首页 > 其他分享 >6-完整的LSTM案例

6-完整的LSTM案例

时间:2023-02-09 00:00:22浏览次数:51  
标签:yhat batch scaled 案例 完整 train values LSTM model

import pandas as pd
import datetime
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import Dense, LSTM
import math
import matplotlib.pyplot as plt
import numpy as np

# 读取时间数据的格式化
def parser(x):
    return datetime.datetime.strptime(x, '%Y/%m/%d')

# 转换成有监督数据
def timeseries_to_supervised(data, lag=1):
    df = pd.DataFrame(data)
    tmp = [df.shift(i) for i in range(1, lag+1)]
    tmp.append(df)
    df = pd.concat(tmp, axis=1)
    df.fillna(0, inplace=True)
    return df

# 转换成差分数据
def difference(dataset, interval=1):
    diff = []
    for i in range(interval, len(dataset)):
        value = dataset[i] - dataset[i - interval]
        diff.append(value)
    return pd.Series(diff)

# 逆差分
def inverse_difference(history, yhat, interval=1):
    return yhat + history[-interval]

# 缩放
def scale(train, test):
    # 根据训练数据建立缩放器
    scaler = MinMaxScaler(feature_range=(-1, 1))
    scaler.fit(train)
    train_scaled = scaler.transform(train)
    test_scaled = scaler.transform(test)
    return scaler, train_scaled, test_scaled

# 逆缩放
def invert_scale(scaler, X, value):
    new_row = [x for x in X] + [value]
    array = np.array(new_row).reshape(1, len(new_row))
    inverted = scaler.inverse_transform(array)
    return inverted[0, -1]

# fit LSTM来训练数据
def fit_lstm(train, batch_size, nb_epoch, neurons):
    X, y = train[:, 0:-1], train[:, -1]
    X = X.reshape(X.shape[0], 1, X.shape[1])
    model = Sequential()
    model.add(LSTM(neurons, batch_input_shape=(batch_size, X.shape[1], X.shape[2]), stateful=True))
    model.add(Dense(1))
    print(model.summary())
    model.compile(loss='mean_squared_error', optimizer='adam')
    for i in range(nb_epoch):
        # 按照batch_size,一次读取batch_size个数据
        model.fit(X, y, epochs=1, batch_size=batch_size, verbose=0, shuffle=False)
        model.reset_states()
        print("当前计算次数:", i+1)
    return model

# 1步长预测
def forecast_lstm(model, batch_size, X):
    X = X.reshape(1, 1, len(X))
    yhat = model.predict(X, batch_size=batch_size)
    return yhat[0, 0]

# 加载数据
def parser(x):
    return datetime.datetime.strptime(x, '%Y/%m/%d')

ser = pd.read_csv('../LSTM系列/LSTM单变量1/data_set/shampoo-sales.csv', 
                header=0, parse_dates=[0], index_col=0, date_parser=parser).squeeze('columns')

# 稳定
raw_values = ser.values
diff_values = difference(raw_values, 1)

# 有监督
supervised = timeseries_to_supervised(diff_values, 1)
supervised_values = supervised.values

# 拆分训练集、测试集合
train, test = supervised_values[:-12], supervised_values[-12:]

# 缩放
scaler, train_scaled, test_scaled = scale(train, test)

# fit模型
lstm_model = fit_lstm(train_scaled, 1, 100, 4)

# 预测
train_reshaped = train_scaled[:, 0].reshape(len(train_scaled), 1, 1)  # 训练数据转换为可输入的矩阵
lstm_model.predict(train_reshaped, batch_size=1)
predictions = []
for i in range(len(test_scaled)):
    # 1步长预测
    X, y = test_scaled[i, 0:-1], test_scaled[i, -1]
    yhat = forecast_lstm(lstm_model, 1, X)
    yhat = invert_scale(scaler, X, yhat)
    yhat = inverse_difference(raw_values, yhat, len(test_scaled) + 1 - i)
    predictions.append(yhat)
    expected = raw_values[len(train) + i + 1]
    print('Moth=%d, Predicted=%f, Expected=%f' % (i + 1, yhat, expected))

# 性能报告
rmse = math.sqrt(mean_squared_error(raw_values[-12:], predictions))
print('Test RMSE:%.3f' % rmse)

# 绘图
plt.plot(raw_values[-12:])
plt.plot(predictions)
plt.show()

标签:yhat,batch,scaled,案例,完整,train,values,LSTM,model
From: https://www.cnblogs.com/lotuslaw/p/17103784.html

相关文章

  • 【Spring】Spring框架入门案例
    1.下载Spring5(1)Spring官网https://spring.io/(2)下载地址https://repo.spring.io/ui/native/release/org/springframework/spring/下载解压,文件夹说明2.创建普通Java......
  • 完整记录一次 Microsoft Teams 登录过程
    搜索teams打开MicrosoftTeams(workorschool)使用其他账户或注册卡在GitHubMobile认证(手机app已经认证,但是没反应).重新来一次还是没反应,使用验证器......
  • 完整工作流整合方案,自定义配置,Java+Vue+Activiti@附配套文档
    前言activiti工作流引擎项目,企业erp、oa、hr、crm等企事业办公系统轻松落地,一套完整并且实际运用在多套项目中的案例,满足日常业务流程审批需求。一、项目形式springboot......
  • Cookie案例 分析 实现
    案例:记住上一次访问时间1需求 访问一个Servlet 如果是第一次访问 则提示 您好欢迎您首次访问2如果不是第一次访问 则提示 欢迎回来 您上次访问时间为 显示时间字......
  • C# Winform MessageBox使用方法及案例
    我们在程序中经常会用到MessageBox。  MessageBox.Show()共有21中重载方法。现将其常见用法总结如下:   1.MessageBox.Show("Hello~~~~");最简单的,只显示提示信息......
  • golang 内存泄漏分析案例
    1.前言关于内存泄漏的情形已经在之前文章总结过了,本文将讨论如何发现内存泄漏。2.怎么发现内存泄露在Go中发现内存泄露有2种方法,一个是通用的监控工具,另一个是goppro......
  • 登录案例-BeanUtils基本使用、BeanUtils介绍
    登录案例BeanUtils基本使用、BeanUtils介绍login.html中form表单的action路径的写法虚拟目录+Servlet的资源路径BeanUtils工具类,简化数据封装用于封......
  • 案例-使用maven
    需求:创建一个maven项目并添加web使用servlet对index.html进行跳转代码依赖<?xmlversion="1.0"encoding="UTF-8"?><projectxmlns="http://maven.apache.org/POM......
  • 合并分支案例
    需求:目前我在dev分支,现在要将远程的master分支合并到我本地这个分支。操作步骤:1.gitcheckoutmaster切换分支到master2.gitpull拉取远程分支,目的是确保当前分支......
  • 登录案例需求和分析
    登录案例需求用户登录案例需求:1.编写login.html登录页面username&password两个输入框2.使用Druid数据库连接池技术,操作mysql,day14......