首页 > 编程问答 >保存/加载自定义 tf.keras.Model 时出现问题

保存/加载自定义 tf.keras.Model 时出现问题

时间:2024-08-07 15:54:52浏览次数:14  
标签:python tensorflow tf.keras

我正在构建一个作为自定义 tf.keras.Model 实现的自动编码器。虽然训练后的模型表现良好,但我无法正确保存并重新加载它。我已经尝试过 model.save() 方法和 save_weights() 但在这两种情况下,模型完全无法执行其任务。

此自动编码器正在调用另外两个 tf.keras.Model,即编码器和解码器,而这两个模型又依次调用调用自定义层。

残差卷积块:

@tf.keras.utils.register_keras_serializable(package="ae", name="ResidualConvBlock")
class ResidualConvBlock(tf.keras.Layer):
  def __init__(self, n_filters: int, activation = 'relu', is_res = False, **kwargs) -> None:
    super().__init__(**kwargs)
    self.is_res = is_res
    self.conv1 = tf.keras.layers.Conv2D(filters = n_filters, kernel_size = 3,\
            strides=1, kernel_initializer = 'he_normal', padding = 'same')
    self.norm1 = tf.keras.layers.BatchNormalization()
    self.activation1 = tf.keras.layers.Activation(activation)

    self.conv2 = tf.keras.layers.Conv2D(filters = n_filters, kernel_size = 3,\
            strides=1, kernel_initializer = 'he_normal', padding = 'same')
    self.norm2 = tf.keras.layers.BatchNormalization()
    self.activation2 = tf.keras.layers.Activation(activation)
    self.shortcut = tf.keras.layers.Conv2D(n_filters, kernel_size=1, strides=1, padding='valid')

  def call(self, inputs, training=False):
    # First convolutional layer
    x1 = self.conv1(inputs)
    x1 = self.norm1(x1)
    x1 = self.activation1(x1)

    # Second convolutional layer
    x2 = self.conv2(x1)
    x2 = self.norm2(x2)
    out = self.activation2(x2)

    if self.is_res:
      if inputs.shape[-1] == out.shape[-1]:
        out = inputs + out
      else:
        out = self.shortcut(inputs) + out
      out = out / 1.414
    return out

编码器块:

@tf.keras.utils.register_keras_serializable(package="ae", name="EncoderBlock")
class EncoderBlock(tf.keras.Layer):
  def __init__(self, n_filters=64, pool_size=(2,2), dropout=0.3, **kwargs):
    super().__init__(**kwargs)
    self.c = ResidualConvBlock(n_filters=n_filters)
    self.p = tf.keras.layers.MaxPooling2D(pool_size=pool_size)
    self.d = tf.keras.layers.Dropout(0.3)

  def call(self, inputs):
    c = self.c(inputs)
    p = self.p(c)
    d = self.d(p)
    return d, c

编码器模型:

@tf.keras.utils.register_keras_serializable(package="ae", name="Encoder")
class Encoder(tf.keras.Model):
  def __init__(self, latent_dim:int, n_filters: int, depth: int, **kwargs):
    super().__init__(**kwargs)
    self.n_filters = n_filters
    self.depth = depth
    self.enc_blocks = []
    self.bottle_neck = tf.keras.layers.Dense(units = latent_dim)
    for i in range(self.depth):
      if i == 0:
        self.enc_blocks.append(EncoderBlock(n_filters=self.n_filters, pool_size=(2,3)))
      else:
        self.enc_blocks.append(EncoderBlock(n_filters=2 ** i * self.n_filters))


  def call(self,inputs):
    convs = []
    x = inputs
    for block in self.enc_blocks:
      x, c = block(x)
      convs.append(c)
    out = self.bottle_neck(x)
    return out, convs

  def build_graph(self, raw_shape):
    x = tf.keras.Input(shape=raw_shape)
    return tf.keras.Model(inputs=[x], outputs=self.call(x))

  def get_config(self):
    base_config = super().get_config()
    config = {
        "n_filters": self.n_filters,
        "depth": self.depth,
        "EncoderBlock": tf.keras.legacy.saving.serialize_keras_object(self.enc_blocks[0])
    }
    return {**base_config, **config}

  @classmethod
  def from_config(cls, config):
    EncoderBlock_config = config.pop("EncoderBlock")
    EncoderBlock = tf.keras.legacy.saving.deserialize_keras_object(EncoderBlock_config)
    return cls(EncoderBlock, **config)
    # return cls(**config)

解码器块:

@tf.keras.utils.register_keras_serializable(package="ae", name="DecoderBlock")
class DecoderBlock(tf.keras.Layer):
  def __init__(self, n_filters=64, kernel_size=3, strides=(2,2), dropout=0.3, is_res = False, **kwargs):
    super().__init__(**kwargs)
    self.is_res = is_res
    self.u = tf.keras.layers.Conv2DTranspose(n_filters, kernel_size, strides = strides, padding = 'same')
    self.d = tf.keras.layers.Dropout(dropout)
    self.c = ResidualConvBlock(n_filters=n_filters)
    self.is_res = is_res

  def call(self, inputs, conv):
    u = self.u(inputs)
    if self.is_res:
      x = tf.keras.layers.concatenate([u, conv])
    else:
      x = u
    x = self.d(x)
    out = self.c(x)
    return out

解码器模型:

@tf.keras.utils.register_keras_serializable(package="ae", name="Decoder")
class Decoder(tf.keras.Model):
  def __init__(self, n_filters:int, depth:int = 4, output_channels:int =3, **kwargs):
    super().__init__(**kwargs)
    self.n_filters = n_filters
    self.depth = depth
    self.output_channels = output_channels
    self.decoder_blocks = []
    for i in range(depth):
      if i == depth -1:
        self.decoder_blocks.append(DecoderBlock(n_filters=2 ** (depth - i -1) * self.n_filters, strides = (2,3)))
      else:
        self.decoder_blocks.append(DecoderBlock(n_filters=2 ** (depth - i -1) * self.n_filters))

    self.final_conv = tf.keras.layers.Conv2D(self.output_channels, (1, 1), activation='sigmoid')

  def call(self, inputs, convs):
    out = inputs
    for i in range(self.depth):
      out = self.decoder_blocks[i](out, convs[-i-1])

    outputs = self.final_conv(out)
    return outputs

  def build_graph(self, raw_shape):
    x = tf.keras.Input(shape=raw_shape)
    y = []
    for i in range(self.depth-1):
      y.append(tf.keras.Input(shape=(raw_shape[0] * 2 ** (i+1), raw_shape[1] * 2 ** (i+1), int(self.n_filters * 2 ** (self.depth-i-1)))))
    y.append(tf.keras.Input(shape=(raw_shape[1] * 2 ** (self.depth), raw_shape[0] * 2 ** (self.depth-1) * 3, int(self.n_filters))))
    y.reverse()
    return tf.keras.Model(inputs=[x], outputs=self.call(x, y))

  def get_config(self):
    base_config = super().get_config()
    config = {
        "n_filters": self.n_filters,
        "depth": self.depth,
        "output_channels": self.output_channels,
        "DecoderBlock": tf.keras.legacy.saving.serialize_keras_object(self.decoder_blocks[0])
    }
    return {**base_config, **config}

  @classmethod
  def from_config(cls, config):
    DecoderBlock_config = config.pop("DecoderBlock")
    DecoderBlock = tf.keras.legacy.saving.deserialize_keras_object(DecoderBlock_config) 
    return cls(**config)

最后是自动编码器模型:

@tf.keras.utils.register_keras_serializable(package="ae", name="AutoEncoder")
class AE_model(tf.keras.Model):
  def __init__(self, n_filters: int, latent_dim: int, depth:int, **kwargs):
    super().__init__(**kwargs)
    self.latent_dim = latent_dim
    self.depth = depth
    # encoder
    self.encoder = Encoder(n_filters=n_filters, latent_dim=latent_dim, depth=self.depth) #encoder(latent_dim, n_filters)

    # decoder
    self.decoder = Decoder(n_filters=n_filters, depth=self.depth, output_channels=3)

  def call (self, inputs):
    encoded, convs = self.encoder(inputs)
    decoded = self.decoder(encoded, convs)
    return decoded

  def build_graph(self, raw_shape):
    x = tf.keras.Input(shape=raw_shape)
    return tf.keras.Model(inputs=[x], outputs=self.call(x))

  def get_config(self):
    base_config = super().get_config()
    config = {
        "latent_dim": self.latent_dim,
        "depth": self.depth,
        "n_filters": self.encoder.n_filters,
        "encoder": tf.keras.legacy.saving.serialize_keras_object(self.encoder),
        "decoder": tf.keras.legacy.saving.serialize_keras_object(self.decoder)
    }
    return {**base_config, **config}

  @classmethod
  def from_config(cls, config):
    encoder_config = config.pop("encoder")
    encoder = tf.keras.legacy.saving.deserialize_keras_object(encoder_config)
    decoder_config = config.pop("decoder")
    decoder = tf.keras.legacy.saving.deserialize_keras_object(decoder_config)
    #return cls(encoder, decoder, **config)
    return cls(**config)

为了能够保存自动编码器并重新加载它而不出错,我必须覆盖编码器、解码器和 AE_model 类的 get_config 和 from_config 方法。 但是我不明白为什么在编码器的情况下我必须返回编码器类反序列化的配置,否则它会抱怨编码器未知,而对于 AE_model 的解码器,它会抱怨诸如 n_filters 之类的元素定义多次。

通过上述配置,我可以保存模型并重新加载它。然而,自动编码器的重建图像是完全灰色的。 训练后的结果模型: 在此处输入图像描述

训练后的结果模型,保存并重新加载: 在此处输入图像描述

问题: 如何保存模型或其权重并能够重新加载它以供将来的推理或额外训练?


出现模型保存和加载问题的原因有很多,特别是当你有多个嵌套的自定义层和模型时。让我们来分解可能出现的问题,并提供一个解决方案:

  1. 自定义层和模型的序列化: 虽然你已经为你的自定义类定义了 get_config from_config 方法,但这些方法的实现方式存在一些问题。

    • Encoder Decoder 中不一致的 from_config Encoder from_config 中,你正确地反序列化了 EncoderBlock 并将其传递给 cls 。但是,在 Decoder 中,你没有对 DecoderBlock 执行相同的操作。这可能会导致不一致,因为加载的 Decoder 可能无法正确实例化 DecoderBlock

    • 冗余的配置项: AE_model get_config 中,你存储了 n_filters depth ,这些信息已经包含在 encoder decoder 的配置中。这种冗余可能会导致冲突。

  2. 权重保存和加载:

    • 缺少 build 调用: 你没有在 Encoder Decoder 中明确调用 build 方法。虽然在第一次调用 call 时会隐式地构建模型,但这可能会在加载模型时导致问题,因为在加载权重之前,图结构可能无法完全定义。

以下是修正后的代码,以及一些解释:

from typing import Tuple

import tensorflow as tf


@tf.keras.utils.register_keras_serializable(package="ae", name="ResidualConvBlock")
class ResidualConvBlock(tf.keras.layers.Layer):
    # ... (代码与之前相同) ...


@tf.keras.utils.register_keras_serializable(package="ae", name="EncoderBlock")
class EncoderBlock(tf.keras.layers.Layer):
    # ... (代码与之前相同) ...


@tf.keras.utils.register_keras_serializable(package="ae", name="Encoder")
class Encoder(tf.keras.Model):
    def __init__(self, latent_dim: int, n_filters: int, depth: int, **kwargs):
        super().__init__(**kwargs)
        # ... (代码与之前相同) ...
        self.build((None,) + self.input_shape) # 添加 build 调用

    def call(self, inputs: tf.Tensor) -> Tuple[tf.Tensor, list]:
        # ... (代码与之前相同) ...

    def get_config(self):
        config = super().get_config()
        config.update({
            "latent_dim": self.latent_dim,
            "n_filters": self.n_filters,
            "depth": self.depth,
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@tf.keras.utils.register_keras_serializable(package="ae", name="DecoderBlock")
class DecoderBlock(tf.keras.layers.Layer):
    # ... (代码与之前相同) ...


@tf.keras.utils.register_keras_serializable(package="ae", name="Decoder")
class Decoder(tf.keras.Model):
    def __init__(self, n_filters: int, depth: int = 4, output_channels: int = 3, **kwargs):
        super().__init__(**kwargs)
        # ... (代码与之前相同) ...
        self.build((None,) + self.input_shape) # 添加 build 调用

    def call(self, inputs: tf.Tensor, convs: list) -> tf.Tensor:
        # ... (代码与之前相同) ...

    def get_config(self):
        config = super().get_config()
        config.update({
            "n_filters": self.n_filters,
            "depth": self.depth,
            "output_channels": self.output_channels,
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


@tf.keras.utils.register_keras_serializable(package="ae", name="AutoEncoder")
class AE_model(tf.keras.Model):
    def __init__(self, n_filters: int, latent_dim: int, depth: int, input_shape: tuple, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.depth = depth
        self.input_shape = input_shape
        # encoder
        self.encoder = Encoder(latent_dim=latent_dim, n_filters=n_filters, depth=self.depth, input_shape=input_shape)

        # decoder
        self.decoder = Decoder(n_filters=n_filters, depth=self.depth, output_channels=3, input_shape=input_shape)

        self.build((None,) + self.input_shape) # 添加 build 调用

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        encoded, convs = self.encoder(inputs)
        decoded = self.decoder(encoded, convs)
        return decoded

    def get_config(self):
        config = super().get_config()
        config.update({
            "latent_dim": self.latent_dim,
            "depth": self.depth,
            "n_filters": self.encoder.n_filters,
            "input_shape": self.input_shape
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

解释:

  • build 方法: __init__ 方法的末尾添加了 self.build((None,) + self.input_shape) 来明确构建模型。 这确保在保存模型之前,图结构是完全定义的。
  • 简化的 get_config from_config : 移除了冗余的配置信息,并简化了 from_config 方法。 现在,你只需要返回 cls(**config) ,因为所有必要的配置信息都已经包含在 config 中。
  • 传递 input_shape : 为了确保模型在加载时能够正确构建,我们需要在初始化子模型 ( Encoder Decoder ) 时传递输入形状。

通过这些修改,你的模型现在应该可以正确保存和加载。尝试使用以下代码保存和加载模型:

# 创建模型实例
input_shape = (128, 128, 3) # 替换为你的输入形状
model = AE_model(n_filters=64, latent_dim=128, depth=4, input_shape=input_shape)

# ... 训练你的模型 ...

# 保存模型
model.save("my_ae_model")

# 加载模型
loaded_model = tf.keras.models.load_model("my_ae_model")

如果仍然遇到问题,请提供以下信息以便我更好地帮助你:

  • 完整的错误消息和跟踪信息。
  • 你用来保存和加载模型的确切代码。
  • 你的 TensorFlow 和 Keras 版本。

标签:python,tensorflow,tf.keras
From: 78842403

相关文章

  • 改进删除文件和目录的 python 脚本运行时间
    我有一个Python脚本,可以删除X天之前的文件和目录。然而,该脚本运行在一个包含数百万个文件和目录的巨大目录上。按照目前的速度,完成删除过程大约需要六周时间(查看磁盘空间指标)。看来主要瓶颈在于列出文件和目录。任何人都可以建议代码更改或优化,以帮助减少运行时间?不......
  • python+flask计算机毕业设计新冠疫情后病历管理系统(程序+开题+论文)
    志羽·羽场管理与智能推荐系统2220o本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容研究背景新冠疫情的爆发对全球医疗体系产生了深远影响,特别是在病历管理方面。传统的病历管理方式在面对大规模......
  • python+flask计算机毕业设计微信小程序“班级小管家”(程序+开题+论文)
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容研究背景随着信息技术的迅猛发展和移动互联网的普及,微信小程序作为一种轻量级的应用程序,凭借其无需下载、即用即走的特性,在教育领域展现出了巨大的......
  • 您好,我有一个关于仅使用 python 3.10 发送电子邮件附件的问题
    我在发送包含附件的电子邮件时遇到问题。我的电子邮件的内容类型似乎设置不正确,这导致附件无法正确附加。这是我的电子邮件发送功能的片段:python复制代码self.send(subject=self.subject、recipients=self.recipients、html=""、text=""、attachments=self.attac......
  • python+flask计算机毕业设计社区居民信息管理系统 (程序+开题+论文)
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容研究背景随着城市化进程的加快,社区居民信息管理成为社区管理的重要组成部分。传统的社区管理方式存在信息更新不及时、管理效率低下等问题,难以满足......
  • Python安装教程(含MacOS&&Linux系统)
    Python安装教程Windows用户访问Python官网:WelcometoPython.org 打开下载好的安装包根据提示安装   Pip换源(系统级别)(注:Pip在3.4以上的版本才支持,3.4之前的版本可以在cmd中输入 easy_installpip 下载pip)1.为什么要换源?Python安装......
  • python
    字符串比较按位比较,有一位大,整体就大。函数多返回值正确:deftest_return():return1,2,3错误:return1return2函数的多种传参方式位置参数:关键字参数:函数调用时通过“键=值”的形式传递参数(传参顺序无所谓)eg:test(name="niu",age="19")缺省参数:举例说明:def......
  • 将普通 python 文件导入另一个文件时出现 AttributeError
    我是新手。我正在尝试将简单的python文件导入到我的主文件中。相同的代码在我的mac上工作,但在我的电脑上不起作用。我不断收到此错误消息。“AttributeError:模块‘logo’没有属性‘hammer_logo’”第一个文件拍卖.py代码importlogoprint(logo.hammer_logo)第......
  • 使用python读取mysql数据,并记录到本地的文件中
    上次写过一次读取sqlserver数据,写入本地文件。今天分享一下mysql的。原理相似,希望对大家有小小的帮忙PS,我是3.6.13版本python,上一版本用包mysql-connector,一直不成功,查询官方文档,发现这个版本的PYTHON简直是奇葩的存在了。基本所有版本都支持,就是几个小版本排除在外了。......
  • python合并音视频-通过moviepy模块合并音视频
    ......