首页 > 其他分享 >Tensorflow整理[11]. 变量创建初始化保存加载

Tensorflow整理[11]. 变量创建初始化保存加载

时间:2023-06-25 12:31:38浏览次数:41  
标签:11 初始化 变量 ... tf Variable Tensorflow name


变量:创建、初始化、保存和加载

当训练模型时,用变量来存储和更新参数。变量包含张量 (Tensor)存放于内存的缓存区。建模时它们需要被明确地初始化,模型训练后它们必须被存储到磁盘。这些变量的值可在之后模型训练和分析是被加载。

描述以下两个TensorFlow类。点击以下链接可查看完整的API文档:

  • tf.Variable 类
  • tf.train.Saver 类

创建

当创建一个变量时,你将一个张量作为初始值传入构造函数Variable()。TensorFlow提供了一系列操作符来初始化张量,初始值是常量或是随机值。

注意,所有这些操作符都需要你指定张量的shape。那个形状自动成为变量的shape。变量的shape通常是固定的,但TensorFlow提供了高级的机制来重新调整其行列数。

# Create two variables.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")

调用tf.Variable()添加一些操作(Op, operation)到graph:

  • 一个Variable操作存放变量的值。
  • 一个初始化op将变量设置为初始值。这事实上是一个tf.assign操作.
  • 初始值的操作,例如示例中对biases变量的zeros操作也被加入了graph。

tf.Variable的返回值是Python的tf.Variable类的一个实例。

初始化

变量的初始化必须在模型的其它操作运行之前先明确地完成。最简单的方法就是添加一个给所有变量初始化的操作,并在使用模型之前首先运行那个操作。

你或者可以从检查点文件中重新获取变量值,详见下文。

使用tf.initialize_all_variables()添加一个操作对变量做初始化。记得在完全构建好模型并加载之后再运行那个操作。

# Create two variables.
weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")
biases = tf.Variable(tf.zeros([200]), name="biases")
...
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()

# Later, when launching the model
with tf.Session() as sess:
  # Run the init operation.
  sess.run(init_op)
  ...
  # Use the model
  ...

由另一个变量初始化

你有时候会需要用另一个变量的初始化值给当前变量初始化。由于tf.initialize_all_variables()是并行地初始化所有变量,所以在有这种需求的情况下需要小心。用其它变量的值初始化一个新的变量时,使用其它变量的initialized_value()属性。你可以直接把已初始化的值作为新变量的初始值,或者把它当做tensor计算得到一个值赋予新变量。

weights = tf.Variable(tf.random_normal([784, 200], stddev=0.35),
                      name="weights")

w2 = tf.Variable(weights.initialized_value(), name="w2")

w_twice = tf.Variable(weights.initialized_value() * 0.2, name="w_twice")

自定义初始化

tf.initialize_all_variables()函数便捷地添加一个op来初始化模型的所有变量。你也可以给它传入一组变量进行初始化。

保存和加载

最简单的保存和恢复模型的方法是使用tf.train.Saver对象。构造器给graph的所有变量,或是定义在列表里的变量,添加saverestore ops。saver对象提供了方法来运行这些ops,定义检查点文件的读写路径。

检查点文件

变量存储在二进制文件里,主要包含从变量名到tensor值的映射关系。

当你创建一个Saver对象时,你可以选择性地为检查点文件中的变量挑选变量名。默认情况下,将每个变量Variable.name属性的值。

保存变量

tf.train.Saver()创建一个Saver来管理模型中的所有变量。

# 创建变量
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# 添加一个初始化操作.
init_op = tf.initialize_all_variables()

# 添加一个保存和恢复变量的操作.
saver = tf.train.Saver()

# 启动模型,初始化变量并对模型进行其他操作,进行保存
with tf.Session() as sess:
  sess.run(init_op)
  # 对模型进行其他操作.
  ..
  # 将变量保存到磁盘.
  save_path = saver.save(sess, "/tmp/model.ckpt")
  print "模型保存路径: ", save_path

恢复变量

用同一个Saver对象来恢复变量。注意,当你从文件中恢复变量时,不需要事先对它们做初始化。

# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")
  print "Model restored."
  # Do some work with the model
  ...

选择存储和恢复哪些变量

如果你不给tf.train.Saver()传入任何参数,那么saver将处理graph中的所有变量。其中每一个变量都以变量创建时传入的名称被保存。

有时候在检查点文件中明确定义变量的名称很有用。举个例子,你也许已经训练得到了一个模型,其中有个变量命名为"weights",你想把它的值恢复到一个新的变量"params"中。

有时候仅保存和恢复模型的一部分变量很有用。再举个例子,你也许训练得到了一个5层神经网络,现在想训练一个6层的新模型,可以将之前5层模型的参数导入到新模型的前5层中。

你可以通过给tf.train.Saver()构造函数传入Python字典,很容易地定义需要保持的变量及对应名称:键对应使用的名称,值对应被管理的变量。

注意:

  • 如果需要保存和恢复模型变量的不同子集,可以创建任意多个saver对象。同一个变量可被列入多个saver对象中,只有当saver的restore()函数被运行时,它的值才会发生改变。
  • 如果你仅在session开始时恢复模型变量的一个子集,你需要对剩下的变量执行初始化操作tf.initialize_variables()。
# Create some variables.
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...
# Add ops to save and restore only 'v2' using the name "my_v2"
saver = tf.train.Saver({"my_v2": v2})
# Use the saver object normally after that.
...

标签:11,初始化,变量,...,tf,Variable,Tensorflow,name
From: https://blog.51cto.com/u_12667998/6544846

相关文章

  • 111111111111
    ##----------PromptGenModule-----------------------classPromptGenBlock(nn.Module):def__init__(self,prompt_dim=128,prompt_len=5,prompt_size=96,lin_dim=192):super(PromptGenBlock,self).__init__()self.prompt_param=nn.Paramet......
  • TensorFlow11.1 循环神经网络RNN01-序列表达方法
    在自然界中除了位置相关的信息(图片)以外,还用一种存在非常广泛的类型,就是时间轴上的数据,比如说序列信号,语音信号,聊天文字。就是有先后顺序。对于下面这个:不如说我们输入有10个句子,每个句子都有4个单词,然后我们怎么把这些句子转化为具体的数值呢。如果一个表示方法能够很好的表示这......
  • 11.springboot 原理 ( 起步依赖-自动配置)
    springboot原理springframeworkspringboot(配置起步依赖-自动配置)spring-boot-starter-web起步依赖(其他依赖自动传递)自动配置原理:自动将内置类存入IOC容器中,不用收到配置,只能扫描包内即子包的类,可以指定扫描的包内容:@ComponentScan("com.alex","com.ite");@Import导......
  • TensorFlow10.4 卷积神经网络-ResNet与DenseNet及ResNet实战
    1ResNet我们是实验发现在我们堆叠更多的网络结构的时候,我们并不能又一个很好的结果,就是它网络层次变多了之后他会产生一个多层的loss的堆叠,使得梯度爆炸,或者梯度弥散。然后我们想了一个办法,就是我们比如说设置了一个30层的神经网络,我们在差也不能比22层的差。就是我们设置了一......
  • pybind11内嵌解释器
    一、创建解释器需要在使用任意PythonAPI前初始化解释器,包括pybind11Python函数和类。RAIIguard类`scoped_interpreter`可用来管理解释器的生命周期。在guard类销毁时,解释器将会关闭并占用的内存。必须在所有Python函数前调用它。#include<pybind11/embed.h>//everything......
  • mysql8 执行聚合函数报错:Error 1140: In aggregated query without GROUP BY,sql_mode
    解决办法:setglobalsql_mode='STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZERO_DATE,ERROR_FOR_DIVISION_BY_ZERO,NO_ENGINE_SUBSTITUTION';SETGLOBALlog_bin_trust_function_creators=1;setsessionsql_mode='STRICT_TRANS_TABLES,NO_ZERO_IN_DATE,NO_ZER......
  • [leetcode]114. 二叉树展开为链表
    总结:怎样写递归函数?关键是把递归函数的功能定义清楚,并在递归函数体中使用自身来做事,此时不要关注递归函数执行的细节。也就是写高层级代码的时候不要关注低层级的事情,这就叫抽象。关注也没有用,想不清楚的。 1classSolution{2publicvoidflatten(TreeNoderoot){......
  • pybind11绑定类(一)
    一、自定义数据结构-结构体`class_`会创建C++class或struct的绑定。`init()`方法使用类构造函数的参数类型作为模板参数,并包装相应的构造函数;静态成员函数需要使用`class_::def_static`来绑定#include<pybind11/pybind11.h>namespacepy=pybind11;structPet{Pet......
  • day112 - mybatis的查询与特殊sql语句
    mybatis查询与特殊语句查询普通语句/***根据id查询用户信息*@paramid*@return*/UsergetUserById(@Param("id")Integerid);​<!--UsergetUserById(@Param("id")Integerid);--><selectid="getUserById"resultType="com.......
  • TensorFlow10.4 卷积神经网络-batchnorm
    我们发现这个sigmoid函数在小于-4或者大于4的时候他的导数趋近于0。然后我们送进去的input的值在[-100,100]之间,这样很容易引起梯度弥散的现象。所以我们一般情况下使用ReLU函数,但是我们有时候又不得不使用sigmoid函数。这个时候我们在送到下一层的时候我们应该先经过Normalizatio......