十年网站开发经验 + 多家企业客户 + 靠谱的建站团队
量身定制 + 运营维护+专业推广+无忧售后,网站问题一站解决
这篇文章主要为大家展示了tensorflow 20如何搭网络,导出模型和运行模型,内容简而易懂,希望大家可以学习一下,学习完之后肯定会有收获的,下面让小编带大家一起来看看吧。
成都创新互联是一家集网站建设,梓潼企业网站建设,梓潼品牌网站建设,网站定制,梓潼网站建设报价,网络营销,网络优化,梓潼网站推广为一体的创新建站企业,帮助传统企业提升企业形象加强企业竞争力。可充分满足这一群体相比中小企业更为丰富、高端、多元的互联网需求。同时我们时刻保持专业、时尚、前沿,时刻以成就客户成长自我,坚持不断学习、思考、沉淀、净化自己,让我们为更多的企业打造出实用型网站。概述
以前自己都利用别人搭好的工程,修改过来用,很少把模型搭建、导出模型、加载模型运行走一遍,搞了一遍才知道这个事情也不是那么简单的。
搭建模型和导出模型
参考《TensorFlow固化模型》,导出固化的模型有两种方式.
方式1:导出pb图结构和ckpt文件,然后用 freeze_graph 工具冻结生成一个pb(包含结构和参数)
在我的代码里测试了生成pb图结构和ckpt文件,但是没接着往下走,感觉有点麻烦。我用的是第二种方法。
注意我这里只在最后保存了一次ckpt,实际应该在训练中每隔一段时间就保存一次的。
saver = tf.train.Saver(max_to_keep=5) #tf.train.write_graph(session.graph_def, FLAGS.model_dir, "nn_model.pbtxt", as_text=True) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) max_step = 2000 for i in range(max_step): batch = mnist.train.next_batch(50) if i % 100 == 0: train_accuracy = accuracy.eval(feed_dict={ x: batch[0], y_: batch[1], keep_prob: 1.0}) print('step %d, training accuracy %g' % (i, train_accuracy)) train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) print('test accuracy %g' % accuracy.eval(feed_dict={ x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})) # 保存pb和ckpt print('save pb file and ckpt file') tf.train.write_graph(sess.graph_def, graph_location, "graph.pb",as_text=False) checkpoint_path = os.path.join(graph_location, "model.ckpt") saver.save(sess, checkpoint_path, global_step=max_step)