1. 确定权重名称:
tvars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for tmp in tvars1:
print('all-->',tmp.name)
2. 根据网络结构从1中找到想要打印的权重名称 weight_name,通过下面的方式进行打印
fc_logits=tf.get_default_graph().get_tensor_by_name(weight_name)
with tf.Session() as sess:
init_op = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
sess.run(init_op)
fc_logits_ = sess.run(fc_logits,feed_dict={input_placeholder: gen_input,gt: label})
print('fc_logits_:',fc_logits_)