我正在构建一个作为自定义 tf.keras.Model 实现的自动编码器。虽然训练后的模型表现良好,但我无法正确保存并重新加载它。我已经尝试过 model.save() 方法和 save_weights() 但在这两种情况下,模型完全无法执行其任务。
此自动编码器正在调用另外两个 tf.keras.Model,即编码器和解码器,而这两个模型又依次调用调用自定义层。
残差卷积块:
@tf.keras.utils.register_keras_serializable(package="ae", name="ResidualConvBlock")
class ResidualConvBlock(tf.keras.Layer):
def __init__(self, n_filters: int, activation = 'relu', is_res = False, **kwargs) -> None:
super().__init__(**kwargs)
self.is_res = is_res
self.conv1 = tf.keras.layers.Conv2D(filters = n_filters, kernel_size = 3,\
strides=1, kernel_initializer = 'he_normal', padding = 'same')
self.norm1 = tf.keras.layers.BatchNormalization()
self.activation1 = tf.keras.layers.Activation(activation)
self.conv2 = tf.keras.layers.Conv2D(filters = n_filters, kernel_size = 3,\
strides=1, kernel_initializer = 'he_normal', padding = 'same')
self.norm2 = tf.keras.layers.BatchNormalization()
self.activation2 = tf.keras.layers.Activation(activation)
self.shortcut = tf.keras.layers.Conv2D(n_filters, kernel_size=1, strides=1, padding='valid')
def call(self, inputs, training=False):
# First convolutional layer
x1 = self.conv1(inputs)
x1 = self.norm1(x1)
x1 = self.activation1(x1)
# Second convolutional layer
x2 = self.conv2(x1)
x2 = self.norm2(x2)
out = self.activation2(x2)
if self.is_res:
if inputs.shape[-1] == out.shape[-1]:
out = inputs + out
else:
out = self.shortcut(inputs) + out
out = out / 1.414
return out
编码器块:
@tf.keras.utils.register_keras_serializable(package="ae", name="EncoderBlock")
class EncoderBlock(tf.keras.Layer):
def __init__(self, n_filters=64, pool_size=(2,2), dropout=0.3, **kwargs):
super().__init__(**kwargs)
self.c = ResidualConvBlock(n_filters=n_filters)
self.p = tf.keras.layers.MaxPooling2D(pool_size=pool_size)
self.d = tf.keras.layers.Dropout(0.3)
def call(self, inputs):
c = self.c(inputs)
p = self.p(c)
d = self.d(p)
return d, c
编码器模型:
@tf.keras.utils.register_keras_serializable(package="ae", name="Encoder")
class Encoder(tf.keras.Model):
def __init__(self, latent_dim:int, n_filters: int, depth: int, **kwargs):
super().__init__(**kwargs)
self.n_filters = n_filters
self.depth = depth
self.enc_blocks = []
self.bottle_neck = tf.keras.layers.Dense(units = latent_dim)
for i in range(self.depth):
if i == 0:
self.enc_blocks.append(EncoderBlock(n_filters=self.n_filters, pool_size=(2,3)))
else:
self.enc_blocks.append(EncoderBlock(n_filters=2 ** i * self.n_filters))
def call(self,inputs):
convs = []
x = inputs
for block in self.enc_blocks:
x, c = block(x)
convs.append(c)
out = self.bottle_neck(x)
return out, convs
def build_graph(self, raw_shape):
x = tf.keras.Input(shape=raw_shape)
return tf.keras.Model(inputs=[x], outputs=self.call(x))
def get_config(self):
base_config = super().get_config()
config = {
"n_filters": self.n_filters,
"depth": self.depth,
"EncoderBlock": tf.keras.legacy.saving.serialize_keras_object(self.enc_blocks[0])
}
return {**base_config, **config}
@classmethod
def from_config(cls, config):
EncoderBlock_config = config.pop("EncoderBlock")
EncoderBlock = tf.keras.legacy.saving.deserialize_keras_object(EncoderBlock_config)
return cls(EncoderBlock, **config)
# return cls(**config)
解码器块:
@tf.keras.utils.register_keras_serializable(package="ae", name="DecoderBlock")
class DecoderBlock(tf.keras.Layer):
def __init__(self, n_filters=64, kernel_size=3, strides=(2,2), dropout=0.3, is_res = False, **kwargs):
super().__init__(**kwargs)
self.is_res = is_res
self.u = tf.keras.layers.Conv2DTranspose(n_filters, kernel_size, strides = strides, padding = 'same')
self.d = tf.keras.layers.Dropout(dropout)
self.c = ResidualConvBlock(n_filters=n_filters)
self.is_res = is_res
def call(self, inputs, conv):
u = self.u(inputs)
if self.is_res:
x = tf.keras.layers.concatenate([u, conv])
else:
x = u
x = self.d(x)
out = self.c(x)
return out
解码器模型:
@tf.keras.utils.register_keras_serializable(package="ae", name="Decoder")
class Decoder(tf.keras.Model):
def __init__(self, n_filters:int, depth:int = 4, output_channels:int =3, **kwargs):
super().__init__(**kwargs)
self.n_filters = n_filters
self.depth = depth
self.output_channels = output_channels
self.decoder_blocks = []
for i in range(depth):
if i == depth -1:
self.decoder_blocks.append(DecoderBlock(n_filters=2 ** (depth - i -1) * self.n_filters, strides = (2,3)))
else:
self.decoder_blocks.append(DecoderBlock(n_filters=2 ** (depth - i -1) * self.n_filters))
self.final_conv = tf.keras.layers.Conv2D(self.output_channels, (1, 1), activation='sigmoid')
def call(self, inputs, convs):
out = inputs
for i in range(self.depth):
out = self.decoder_blocks[i](out, convs[-i-1])
outputs = self.final_conv(out)
return outputs
def build_graph(self, raw_shape):
x = tf.keras.Input(shape=raw_shape)
y = []
for i in range(self.depth-1):
y.append(tf.keras.Input(shape=(raw_shape[0] * 2 ** (i+1), raw_shape[1] * 2 ** (i+1), int(self.n_filters * 2 ** (self.depth-i-1)))))
y.append(tf.keras.Input(shape=(raw_shape[1] * 2 ** (self.depth), raw_shape[0] * 2 ** (self.depth-1) * 3, int(self.n_filters))))
y.reverse()
return tf.keras.Model(inputs=[x], outputs=self.call(x, y))
def get_config(self):
base_config = super().get_config()
config = {
"n_filters": self.n_filters,
"depth": self.depth,
"output_channels": self.output_channels,
"DecoderBlock": tf.keras.legacy.saving.serialize_keras_object(self.decoder_blocks[0])
}
return {**base_config, **config}
@classmethod
def from_config(cls, config):
DecoderBlock_config = config.pop("DecoderBlock")
DecoderBlock = tf.keras.legacy.saving.deserialize_keras_object(DecoderBlock_config)
return cls(**config)
最后是自动编码器模型:
@tf.keras.utils.register_keras_serializable(package="ae", name="AutoEncoder")
class AE_model(tf.keras.Model):
def __init__(self, n_filters: int, latent_dim: int, depth:int, **kwargs):
super().__init__(**kwargs)
self.latent_dim = latent_dim
self.depth = depth
# encoder
self.encoder = Encoder(n_filters=n_filters, latent_dim=latent_dim, depth=self.depth) #encoder(latent_dim, n_filters)
# decoder
self.decoder = Decoder(n_filters=n_filters, depth=self.depth, output_channels=3)
def call (self, inputs):
encoded, convs = self.encoder(inputs)
decoded = self.decoder(encoded, convs)
return decoded
def build_graph(self, raw_shape):
x = tf.keras.Input(shape=raw_shape)
return tf.keras.Model(inputs=[x], outputs=self.call(x))
def get_config(self):
base_config = super().get_config()
config = {
"latent_dim": self.latent_dim,
"depth": self.depth,
"n_filters": self.encoder.n_filters,
"encoder": tf.keras.legacy.saving.serialize_keras_object(self.encoder),
"decoder": tf.keras.legacy.saving.serialize_keras_object(self.decoder)
}
return {**base_config, **config}
@classmethod
def from_config(cls, config):
encoder_config = config.pop("encoder")
encoder = tf.keras.legacy.saving.deserialize_keras_object(encoder_config)
decoder_config = config.pop("decoder")
decoder = tf.keras.legacy.saving.deserialize_keras_object(decoder_config)
#return cls(encoder, decoder, **config)
return cls(**config)
为了能够保存自动编码器并重新加载它而不出错,我必须覆盖编码器、解码器和 AE_model 类的 get_config 和 from_config 方法。 但是我不明白为什么在编码器的情况下我必须返回编码器类反序列化的配置,否则它会抱怨编码器未知,而对于 AE_model 的解码器,它会抱怨诸如 n_filters 之类的元素定义多次。
通过上述配置,我可以保存模型并重新加载它。然而,自动编码器的重建图像是完全灰色的。 训练后的结果模型: 在此处输入图像描述
训练后的结果模型,保存并重新加载: 在此处输入图像描述
问题: 如何保存模型或其权重并能够重新加载它以供将来的推理或额外训练?
出现模型保存和加载问题的原因有很多,特别是当你有多个嵌套的自定义层和模型时。让我们来分解可能出现的问题,并提供一个解决方案:
-
自定义层和模型的序列化: 虽然你已经为你的自定义类定义了
get_config
和from_config
方法,但这些方法的实现方式存在一些问题。-
在
Encoder
和Decoder
中不一致的from_config
: 在Encoder
的from_config
中,你正确地反序列化了EncoderBlock
并将其传递给cls
。但是,在Decoder
中,你没有对DecoderBlock
执行相同的操作。这可能会导致不一致,因为加载的Decoder
可能无法正确实例化DecoderBlock
。 -
冗余的配置项: 在
AE_model
的get_config
中,你存储了n_filters
和depth
,这些信息已经包含在encoder
和decoder
的配置中。这种冗余可能会导致冲突。
-
-
权重保存和加载:
-
缺少
build
调用: 你没有在Encoder
和Decoder
中明确调用build
方法。虽然在第一次调用call
时会隐式地构建模型,但这可能会在加载模型时导致问题,因为在加载权重之前,图结构可能无法完全定义。
-
缺少
以下是修正后的代码,以及一些解释:
from typing import Tuple
import tensorflow as tf
@tf.keras.utils.register_keras_serializable(package="ae", name="ResidualConvBlock")
class ResidualConvBlock(tf.keras.layers.Layer):
# ... (代码与之前相同) ...
@tf.keras.utils.register_keras_serializable(package="ae", name="EncoderBlock")
class EncoderBlock(tf.keras.layers.Layer):
# ... (代码与之前相同) ...
@tf.keras.utils.register_keras_serializable(package="ae", name="Encoder")
class Encoder(tf.keras.Model):
def __init__(self, latent_dim: int, n_filters: int, depth: int, **kwargs):
super().__init__(**kwargs)
# ... (代码与之前相同) ...
self.build((None,) + self.input_shape) # 添加 build 调用
def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, list]:
# ... (代码与之前相同) ...
def get_config(self):
config = super().get_config()
config.update({
"latent_dim": self.latent_dim,
"n_filters": self.n_filters,
"depth": self.depth,
})
return config
@classmethod
def from_config(cls, config):
return cls(**config)
@tf.keras.utils.register_keras_serializable(package="ae", name="DecoderBlock")
class DecoderBlock(tf.keras.layers.Layer):
# ... (代码与之前相同) ...
@tf.keras.utils.register_keras_serializable(package="ae", name="Decoder")
class Decoder(tf.keras.Model):
def __init__(self, n_filters: int, depth: int = 4, output_channels: int = 3, **kwargs):
super().__init__(**kwargs)
# ... (代码与之前相同) ...
self.build((None,) + self.input_shape) # 添加 build 调用
def call(self, inputs: tf.Tensor, convs: list) -> tf.Tensor:
# ... (代码与之前相同) ...
def get_config(self):
config = super().get_config()
config.update({
"n_filters": self.n_filters,
"depth": self.depth,
"output_channels": self.output_channels,
})
return config
@classmethod
def from_config(cls, config):
return cls(**config)
@tf.keras.utils.register_keras_serializable(package="ae", name="AutoEncoder")
class AE_model(tf.keras.Model):
def __init__(self, n_filters: int, latent_dim: int, depth: int, input_shape: tuple, **kwargs):
super().__init__(**kwargs)
self.latent_dim = latent_dim
self.depth = depth
self.input_shape = input_shape
# encoder
self.encoder = Encoder(latent_dim=latent_dim, n_filters=n_filters, depth=self.depth, input_shape=input_shape)
# decoder
self.decoder = Decoder(n_filters=n_filters, depth=self.depth, output_channels=3, input_shape=input_shape)
self.build((None,) + self.input_shape) # 添加 build 调用
def call(self, inputs: tf.Tensor) -> tf.Tensor:
encoded, convs = self.encoder(inputs)
decoded = self.decoder(encoded, convs)
return decoded
def get_config(self):
config = super().get_config()
config.update({
"latent_dim": self.latent_dim,
"depth": self.depth,
"n_filters": self.encoder.n_filters,
"input_shape": self.input_shape
})
return config
@classmethod
def from_config(cls, config):
return cls(**config)
解释:
-
build
方法: 在__init__
方法的末尾添加了self.build((None,) + self.input_shape)
来明确构建模型。 这确保在保存模型之前,图结构是完全定义的。 -
简化的
get_config
和from_config
: 移除了冗余的配置信息,并简化了from_config
方法。 现在,你只需要返回cls(**config)
,因为所有必要的配置信息都已经包含在config
中。 -
传递
input_shape
: 为了确保模型在加载时能够正确构建,我们需要在初始化子模型 (Encoder
和Decoder
) 时传递输入形状。
通过这些修改,你的模型现在应该可以正确保存和加载。尝试使用以下代码保存和加载模型:
# 创建模型实例
input_shape = (128, 128, 3) # 替换为你的输入形状
model = AE_model(n_filters=64, latent_dim=128, depth=4, input_shape=input_shape)
# ... 训练你的模型 ...
# 保存模型
model.save("my_ae_model")
# 加载模型
loaded_model = tf.keras.models.load_model("my_ae_model")
如果仍然遇到问题,请提供以下信息以便我更好地帮助你:
- 完整的错误消息和跟踪信息。
- 你用来保存和加载模型的确切代码。
- 你的 TensorFlow 和 Keras 版本。