首页 > 其他分享 >12-LSTM多变量-定义&训练模型

12-LSTM多变量-定义&训练模型

时间:2023-02-09 00:11:52浏览次数:53  
标签:12 变量 yhat inv print shape train test LSTM

import math
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from sklearn.metrics import mean_squared_error
from keras.models import Sequential
from keras.layers import Dense, LSTM

# 转换成有监督数据
def series_to_supervised(data, n_in=1, n_out=1, dropnan=True):  # n_in, n_out相当于lag
    n_vars = 1 if type(data) is list else data.shape[1]  # 变量个数
    df = pd.DataFrame(data)
    print('待转换数据')
    print(df.head())
    cols, names = [], []
    # 输入序列(t-n, ..., t-1)
    for i in range(n_in, 0, -1):
        cols.append(df.shift(i))
        print('shift数据')
        print(cols[0][:5])
        names += [('var%d(t-%d)' % (j+1, i)) for j in range(n_vars)]
        print('names数据')
        print(names[:5])
    # 预测序列(t, t+1, ..., t+n)
    for i in range(0, n_out):
        cols.append(df.shift(-i))
        if i == 0:  # t时刻
            names += [('var%d(t)' % (j+1)) for j in range(n_vars)]
        else:
            names += [('var%d(t+%d)' % (j+1, i)) for j in range(n_vars)]
    # 拼接
    agg = pd.concat(cols, axis=1)
    print('拼接')
    print(agg[:5])
    agg.columns = names
    # 将空值NaN行删除
    if dropnan:
        agg.dropna(inplace=True)
    return agg

# 数据预处理
dataset = pd.read_csv('../LSTM系列/LSTM多变量3/data_set/pollution.csv', header=0, index_col=0)
values = dataset.values

# 标签编码
encoder = LabelEncoder()
values[:, 4] = encoder.fit_transform(values[:, 4])

# 转换float
values = values.astype(np.float32)

# 归一化
scaler = MinMaxScaler(feature_range=(0, 1))
scaled = scaler.fit_transform(values)

# 转换成有监督数据
reframed = series_to_supervised(scaled, 1, 1)

# 删除不预测的列
reframed.drop(reframed.columns[[9, 10, 11, 12, 13, 14, 15]], axis=1, inplace=True)

print(reframed.head())

# 数据准备
# 把数据分为训练数据和测试数据
values = reframed.values

# 拿一年的时间长度训练
n_train_hours = 365 * 24

# 划分训练数据和测试数据
train = values[:n_train_hours, :]
test = values[n_train_hours:, :]

# 拆分输入输出
train_X, train_y = train[:, :-1], train[:, -1]
test_X, test_y = test[:, :-1], test[:, -1]

# reshape输入为LSTM的输入格式
train_X = train_X.reshape((train_X.shape[0], 1, train_X.shape[1]))
test_X = test_X.reshape((test_X.shape[0], 1, test_X.shape[1]))
print('train_X.shape, train_y.shape, test_X.shape, test_y.shape')
print(train_X.shape, train_y.shape, test_X.shape, test_y.shape)

# 模型定义
model = Sequential()
model.add(LSTM(50, input_shape=(train_X.shape[1], train_X.shape[2])))
model.add(Dense(1))
model.compile(loss='mae', optimizer='adam')
print(model.summary())

# 模型训练
history = model.fit(train_X, train_y, epochs=5, batch_size=72, validation_data=(test_X, test_y), verbose=2, shuffle=False)

# 输出plot history
plt.plot(history.history['loss'], label='train')
plt.plot(history.history['val_loss'], label='test')
plt.legend()
plt.show()

# 进行预测
yhat = model.predict(test_X)
test_X = test_X.reshape((test_X.shape[0], test_X.shape[2]))

# 预测数据逆缩放
inv_yhat = np.concatenate((yhat, test_X[:, 1:]), axis=1)
inv_yhat = scaler.inverse_transform(inv_yhat)
inv_yhat = inv_yhat[:, 0]
inv_yhat = np.array(inv_yhat)

# 真实数据逆缩放
test_y = test_y.reshape((len(test_y), 1))
inv_y = np.concatenate((test_y, test_X[:, 1:]), axis=1)
inv_y = scaler.inverse_transform(inv_y)
inv_y = inv_y[:, 0]

# 画出真实数据和预测数据
plt.plot(inv_yhat, label='prediction')
plt.plot(inv_y, label='true')
plt.legend()
plt.show()

# 计算RMSE
rmse = math.sqrt(mean_squared_error(inv_y, inv_yhat))
print('Test RMSE: %.3f' % rmse)

标签:12,变量,yhat,inv,print,shape,train,test,LSTM
From: https://www.cnblogs.com/lotuslaw/p/17103804.html

相关文章

  • 13-LSTM多步预测-静态模型预测
    importpandasaspdfromsklearn.metricsimportmean_squared_errorimportmathimportmatplotlib.pyplotaspltimportdatetimedefparser(x):returndatet......
  • 5-LSTM模型开发
    """长短期记忆网络(LSTM)是一种循环神经网络(RNN)。这种类型的网络的一个好处是它可以学习和记住长序列,并且不依赖于预先指定的窗口滞后观察作为输入。在Keras中,这被称为......
  • 6-完整的LSTM案例
    importpandasaspdimportdatetimefromsklearn.metricsimportmean_squared_errorfromsklearn.preprocessingimportMinMaxScalerfromkeras.modelsimportSequ......
  • Jmeter-数据驱动DDT-CSV-响应断言也使用配置文件数据-且变量里有变量情况
    1、DDT数据驱动性能测试当我们使用Jmeter工具进行接口测试,可利用CSVDataSetConfig配置元件,对测试数据进行参数化,循环读取csv文档中每一行测试用例数据,来实现接口自动化......
  • oracle dblink 连接超时ora-12170问题记录
    应用联系我们说他们自己创建dblink连接我们数据库一会能连接上一会连接不上叫我们帮忙分析一下问题   这是他们创建语句他们那边连接报错   那么问题是怎么......
  • 124. Binary Tree Maximum Path Sum[Hard]
    124.BinaryTreeMaximumPathSumApathinabinarytreeisasequenceofnodeswhereeachpairofadjacentnodesinthesequencehasanedgeconnectingthem.......
  • 【CSP201312-4】有趣的数(数位DP)
    problem问题描述试题编号:201312-4试题名称:有趣的数时间限制:1.0s内存限制:256.0MB问题描述:问题描述我们把一个数称为有趣的,当且仅当:1.它的数字只包含0,......
  • 【CSP201312-3】最大的矩形,单调栈
    problem201312-3试题名称:最大的矩形时间限制:1.0s内存限制:256.0MB问题描述:问题描述在横轴上放了n个相邻的矩形,每个矩形的宽度是1,而第i(1≤i≤n)个矩形的高度......
  • 【CSP201312-2】ISBN号码,字符串,简单模拟
    problem试题编号:201312-2试题名称:ISBN号码时间限制:1.0s内存限制:256.0MB问题描述:问题描述每一本正式出版的图书都有一个ISBN号码与之对应,ISBN码包括9位数字、......
  • 【CSP201312-1 】出现次数最多的数,排序后扫描并记录
    problem问题描述给定n个正整数,找出它们中出现次数最多的数。如果这样的数有多个,请输出其中最小的一个。输入格式输入的第一行只有一个正整数n(1≤n≤1000),表......