首页 > 其他分享 >tensorflow 更新部分参数或参数分开更新

tensorflow 更新部分参数或参数分开更新

时间:2023-05-08 19:14:22浏览次数:55  
标签:name get 更新 gap 参数 input tf tensorflow cls

思路:   1.根据变量名称过滤要更新的权重:   2.如果参数分开更新,还需要设置多个优化器   代码示例: def Net_1(input):     with tf.variable_scope('Net_1'):         fmap_input = tf.layers.conv2d(input,32,32,(1,1),padding='same',name='conv1')         _, xh, xw, xc = fmap_input.get_shape().as_list()         gap = tf.layers.max_pooling2d(fmap_input,(xh,xw),strides=(1,1),name='gap')         _,h,w,c = gap.get_shape().as_list()         gap = tf.reshape(gap,(-1,c))         cls_logit = tf.layers.dense(gap,3,name='fc')         cls_probs_soft = tf.nn.softmax(cls_logit, axis=1)         cls_probs = tf.clip_by_value(cls_probs_soft, 1e-7, 1.0)   。。。。   return tmp   def Net_2(input_fm):     with tf.variable_scope('Net_2'):         _, xh, xw, xc = input_fm.get_shape().as_list()         gap = tf.layers.max_pooling2d(input_fm, (xh, xw),strides=(1,1), name='gap_cls_head')         _,h,w,c = gap.get_shape().as_list()         gap = tf.reshape(gap,(-1,c))         fc_cls_logits = tf.layers.dense(gap,3)         cls_probs_soft = tf.nn.softmax(fc_cls_logits, axis=1)         cls_probs = tf.clip_by_value(cls_probs_soft, 1e-8, 1.0)         return cls_probs   net1的输出作为net2的输入 input_placeholder = tf.placeholder(dtype=tf.float32, shape=[None, 128, 128, 3], name='input_') gt = tf.placeholder(dtype=tf.int32, shape=[None], name='label') global_step = tf.Variable(0, name='globel_step', trainable=False)   output1 = Net_1(input_placeholder) output2 = Net_2(output1) loss=损失(output1,gt) loss1=损失(output2,gt)

#net1更新
optimizer_anet = tf.train.AdamOptimizer(0.01)
#法1
net1_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Net_1')
#法2
#tvars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
#net1_vars = [vr for vr in tvars1 if 'Net_1' in vr.name]
for tmp in net1_vars:
print('net1--->',tmp.name)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
train_op_net1 = optimizer_anet.minimize(loss, global_step=global_step, var_list=net1_vars)

#net2更新
optimizer = tf.train.AdamOptimizer(0.01)
#法1
other_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='Net_2')
#法2
#tvars1 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
#other_vars = [vr for vr in tvars1 if 'Net_2' in vr.name]
for tmp in other_vars:
print('net_2-->',tmp.name)
update_ops1 = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops1):
train_op_net2 = optimizer.minimize(loss2,global_step=global_step,var_list=other_vars)

     with tf.Session() as sess:     init_op = tf.group(tf.global_variables_initializer(),                        tf.local_variables_initializer())     sess.run(init_op)  #只更新net1    _=sess.run(train_op_net1,feed_dict={input_placeholder:数据,label:数据})  #只更新net2  _=sess.run(train_op_net2,feed_dict={input_placeholder:数据,label:数据})

标签:name,get,更新,gap,参数,input,tf,tensorflow,cls
From: https://www.cnblogs.com/BlogLwc/p/17382844.html

相关文章

  • 未提供与“Course.Course(string, int, int)”的所需参数“Name”对应的参数
    当传给类中的参数不确定有无时,则要给父类加个无参构造方法 ......
  • 恒创科技:香港服务器什么情况下需要更新升级?
    ​网站的正常运行离不开服务器的良好支持。任何服务中断都会减慢您的运营速度。通常情况下,随着企业业务的扩张,在使用香港服务器的过程中,难免会遇到高负载运行缓慢或性能不佳的情况。为了确保香港服务器的稳定性和性能,需要对其进行升级。那么,香港服务器什么情况下需要更新升级?......
  • 【验证码逆向专栏】数美验证码全家桶逆向分析以及 AST 获取动态参数
    声明本文章中所有内容仅供学习交流使用,不用于其他任何目的,不提供完整代码,抓包内容、敏感网址、数据接口等均已做脱敏处理,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关!本文章未经许可禁止转载,禁止任何修改后二次传播,擅自使用本文讲解的技术而导致的任何意外,作......
  • 更新macOS系统后,使用gcc/g++命令,提示错误xcrun: error: invalid active developer pat
      更新macOS系统后,使用gcc/g++命令编译程序,提示错误xcrun:error:invalidactivedeveloperpath(/Library/Developer/CommandLineTools),missingxcrunat:/Library/Developer/CommandLineTools/usr/bin/xcrun解决方法:重新安装CommandLineTools,一般安装完成后问题就能......
  • 基于扩张状态观测器eso扰动补偿和权重因子调节的电流预测控制,相比传统方法,增加了参数
    基于扩张状态观测器eso扰动补偿和权重因子调节的电流预测控制,相比传统方法,增加了参数鲁棒性。降低电流脉动,和误差。基于扩张状态观测器eso补偿的三矢量模型预测控制。ID:41123672941746934......
  • 多模态+大模型领域的开源数据集(持续更新中20230508)
     ConceptualCaption简称cc,minigpt4就使用这个数据集,一个大规模的图像文本配对数据集,包含超过30万个图像,每个图像都有5个人工描述。这个数据集的目的是为了促进计算机视觉和自然语言处理之间的研究交叉,可以用于图像检索、视觉问答等任务的训练和评估。ConceptualCaptions为......
  • JVM 启动参数
    JVM启动参数通过jmap查看JVM内存分配jmap-heap[pid]一个Java进程最大占用的物理内存为:MaxMemory=eden+survivor+old+StringConstantPool+Codecache+compressedclassspace+Metaspace+Threadstack(*threadnum)+Direct+Mapped+JVM+Nativ......
  • HSSFClientAnchor 参数说明
    pachePOI 是用Java编写的免费开源的跨平台的JavaAPI,ApachePOI提供API给Java程式对MicrosoftOffice格式档案读和写的功能。HSSFClientAnchor用于创建一个新的端锚,并设置锚的左下和右下坐标,用于图片插入,画线等操作。publicHSSFClientAnchor(intdx1,intdy1,intdx2,......
  • 谷歌浏览器自动更新怎么关闭?
    1.右键单击【计算机】——【管理】——【计算机管理本地】——【系统工具】——【任务计划程序】——【任务计划程序库】——这里找到两个和Google自动更新相关的任务计划【GoogleUpdateTaskMachineCore】与【GoogleUpdateTaskMachineUA】,把这两个选项禁用;有的小伙伴们可能有三个......
  • WPF注入service,将service作为viewModel参数时,无法进入Model的问题。
    WPF注入service,将service作为viewModel参数时,无法进入Model的问题。一开始以为是注入失败,或者注入的service不对。经过排查,发现是viewModel中的参数service,不是当前包的service,是api通用包中的。....更改之后就可以进入Model了。......