首页 > 编程问答 >ValueError:无法识别的关键字参数传递给 LSTM:Keras 中的 {'batch_input_shape'}

ValueError:无法识别的关键字参数传递给 LSTM:Keras 中的 {'batch_input_shape'}

时间:2024-07-29 10:08:21浏览次数:9  
标签:python keras lstm

我正在尝试在 TensorFlow 中使用 Keras 构建和训练有状态 LSTM 模型,但在指定 batch_input_shape 参数时不断遇到 ValueError。

错误消息:

ValueError: Unrecognized keyword arguments passed to LSTM: {'batch_input_shape': (1, 1, 14)}

这是我的代码的简化版本:

import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM

# Load your data
file_path = 'path_to_your_file.csv'
data = pd.read_csv(file_path)

# Create 'date' column with the first day of each month
data['date'] = pd.to_datetime(data['tahun'].astype(str) + '-' + data['bulan'].astype(str) + '-01')
data['date'] = data['date'] + pd.offsets.MonthEnd(0)
data.set_index('date', inplace=True)

# Group by 'date' and sum the 'amaun_rm' column
df_sales = data.groupby('date')['amaun_rm'].sum().reset_index()

# Create a new dataframe to model the difference
df_diff = df_sales.copy()
df_diff['prev_amaun_rm'] = df_diff['amaun_rm'].shift(1)
df_diff = df_diff.dropna()
df_diff['diff'] = df_diff['amaun_rm'] - df_diff['prev_amaun_rm']

# Create new dataframe from transformation from time series to supervised
df_supervised = df_diff.drop(['prev_amaun_rm'], axis=1)
for inc in range(1, 13):
    field_name = 'lag_' + str(inc)
    df_supervised[field_name] = df_supervised['diff'].shift(inc)

# Adding moving averages
df_supervised['ma_3'] = df_supervised['amaun_rm'].rolling(window=3).mean().shift(1)
df_supervised['ma_6'] = df_supervised['amaun_rm'].rolling(window=6).mean().shift(1)
df_supervised['ma_12'] = df_supervised['amaun_rm'].rolling(window=12).mean().shift(1)
df_supervised = df_supervised.dropna().reset_index(drop=True)
df_supervised = df_supervised.fillna(df_supervised.mean())

# Split the data into train and test sets
train_set, test_set = df_supervised[0:-6].values, df_supervised[-6:].values
scaler = MinMaxScaler(feature_range=(-1, 1))
scaler = scaler.fit(train_set)
train_set_scaled = scaler.transform(train_set)
test_set_scaled = scaler.transform(test_set)

# Split into input and output
X_train, y_train = train_set_scaled[:, 1:], train_set_scaled[:, 0]
X_test, y_test = test_set_scaled[:, 1:], test_set_scaled[:, 0]
X_train = X_train.reshape((X_train.shape[0], 1, X_train.shape[1]))
X_test = X_test.reshape((X_test.shape[0], 1, X_test.shape[1]))

# Check the shape of X_train
print("X_train shape:", X_train.shape)  # Should output (44, 1, 14)

# Define the LSTM model
model = Sequential()
model.add(LSTM(4, stateful=True, batch_input_shape=(1, X_train.shape[1], X_train.shape[2])))
model.add(Dense(1))
model.compile(loss='mean_squared_error', optimizer='adam')

# Train the model
model.fit(X_train, y_train, epochs=100, batch_size=1, verbose=1, shuffle=False)

# Summarize the model
model.summary()

我尝试过的:

  • 我验证了 X_train 的形状,即 (44, 1, 14)
  • 我尝试使用 input_shape 而不是 batch_input_shape ,这导致了不同的错误。
  • 我确保TensorFlow 和 Keras 的版本是兼容的。

系统信息:

  • Python 版本:3.12

  • TensorFlow 版本:2.17.0

  • Keras 版本:3.4.1

问题: 如何我在 Keras 中正确指定 batch_input_shape 为我的有状态 LSTM 模型以避免此错误?是否有任何特定的版本要求或需要额外的配置?


在 Keras 的 LSTM 层中, batch_input_shape 参数应该只在第一层中使用,而且当 stateful=True 时是 必须的 。 这个参数帮助有状态 LSTM 理解输入数据的批次大小、时间步长和特征维度。

在你的代码中,你正确地使用了 batch_input_shape ,但是错误信息表明 LSTM 层无法识别这个参数。这通常发生在 Keras 版本不兼容 的时候。 虽然你的 TensorFlow 版本 (2.17.0) 比较新,但你的 Keras 版本 (3.4.1) 比较旧,可能不支持这个参数在有状态 LSTM 中的用法。

为了解决这个问题,我建议你尝试以下两种方案:

方案 1:升级 Keras

将 Keras 升级到最新版本可能可以解决这个问题,因为新版本通常包含对旧版本的错误修复和对新特性的支持。 你可以使用 pip 命令来升级 Keras:

pip install --upgrade keras

方案 2:使用 TensorFlow 中的 Keras

从 TensorFlow 2.0 开始,Keras 已经整合到 TensorFlow 中,你可以直接使用 tensorflow.keras 模块来构建模型。 由于你已经在使用 TensorFlow 2.17.0, 尝试将代码中的 keras 替换为 tensorflow.keras

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LSTM

如果你已经尝试过升级 Keras 并且仍然遇到问题,请确保你的代码中没有其他地方使用了旧版本的 Keras。 另外,请检查你的环境中是否安装了多个版本的 Keras,并确保你正在使用的是你期望的版本。

如果问题仍然存在,请提供以下信息以便我进一步帮助你:

  • 升级 Keras 后的完整错误信息
  • 你的操作系统和环境信息(例如,你是使用虚拟环境吗?)
  • 你是如何安装 TensorFlow 和 Keras 的?

标签:python,keras,lstm
From: 78805181

相关文章

  • Python反编译失败。 (不支持的操作码:JUMP_IF_NOT_EXC_MATCH)
    我尝试使用“pycdc.exe”反编译使用pycdc.exe失败。因为错误“不支持的操作码:JUMP_IF_NOT_EXC_MATCH”在此处输入图像描述使用pycdc.exe失败。因为错误“不支持的操作码:JUMP_IF_NOT_EXC_MATCH”你知道我为什么失败吗?(我试图编译的.pyc似乎是3.10版本)......
  • 计算机毕业设计项目推荐,基于Echarts的高校就业数据可视化管理系统 81461(开题答辩+程序
    摘 要信息化社会内需要与之针对性的信息获取途径,但是途径的扩展基本上为人们所努力的方向,由于站在的角度存在偏差,人们经常能够获得不同类型信息,这也是技术最为难以攻克的课题。针对高校就业管理等问题,对高校就业管理进行研究分析,然后开发设计出高校就业数据可视化管理系统......
  • Python逆向总结(Python反编译)
    目录第一种:直接反编译型第二种:打包成exe的py文件第三种: 给pyc字节码(类汇编形式)第四种:加花的pyc内容参考第一种:直接反编译型除了直接获得题目内容的python文件外,出题人也可以稍微加工一点点,给出题目python文件所对应的pyc文件,即python的字节码。PYC文件的定义pyc......
  • 【Python学习手册(第四版)】学习笔记06-Python动态类型-赋值模型详解
    个人总结难免疏漏,请多包涵。更多内容请查看原文。本文以及学习笔记系列仅用于个人学习、研究交流。主要介绍Python的动态类型(也就是Python自动为跟踪对象的类型,不需要在脚本中编写声明语句),Python中变量和对象是如何通过引用关联,垃圾收集的概念,对象共享引用是如何影响多个变量......
  • Python学习手册(第四版)】学习笔记09.3-Python对象类型-分类、引用VS拷贝VS深拷贝、比较
    个人总结难免疏漏,请多包涵。更多内容请查看原文。本文以及学习笔记系列仅用于个人学习、研究交流。这部分稍杂,视需要选择目录读取。主要讲的是对之前的所有对象类型作复习,以通俗易懂、由浅入深的方式进行介绍,所有对象类型共有的特性(例如,共享引用),引用、拷贝、深拷贝,以及比较、......
  • 同时运行多个Python文件
    如何同时运行python的多个文件我有三个文件pop.pypop1.pypop2.py我想同时运行这个文件这些文件正在被一一运行python代码运行所有文件可以使用以下几种方法同时运行多个Python文件:1.使用多线程/多进程:多线程(threading):如果的Pytho......
  • 《最新出炉》系列入门篇-Python+Playwright自动化测试-56- 多文件上传 - 下篇
    1.简介前边的两篇文章中,宏哥分别对input控件上传文件和非input控件上传文件进行了从理论到实践地讲解和介绍,但是后来又有人提出疑问,前边讲解和介绍的都是上传一个文件,如果上传多个文件,Playwright是如何实现的呢?宏哥看了一下官方的API也有上传多个文件的API,那么今天就来讲解和介绍......
  • 如何使用python模块捕获用户的文本输入
    我正在开发一个项目,它会检测到如果您按“(”,它会自动关闭它“[”和“{”的情况相同,但重点是它检测键盘按钮“{”或“[”不是字符,这意味着如果朋友有不同的方式输入“[”,它将无法工作,因为该程序用于检测“altgr+(”序列,这可能会影响不同语言的键盘因为您不想在按下......
  • 如何更新 numpy 2 的 python 模块?
    在带有pip的Linux上,新的numpy2似乎可以很好地与pandas配合使用:$python3-c'importnumpyasnp;print(np.__version__);importpandasaspd;print(pd.__version__)'2.0.12.2.2但是,在带有miniconda的Windows上,我得到$${localappdata}/miniconda3/en......
  • python BioChemist 数据集的数据字典/描述
    我正在使用生物化学家数据集。我在哪里可以找到包含每个变量描述的“数据字典”?这就是我正在查看的:importpandasaspdfrompydatasetimportdatadata('bioChemists')我已经用谷歌搜索并尝试寻找运算符,但没有运气!pydataset软件包不包含生物化学家数据集的描述......