首页 > 编程问答 >为什么在训练期间我的 TensorFlow Siamese 网络中的所有变量的梯度均为 None?

为什么在训练期间我的 TensorFlow Siamese 网络中的所有变量的梯度均为 None?

时间:2024-07-26 16:12:45浏览次数:18  
标签:python tensorflow machine-learning deep-learning neural-network

我正在 TensorFlow 中训练暹罗网络进行图像配准,但遇到一个问题,所有变量的梯度均为 None。

该网络采用一对图像(固定和移动)并输出仿射参数将移动图像与固定图像对齐的变换。模型、损失函数和训练步骤定义如下:

模型:

import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers, models

def build_siamese_model(input_shape):
    input = layers.Input(shape=input_shape)
    x = layers.Conv2D(64, (5, 5), activation='relu')(input)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(128, (5, 5), activation='relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Flatten()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dense(128, activation='relu')(x)
    output = layers.Dense(6)(x)  # 6 parameters for affine transformation
    model = models.Model(input, output)
    return model

siamese_model = build_siamese_model((970, 482, 1))
siamese_model.summary()

损失函数:

def registration_loss(fixed_image, moving_image, transform_params):
    batch_size = tf.shape(moving_image)[0]
    transform_params = tf.reshape(transform_params, (batch_size, 6))

    # Construct the transformation matrix for affine transformation
    transforms = tf.concat([
        transform_params[:, 0:1], transform_params[:, 1:2], transform_params[:, 4:5],
        transform_params[:, 2:3], transform_params[:, 3:4], transform_params[:, 5:6],
        tf.zeros((batch_size, 2))  # Adding two zeros for the last row
    ], axis=1)

    # Apply affine transformation to the moving image
    transformed_image = tfa.image.transform(moving_image, transforms)

    # Compute the loss between the fixed image and the transformed moving image
    loss = tf.reduce_mean(tf.square(fixed_image - transformed_image))
    return loss

训练步骤:

@tf.function
def train_step(fixed_image, moving_image):
    with tf.GradientTape() as tape:
        transform_params = siamese_model(moving_image, training=True)
        print(f"Transform params shape in train_step: {transform_params.shape}")  # Debugging
        loss = registration_loss(fixed_image, moving_image, transform_params)
    
    # Ensure gradients are computed for the loss w.r.t. the model's trainable variables
    gradients = tape.gradient(loss, siamese_model.trainable_variables)
    
    # Debugging: Print gradients and their shapes
    for grad, var in zip(gradients, siamese_model.trainable_variables):
        print(f"Var: {var.name}, Grad: {grad}")
        if grad is not None:
            print(f"Grad shape: {grad.shape}, Var shape: {var.shape}")

    # Apply gradients to the model's trainable variables
    optimizer.apply_gradients(zip(gradients, siamese_model.trainable_variables))
    return loss

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

训练循环:

for epoch in range(10):
    for step in range(100):
        loss = train_step(fixed_image, moving_image)
        print(f"Epoch {epoch}, Step {step}, Loss: {loss.numpy()}")

回溯:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [18], in <cell line: 2>()
      2 for epoch in range(num_epochs):
      3     for step in range(num_steps_per_epoch):
----> 4         loss = train_step(fixed_image, moving_image)
      5         print(f"Epoch {epoch}, Step {step}, Loss: {loss.numpy()}")

File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1147, in func_graph_from_py_func.<locals>.autograph_handler(*args, **kwargs)
   1145 except Exception as e:  # pylint:disable=broad-except
   1146   if hasattr(e, "ag_error_metadata"):
-> 1147     raise e.ag_error_metadata.to_exception(e)
   1148   else:
   1149     raise

ValueError: in user code:

    File "/local/scratch/melchua/slrmtmp.43866248/ipykernel_12859/2770984432.py", line 18, in train_step  *
        optimizer.apply_gradients(zip(gradients, siamese_model.trainable_variables))
    File "/scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py", line 633, in apply_gradients  **
        grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
    File "/scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/keras/optimizer_v2/utils.py", line 73, in filter_empty_gradients
        raise ValueError(f"No gradients provided for any variable: {variable}. "

    ValueError: No gradients provided for any variable: (['conv2d/kernel:0', 'conv2d/bias:0', 'conv2d_1/kernel:0', 'conv2d_1/bias:0', 'dense/kernel:0', 'dense/bias:0', 'dense_1/kernel:0', 'dense_1/bias:0', 'dense_2/kernel:0', 'dense_2/bias:0'],). Provided `grads_and_vars` is ((None, <tf.Variable 'conv2d/kernel:0' shape=(5, 5, 1, 64) dtype=float32>), (None, <tf.Variable 'conv2d/bias:0' shape=(64,) dtype=float32>), (None, <tf.Variable 'conv2d_1/kernel:0' shape=(5, 5, 64, 128) dtype=float32>), (None, <tf.Variable 'conv2d_1/bias:0' shape=(128,) dtype=float32>), (None, <tf.Variable 'dense/kernel:0' shape=(3579264, 256) dtype=float32>), (None, <tf.Variable 'dense/bias:0' shape=(256,) dtype=float32>), (None, <tf.Variable 'dense_1/kernel:0' shape=(256, 128) dtype=float32>), (None, <tf.Variable 'dense_1/bias:0' shape=(128,) dtype=float32>), (None, <tf.Variable 'dense_2/kernel:0' shape=(128, 6) dtype=float32>), (None, <tf.Variable 'dense_2/bias:0' shape=(6,) dtype=float32>)).

示例输入:

fixed_image = np.random.rand(1, 970, 482, 1).astype(np.float32)
moving_image = np.random.rand(1, 970, 482, 1).astype(np.float32)

的代码中导致梯度为 None 的问题出在 registration_loss 函数中的 tf.concat 的使用。正在使用它来构造仿射变换矩阵,但是使用的方式会破坏梯度流。

问题在于,正在将 transform_params (需要梯度的张量)与一个用 tf.zeros 创建的常量张量连接。此操作(在特定情况下连接常数)会中断从 loss transform_params 的梯度流。

要解决此问题,应该使用一种不会中断梯度流的构造变换矩阵的方法。一种方法是使用 tf.stack transform_params 的切片组合在一起,然后使用 tf.reshape 获得所需的形状。

这是更正后的 registration_loss 函数:

def registration_loss(fixed_image, moving_image, transform_params):
    batch_size = tf.shape(moving_image)[0]
    transform_params = tf.reshape(transform_params, (batch_size, 6))

    # Construct the transformation matrix for affine transformation
    transforms = tf.stack([
        transform_params[:, 0], transform_params[:, 1], transform_params[:, 4],
        transform_params[:, 2], transform_params[:, 3], transform_params[:, 5],
        tf.zeros_like(transform_params[:, 0]), tf.zeros_like(transform_params[:, 0])
    ], axis=1)
    transforms = tf.reshape(transforms, (batch_size, 3, 2))

    # Apply affine transformation to the moving image
    transformed_image = tfa.image.transform(moving_image, transforms)

    # Compute the loss between the fixed image and the transformed moving image
    loss = tf.reduce_mean(tf.square(fixed_image - transformed_image))
    return loss

在这个更正后的版本中:

  1. 我们使用 tf.stack 沿第二轴( axis=1 )连接 transform_params 的切片,这会创建一个形状为 (batch_size, 8) 的张量。
  2. 我们使用 tf.zeros_like 为最后一行创建形状为 (batch_size,) 的零张量。这会确保创建的零张量具有与 transform_params 相同的类型和放置在同一设备上的属性。
  3. 然后,我们使用 tf.reshape 将连接后的张量重塑为所需的 (batch_size, 3, 2) 形状以用于 tfa.image.transform

通过此更改,梯度现在应该可以正确地从 loss 函数流向 transform_params ,并且应该看到的模型开始训练。

标签:python,tensorflow,machine-learning,deep-learning,neural-network
From: 78796915

相关文章

  • 在 Python 中将 Kivy 文件选择器添加到 PopUp
    我一直在尝试通过应用程序的按钮释放创建文件选择器弹出窗口。我分别管理了FileChooser和Popups,但无法将两者一起解决,这里有人可以帮助解决问题吗?我正在尝试用Python而不是Kivy.lang来实现PopUp,因为这是我在弹出窗口方面的经验。我也无法让KivyDoc示例正常工作。我......
  • Python基础知识点(1)基本语句
    基本语句1.if语句if表达式:语句块其中,表达式是一个返回True或False的表达式。如果表达式为True,则执行if下面的语句块;如果为False,则跳过语句块执行下面的语句。2.if…else语句if表达式:语句块1else:语句块2其中,表达式是一个返回True或False的表达式。如果......
  • 使用Python实现深度学习模型:语言翻译与多语种处理
    引言语言翻译和多语种处理是自然语言处理(NLP)中的重要任务,广泛应用于跨语言交流、国际化应用和多语言内容管理等领域。通过使用Python和深度学习技术,我们可以构建一个简单的语言翻译与多语种处理系统。本文将介绍如何使用Python实现这些功能,并提供详细的代码示例。所需工具......
  • python框架之Flask
    之前写过有关flask-restful: https://www.cnblogs.com/xingxia/p/flask_restful.html虽然早期使用python进行web应用搭建的使用该框架,但是好像很少总结,在此记录一下 [安装]pip3installflask [使用]#导入类库fromflaskimportFlask#创建实例......
  • Python 搜索和抓取
    我有一个问题想知道是否值得花时间尝试用Python来解决。我有一个包含鱼类学名的大型CSV文件。我想将该CSV文件与大型鱼类形态信息数据库(www.fishbase.ca)交叉引用,并让代码返回每条鱼的最大长度。基本上,我需要创建代码来搜索Fishbase网站上的每条鱼,然后找到页面上的最......
  • 《最新出炉》系列入门篇-Python+Playwright自动化测试-54- 上传文件(input控件) - 上篇
    1.简介在实际工作中,我们进行web自动化的时候,文件上传是很常见的操作,例如上传用户头像,上传身份证信息等。所以宏哥打算按上传文件的分类对其进行一下讲解和分享。2.上传文件的API(input控件)Playwright是一个现代化的自动化测试工具,它支持多种浏览器和操作系统,可以帮助开发人员和......
  • python requests 报错 Caused by ProxyError ('Unable to connect to proxy', OSError
    背景:访问https接口,使用http代理版本:requests:2.31.0 从报错可以看出,是proxy相关的报错调整代码,设定不使用代理,将http与https对应的proxy值置空即可(尝试过proxies={},但此写法不生效)proxies={'http':'','https':''}response = requests.get('https://xxx......
  • python基础函数
    1.为什么使用函数使用函数的目的是去减少代码的冗余性,简化代码的复杂度2.如何去定义一个函数以def开头去进行相关的定义在def的后面我们就去以见明知意的方式去定义一个函数的名称在函数名称后面的括号中去添加参数值,可以是多个参数,也可以是无餐的3.函数的调用无参多......
  • 不使用 + 或 - 运算符 | 添加 2 个数字Python
    我一直在尝试编写逻辑,但测试用例失败。如何改进我的代码?代码:#Giventwointegersaandb,returnthesumofthetwointegerswithoutusingtheoperators+and-.a=-1b=1min_val=min(a,b)max_val=max(a,b)ifmin_val==max_val:pr......
  • python 中的智能 Cisco IOS 差异
    之前:hostnameFoo!interfaceGigabitEthernet1/1switchportmodetrunkswitchporttrunkallowedvlan10,20,30!interfaceGigabitEthernet1/2ipaddress192.0.2.1255.255.255.128noipproxy-arp!之后:hostnameFoo!interfaceGigabitEt......