首页 > 编程语言 >借助 Transformer 实现美股价格的预测(Python干货)

借助 Transformer 实现美股价格的预测(Python干货)

时间:2024-08-06 12:26:03浏览次数:23  
标签:Transformer set return min Python 美股 df import Close

作者:老余捞鱼

原创不易,转载请标明出处及原作者。

写在前面的话:

          Transformer 是一种在自然语言处理等领域广泛应用的深度学习架构,与传统的循环神经网络(RNN)相比,Transformer 可以并行处理输入序列的各个位置,大大提高了计算效率。而且通过多层的深度堆叠,能够学习到更复杂和抽象的特征表示。本文将利用Python代码来实现美股价格的预测模拟。

话不多说,代码如下:

import numpy as np
import pandas as pd
import os, datetime
import tensorflow as tf
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
print('Tensorflow version: {}'.format(tf.__version__))

import matplotlib.pyplot as plt
plt.style.use('seaborn')

import warnings
warnings.filterwarnings('ignore')

physical_devices = tf.config.list_physical_devices()
print('Physical devices: {}'.format(physical_devices))

# Filter out the CPUs and keep only the GPUs
gpus = [device for device in physical_devices if 'GPU' in device.device_type]

# If GPUs are available, set memory growth to True
if len(gpus) > 0:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)
    print('GPU memory growth: True')

Tensorflow version: 2.9.1
Physical devices: [PhysicalDevice(name=’/physical_device:CPU:0′, device_type=’CPU’)]

      Hyperparameters

  • batch_size = 32
    seq_len = 128
    
    d_k = 256
    d_v = 256
    n_heads = 12
    ff_dim = 256

    Load IBM data

    IBM_path = 'IBM.csv'
    
    df = pd.read_csv(IBM_path, delimiter=',', usecols=['Date', 'Open', 'High', 'Low', 'Close', 'Volume'])
    
    # Replace 0 to avoid dividing by 0 later on
    df['Volume'].replace(to_replace=0, method='ffill', inplace=True) 
    df.sort_values('Date', inplace=True)
    df.tail()
     df.head()

    # print the shape of the dataset
    print('Shape of the dataframe: {}'.format(df.shape))
    Shape of the dataframe: (14588, 6)

    Plot daily IBM closing prices and volume

    fig = plt.figure(figsize=(15,10))
    st = fig.suptitle("IBM Close Price and Volume", fontsize=20)
    st.set_y(0.92)
    
    ax1 = fig.add_subplot(211)
    ax1.plot(df['Close'], label='IBM Close Price')
    ax1.set_xticks(range(0, df.shape[0], 1464))
    ax1.set_xticklabels(df['Date'].loc[::1464])
    ax1.set_ylabel('Close Price', fontsize=18)
    ax1.legend(loc="upper left", fontsize=12)
    
    ax2 = fig.add_subplot(212)
    ax2.plot(df['Volume'], label='IBM Volume')
    ax2.set_xticks(range(0, df.shape[0], 1464))
    ax2.set_xticklabels(df['Date'].loc[::1464])
    ax2.set_ylabel('Volume', fontsize=18)
    ax2.legend(loc="upper left", fontsize=12)

    Calculate normalized percentage change of all columns

    '''Calculate percentage change'''
    
    df['Open'] = df['Open'].pct_change() # Create arithmetic returns column
    df['High'] = df['High'].pct_change() # Create arithmetic returns column
    df['Low'] = df['Low'].pct_change() # Create arithmetic returns column
    df['Close'] = df['Close'].pct_change() # Create arithmetic returns column
    df['Volume'] = df['Volume'].pct_change()
    
    df.dropna(how='any', axis=0, inplace=True) # Drop all rows with NaN values
    
    ###############################################################################
    '''Create indexes to split dataset'''
    
    times = sorted(df.index.values)
    last_10pct = sorted(df.index.values)[-int(0.1*len(times))] # Last 10% of series
    last_20pct = sorted(df.index.values)[-int(0.2*len(times))] # Last 20% of series
    
    ###############################################################################
    '''Normalize price columns'''
    #
    min_return = min(df[(df.index < last_20pct)][['Open', 'High', 'Low', 'Close']].min(axis=0))
    max_return = max(df[(df.index < last_20pct)][['Open', 'High', 'Low', 'Close']].max(axis=0))
    
    # Min-max normalize price columns (0-1 range)
    df['Open'] = (df['Open'] - min_return) / (max_return - min_return)
    df['High'] = (df['High'] - min_return) / (max_return - min_return)
    df['Low'] = (df['Low'] - min_return) / (max_return - min_return)
    df['Close'] = (df['Close'] - min_return) / (max_return - min_return)
    
    ###############################################################################
    '''Normalize volume column'''
    
    min_volume = df[(df.index < last_20pct)]['

标签:Transformer,set,return,min,Python,美股,df,import,Close
From: https://blog.csdn.net/weixin_70955880/article/details/140860467

相关文章

  • 将 Mojo 与 Python 结合使用
    Mojo允许您访问整个Python生态系统,但环境可能会因Python的安装方式而异。花些时间准确了解Python中的模块和包的工作原理是值得的,因为有一些复杂情况需要注意。如果您以前在调用Python代码时遇到困难,这将帮助您入门。Python中的模块和包让我们从Python开始,如......
  • Mojo和Python中的类型详解
    调用Python方法时,Mojo需要在原生Python对象和原生Mojo对象之间来回转换。大多数转换都是自动进行的,但也有一些情况Mojo尚未处理。在这些情况下,您可能需要进行显式转换,或调用额外的方法。Python中的Mojo类型Mojo基本类型隐式转换为Python对象。目前支持的......
  • python绘制圆柱体
     importosimportrandomimportnumpyasnpimportmatplotlib.pyplotasplt#合成管道数据集defplot_cylinder(center,radius,height,num_points=100):#生成圆柱体的侧面点坐标theta=np.linspace(0,2*np.pi,num_points)intervalZ=np.floor(h......
  • 计算机毕业设计必看必学!! 86393 基于微服务架构的餐饮系统的设计与实现,原创定制程序,
    摘   要近年来,我国经济和社会发展迅速,人们物质生活水平日渐提高,餐饮行业更是发展迅速,人们对于餐饮行业的认识和要求也越来越高。传统形式的餐饮行业都是以人为本,管理起来需要很多人力、物力、财力,既不方便管理者的管理,也不方便顾客实时了解餐厅动态,给传统餐......
  • python之高阶内容
    规范使用:类和对象模块导入,模块内部参数是:if__name__=="__main__":导入包(需要使用的公共代码模块):创建python包,里面放共同模块异常捕获优化:自定义异常classMyError(Exception):#异常捕获的类def__init__(self,length,min_len):#length为用户输入的密码长度......
  • python 百度翻译实例
    #-*-coding:utf-8-*-#ThiscodeshowsanexampleoftexttranslationfromEnglishtoSimplified-Chinese.#ThiscoderunsonPython2.7.xandPython3.x.#Youmayinstall`requests`torunthiscode:pipinstallrequests#Pleasereferto`https://a......
  • python入门(1)基础知识介绍
    print函数a=10print(a)print(10)print("您好")print(a,b,"您好")print(chr(98))#chr将98转换为ASVCII值print("你好"+"上海")#都是字符串可以用+连接输出print('您好',end='不换行')#修改结束符,不换行,否则自动视为有\nfp=open("note.txt&......
  • 机器学习领域中选择使用Python还是R
    在机器学习领域中,选择使用Python还是R,这主要取决于个人需求、项目特性、技能水平以及偏好。以下是对两种语言在机器学习方面的详细比较:一、社区支持与生态系统Python:Python在数据科学和机器学习领域拥有庞大的社区支持,这意味着你可以轻松找到大量的教程、文档、库和框架。......
  • 《最新出炉》系列初窥篇-Python+Playwright自动化测试-64 - Canvas和SVG元素推拽
    1.简介今天宏哥分享的在实际测试工作中很少遇到,比较生僻,如果突然遇到我们可能会脑大、懵逼,一时之间不知道怎么办?所以宏哥这里提供一种思路供大家学习和参考。2.SVG简介svg也是html5新增的一个标签,它跟canvas很相似。都可以实现绘图、动画。但是svg绘制出来的都是矢量图,不像canv......
  • 无法写入使用 pygbag 编译的 python/pygame 程序中的文本文件
    我有一个python/pygame程序,它从与该程序位于同一目录中的测试文件中读取数据。在程序结束时,应该将文本写回测试文件。这在Python环境中运行程序时有效,但在使用Pygbag编译并在浏览器中运行时无效。程序(称为main,py)是:importasyncioimportosimportpygamepyg......