快上网专注成都网站设计 成都网站制作 成都网站建设
成都网站建设公司服务热线:028-86922220

网站建设知识

十年网站开发经验 + 多家企业客户 + 靠谱的建站团队

量身定制 + 运营维护+专业推广+无忧售后,网站问题一站解决

如何用RNN进行分类

本篇文章给大家分享的是有关如何用RNN进行分类,小编觉得挺实用的,因此分享给大家学习,希望大家阅读完这篇文章后可以有所收获,话不多说,跟着小编一起来看看吧。

创新互联是一家专业提供剑川企业网站建设,专注与成都网站制作、成都做网站、成都h5网站建设、小程序制作等业务。10年已为剑川众多企业、政府机构等服务。创新互联专业网络公司优惠进行中。

今天我们介绍的是RNN是如何玩分类的。

MNIST数据集,我们都已经很熟悉了,是一个手写数字的数据集,之前我们用它来实战CNN分类器和机器学习的方法(在公众号中回复“MNIST”,即可免费下载)。今天我们就用RNN来对MNIST数据集进行一个预测。
这个时候,我们需要将每一张数据图像当成一个28x28的序列信号(图像的大小为28x28pixels)。对于整个网络框架,我们使用一个150个循环神经元外加一个有10个神经元的全连接层(每个类对应一个),最后接一个softmax层。如下: 如何用RNN进行分类整个模型的构建阶段,也很直接,跟我们前几期学的dnn构建方法非常类似,这里只是用了没有展开的RNN代替了之前的隐藏层,需要注意的是最后的全连接层连接的是RNN的状态tensor,该状态tensor仅仅包含了RNN的最后一个状态,并且y是目标类别。

from tensorflow.contrib.layers import fully_connected
n_steps = 28
n_inputs = 28
n_neurons = 150
n_outputs = 10
learning_rate = 0.001
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.int32, [None])
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)
logits = fully_connected(states, n_outputs, activation_fn=None)
xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=y, logits=logits)
loss = tf.reduce_mean(xentropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)
correct = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
init = tf.global_variables_initializer()

接下来,我们加载数据集,并对数据集进行reshape,如下:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/")
X_test = mnist.test.images.reshape((-1, n_steps, n_inputs))
y_test = mnist.test.labels

现在,我们将对上面的RNN进行training,在执行阶段跟之前的dnn也是非常类似的,如下:

n_epochs = 100
batch_size = 150
with tf.Session() as sess:
   init.run()
   for epoch in range(n_epochs):
       for iteration in range(mnist.train.num_examples // batch_size):
           X_batch, y_batch = mnist.train.next_batch(batch_size)
           X_batch = X_batch.reshape((-1, n_steps, n_inputs))
           sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
       acc_train = accuracy.eval(feed_dict={X: X_batch, y: y_batch})
       acc_test = accuracy.eval(feed_dict={X: X_test, y: y_test})
       print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)

输出的结果如下:

0 Train accuracy: 0.713333 Test accuracy: 0.7299
1 Train accuracy: 0.766667 Test accuracy: 0.7977
...
98 Train accuracy: 0.986667 Test accuracy: 0.9777
99 Train accuracy: 0.986667 Test accuracy: 0.9809

最终得到了98%的准确率,还挺不错的,如果我们调整下超参数或者RNN权重初始化的方式,训练的更久一些,或者加一些正则化的方法,结果应该还会更好。

以上就是如何用RNN进行分类,小编相信有部分知识点可能是我们日常工作会见到或用到的。希望你能通过这篇文章学到更多知识。更多详情敬请关注创新互联行业资讯频道。


网页名称:如何用RNN进行分类
URL网址:http://6mz.cn/article/gscsic.html

其他资讯