I'm training a Siamese network in TensorFlow for image registration, but I've run into a problem where the gradients for all variables are None.
The network takes a pair of images (fixed and moving) and outputs the affine parameters of a transformation that aligns the moving image with the fixed image. The model, loss function, and training step are defined below:
Model:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers, models

def build_siamese_model(input_shape):
    input = layers.Input(shape=input_shape)
    x = layers.Conv2D(64, (5, 5), activation='relu')(input)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Conv2D(128, (5, 5), activation='relu')(x)
    x = layers.MaxPooling2D((2, 2))(x)
    x = layers.Flatten()(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.Dense(128, activation='relu')(x)
    output = layers.Dense(6)(x)  # 6 parameters for affine transformation
    model = models.Model(input, output)
    return model

siamese_model = build_siamese_model((970, 482, 1))
siamese_model.summary()
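As a quick sanity check (a sketch on a random batch, matching the sample inputs further below), the head should emit one 6-vector of affine parameters per image:

import numpy as np

dummy = np.random.rand(1, 970, 482, 1).astype(np.float32)
print(siamese_model(dummy).shape)  # expect (1, 6)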
Loss function:
def registration_loss(fixed_image, moving_image, transform_params):
    batch_size = tf.shape(moving_image)[0]
    transform_params = tf.reshape(transform_params, (batch_size, 6))
    # Construct the transformation matrix for affine transformation
    transforms = tf.concat([
        transform_params[:, 0:1], transform_params[:, 1:2], transform_params[:, 4:5],
        transform_params[:, 2:3], transform_params[:, 3:4], transform_params[:, 5:6],
        tf.zeros((batch_size, 2))  # Adding two zeros for the last row
    ], axis=1)
    # Apply affine transformation to the moving image
    transformed_image = tfa.image.transform(moving_image, transforms)
    # Compute the loss between the fixed image and the transformed moving image
    loss = tf.reduce_mean(tf.square(fixed_image - transformed_image))
    return loss
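For context, tfa.image.transform interprets each row of transforms as the 8 projective parameters [a0, a1, a2, b0, b1, b2, c0, c1], mapping an output pixel (x, y) back to the input point ((a0*x + a1*y + a2)/k, (b0*x + b1*y + b2)/k) with k = c0*x + c1*y + 1, so the two trailing zeros keep the transform affine. A quick sanity check (a sketch on a random image) is the identity transform:

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

image = tf.constant(np.random.rand(1, 8, 8, 1).astype(np.float32))
identity = tf.constant([[1., 0., 0., 0., 1., 0., 0., 0.]])  # a0 = b1 = 1, rest 0
out = tfa.image.transform(image, identity)
print(float(tf.reduce_max(tf.abs(out - image))))  # expect 0.0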
Training step:
@tf.function
def train_step(fixed_image, moving_image):
    with tf.GradientTape() as tape:
        transform_params = siamese_model(moving_image, training=True)
        print(f"Transform params shape in train_step: {transform_params.shape}")  # Debugging
        loss = registration_loss(fixed_image, moving_image, transform_params)
    # Ensure gradients are computed for the loss w.r.t. the model's trainable variables
    gradients = tape.gradient(loss, siamese_model.trainable_variables)
    # Debugging: print gradients and their shapes
    for grad, var in zip(gradients, siamese_model.trainable_variables):
        print(f"Var: {var.name}, Grad: {grad}")
        if grad is not None:
            print(f"Grad shape: {grad.shape}, Var shape: {var.shape}")
    # Apply gradients to the model's trainable variables
    optimizer.apply_gradients(zip(gradients, siamese_model.trainable_variables))
    return loss
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
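Note that the plain Python print calls inside train_step only fire while @tf.function traces the function, not on every step; for per-step output inside the compiled graph, tf.print is the graph-aware alternative. A minimal sketch:

import tensorflow as tf

@tf.function
def debug_step(x):
    tf.print("batch shape:", tf.shape(x))  # runs on every call, even in graph mode
    return x

debug_step(tf.zeros((2, 4)))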
Training loop:
for epoch in range(10):
    for step in range(100):
        loss = train_step(fixed_image, moving_image)
        print(f"Epoch {epoch}, Step {step}, Loss: {loss.numpy()}")
Traceback:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Input In [18], in <cell line: 2>()
2 for epoch in range(num_epochs):
3 for step in range(num_steps_per_epoch):
----> 4 loss = train_step(fixed_image, moving_image)
5 print(f"Epoch {epoch}, Step {step}, Loss: {loss.numpy()}")
File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py:1147, in func_graph_from_py_func.<locals>.autograph_handler(*args, **kwargs)
1145 except Exception as e: # pylint:disable=broad-except
1146 if hasattr(e, "ag_error_metadata"):
-> 1147 raise e.ag_error_metadata.to_exception(e)
1148 else:
1149 raise
ValueError: in user code:
File "/local/scratch/melchua/slrmtmp.43866248/ipykernel_12859/2770984432.py", line 18, in train_step *
optimizer.apply_gradients(zip(gradients, siamese_model.trainable_variables))
File "/scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/keras/optimizer_v2/optimizer_v2.py", line 633, in apply_gradients **
grads_and_vars = optimizer_utils.filter_empty_gradients(grads_and_vars)
File "/scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/keras/optimizer_v2/utils.py", line 73, in filter_empty_gradients
raise ValueError(f"No gradients provided for any variable: {variable}. "
ValueError: No gradients provided for any variable: (['conv2d/kernel:0', 'conv2d/bias:0', 'conv2d_1/kernel:0', 'conv2d_1/bias:0', 'dense/kernel:0', 'dense/bias:0', 'dense_1/kernel:0', 'dense_1/bias:0', 'dense_2/kernel:0', 'dense_2/bias:0'],). Provided `grads_and_vars` is ((None, <tf.Variable 'conv2d/kernel:0' shape=(5, 5, 1, 64) dtype=float32>), (None, <tf.Variable 'conv2d/bias:0' shape=(64,) dtype=float32>), (None, <tf.Variable 'conv2d_1/kernel:0' shape=(5, 5, 64, 128) dtype=float32>), (None, <tf.Variable 'conv2d_1/bias:0' shape=(128,) dtype=float32>), (None, <tf.Variable 'dense/kernel:0' shape=(3579264, 256) dtype=float32>), (None, <tf.Variable 'dense/bias:0' shape=(256,) dtype=float32>), (None, <tf.Variable 'dense_1/kernel:0' shape=(256, 128) dtype=float32>), (None, <tf.Variable 'dense_1/bias:0' shape=(128,) dtype=float32>), (None, <tf.Variable 'dense_2/kernel:0' shape=(128, 6) dtype=float32>), (None, <tf.Variable 'dense_2/bias:0' shape=(6,) dtype=float32>)).
Sample inputs:
import numpy as np

fixed_image = np.random.rand(1, 970, 482, 1).astype(np.float32)
moving_image = np.random.rand(1, 970, 482, 1).astype(np.float32)
The problem in your code that makes the gradients None is the use of tf.concat in the registration_loss function. You are using it to build the affine transformation matrix, but the way it is used breaks the gradient flow.

The issue is that you are concatenating transform_params (the tensor that needs gradients) with a constant tensor created by tf.zeros. In this particular case, concatenating with a constant cuts the gradient flow from loss back to transform_params.

To fix this, construct the transformation matrix in a way that does not interrupt the gradient flow. One approach is to use tf.stack to combine the slices of transform_params directly into the (batch_size, 8) layout that tfa.image.transform expects.

Here is the corrected registration_loss function:
def registration_loss(fixed_image, moving_image, transform_params):
    batch_size = tf.shape(moving_image)[0]
    transform_params = tf.reshape(transform_params, (batch_size, 6))
    # Construct the (batch_size, 8) projective transform for tfa.image.transform
    transforms = tf.stack([
        transform_params[:, 0], transform_params[:, 1], transform_params[:, 4],
        transform_params[:, 2], transform_params[:, 3], transform_params[:, 5],
        tf.zeros_like(transform_params[:, 0]), tf.zeros_like(transform_params[:, 0])
    ], axis=1)
    # Apply affine transformation to the moving image
    transformed_image = tfa.image.transform(moving_image, transforms)
    # Compute the loss between the fixed image and the transformed moving image
    loss = tf.reduce_mean(tf.square(fixed_image - transformed_image))
    return loss
In this corrected version:

- We use tf.stack to combine the slices of transform_params along the second axis (axis=1), which directly produces a tensor of shape (batch_size, 8). This is already the N x 8 layout tfa.image.transform expects, so no further reshape is needed (tfa.image.transform requires a vector of length 8 or an N x 8 tensor and would reject a (batch_size, 3, 2) shape).
- We use tf.zeros_like to create zero tensors of shape (batch_size,) for the last two entries. This guarantees the zeros have the same dtype as transform_params and are placed on the same device.
With this change, gradients should flow correctly from the loss back to transform_params, and you should see the model start to train.
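A quick way to confirm the fix (a sketch reusing siamese_model, registration_loss, and the random sample inputs defined above) is to check that the tape now returns non-None gradients:

with tf.GradientTape() as tape:
    params = siamese_model(moving_image, training=True)
    loss = registration_loss(fixed_image, moving_image, params)
grads = tape.gradient(loss, siamese_model.trainable_variables)
print(all(g is not None for g in grads))  # True once gradients reach the model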