首页 > 编程问答 >将三个经过训练的二元分类模型组合成 keras 中的单个多分类模型

将三个经过训练的二元分类模型组合成 keras 中的单个多分类模型

时间:2024-07-21 02:35:32浏览次数:18  
标签:python tensorflow machine-learning keras deep-learning

我有三个经过训练的二元分类模型,它们在输出层使用 sigmoid 激活进行训练。

  1. 第一个模型返回从0到1的概率标量,以检查图像是否为数字 或不是。
  2. 第二个模型返回从 0 到 1 的概率标量来检查图像是否是数字 ONE 或否。
  3. 第三个模型返回从 0 到 1 的概率标量来检查图像是否是数字|| |两个 或不。 我知道我可以用

enter image description here

softmax 在输出层构建具有三个神经元的模型来训练它们。但是假设我遇到一种情况,由于模型复杂,训练它们的权重确实需要很长时间,我只有它们各自的二元分类模型。或者,我想在隐藏层提取它们的隐藏表示特征,例如 (二元分类以检查图像是否为零)。 model_0 那么,如何将它们连接/组合/合并为单个模型? 我的代码目前停留在这一点:

其中:

model_0 = init_binary_classification_model((28,28))
model_0.load_weights('trained_weight_of_binary_classification_to_check_whether_image_is_zero.h5')

model_1 = init_binary_classification_model((28,28))
model_1.load_weights('trained_weight_of_binary_classification_to_check_whether_image_is_one.h5')

model_2 = init_binary_classification_model((28,28))
model_2.load_weights('trained_weight_of_binary_classification_to_check_whether_image_is_two.h5')

我希望多分类模型具有相同的输入形状

def init_binary_classification_model(input_shape=(28,28)):
  input_layer = Input(shape=input_shape)
  tensor = Flatten()(input_layer)
  tensor = Dense(16, activation='relu')(tensor)
  tensor = Dense(8, activation='relu')(tensor)
  output_layer = Dense(1, activation='sigmoid')(tensor)

  return Model(inputs=input_layer, outputs=output_layer)

和不同的输出形状 (28,28) 并且我不需要重新训练模型(如果可能) . (3) 完整代码可在

https://colab.research.google.com/drive/1y1mvAzebIFU_cuEQo8Q60L1I6uT8i2Ce?usp=sharing https://colab.research.google.com/drive/1y1mvAzebIFU_cuEQo8Q60L1I6uT8i2Ce?usp=sharing


可以通过定义一个新模型来组合的二进制分类模型,该模型将三个模型的预测作为输入并输出一个三类概率向量。以下是如何执行此操作:

from tensorflow.keras.layers import Input, Concatenate, Dense
from tensorflow.keras.models import Model

# ... 加载的三个模型(model_0、model_1、model_2)...

# 创建一个输入层,其形状与的二进制分类模型相同
input_layer = Input(shape=(28,28))

# 从每个模型中获取预测
output_0 = model_0(input_layer)
output_1 = model_1(input_layer)
output_2 = model_2(input_layer)

# 将预测连接成一个向量
merged_outputs = Concatenate()([output_0, output_1, output_2])

# 可选:添加一个密集层,后跟 softmax 以获得更准确的概率分布
# merged_outputs = Dense(3, activation='softmax')(merged_outputs)

# 创建新的多类模型
multiclass_model = Model(inputs=input_layer, outputs=merged_outputs)

# 打印模型摘要
multiclass_model.summary()

这个方法做了以下几件事:

  1. 创建输入层: 它定义了一个输入层来处理图像,就像的二进制分类模型一样。
  2. 获取每个模型的预测: 它将输入图像传递给每个二进制分类模型并获取其预测(0 到 1 之间的概率)。
  3. 连接预测: 它将三个模型的预测连接成一个包含三个值的向量。
  4. 创建多类模型: 它使用输入层和连接的输出创建一个新的 Keras Model

这个新创建的 multiclass_model 将的三个二进制分类模型的预测视为输入。然后,可以使用此模型来预测输入图像属于三个类中的哪一个。

可选:

  • 可以通过添加一个具有 softmax 激活函数的密集层来提高多类模型的性能,以获得更标准化的概率分布。
  • 如果想访问中间层以提取特征,可以使用 keras 函数式 API。

请注意,此方法假设的三个二进制分类模型已经过良好的训练,并且每个模型的输出代表输入图像属于相应类别的置信度。

标签:python,tensorflow,machine-learning,keras,deep-learning
From: 78773489

相关文章

  • Python 是一种选择性解释语言吗?为什么下面的代码不起作用?
    由于程序是从上到下运行的,为什么下面的代码不执行块中的第一行就直接抛出错误?if5>2:print("TwoislessthanFive!")print("Fiveisgreaterthantwo!")错误:文件“/Users/____/Desktop/Pythonpractise/practise.py”,第3行print("五比二大!")Indentati......
  • 裁剪时间变量 Python Matplotlib Xarray
    我不确定这是否是一个愚蠢的问题,但我想按时间变量剪辑.nc文件。我在xarray中打开了数据集,但以下ds.sel行(之前已运行)仅返回错误。ds=xr.open_dataset('/Users/mia/Desktop/RMP/data/tracking/mcs_tracks_2015_11.nc')selected_days=ds.sel(time=slice('2015-11-22',......
  • 用于匹配两个数据列表中的项目的高效数据结构 - python
    我有两个列表,其中一个列表填充ID,另一个列表填充进程名称。多个进程名称可以共享一个ID。我希望能够创建一个可以使用特定ID的数据结构,然后返回与该ID关联的进程列表。我还希望能够使用特定的进程名称并返回与其连接的ID列表。我知道我可以为此创建一个字典,但是I......
  • 有人可以解决我的代码中的问题吗?而且我无法在我的电脑上安装 nsetools。如何在 python
    从nsetools导入Nseimportpandasaspdnse=Nse()all_stock_codes=nse.get_stock_codes()companies_with_low_pe=[]对于all_stock_codes中的代码:如果代码=='符号':继续尝试:stock_quote=nse.get_quote(代码)pe_ratio=stock_quote.get('priceT......
  • 将 python 脚本的 stdin 重定向到 fifo 会导致 RuntimeError: input():lost sys.stdin
    我有这个python脚本,它的作用是充当服务器,它从重定向到fifo的stdin读取命令:test.py:whileTrue:try:line=input()exceptEOFError:breakprint(f'Received:{line}')在bash中运行命令:mkfifotestfifotest.py<testfifo......
  • Python/Flask mysql 游标:为什么它不起作用?
    fromflaskimportFlaskfromflask_mysqldbimportMySQLapp=Flask(__name__)app.config['MYSQL_HOST']='localhost'app.config['MYSQL_USER']='root'app.config['MYSQL_PASSWORD']='password'a......
  • Python pandas to_csv 导致 OSError: [Errno 22] 参数无效
    我的代码如下:importpandasaspdimportnumpyasnpdf=pd.read_csv("path/to/my/infile.csv")df=df.sort_values(['distance','time'])df.to_csv("path/to/my/outfile.csv")此代码成功从infile.csv(一个3GBcsv文件)读取数据,对其进行排......
  • 从 python 中的字符串列表中提取 def 定义函数的标签
    我想使用Python中的正常def过程创建函数,并将标签分配给从字符串列表中提取的命名空间。如何实现这一点?这个问题的动机:我正在创建一个与sympy兼容的python函数库,供数学家用于符号计算实验。许多函数需要初始化具有相关标签的多个对象的系统,这些标签分别由用户提供的字......
  • 在 Raspberry Pi 4 上使用 Python 从具有 SPI 连接的 MT6816 磁性编码器读取
    我对这个领域完全陌生,并不真正知道自己在做什么并且需要帮助。我正在尝试使用MT681614位磁性编码器通过RaspberryPi的SPI连接读取绝对角度。我有以下问题:在硬件方面,是否只是简单地连接必要的连接(3.3V、MOSI、MISO、SCK、GND、CE01)?对于编码......
  • PythonW 不运行脚本。严重地
    因此,使用Windows10和Python3.6。我创建了一个.py脚本,它可以使用命令pythonmyscript.py在命令提示符下正常运行,但是当我制作该脚本的精确副本并为其赋予扩展名.pyw,并尝试使用pythonw运行它时命令pythonwmyscript.pyw,什么也没有发生......