
Tensorflow: How To Use Pretrained Weights In New Graph?

I'm trying to build a CNN object detector with TensorFlow in Python. I would like to train my model to do just object recognition (classification) at first and t…

Solution 1:

Create a tf.train.Saver with no arguments; it saves every variable in the model.

import tensorflow as tf  # TensorFlow 1.x API

tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer=tf.random_normal_initializer())
v2 = tf.get_variable("v2", [2], initializer=tf.random_normal_initializer())
saver = tf.train.Saver()  # no arguments: saves all global variables

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_path='./test-case.ckpt')

    print(v1.eval())
    print(v2.eval())
saver = None
Output (the values are random and will differ between runs):
v1 = [ 2.1882825   1.159807   -0.26564872]
v2 = [0.11437789 0.5742971 ]

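Then, in the model whose variables you want to restore, pass a var_list to the Saver: either a list of the variables to restore (they are looked up in the checkpoint under their own names) or a dictionary of {"name in checkpoint": variable}. Variables left out of var_list are not restored, so they still have to be initialized.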

import tensorflow as tf  # TensorFlow 1.x API

tf.reset_default_graph()
b1 = tf.get_variable("b1", [3], initializer=tf.random_normal_initializer())
b2 = tf.get_variable("b2", [3], initializer=tf.random_normal_initializer())
# Load the checkpoint entry "v1" into b1; b2 is not in var_list
saver = tf.train.Saver(var_list={'v1': b1})

with tf.Session() as sess:
    saver.restore(sess, "./test-case.ckpt")
    print(b1.eval())
    print(b2.eval())  # fails: b2 was neither restored nor initialized
Output:
INFO:tensorflow:Restoring parameters from ./test-case.ckpt
b1 = [ 2.1882825   1.159807   -0.26564872]
b2 = FailedPreconditionError: Attempting to use uninitialized value b2
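The error above only means that b2 was never given a value. A minimal sketch of the usual fix, assuming the tf.compat.v1 API (so it also runs on TensorFlow 2.x) and a throwaway checkpoint in a temporary directory: initialize every variable first, then let restore overwrite the ones in var_list. It also uses tf.train.list_variables to show which names the checkpoint actually contains, which are the keys the var_list dictionary must use.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Temporary checkpoint location (assumption: any writable path works)
ckpt_path = os.path.join(tempfile.mkdtemp(), "test-case.ckpt")

# 1) Save a checkpoint that contains "v1" and "v2".
tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer=tf.random_normal_initializer())
v2 = tf.get_variable("v2", [2], initializer=tf.random_normal_initializer())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saved_v1 = sess.run(v1)
    tf.train.Saver().save(sess, ckpt_path)

# 2) List the (name, shape) pairs stored in the checkpoint.
print(tf.train.list_variables(ckpt_path))

# 3) New graph: restore checkpoint entry "v1" into b1 only.
tf.reset_default_graph()
b1 = tf.get_variable("b1", [3], initializer=tf.random_normal_initializer())
b2 = tf.get_variable("b2", [3], initializer=tf.random_normal_initializer())
saver = tf.train.Saver(var_list={"v1": b1})

with tf.Session() as sess:
    # Initialize everything first (this is the step b2 was missing) ...
    sess.run(tf.global_variables_initializer())
    # ... then overwrite b1 with the value saved under "v1".
    saver.restore(sess, ckpt_path)
    restored_b1, fresh_b2 = sess.run([b1, b2])

print(np.allclose(restored_b1, saved_v1))
```

Running the initializer before restore is safe because restore assigns new values afterwards; only the variables listed in var_list are overwritten.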

Solution 2:

I agree with Aechlys's approach for restoring variables, but the problem is harder when we want to freeze them. For example, we have trained these variables and want to use them in another model, this time without training them further, while new variables are trained (as in transfer learning). You can see the answer I posted here.
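For comparison, a common alternative (not the method used in this answer) keeps the restored variables in the graph and simply hands the optimizer only the new variables through the var_list argument of minimize. This is a minimal sketch assuming the tf.compat.v1 API; the tiny two-layer model and its constant values are made up, and in practice w1 would be filled by saver.restore instead of its initializer.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.reset_default_graph()

# Stand-in for a pretrained weight (hypothetical values)
w1 = tf.get_variable("w1", initializer=tf.constant([[1.0], [2.0]]))
# New layer that we DO want to train
w_new = tf.get_variable("w_new", initializer=tf.constant([[0.5]]))

x = tf.constant([[1.0, 1.0]])
y = tf.constant([[10.0]])
pred = tf.matmul(tf.matmul(x, w1), w_new)
loss = tf.reduce_mean(tf.square(pred - y))

# Only w_new is in var_list, so gradient steps never touch w1
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, var_list=[w_new])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    w1_before = sess.run(w1)
    for _ in range(10):
        sess.run(train_op)
    w1_after, w_new_after = sess.run([w1, w_new])

print(np.array_equal(w1_before, w1_after))  # True: w1 was not updated
```

The trade-off: w1 stays trainable in principle, so every optimizer you build must be given the restricted var_list, whereas a variable created with trainable=False is excluded by default.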

Quick example:

import tensorflow as tf  # TensorFlow 1.x API

with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph(pathToMeta)  # path to the .meta file
    new_saver.restore(sess, pathToNonMeta)  # checkpoint prefix (no extension)

    # Copy the trained value of "w1" out of the graph as a NumPy array
    weight1 = sess.run(sess.graph.get_tensor_by_name("w1:0"))

tf.reset_default_graph()  # this will eliminate the variables we restored

with tf.Session() as sess:
    weights = {
        '1': tf.Variable(weight1, name='w1-bis', trainable=False)
    }
    ...

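The restored variables are now no longer part of the graph: only their trained values survive, copied into the non-trainable variable w1-bis, so an optimizer will not update them.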
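A self-contained sketch of the whole Solution 2 flow, assuming the tf.compat.v1 API (so it runs on TensorFlow 1.15 or 2.x). The tiny "pretrained" model, the temporary checkpoint path, and the new trainable layer w2 are made up for illustration; the middle step mirrors the answer's code. It checks that w1-bis is absent from the trainable variables and unchanged after training.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# 1) "Pretrain" a tiny graph with a variable named "w1" and save it.
tf.reset_default_graph()
w1 = tf.get_variable("w1", [2, 1], initializer=tf.random_normal_initializer())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, ckpt_prefix)  # also writes model.ckpt.meta

# 2) As in the answer: import the meta graph, restore, copy the value out.
tf.reset_default_graph()
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    new_saver.restore(sess, ckpt_prefix)
    weight1 = sess.run(sess.graph.get_tensor_by_name("w1:0"))

# 3) Fresh graph: the pretrained value becomes a non-trainable variable.
tf.reset_default_graph()
w1_frozen = tf.Variable(weight1, name="w1-bis", trainable=False)
w2 = tf.get_variable("w2", [1, 1], initializer=tf.ones_initializer())

x = tf.constant([[1.0, 2.0]])
loss = tf.reduce_mean(tf.square(tf.matmul(tf.matmul(x, w1_frozen), w2) - 1.0))
# minimize() defaults to the trainable variables, which exclude w1-bis
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(5):
        sess.run(train_op)
    frozen_after = sess.run(w1_frozen)

print([v.name for v in tf.trainable_variables()])  # ['w2:0']
print(np.array_equal(frozen_after, weight1))        # True
```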
