首页 > 编程问答 >如何将keras LSTM层的输出输入到输入层?

如何将keras LSTM层的输出输入到输入层?

时间:2024-07-24 10:56:48浏览次数:17  
标签:python tensorflow keras recurrent-neural-network

我对张量流和 keras 相当陌生,有一个问题。 我想使用 LSTM 层进行时间序列预测,并进行一些修改。我从张量流教程中给出的示例开始

def build_LSTM(neurons, batch_size, history_size, features):
   model = tf.keras.models.Sequential()
   model.add(tf.keras.layers.LSTM(neurons, 
                                  batch_input_shape=(batch_size, history_size, features),
                                  stateful=True))
   model.add(tf.keras.layers.Dense(1))
   model.compile(loss='mean_squared_error', optimizer='adam')
   return(model)

在示例的当前状态中,模型的输入采用以下形式(观测值、时间步长、特征),并且它返回一个数字(对的预测)

我想要做的是在 LSTM 层中使用模式 return_sequence=True

这返回形状(时间步长、特征)的张量 T 是否正确?

有没有一种方法可以将该张量从一个步骤(假设观察 = 1)转移到下一步(观察 = 2)?我猜对应的图表应该是这样的:

enter image description here


是的,的理解是正确的。当在 Keras LSTM 层中设置 return_sequences=True 时,它会为每个时间步长返回一个输出,从而产生形状为 (时间步长, 特征) 的张量。

要将一个时间步长的 LSTM 输出传递到下一个时间步长的输入,可以使用以下方法:

1. 使用 tf.keras.layers.TimeDistributed 包装器:

这是一种更简洁、更易读的方法。可以使用 TimeDistributed 包装器将 Dense 层应用于 LSTM 输出的每个时间步长。这将创建一个与 LSTM 输出具有相同形状的输出,然后可以将其反馈到输入。

def build_LSTM(neurons, batch_size, history_size, features):
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.LSTM(neurons, 
                                   batch_input_shape=(batch_size, history_size, features),
                                   stateful=True,
                                   return_sequences=True))
    model.add(tf.keras.layers.TimeDistributed(tf.keras.layers.Dense(features)))
    model.compile(loss='mean_squared_error', optimizer='adam')
    return(model)

在这个例子中: - TimeDistributed(tf.keras.layers.Dense(features)) 将一个 Dense 层应用于 LSTM 输出的每个时间步长。 - 由于我们希望将输出反馈到输入,因此 Dense 层的单元数应与输入特征数 ( features ) 相同。

2. 重构模型以进行迭代预测:

如果需要对输出进行更精细的控制或需要执行其他操作(例如,在将其反馈到输入之前添加其他层),则可以重构模型以迭代执行预测。

def build_LSTM(neurons, features):
    input_layer = tf.keras.Input(shape=(1, features))
    lstm_layer = tf.keras.layers.LSTM(neurons, stateful=True, return_state=True)
    dense_layer = tf.keras.layers.Dense(features)

    outputs = []
    states = None
    for t in range(history_size):
        # 获取当前时间步长的输入
        input_t = input_layer[:, t, :]
        # 将输入传递给 LSTM
        output, state_h, state_c = lstm_layer(tf.expand_dims(input_t, axis=1), initial_state=states)
        # 应用 Dense 层
        output = dense_layer(output)
        # 将输出存储在列表中
        outputs.append(output)
        # 更新 LSTM 状态
        states = [state_h, state_c]

    model = tf.keras.Model(inputs=input_layer, outputs=outputs)
    model.compile(loss='mean_squared_error', optimizer='adam')
    return model

在这个例子中:

  • 我们使用 tf.keras.Input tf.keras.Model 创建一个更灵活的模型。
  • 我们在 for 循环中迭代时间步长。
  • return_state=True 使 LSTM 返回其内部状态,以便我们可以在下一个时间步长将其反馈。
  • initial_state 参数用于将先前的 LSTM 状态传递给下一个时间步长。

选择哪种方法取决于的特定需求。如果只需要一个简单的反馈机制,则 TimeDistributed 包装器更容易实现。如果需要对输出进行更精细的控制或需要执行其他操作,则迭代预测方法会更灵活。

标签:python,tensorflow,keras,recurrent-neural-network
From: 60660394

相关文章

  • Python 中的常量应该大写吗?
    在PEP8中,一般规则是在UPPER_CASE字符中声明常量。在现实生活中,可能有多种情况:#!envpythonDATABASE_HOST='localhost'app=Flask('myapp')base_two=partial(int,base=2)通常我们将字符串类型或数字类型变量视为不可变的,因此是常量,而不是对象或函数。问题是......
  • 多重处理会导致 Python 崩溃,并给出一个错误:调用 fork() 时可能已在另一个线程中进行
    我对Python比较陌生,并试图为我的for循环实现一个多处理模块。我在img_urls中存储了一个图像url数组,我需要下载并应用一些Google视觉。if__name__=='__main__':img_urls=[ALL_MY_Image_URLS]runAll(img_urls)print("---%sseconds---"%(......
  • Python编程时输入操作数错误
    我正在用Python编写下面的代码来模拟控制系统。但是,当我调试代码时,我面临以下问题:matmul:输入操作数1没有足够的维度(有0,gufunc核心,签名为(n?,k),(k,m?)->(n?,m?)需要1)文件“D:\ÁreadeTrabalho\GitHub\TCC\CódigosMarcela\SistemaSISO_tres_estados_new.py”,......
  • Python入门知识点 7--散列类型与字符编码
    1、初识散列类型(无序序列)数据类型分为3种:   前面已经学过了两种类型   1.数值类型:int/float/bool只能存储单个数据      2.序列类型:str/list/tuple,有序的存储多个数据--有序类型,有下标,可以进行索引切片步长操作          3.散列类型......
  • Python入门知识点 6--序列类型的方法
    1、初识序列类型方法序列类型的概念:数据的集合,在序列类型里面可以存放任意的数据也可以对数据进行更方便的操作这个操作就是叫增删改查(crud)(增加(Creat),读取查询(Retrieve),更新(Update),删除(Delete)几个单词的首字母简写)增删改查是操作数据最底层的操作(从本质......
  • Python项目流程图
    我有一个由多个文件夹组成的Python项目,每个文件夹包含多个脚本。我正在寻找一个Python库或软件/包,它们可以生成流程图,说明这些脚本如何互连并绘制出从开始到结束的整个过程。自动生成Python项目流程图确实是一个挑战,目前没有完美通用的解决方案。主要原因是:......
  • 使用 mypy 时Python中的继承和多态性不起作用
    我正在寻找用mypy做一些标准的多态性,我以前从未使用过它,而且到目前为止它并不直观。基类classContentPullOptions:passclassTool(Protocol):asyncdefpull_content(self,opts:ContentPullOptions)->str|Dict[str,Any]:...子类classGoogle......
  • Python函数获取匹配和错误记录
    我有一个以下格式的json文件:[{"type":"BEGIN","id":"XYZ123"},{"type":"END","id":"XYZ123",},{"type":&......
  • python,替换标点符号但保持特殊单词完整的最佳方法
    我正在制作一个调制函数,它将采用带有特殊字符(@&*%)的关键字,并保持它们完整,同时从句子中删除所有其他标点符号。我设计了一个解决方案,但它非常庞大,而且可能比需要的更复杂。有没有一种方法可以以更简单的方式做到这一点。简而言之,我的代码匹配特殊单词的所有实例以查找跨度。然......
  • Python 检测 USB 设备 - IDLE 和 CMD 解释器之间的不同结果
    我正在尝试解决VDI解决方案中智能卡设备的USB重定向问题。我正在使用pyscard模块作为智能卡。对于进一步的上下文,主要问题是当浏览器插件调用用于处理智能卡的python脚本时,未检测到读卡器。关于问题,当我从CMD解释器运行此代码片段时,我收到空列表,表示系统上未找......