首页 > 编程问答 >ValueError:仅顺序模型或功能模型支持参数clone_function和input_tensors

ValueError:仅顺序模型或功能模型支持参数clone_function和input_tensors

时间:2024-07-26 11:50:15浏览次数:16  
标签:python tensorflow keras quantization-aware-training

我正在使用 量化感知训练 我参考了网上的lstm代码,想放入 QAT 放入lstm,但是遇到了ValueError。

ValueError Traceback(最近一次调用last) 在<细胞系:6>() 4返回层 5 ----> 6 注释模型 = tf.keras.models.clone_model( 7 型号, 8clone_model(模型,input_tensors,clone_function)中的clone_function = apply_quantization_to_dense,

/usr/local/lib/python3.10/dist-packages/tf_keras/src/models/cloning.py 第544章 第545章 --> 546 引发值错误( 第547章 第548章收到类型为“Sequential”的模型,其中clone_function=<function apply_quantization_to_dense at 0x78b727ec4040>和input_tensors=None```

这是我的代码

Here's my code

import keras
from keras.layers import LSTM
from keras.layers import Dense, Activation
from keras.datasets import mnist
from keras.models import Sequential
from keras.optimizers import Adam

learning_rate = 0.001
training_iters = 20
batch_size = 128
display_step = 10

n_input = 28
n_step = 28
n_hidden = 128
n_classes = 10

(x_train, y_train), (x_test, y_test) = mnist.load_data()

x_train = x_train.reshape(-1, n_step, n_input)
x_test = x_test.reshape(-1, n_step, n_input)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

y_train = keras.utils.to_categorical(y_train, n_classes)
y_test = keras.utils.to_categorical(y_test, n_classes)

model = Sequential()
model.add(LSTM(n_hidden,
               batch_input_shape=(None, n_step, n_input),
               unroll=True))

model.add(Dense(n_classes))
model.add(Activation('softmax'))

adam = Adam(lr=learning_rate)
model.summary()
model.compile(optimizer=adam,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=training_iters,
          verbose=1,
          validation_data=(x_test, y_test))

scores = model.evaluate(x_test, y_test, verbose=0)
print('LSTM test score:', scores[0])
print('LSTM test accuracy:', scores[1])

def apply_quantization_to_dense(layer):
    if isinstance(layer, tf.keras.layers.LSTM):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer

annotated_model = tf.keras.models.clone_model(
    model,
    clone_function=apply_quantization_to_dense,
)
qat_model = tfmot.quantization.keras.quantize_apply(annotated_model)
qat_model.summary()

错误信息“ValueError:仅顺序模型或功能模型支持参数 clone_function 和 input_tensors”表明,正尝试将 tf.keras.models.clone_model clone_function 参数一起使用,而该参数仅适用于顺序模型或功能模型。当前的模型是一个顺序模型,但 clone_function 参数仅在想使用 tf.keras.models.clone_model 的功能性 API 时使用。

在的情况下,因为要对密集层进行量化,所以可以直接修改模型定义,而不是使用 clone_model 。以下是如何操作的方法:

import tensorflow_model_optimization as tfmot

# ...(的现有代码)...

# 创建模型
model = Sequential()
model.add(LSTM(n_hidden,
               batch_input_shape=(None, n_step, n_input),
               unroll=True))
# 对密集层应用量化
model.add(tfmot.quantization.keras.quantize_annotate_layer(Dense(n_classes)))
model.add(Activation('softmax'))

# ...(的现有代码)...

# 无需使用 clone_model,直接编译和训练量化模型
adam = Adam(lr=learning_rate)
model.summary()
model.compile(optimizer=adam,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=training_iters,
          verbose=1,
          validation_data=(x_test, y_test))

# ...(的现有代码)...

在此代码中:

  1. 我们直接在创建 Dense 层时使用 tfmot.quantization.keras.quantize_annotate_layer 对其进行量化。
  2. 我们不再需要 apply_quantization_to_dense 函数,因为我们直接在模型定义中应用量化。
  3. 我们不再需要使用 clone_model ,因为我们直接修改原始模型。

通过这些更改,应该能够在的 LSTM 模型中使用量化感知训练,而不会遇到 ValueError 。请注意,可能需要根据的特定需求和 TensorFlow 版本调整量化过程。请参阅 TensorFlow 模型优化文档以获取最新信息和示例。

标签:python,tensorflow,keras,quantization-aware-training
From: 78796155

相关文章

  • 【Python自动化办公】用Pandas库自动化操作Excel表格,从读取、写入到数据处理和分析
    文末免费赠送精品编程资料~~前言Python的第三方Pandas库是数据处理和分析中的利器,其强大的功能可以帮助我们轻松地对Excel表格进行自动化操作。接下来,我们将介绍九个用Pandas库操作Excel的编程例子,并且每个例子都会涉及不同的知识点,确保全面掌握这个主题。1.读取和写入E......
  • 总结24个Python接单赚钱平台与详细教程,兼职月入5000+
    如果说当下什么编程语言最靠谱或者比较适合搞副业?答案肯定100%是:Python。python是所有语法中最简单易上手的语言,不需要特别的的英语词汇量,逻辑思维也不需要很差就能上手。而且学会了之后就能编写代码爬取各种数据,制作各种图表,提升工作效率。而且还能利用业余时间接点私活......
  • python安装第三方库的国内镜像
    直接:pipconfigsetglobal.index-urlhttps://pypi.doubanio.com/simple设置了全局的第三方库的下载文件镜像请求网址。安装第三方库:pipinstallscrapy--scrapy第三方库名称 pip从国内镜像安装的命令使用中国大陆地区的Python包镜像服务时,可以通过修改p......
  • 如何将Python嵌入.Net?
    我尝试基于文档此处和此处使用pythonnet将Python嵌入到.Net中。这是我的代码Runtime.PythonDLL=@"D:\Dev\Console\.conda\python311.dll";PythonEngine.Initialize();dynamicsys=Py.Import("sys");Console.WriteLine("Pythonversion:&quo......
  • 使用pybind11封装c++的dll,供python调用
    用pip安装好pybind11 文件清单,都写在一个目录里//文件名:add.cppextern"C"doubleadd(doublein1,doublein2){returnin1+in2;}//文件名:sub.cppextern"C"doublesub(doublein1,doublein2){returnin1-in2;}//文件名:mul.cppextern"......
  • python-myStudyList
     1  下载软件1.1下载python最新版本并安装下载地址:百度搜索python官网。WelcometoPython.org。 1.2官网学习网页:PythonTutorials–RealPython   1.3也可以下载集成环境软件Anaconda。 Anaconda软件商城官方正版免费下载(msc23.cn) 2 ......
  • Python语法基础
    基本语句输入input() eg:输出print(内容)注释单行注释:#注释内容多行注释:"""注释内容"""数据类型: 字面量:整型、浮点数、字符串......intfloatstring查看数据类型:type(数据)查看数据类型 转换函数int(x):将x转换成整数类型float(x):将x转......
  • PyTesseract 不提取文本?我是所有这些Python的新手,请需要h3lp
    它不想从图像中提取文本,就像终端保持黑色并带有空格,就像它实际上试图提取文本一样,这是我的代码和图像从PIL导入图像导入pytesseract导入CV2“C:\用户\埃米利亚诺\下载\practic.png”pytesseract.pytesseract.tesseract_cmd="C:\ProgramFiles\Tesseract-OCR\tesseract.exe......
  • Python安装第三方库
    Python安装PILPIL(PythonImagingLibrary)是一个旧的Python库,用于处理图像。然而,PIL已经不再维护,并被一个名为Pillow的库所取代。Pillow是PIL的一个分支,并且完全兼容PIL。建议使用Pillow而不是PIL。pipinstallpillowPython安装moviepymoviepy是一个用于视频编辑的Python库,......
  • 优化Python中图像中的OCR文本检测
    我目前正在用python编写一个程序,该程序获取包含大量文本的图像,将其提取到.txt文件,然后将找到的单词与另一个文件中的单词列表进行比较,并创建一些坐标(根据像素)在图像中找到的单词中,如果找到图像,则会在图像中绘制红色方块。到目前为止,我已经正确处理了坐标部分,在单词周围绘制了......