在使用tensorflow分类MNIST数据集中,最容易遇到的问题是下载MNIST样本的问题。
一般是通过使用tensorflow内置的函数进行下载和加载,
from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
但是我使用时遇到了“urllib.error.URLError:
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
最后运行我们的程序
1 import tensorflow as tf 2 from tensorflow.examples.tutorials.mnist import input_data 3 4 #通过tensorflow的库来载入训练的样本 5 mnist = input_data.read_data_sets("MNIST_data", one_hot=True) 6 7 #每个批次的大小 8 batch_size = 100 9 10 #计算有多少批次 11 n_batch = mnist.train.num_examples // batch_size 12 13 #定义两个placeholder,x是图片样本,y是输出的结果 14 x = tf.placeholder(tf.float32, [None,784]) 15 y = tf.placeholder(tf.float32, [None,10]) 16 17 #创建一个简单的神经网络 18 W = tf.Variable(tf.zeros([784,10])) 19 b = tf.Variable(tf.zeros([10])) 20 prediction = tf.nn.softmax(tf.matmul(x,W)+b) 21 22 #二次代价函数 23 loss = tf.reduce_mean(tf.square(y - prediction)) 24 25 #使用梯度下降法 26 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss) 27 28 #初始化变量 29 init = tf.global_variables_initializer() 30 31 #结果存放在一个布尔类型列表中, tf.argmax返回一维张量中最大的值所在的位置,就是返回识别出来最可能的结果 32 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1)) 33 34 #求准确率,tf.case()把bool转化为float 35 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 36 37 with tf.Session() as sess: 38 sess.run(init) 39 for epoch in range(21): 40 for batch in range(n_batch): 41 batch_xs,batch_ys = mnist.train.next_batch(batch_size) 42 sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys}) 43 44 acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels}) 45 print("Iter " + str(epoch) + ", Testing Accuracy" + str(acc)) 46