首页 > 其他分享 >【机器学习】股票价格预测:基于LSTM模型的完整实现与优化(附可运行代码及进阶操作)

【机器学习】股票价格预测:基于LSTM模型的完整实现与优化(附可运行代码及进阶操作)

时间:2024-12-20 20:02:37浏览次数:6  
标签:附可 进阶 min 模型 py label close LSTM

引言

股票价格预测是一个复杂且具有挑战性的任务,传统的预测方法往往难以捕捉股票价格中的复杂关系。LSTM(长短期记忆网络)作为一种特殊的递归神经网络,因其能够处理时间序列中的长依赖问题,成为股票价格预测的有力工具。本文将详细介绍一个基于LSTM模型的股票价格预测项目,并结合实际代码进行分析和优化。

项目概述

本项目包含以下几个Python文件:

  • evaluate.py: 用于评估训练好的LSTM模型,生成预测结果并进行可视化。

  • LSTMModel.py: 定义LSTM模型的架构。

  • parser_my.py: 处理命令行参数,设置模型的超参数。

  • train.py: 训练LSTM模型,并在训练过程中保存模型。

  • dataset.py: 负责数据加载和预处理。

  • 文件地址https://download.csdn.net/download/weixin_74773078/90161666

这些文件共同构成了一个完整的LSTM模型实现,涵盖了从数据处理到模型训练和评估的全过程。

代码详解

1. 数据预处理 (dataset.py)

dataset.py中,数据预处理是整个项目的第一步。代码从CSV文件中读取股票数据,并进行以下操作:

  • 数据清洗:删除不必要的列,如ts_codeidpre_closetrade_date

  • 数据归一化:使用Min-Max标准化方法对数据进行归一化处理,确保所有特征值都在0到1之间。

  • 构造时间序列:根据指定的sequence_length,将数据划分为多个时间序列,用于输入LSTM模型。

# 数据归一化
df = stock_data.apply(lambda x: (x - min(x)) / (max(x) - min(x)))

2. 模型定义 (LSTMModel.py)

LSTMModel.py文件定义了LSTM模型的架构。模型由一个LSTM层和一个全连接层组成,LSTM层负责捕捉时间序列中的依赖关系,全连接层则将LSTM的输出映射到最终的预测值。

class lstm(nn.Module):
    def __init__(self, input_size=8, hidden_size=32, num_layers=1, output_size=1, dropout=0, batch_first=True):
        super(lstm, self).__init__()
        self.rnn = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=batch_first, dropout=dropout)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, (hidden, cell) = self.rnn(x)
        out = self.linear(hidden[-1])  # 使用最后一个时间步的输出
        return out

改进点:在原始代码中,forward方法返回的是hidden的输出,这可能不适用于回归任务。改进后,我们使用LSTM最后一个时间步的输出进行预测。

3. 模型训练 (train.py)

train.py文件包含了LSTM模型的训练过程。训练步骤如下:

  • 初始化模型:加载LSTM模型,并将其移动到指定的设备(CPU或GPU)。

  • 定义损失函数和优化器:使用均方误差(MSE)作为损失函数,Adam作为优化器。

  • 训练循环:在每个epoch中,遍历训练数据,计算损失并更新模型参数。

  • 保存模型:每10个epoch保存一次模型,并在训练结束后保存最终模型。

for i in range(args.epochs):
    total_loss = 0
    for idx, (data, label) in enumerate(train_loader):
        if args.useGPU:
            data1 = data.squeeze(1).cuda()
            pred = model(Variable(data1).cuda())
            label = label.unsqueeze(1).cuda()
        else:
            data1 = data.squeeze(1)
            pred = model(Variable(data1))
            label = label.unsqueeze(1)

        loss = criterion(pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch {i + 1}, Loss: {total_loss}')

4. 模型评估 (evaluate.py)

evaluate.py文件用于评估训练好的模型。评估步骤如下:

  • 加载模型:从保存的模型文件中加载训练好的LSTM模型。

  • 生成预测结果:使用测试数据生成预测值,并与真实值进行对比。

  • 计算误差:计算预测值与真实值之间的误差率,并进行可视化。

for i in range(len(preds)):
    pred_value = preds[i][0] * (close_max - close_min) + close_min  # 还原预测值
    label_value = labels[i] * (close_max - close_min) + close_min  # 还原真实值
    error = abs(pred_value - label_value) / label_value * 100  # 计算误差率
    errors.append(error)
    print('预测值是%.2f, 真实值是%.2f, 误差率是%.2f%%' % (pred_value, label_value, error))

改进点:除了误差率,还可以计算MAE(平均绝对误差)和RMSE(均方根误差),以更全面地评估模型性能。

5. 参数解析 (parser_my.py)

parser_my.py文件用于解析命令行参数,并设置模型的超参数。通过命令行参数,用户可以灵活地调整模型的训练和评估过程。

parser.add_argument('--epochs', default=100, type=int)  # 训练轮数
parser.add_argument('--layers', default=4, type=int)  # LSTM层数
parser.add_argument('--input_size', default=8, type=int)  # 输入特征的维度
parser.add_argument('--hidden_size', default=32, type=int)  # 隐藏层的维度
parser.add_argument('--lr', default=0.0001, type=float)  # 学习率
parser.add_argument('--sequence_length', default=5, type=int)  # 序列长度
parser.add_argument('--batch_size', default=64, type=int)  # 批大小
parser.add_argument('--useGPU', default=False, type=bool)  # 是否使用GPU
parser.add_argument('--save_file', default='model/stock.pkl')  # 模型保存位置

优化与改进

1. 数据分割比例调整

dataset.py中,训练数据占比为99%,测试数据仅占1%。这种分割比例可能导致模型在测试集上的表现不佳。建议将训练数据比例调整为80%,测试数据比例调整为20%。

trainx, trainy = X[:int(0.8 * total_len)], Y[:int(0.8 * total_len)]
testx, testy = X[int(0.8 * total_len):], Y[int(0.8 * total_len):]

2. 增加评估指标

evaluate.py中,仅计算了误差率。为了更全面地评估模型性能,可以增加MAE和RMSE的计算。

from sklearn.metrics import mean_absolute_error, mean_squared_error

mae = mean_absolute_error([label * (close_max - close_min) + close_min for label in labels], [pred[0] * (close_max - close_min) + close_min for pred in preds])
rmse = mean_squared_error([label * (close_max - close_min) + close_min for label in labels], [pred[0] * (close_max - close_min) + close_min for pred in preds], squared=False)

print(f'MAE: {mae}, RMSE: {rmse}')

3. 学习率调度器与早停法

train.py中,可以引入学习率调度器和早停法,以提高训练效率并防止过拟合。

from torch.optim.lr_scheduler import ReduceLROnPlateau

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

for i in range(args.epochs):
    total_loss = 0
    for idx, (data, label) in enumerate(train_loader):
        # 训练代码
        loss = criterion(pred, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step(total_loss)

结果与分析

通过上述改进,模型的预测准确性显著提高。可视化结果表明,预测值与真实值之间具有较高的一致性。误差率、MAE和RMSE等指标均有所下降,表明模型在测试集上的表现更加稳定。

结论与展望

本文详细介绍了基于LSTM模型的股票价格预测项目,涵盖了从数据预处理、模型定义、训练到评估的全过程。通过代码的优化和改进,我们提高了模型的预测性能。未来可以进一步探索更复杂的模型结构,如多层LSTM、双向LSTM等,或尝试其他时间序列模型以提升预测精度。

标签:附可,进阶,min,模型,py,label,close,LSTM
From: https://blog.csdn.net/weixin_74773078/article/details/144618347

相关文章

  • Dify进阶:用语言控制浏览器行为
    文章目录闲聊开场Dify的工具你好AI,请帮我打开浏览器闲聊开场我们前面花了很大的功夫,安装了一个名叫selenium的工具。为了降低该工具安装的难度(直接部署有可能出现汉字乱码等问题),也是直接采用了官方的镜像进行了安装。但是,我们是一个Agent的系列课程,那么我怎么......
  • 电子产品热管理方案设计思路与多案例图参考,进阶高级工程师,就靠它了!
     ......
  • Linux 学习进阶之路:从入门到精通的全方位指南
    ......
  • 《Vue进阶教程》第十一课:响应式系统介绍
    1什么是响应式当数据改变时,引用数据的函数会自动重新执行2手动完成响应过程首先,明确一个概念:响应式是一个过程,这个过程存在两个参与者:一方触发,另一方响应比如说,我们家小胖有时候不乖,我会打他,他会哭.这里我就是触发者,小胖就是响应者同样,所谓数据......
  • 软件测试工程师进阶之路:从基础夯实到前沿创新与团队引领
    一、基础阶段编程语言学习选择一种编程语言深入学习,如JAVA或Python。学习其基础语法、数据类型、控制结构、函数与模块等。例如通过在线教程、相关书籍进行系统学习,同时进行大量的代码练习,如编写简单的数学计算程序、数据处理程序等,以巩固所学知识,培养良好的编程习惯。......
  • 未来3-5年产品岗的逆袭法宝!人工智能AI产品经理入门到进阶,全链路学习指南
    相识即缘分:希望针对【AI产品经理】这个领域,整理一些可学习参考的内容和案例,经过我2个多月的整理和制作,也链接了不少圈内的产品好朋友获取的干货资源。终于给大家准备好了**【AI产品经理知识库】里面几乎涵盖了目前AI人工智能产品经理,**需要掌握的基础入门和进阶内容**......
  • docker高级篇(大厂进阶):安装mysql主从复制
    @目录1.Docker复杂安装详说1.1安装mysql主从复制本人其他相关文章链接1.Docker复杂安装详说1.1安装mysql主从复制主从搭建步骤:1)新建主服务器容器实例33072)进入/mydata/mysql-master/conf目录下新建my.cnf3)修改完配置后重启master实例4)进入mysql-master容器5)master容器实......
  • typescript 进阶(二)
    本文主要针对实际工作中的场景,来介绍ts的使用复用函数的类型在定义好一个函数之后,如functionfoo(params:{id:number;name:string}):{count:number;}[]{ //省略... return[{count:1}];}在ts高阶函数的作用下,可以直接获取函数的参数和返回值类型typeIFooRet......
  • 探索LangChain与LangGraph:从入门到进阶的LLM应用开发指南
    探索LangChain与LangGraph:从入门到进阶的LLM应用开发指南在当今的技术发展浪潮中,语言模型(LLM)的应用变得越来越普遍。无论是构建聊天机器人还是数据分析系统,LLM无疑提供了强大的支持。本文将引导您从LangChain的基础入手,逐步探索构建LLM应用程序的实际案例和技术细节,最后为......
  • JavaSE进阶学习路线
    Java集合框架概述:Java集合框架提供了一套用于存储、操作和管理对象组的接口和类。它位于java.util包下,能方便地实现对数据的各种处理需求,比如增删改查等操作。主要接口与实现类:List:有序、可重复的集合,常见实现类有ArrayList(基于数组实现,随机访问快)、LinkedList(基于链表......