# Approach:
#   1. Filter the weights to update by variable name (variable scope).
#   2. If the parameter groups are updated separately, create one optimizer
#      per group.
#
# Example code (TensorFlow 1.x graph API; assumes `tf` is already imported).


def Net_1(input):
    """Build the first sub-network under the variable scope 'Net_1'.

    Args:
        input: 4-D tensor whose static height and width are known (they are
            read via ``get_shape()`` to size the pooling window).

    Returns:
        ``tmp``, which is produced by layers omitted from the original notes.
        NOTE(review): presumably a 4-D feature map, since ``Net_2`` unpacks
        four dimensions from it -- confirm against the full source.
    """
    with tf.variable_scope('Net_1'):
        fmap_input = tf.layers.conv2d(input, 32, 32, (1, 1), padding='same', name='conv1')
        _, xh, xw, xc = fmap_input.get_shape().as_list()
        # The pool window covers the whole feature map, i.e. a global max
        # pool (the tensor is named `gap` in the original notes regardless).
        gap = tf.layers.max_pooling2d(fmap_input, (xh, xw), strides=(1, 1), name='gap')
        _, h, w, c = gap.get_shape().as_list()
        gap = tf.reshape(gap, (-1, c))
        cls_logit = tf.layers.dense(gap, 3, name='fc')
        cls_probs_soft = tf.nn.softmax(cls_logit, axis=1)
        # Clip the probabilities away from zero.
        cls_probs = tf.clip_by_value(cls_probs_soft, 1e-7, 1.0)
        # ... remaining layers were omitted in the original notes; they are
        # expected to define `tmp` ...
        return tmp


def Net_2(input_fm):
    """Build the second sub-network under the variable scope 'Net_2'.

    Args:
        input_fm: 4-D feature map with statically known height and width.

    Returns:
        Softmax class probabilities over 3 classes, clipped to [1e-8, 1.0].
    """
    with tf.variable_scope('Net_2'):
        _, xh, xw, _ = input_fm.get_shape().as_list()
        # Pool window equals the spatial size: global max pooling.
        gap = tf.layers.max_pooling2d(input_fm, (xh, xw), strides=(1, 1), name='gap_cls_head')
        _, _, _, c = gap.get_shape().as_list()
        gap = tf.reshape(gap, (-1, c))
        fc_cls_logits = tf.layers.dense(gap, 3)
        cls_probs_soft = tf.nn.softmax(fc_cls_logits, axis=1)
        cls_probs = tf.clip_by_value(cls_probs_soft, 1e-8, 1.0)
        return cls_probs


# The output of Net_1 is used as the input of Net_2.
input_placeholder = tf.placeholder(dtype=tf.float32, shape=[None, 128, 128, 3], name='input_')
gt = tf.placeholder(dtype=tf.int32, shape=[None], name='label')
# NOTE(review): the variable name 'globel_step' looks like a typo for
# 'global_step'; it is left unchanged because renaming it would change the
# name stored in existing checkpoints.
global_step = tf.Variable(0, name='globel_step', trainable=False)
output1 = Net_1(input_placeholder)
output2 = Net_2(output1)
# `损失` ("loss") is a placeholder for the user's loss function; it is not
# defined in these notes.
loss = 损失(output1, gt)    # loss on the Net_1 output
loss1 = 损失(output2, gt)   # loss on the Net_2 output
# --- Net_1 update: a dedicated optimizer that only touches Net_1 weights ---
optimizer_anet = tf.train.AdamOptimizer(0.01)
# Option 1: collect the trainable variables by variable scope.
net1_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Net_1')
# Option 2: fetch every trainable variable and filter by name.
# tvars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
# net1_vars = [vr for vr in tvars1 if 'Net_1' in vr.name]
# Print the selected variables so the filter can be checked by eye.
for var in net1_vars:
    print('net1--->', var.name)
# Run any registered update ops before the training step.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # `var_list` restricts the gradient update to the Net_1 variables.
    train_op_net1 = optimizer_anet.minimize(loss, global_step=global_step, var_list=net1_vars)
# --- Net_2 update: a second optimizer that only touches Net_2 weights ---
optimizer = tf.train.AdamOptimizer(0.01)
# Option 1: collect the trainable variables by variable scope.
other_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Net_2')
# Option 2: fetch every trainable variable and filter by name.
# tvars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
# other_vars = [vr for vr in tvars1 if 'Net_2' in vr.name]
# Print the selected variables so the filter can be checked by eye.
for var in other_vars:
    print('net_2-->', var.name)
# Run any registered update ops before the training step.
update_ops1 = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops1):
    # The Net_2 loss is defined as `loss1` (computed from output2); the
    # original referenced an undefined `loss2`.
    # NOTE(review): this op and train_op_net1 both pass the same
    # `global_step`, so running both in one iteration increments it twice.
    # Drop `global_step` from one of them if a single increment per
    # iteration is wanted.
    train_op_net2 = optimizer.minimize(loss1, global_step=global_step, var_list=other_vars)