首页 > 其他分享 >基于N-HiTS神经层次插值模型的时间序列预测——cross validation交叉验证与ray tune超参数优化

基于N-HiTS神经层次插值模型的时间序列预测——cross validation交叉验证与ray tune超参数优化

时间:2025-01-03 17:59:17浏览次数:3  
标签:HiTS 512 idx df series cross ray hat size

论文链接:https://arxiv.org/pdf/2201.12886v3


N-HiTS: Neural Hierarchical Interpolation for TimeSeries Forecasting \begin{aligned} &\text{\large \color{#CDA59E}N-HiTS: Neural Hierarchical Interpolation for TimeSeries Forecasting}\\ \end{aligned} ​N-HiTS: Neural Hierarchical Interpolation for TimeSeries Forecasting​
NHITS builds upon NBEATS and specializes its partial outputs in the different frequencies of the time series through hierarchical interpolation and multi-rate input processing. On the long-horizon forecasting task NHITS improved accuracy by 25% on AAAI’s best paper award the Informer, while being 50x faster.

References
-Boris N. Oreshkin, Dmitri Carpov, Nicolas Chapados, Yoshua Bengio (2019). “N-BEATS: Neural basis expansion analysis for interpretable time series forecasting”.
-Cristian Challu, Kin G. Olivares, Boris N. Oreshkin, Federico Garza, Max Mergenthaler-Canseco, Artur Dubrawski (2023). “NHITS: Neural Hierarchical Interpolation for Time Series Forecasting”. Accepted at the Thirty-Seventh AAAI Conference on Artificial Intelligence.
-Zhou, H.; Zhang, S.; Peng, J.; Zhang, S.; Li, J.; Xiong, H.; and Zhang, W. (2020). “Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting”. Association for the Advancement of Artificial Intelligence Conference 2021 (AAAI 2021).

在这里插入图片描述


前言

系列专栏:【深度学习:算法项目实战】✨︎
涉及医疗健康、财经金融、商业零售、食品饮料、运动健身、交通运输、环境科学、社交媒体以及文本和图像处理等诸多领域,讨论了各种复杂的深度神经网络思想,如卷积神经网络、循环神经网络、生成对抗网络、门控循环单元、长短期记忆、自然语言处理、深度强化学习、大型语言模型和迁移学习。

NHITS是一种解决时间序列长期预测中波动性和计算复杂性的模型。它采用了分层插值和多率数据采样技术,通过构建分层结构来降低计算成本并提高预测精度‌。相较于最新的Transformer架构,NHITS在平均精度上提升了16%,同时计算时间减少了50倍‌。这种模型能够更有效地处理时间序列数据,为时间序列分析提供了新的方法。

具体来说,NHITS通过结合新的分层插值和多率数据采样技术,解决了长期预测中的两个常见挑战:预测的波动性和计算复杂性。这些技术使NHITS能够依次组装其预测,强调具有不同频率和尺度的分量,同时分解输入信号并合成预测‌。这种独特的处理方式使得NHITS在长期预测任务中表现出色。

文章目录

import pandas as pd
import matplotlib.pyplot as plt

from ray import tune
from neuralforecast.auto import AutoNHITS
from neuralforecast.core import NeuralForecast
from neuralforecast.losses.numpy import mae, mse, mape, rmse

from datasetsforecast.long_horizon import LongHorizon

1. 数据集加载

datasetsforecast 是一个用于处理时间序列预测相关数据集的库。它的主要目的是方便用户获取、加载和预处理适合于时间序列预测任务的数据集。在时间序列分析和预测领域,拥有高质量、合适的数据集是非常关键的一步,这个库能够帮助我们更高效地开展工作。

# Change this to your own data to try the model
Y_df, X_df, _ = LongHorizon.load(directory='./', group='ETTm2')

2. 数据预处理

Y_df['ds'] = pd.to_datetime(Y_df['ds'])
# For this excercise we are going to take 20% of the DataSet
n_time = len(Y_df.ds.unique())
val_size = int(.2 * n_time)
test_size = int(.2 * n_time)

Y_df.groupby('unique_id').head(2)

3. 数据可视化

# We are going to plot the temperature of the transformer
# and marking the validation and train splits
u_id = 'HUFL'
x_plot = pd.to_datetime(Y_df[Y_df.unique_id==u_id].ds)
y_plot = Y_df[Y_df.unique_id==u_id].y.values

x_val = x_plot[n_time - val_size - test_size]
x_test = x_plot[n_time - test_size]

fig = plt.figure(figsize=(10, 5))
fig.tight_layout()

plt.plot(x_plot, y_plot)
plt.xlabel('Date', fontsize=17)
plt.ylabel('HUFL [15 min temperature]', fontsize=17)

plt.axvline(x_val, color='black', linestyle='-.')
plt.axvline(x_test, color='black', linestyle='-.')
plt.text(x_val, 5, '  Validation', fontsize=12)
plt.text(x_test, 5, '  Test', fontsize=12)

plt.grid()

HUFL

4. 定义超参数

Ray Tune 是一个用于超参数优化的库,它是基于 Ray 框架的一部分。Ray 是一个开源的分布式计算框架,旨在简化并行和分布式Python编程。Ray Tune 专门设计用来帮助开发者高效地搜索机器学习模型的超参数空间,以找到性能最佳的模型配置

horizon = 96 # 24hrs = 4 * 15 min.

# Use your own config or AutoNHITS.default_config
nhits_config = {
       "learning_rate": tune.choice([1e-3]),                                     # Initial Learning rate
       "max_steps": tune.choice([1000]),                                         # Number of SGD steps
       "input_size": tune.choice([5 * horizon]),                                 # input_size = multiplier * horizon
       "batch_size": tune.choice([7]),                                           # Number of series in windows
       "windows_batch_size": tune.choice([256]),                                 # Number of windows in batch
       "n_pool_kernel_size": tune.choice([[2, 2, 2], [16, 8, 1]]),               # MaxPool's Kernel size
       "n_freq_downsample": tune.choice([[168, 24, 1], [24, 12, 1], [1, 1, 1]]), # Interpolation expressivity ratios
       "activation": tune.choice(['ReLU']),                                      # Type of non-linear activation
       "n_blocks":  tune.choice([[1, 1, 1]]),                                    # Blocks per each 3 stacks
       "mlp_units":  tune.choice([[[512, 512], [512, 512], [512, 512]]]),        # 2 512-Layers per block for each stack
       "interpolation_mode": tune.choice(['linear']),                            # Type of Multi-step interpolation
       "val_check_steps": tune.choice([100]),                                    # Compute validation every 100 epochs
       "random_seed": tune.randint(3, 5),
    }

5. 构建模型

方法一:NeuralForecast 库的 Auto 模块主要用于自动模型选择和超参数调整,以帮助用户更高效地构建和优化时间序列预测模型。它能够在一系列预定义的模型和参数组合中进行搜索,找到对于给定数据集比较合适的模型配置。

nf = NeuralForecast(
    models = [
        AutoNHITS(h=horizon,
                  config=nhits_config,
                  num_samples=5
                  )
    ],
    freq='15min')

方法二:NeuralForecast 库的 models 模块也可以用于构建 N-HiTS 神经层次插值模型,它与 AutoNHITS 模型相比需要设置大量的参数,对用户的专业知识和经验要求较高,调参过程费时费力。

from neuralforecast.models import NHITS
nf = NeuralForecast(
    models = [
        NHITS(h=horizon,                                         # Forecasting horizon 预测步长
              input_size=5 * horizon,                            # Input size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2] 时间步
              stat_exog_list=None,                               # static exogenous columns 静态外生列
              hist_exog_list=None,                               # historic exogenous columns
              futr_exog_list=None,                               # future exogenous columns
              exclude_insample_y=False,                          # the model skips the autoregressive features y[t-input_size:t] if True
              stack_types = ["identity", "identity", "identity"],
              n_blocks= [1, 1, 1],
              mlp_units= 3 * [[512, 512]],
              n_pool_kernel_size = [2, 2, 1],
              n_freq_downsample = [4, 2, 1],
              pooling_mode = "MaxPool1d",
              interpolation_mode = "linear",
              dropout_prob_theta=0.0,
              activation="ReLU",
              loss=MAE(),
              valid_loss=None,
              max_steps = 1000,
              learning_rate = 1e-3,
              num_lr_decays = 3,
              early_stop_patience_steps = -1,
              val_check_steps = 100,
              batch_size = 32,
              valid_batch_size = None,
              windows_batch_size = 1024,
              inference_windows_batch_size = -1,
              start_padding_enabled=False,
              step_size = 1,
              scaler_type = "identity",
              random_seed = 1,
              num_workers_loader=0,
              drop_last_loader=False,
              optimizer=None,
              optimizer_kwargs=None,
              lr_scheduler=None,
              lr_scheduler_kwargs=None,
              dataloader_kwargs=None,
        )
    ],
    freq='15min'
)

exclude_insample_y: bool=False, the model skips the autoregressive features y[t-input_size:t] if True.意思是如果设置为True,说明模型会跳过(也就是不使用、忽略)自回归特征中从 y t − i n p u t s i z e y_{t-inputsize} yt−inputsize​到 y t y_t yt​这一部分数据。正常情况下,这些数据往往会被纳入模型的输入,作为帮助模型学习时间序列规律以及进行预测的重要依据。但当满足上述条件时,模型就不会把这一段对应的历史时间序列值当作输入信息了,相当于切断了这部分自回归的信息链路,模型会基于其他可用的输入(比如外生变量、其他历史阶段的数据等,如果有的话)来进行后续的处理和预测工作。

6. 交叉验证

交叉验证方法 cross_validation 将返回模型在测试集上的预测结果。这里我们使用第一种方法进行交叉验证

Y_hat_df = nf.cross_validation(df=Y_df, val_size=val_size,
                               test_size=test_size, n_windows=None)
nf.models[0].results.get_best_result().config
{'learning_rate': 0.001,
 'max_steps': 1000,
 'input_size': 480,
 'batch_size': 7,
 'windows_batch_size': 256,
 'n_pool_kernel_size': [2, 2, 2],
 'n_freq_downsample': [1, 1, 1],
 'activation': 'ReLU',
 'n_blocks': [1, 1, 1],
 'mlp_units': [[512, 512], [512, 512], [512, 512]],
 'interpolation_mode': 'linear',
 'val_check_steps': 100,
 'random_seed': 3,
 'h': 96,
 'loss': MAE(),
 'valid_loss': MAE()}

7. 预测结果

y_true = Y_hat_df.y.values
y_hat = Y_hat_df['AutoNHITS'].values

n_series = len(Y_df.unique_id.unique())

y_true = y_true.reshape(n_series, -1, horizon)
y_hat = y_hat.reshape(n_series, -1, horizon)

print('Parsed results')
print('2. y_true.shape (n_series, n_windows, n_time_out):\t', y_true.shape)
print('2. y_hat.shape  (n_series, n_windows, n_time_out):\t', y_hat.shape)
Parsed results
2. y_true.shape (n_series, n_windows, n_time_out):	 (7, 11425, 96)
2. y_hat.shape  (n_series, n_windows, n_time_out):	 (7, 11425, 96)
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(10, 11))
fig.tight_layout()

series = ['HUFL','HULL','LUFL','LULL','MUFL','MULL','OT']
series_idx = 3

for idx, w_idx in enumerate([200, 300, 400]):
  axs[idx].plot(y_true[series_idx, w_idx,:],label='True')
  axs[idx].plot(y_hat[series_idx, w_idx,:],label='Forecast')
  axs[idx].grid()
  axs[idx].set_ylabel(series[series_idx]+f' window {w_idx}',
                      fontsize=17)
  if idx==2:
    axs[idx].set_xlabel('Forecast Horizon', fontsize=17)
plt.legend()
plt.show()
#plt.savefig('./results/HUFL_window.png', dpi=300)
plt.close()

在这里插入图片描述

8. 模型评估

以下代码使用了一些常见的评估指标:平均绝对误差(MAE)、平均绝对百分比误差(MAPE)、均方误差(MSE)、均方根误差(RMSE)来衡量模型预测的性能。这里我们将调用 neuralforecast.losses.numpy 模块中的 mae, mse, mape, rmse 函数来对模型的预测效果进行评估。

mae = mae(Y_hat_df['y'], Y_hat_df['AutoNHITS'])
print(f"MAE: {mae:.4f}")

mape = mape(Y_hat_df['y'], Y_hat_df['AutoNHITS'])
print(f"MAPE: {mape * 100:.4f}%")

mse = mse(Y_hat_df['y'], Y_hat_df['AutoNHITS'])
print(f"MSE: {mse:.4f}")

rmse = rmse(Y_hat_df['y'], Y_hat_df['AutoNHITS'])
print(f"RMSE: {rmse:.4f}")

标签:HiTS,512,idx,df,series,cross,ray,hat,size
From: https://blog.csdn.net/m0_63287589/article/details/144893387

相关文章

  • Java集合 —— ArrayList详解(源码)
    我这里阅读的是JDK17关于ArrayList的源码,不过思路都是一样的简介 ArrayList是一个可以动态修改的数组,与普通数组的区别就是它是没有固定大小的限制,我们可以添加或删除元素。 ArrayList继承了AbstractList,并实现了List接口。属性设置//序列化Idprivatestatic......
  • Ray 源码分析系列(8)—RuntimeEnv
    前言运行时的环境管理是最容易被大家忽略的部分,如果只是一个人使用,确实不会是什么大问题。但如果是几百人使用,同时单任务涉及到数十个分布式节点呢?答案显而易见,很容易形成木桶效应,还有就是本机磁盘容易OOM。使用示例假如没有使用过ray,这里来个简单的示例,大家理解起来可能......
  • CopyOnWriteArraySet与CopyOnWriteArrayList
    这两个集合都支持写复制,在并发性方面比,ArrayList,LinkList要好一些。适用场景:读多邪少的情况看下源码为甚么读多写少的情况下比较好第一步:CopyOnWriteArraySetcopyOnWriteArraySet=newCopyOnWriteArraySet<>();copyOnWriteArraySet......
  • 叉乘 CrossProduct
    更新日志2025/1/1:开工。公式\[(a,b)\times(c,d)=ad-bc\]简介考虑如下的两个向量,它们之间叉乘的绝对值就是那个平行四边形的面积:[没有开网,上传失败]你发现,叉乘是有正负的。具体的,对于\(\vecA\times\vecB\),若\(\vecB\)在\(\vecA\)逆时针方向,就是正的。顺时针......
  • 如何在 Ubuntu 20.04 上部署 Graylog 日志管理平台教程
    如何在Ubuntu20.04上部署Graylog日志管理平台教程简介Graylog是一个开源的、基于Web的日志管理和聚合系统,它可以帮助你高效地管理和分析大量日志数据。通过收集服务器日志,并使用Elasticsearch进行索引,以及MongoDB保存元数据,Graylog使得详细的日志分析成为可......
  • js数组-实例方法:Array.prototype.findLast(),Array.prototype.findLastIndex(),Array
    Array.prototype.findLast()findLast()方法反向迭代数组,并返回满足提供的测试函数的第一个元素的值。如果没有找到对应元素,则返回undefined语法findLast(callbackFn)findLast(callbackFn,thisArg)参数callbackFn:数组中测试元素的函数。回调应该返回一个真值,表示已......
  • 【深度学习基础|知识概述】基础数学和理论知识中的信息论知识:交叉熵(Cross-Entropy)和KL
    【深度学习基础|知识概述】基础数学和理论知识中的信息论知识:交叉熵(Cross-Entropy)和KL散度(Kullback-LeiblerDivergence)的应用,附代码。【深度学习基础|知识概述】基础数学和理论知识中的信息论知识:交叉熵(Cross-Entropy)和KL散度(Kullback-LeiblerDivergence)的应用,附代码。......
  • Elasticsearch:如何在搜索时得到精确的总 hits 数
    Elasticsearch:如何在搜索时得到精确的总hits数|Id|Title|DateAdded|SourceUrl|PostType|Body|BlogId|Description|DateUpdated|IsMarkdown|EntryName|CreatedTime|IsActive|AutoDesc|AccessPermission||-------------|-------------|--------......
  • wx.arrayBufferToBase64
    stringwx.arrayBufferToBase64(ArrayBufferarrayBuffer)从基础库2.4.0开始,本接口停止维护基础库1.1.0开始支持,低版本需做兼容处理。小程序插件:支持微信Windows版:支持微信Mac版:支持微信鸿蒙OS版:支持功能描述将ArrayBuffer对象转成Base64字符串参数A......
  • Array.from()
    Array.from() 功能:将类数组对象转换为数组将字符串转换为数组拷贝一个素组Array.from()方法就是将一个类数组对象或者可遍历对象转换成一个真正的数组。所谓类数组对象,最基本的要求就是具有length属性的对象。第一个接收参数可以是:类数组对象/字符串/数组/{length:长......