首页 > 其他分享 >机器学习项目--库存需求预测3--LSTM模型

机器学习项目--库存需求预测3--LSTM模型

时间:2024-06-13 17:58:54浏览次数:18  
标签:-- 需求预测 shape train test import LSTM data valid

一、导入库和数据集

代码环境:

主要的包版本如下

python==3.10

scikit-learn==1.0.2

tensorflow==2.15.0
导入库

import pandas as pd
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, LSTM, Dropout
from keras.regularizers import l2
import matplotlib.pyplot as plt
import glob, os
import seaborn as sns
import sys
from sklearn.preprocessing import MinMaxScaler
import keras
from sklearn.metrics import mean_absolute_error as mae

 读取数据集加载到pandas,并打印前5行和列

dataframe = pd.read_csv("D:/data/rossmann-stores-clustering-and-forecast/train.csv")
dataframe.head()

 打印列名

dataframe.columns

对StateHoliday字母进行转化

def transform_state_holiday(x):
    if x == "a":
        return 1
    elif x == "b":
        return 2
    elif x == "c":
        return 3
    return x

dataframe["StateHoliday"] = dataframe.apply(lambda x:transform_state_holiday(x.StateHoliday), axis=1)

二、数据预处理

 获取一个门店的数据进行LSTM建模

data = dataframe[dataframe["Store"] == 1].sort_values(by="Date")

 将数据归一化到0-1之间,无量纲化

scaler = MinMaxScaler(feature_range=(0,1))
column_list_scaler = ['Sales','DayOfWeek','Open','Promo','StateHoliday','SchoolHoliday']
scaled_data = scaler.fit_transform(data[column_list_scaler].values)

时间序列数据转化为监督问题数据

def series_to_supervised(data, n_in=1, n_out=1, dropnana=True):
    n_vars = 1 if type(data) is list else data.shape[1]
    df = pd.DataFrame(data)
    cols, names = list(), list()
    # input sequence (t-n, t-1)
    for i in range(n_in, 0, -1):
        cols.append(df.shift(i))
        names += [('var%d(t-%d)'%(j+1, i)) for j in range(n_vars)]

    # forcast sequence(t, t+1, t+n)
    for i in range(n_out):
        cols.append(df.shift(-i))
        if i == 0:
            names += [('var%d'%(j+1)) for j in range(n_vars)]
        else:
            names += [('var%d(t+%d)' % (j + 1, i)) for j in range(n_vars)]

    #put it all together
    print("cols : {}".format(cols))
    print("names : {}".format(names))
    agg = pd.concat(cols, axis=1)
    agg.columns = names
    # drop rows witn NAN values
    if dropnana:
        agg.dropna(inplace=True)
    return agg

# 将时序数据转换为监督问题数据
reframed = series_to_supervised(scaled_data, 1, 1)

 删除无用的列

reframed = reframed.iloc[:, 0:len(column_list_scaler) + 1]
reframed.head()

 三、数据建模

数据集划分,选取前550天数据作为训练集,中间250天数据作为验证集,其余全为测试集

train_days = 550
valid_days = 250
values = reframed.values
train = values[:train_days, :]
valid = values[train_days:train_days+valid_days, :]
test = values[train_days+valid_days:, :]
train_X, train_Y = train[:, :-1], train[:, -1]
valid_X, valid_Y = valid[:, :-1], valid[:, -1]
test_X, test_Y = test[:, :-1], test[:, -1]

将数据集重构为符合LSTM要求的数据格式,即 [样本,时间步,特征]

train_X = train_X.reshape((train_X.shape[0], 1, train_X.shape[1]))
valid_X = valid_X.reshape((valid_X.shape[0], 1, valid_X.shape[1]))
test_X = test_X.reshape((test_X.shape[0], 1, test_X.shape[1]))
print(train_X.shape, train_Y.shape, valid_X.shape, valid_Y.shape, test_X.shape, test_Y.shape)

建立模型并训练

model = Sequential()
model.add(LSTM(100, activation='relu',input_shape=(train_X.shape[1], train_X.shape[2]), return_sequences=True))
model.add(Dropout(0.2))
model.add(Dense(1, activation='linear'))
model.compile(loss='mean_squared_error', optimizer='adam')
#fit network
LSTM = model.fit(train_X,
                  train_Y,
                  epochs=100,
                  batch_size=20,
                  validation_data=(valid_X, valid_Y),
                  verbose=2,
                  shuffle=False)

 打印loss

plt.plot(LSTM.history['loss'], label='train')
plt.plot(LSTM.history['val_loss'], label='valid')
plt.legend()
plt.show()

 预测test数据

test_predict = model.predict(test_X)

 数据进行反归一化

test_data_inverse = scaler.inverse_transform(np.concatenate(( test_Y.reshape(-1, 1), test_X.reshape(-1, len(column_list_scaler))[:, 1:]), axis=1))[:,0]
test_predict_inverse = scaler.inverse_transform(np.concatenate(( test_predict.reshape(-1, 1), test_X.reshape(-1, len(column_list_scaler))[:, 1:]), axis=1))[:,0]
print("mae : ", mae(test_data_inverse, test_predict_inverse))

四、数据来源和源码获取

训练数据来源于kaggle,读者可以去kagga下载。

或者加小编微信获取数据和源码:

 

标签:--,需求预测,shape,train,test,import,LSTM,data,valid
From: https://blog.csdn.net/u014460433/article/details/139659207

相关文章

  • Spring5的基本使用
    Spring5的一些变化Spring5.x整个框架已经全面基于Java8及以上版本,所以Spring5最低JDK版本要求是8由于Java8的反射增强,因此Spring5.x可以对方法的参数进行更高效的访问Spring5.x核心接口已经加入了Java8接口支持的默认方法Spring5.x已经自带了通用的日志封装,不需要再......
  • 文件上传下载
    前端文件上传这里说的文件上传和文件下载都是针对客户端进行的操作使用如下jsp代码,通过Servlet获取表单数据,是否可以获取到文件信息<%--CreatedbyIntelliJIDEA.User:carlDate:2021/10/9Time:15:46TochangethistemplateuseFile|Settings|F......
  • 前段时间的开发过程中存在的问题
    针对前段时间的开发过程中存在的问题,我们团队目前面临多个挑战。以下是我们讨论后识别出的主要问题,并进行了投票以选出需要改进的最主要三个问题:家长端功能不完善:目前,家长端询问儿童的心理问题,但无法直观地查看孩子的最近状态。这影响了家长使用软件的体验,并限制了软件在家长用户......
  • 孩子上初中厌学怎么办?家长做好3件事,能让孩子爱上学习
    孩子上初中厌学www.zjia8.com会直接导致学习成绩下滑,甚至会染上很多不良习惯,成为危害社会的“不稳定因子”。而改变这种状况,不能只靠初中生个人,更多的要靠家长。如果家长能做好以下3件事,就有希望让孩子爱上学习。第一件事:找出原因,进行相应的调整初中是小学向高中......
  • XOR的艺术
    #include<iostream>#include<cstdio>#include<cstring>#include<string>#include<cmath>#include<algorithm>#include<cstdlib>#include<set>#include<map>#include<vector>#include<qu......
  • 孩子成绩不好怎么办?比打骂更有效的沟通方式,成绩提升30分
      孩子考试结束后,看着孩子试卷上可怜的分数,都不想承认这是自己的孩子。孩子学习成绩有好有坏,这很正常。想要孩子提升,只靠打骂是不不行的。学习的主力是孩子,不是我们已经毕业多年的家长。掌握和孩子沟通的技巧www.zjia8.com发掘孩子的问题,引导孩子才是育儿的法宝。  每......
  • 医院设备管理系统的设计与实现 毕业设计-附源码39673
    摘 要随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,医院当然也不能排除在外。医院设备管理系统是以实际运用为开发背景,运用软件工程开发方法,采用SSM技术构建的一个管理系统。整个开发过程首先对软件系统进行......
  • TypeScript声明文件
    TypeScript声明文件是一种用于描述JavaScript库、模块或框架的类型信息的文件。它们具有.d.ts扩展名,并包含了类型定义和类型注解,以便在TypeScript项目中使用这些JavaScript代码时提供类型检查和智能提示。声明文件的作用是为JavaScript代码提供静态类型检查的能力,使开发者能够在......
  • chatgpt tools调用
    chatgpttools调用1.引入openai,创建clientimportjsonimportosimportsubprocessfromopenaiimportOpenAI#api_key可以填入自己的key#base_url可以使用国内的代理,海外可以使用官方地址client=OpenAI(api_key="",base_url="https://api.openai-proxy.com......
  • 阿里云运维第一步(监控):开箱即用的监控
    作者:仲阳这是云的时代,现在云计算已经在各行各业广泛的应用。但是上云对于大多数客户来说,依然有很大的学习成本,如下图仅是阿里云都有几百款产品,怎么选择?怎么用?对于客户来说都是问题。“用好云、管好云”不仅仅是口号,还是我们的目标。来自于:https://developer.aliyun.com/ebook/8......