首页 > tensorflow学习笔记————分类MNIST数据集

tensorflow学习笔记————分类MNIST数据集

 

在使用tensorflow分类MNIST数据集中,最容易遇到的问题是下载MNIST样本的问题。

 

一般是通过使用tensorflow内置的函数进行下载和加载,

 

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

 

但是我使用时遇到了“urllib.error.URLError: 错误,查了一下也没什么好的解决方案,最后就自己去手动下载了。在python文件同目录下建立MNIST_data,进入目录后通过wget来下载

wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

 

最后运行我们的程序

 1 import tensorflow as tf
 2 from tensorflow.examples.tutorials.mnist import input_data
 3 
 4 #通过tensorflow的库来载入训练的样本
 5 mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
 6 
 7 #每个批次的大小
 8 batch_size = 100
 9 
10 #计算有多少批次
11 n_batch = mnist.train.num_examples // batch_size
12 
13 #定义两个placeholder,x是图片样本,y是输出的结果
14 x = tf.placeholder(tf.float32, [None,784])
15 y = tf.placeholder(tf.float32, [None,10])
16 
17 #创建一个简单的神经网络
18 W = tf.Variable(tf.zeros([784,10]))
19 b = tf.Variable(tf.zeros([10]))
20 prediction = tf.nn.softmax(tf.matmul(x,W)+b)
21 
22 #二次代价函数
23 loss = tf.reduce_mean(tf.square(y - prediction))
24 
25 #使用梯度下降法
26 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
27 
28 #初始化变量
29 init = tf.global_variables_initializer()
30 
31 #结果存放在一个布尔类型列表中, tf.argmax返回一维张量中最大的值所在的位置,就是返回识别出来最可能的结果
32 correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1))
33 
34 #求准确率,tf.case()把bool转化为float
35 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
36 
37 with tf.Session() as sess:
38     sess.run(init)
39     for epoch in range(21):
40         for batch in range(n_batch):
41             batch_xs,batch_ys = mnist.train.next_batch(batch_size)
42             sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
43     
44         acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})
45         print("Iter " + str(epoch) + ", Testing Accuracy" + str(acc))
46     

 

转载于:https://www.cnblogs.com/QKSword/p/8723677.html

更多相关: