Tensorflow保存模型和加载模型

保存

首先一些小注意事项,tf.placeholder/tf.Variable/tf.Constant等等这些tensor时可以命名的,当需要保存模型的时候,一定要尽可能的给所有变量去命名


具体保存时可以按照下面的步骤:
首先创建文件夹,这个ok,模型一般文件后缀为ckpt

ckpt_dir = "./ckpt_dir"
if not os.path.exists(ckpt_dir):
    os.makedirs(ckpt_dir)

然后设置一个记录迭代次数的参数:

global_step = tf.Variable(0, name='global_step', trainable=False)

定义Saver

saver = tf.train.Saver()

此外,如果你对某些操作没有命名,但是你在加载模型时需要利用这些,比如train的predict_op,你需要在test的时候去使用,你可以使用tf自带的图收集功能

tf.add_to_collection(name='predice_op', value=predice_op)
对应的是:
tf.get_collection('predice_op')

然后具体sess执行,assign赋值,eval显式的赋值,实际上等于重新定义一个变量等于sess.run(变量):

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    for i in range(iteration_num):
        ..sess.run()等等..
        global_step.assign(i).eval()
        saver.save(sess, ckpt_dir + "/model.ckpt", global_step=global_step)

这时候你的执行程序的同层级文件夹会有ckpt_dir的文件夹,里面会保存至多5个模型

加载模型

with tf.Session() as sess:
    #加载模型
    new_saver = tf.train.import_meta_graph('ckpt_dir/model.ckpt-0.meta')
    new_saver.restore(sess, 'ckpt_dir/model.ckpt-0')
    
    #加载图
    graph = tf.get_default_graph()
    #处理test数据为随机
    test_indices = np.arange(len(teX))
    np.random.shuffle(test_indices)
    test_indices = test_indices[0:test_size]
    #定义的predice_op从collection中提取
    #定义的X、p_keep_conv等等从graph中提取,两者本质上都是从图中提取,只是命没命名
    predice_op = tf.get_collection('predice_op')
    X = graph.get_operation_by_name('X').outputs[0]
    p_keep_conv=graph.get_operation_by_name('p_keep_conv').outputs[0]
    p_keep_fc = graph.get_operation_by_name('p_keep_fc').outputs[0]
    #把你需要放进feed_dict的变量提取出来,数据也准备好,就可以run了
    print(sess.run(predice_op, feed_dict={X: teX[test_indices],
                                          p_keep_conv: 1,
                                          p_keep_fc: 1}))

发表评论

电子邮件地址不会被公开。