[ICLR 2024]TIMEMIXER: DECOMPOSABLE MULTISCALE MIXING FOR TIME SERIES FORECASTING
研究背景与动机
- 现有方法的局限性:尽管分解方法和多周期性分析方法在时间序列预测中取得了进展,但它们往往忽略了不同尺度上信息的重要性。分解方法将时间序列分解成趋势和季节性等成分,但忽略了不同尺度上信息之间的相互作用;多周期性分析方法将时间序列分解成不同周期长度的成分,但忽略了不同尺度上信息之间的互补性。
- 多尺度分析的优势:时间序列在不同尺度上表现出不同的特征,例如,细粒度的时间序列可以捕捉到更详细的模式,而粗粒度的时间序列可以捕捉到更宏观的趋势。多尺度分析可以有效地将时间序列分解成不同的成分,并利用不同尺度上信息之间的相互作用,从而提高时间序列预测的准确性。
模型和方法
TimeMixer的网络结构如下图所示。
多尺度混合架构
- 数据降采样: 首先,模型通过平均池化对原始时间序列进行逐层降采样(代码中还支持最大池化和卷积降采样),生成多个不同尺度的序列,分别代表细粒度和粗粒度的时间序列特征。
- 嵌入层: 将不同尺度的序列映射到高维特征空间,以便进行后续的处理。
- Past-Decomposable-Mixing (PDM) 块: 对不同尺度的序列进行分解,将趋势和季节性成分分离,并进行多尺度混合,以捕捉不同尺度上信息的互补性。
- Future-Multipredictor-Mixing (FMM) 块: 对不同尺度上的预测结果进行集成,以充分利用不同尺度上信息的预测能力。
Past-Decomposable-Mixing (PDM) 块
- 分解: 使用 Autoformer 论文中提出的序列分解模块(基于滑动平均),将时间序列分解成趋势和季节性成分;代码中还提供了基于 DFT 的分解方式(decomp_method='dft_decomp')。
- 季节性混合: 采用自底向上的混合方式,将细粒度季节性信息逐层聚合到粗粒度季节性信息中。
- 趋势性混合: 采用自顶向下的混合方式,将粗粒度趋势性信息逐层传递到细粒度趋势性信息中。
下面的示意图展示了季节分量和趋势分量在不同尺度之间是按照什么规律进行混合的。
Future-Multipredictor-Mixing (FMM) 块
- 多尺度预测: 对不同尺度上的特征进行预测,得到不同尺度上的预测结果。
- 集成预测: 将不同尺度上的预测结果进行集成,得到最终的预测结果。
代码
代码:TimeMixer
Time-Series-Library 的 models 文件夹中包含 TimeMixer 的实现(models/TimeMixer.py)。代码比较长,结构也相对复杂,但整体流程并不难理解,值得认真研读。
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Autoformer_EncDec import series_decomp
from layers.Embed import DataEmbedding_wo_pos
from layers.StandardNorm import Normalize
class DFT_series_decomp(nn.Module):
    """
    Series decomposition block based on the discrete Fourier transform.

    Along the last dimension, the ``top_k`` strongest non-DC frequency bins of
    each row are kept as the seasonal part. Everything else, including the DC
    (mean) component, is returned as the trend part, so that
    ``x_season + x_trend == x``.
    """

    def __init__(self, top_k=5):
        super(DFT_series_decomp, self).__init__()
        self.top_k = top_k

    def forward(self, x):
        """Split ``x`` into ``(x_season, x_trend)`` along its last dimension.

        Args:
            x: real tensor; the FFT is taken over the last dimension.

        Returns:
            Tuple ``(x_season, x_trend)``, both shaped like ``x``.
        """
        n = x.size(-1)
        xf = torch.fft.rfft(x)
        # Frequency selection needs no gradient. Detaching also keeps the
        # in-place masking of ``xf`` below from clobbering a tensor that
        # autograd saved for ``abs``'s backward pass.
        freq = xf.detach().abs()
        # Drop the DC bin of every row from selection, so the mean level
        # lands in the trend rather than the season.
        freq[..., 0] = 0
        k = min(self.top_k, freq.size(-1))
        if k < 1:
            # Nothing selected as seasonal: the whole signal is trend.
            x_season = torch.zeros_like(x)
            return x_season, x - x_season
        top_k_freq, _ = torch.topk(freq, k, dim=-1)
        # Per-row threshold = k-th largest magnitude. Keep every bin at or
        # above it (exactly k bins unless there are ties).
        threshold = top_k_freq[..., -1:]
        xf[freq < threshold] = 0
        # Passing n preserves the original length for odd-length inputs.
        x_season = torch.fft.irfft(xf, n=n)
        x_trend = x - x_season
        return x_season, x_trend
class MultiScaleSeasonMixing(nn.Module):
    """
    Bottom-up mixing season pattern.

    ``season_list[i]`` is the seasonal component at scale ``i`` (scale 0 is the
    finest), with time as the last dimension. Each coarser scale receives a
    learned down-sampling of the already-mixed next-finer scale as a residual.
    """

    def __init__(self, configs):
        super(MultiScaleSeasonMixing, self).__init__()
        # Layer i maps time length seq_len // w**i -> seq_len // w**(i + 1).
        self.down_sampling_layers = torch.nn.ModuleList(
            [
                nn.Sequential(
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                    ),
                    nn.GELU(),
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                    ),
                )
                for i in range(configs.down_sampling_layers)
            ]
        )

    def forward(self, season_list):
        """Mix seasons from fine to coarse.

        Args:
            season_list: per-scale tensors ``[batch, channels, time]``, finest first.

        Returns:
            List of mixed seasons, one per scale, permuted to
            ``[batch, time, channels]``. A single-scale input is returned as-is
            (permuted), with no mixing.
        """
        # mixing high->low
        out_high = season_list[0]
        out_season_list = [out_high.permute(0, 2, 1)]
        for i, season_low in enumerate(season_list[1:]):
            out_high = season_low + self.down_sampling_layers[i](out_high)
            out_season_list.append(out_high.permute(0, 2, 1))
        return out_season_list
class MultiScaleTrendMixing(nn.Module):
    """
    Top-down mixing trend pattern.

    ``trend_list[i]`` is the trend component at scale ``i`` (scale 0 is the
    finest), with time as the last dimension. Each finer scale receives a
    learned up-sampling of the already-mixed next-coarser scale as a residual.
    """

    def __init__(self, configs):
        super(MultiScaleTrendMixing, self).__init__()
        # Built coarse-first: layer j maps the (j)-th coarsest length to the
        # next finer length, i.e. seq_len // w**(i + 1) -> seq_len // w**i.
        self.up_sampling_layers = torch.nn.ModuleList(
            [
                nn.Sequential(
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** (i + 1)),
                        configs.seq_len // (configs.down_sampling_window ** i),
                    ),
                    nn.GELU(),
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.seq_len // (configs.down_sampling_window ** i),
                    ),
                )
                for i in reversed(range(configs.down_sampling_layers))
            ]
        )

    def forward(self, trend_list):
        """Mix trends from coarse to fine.

        Args:
            trend_list: per-scale tensors ``[batch, channels, time]``, finest first.

        Returns:
            List of mixed trends in the original (finest-first) order, permuted
            to ``[batch, time, channels]``. A single-scale input is returned
            as-is (permuted), with no mixing.
        """
        # mixing low->high
        trend_list_reverse = list(reversed(trend_list))
        out_low = trend_list_reverse[0]
        out_trend_list = [out_low.permute(0, 2, 1)]
        for i, trend_high in enumerate(trend_list_reverse[1:]):
            out_low = trend_high + self.up_sampling_layers[i](out_low)
            out_trend_list.append(out_low.permute(0, 2, 1))
        out_trend_list.reverse()
        return out_trend_list
class PastDecomposableMixing(nn.Module):
    """
    Past-Decomposable-Mixing (PDM) block.

    Every scale in the input list is split into a seasonal and a trend part.
    Seasons are mixed bottom-up (fine -> coarse) and trends top-down
    (coarse -> fine), and the two mixed parts are summed again per scale.
    """

    @staticmethod
    def _feed_forward(d_model, d_ff):
        # Feature-wise MLP d_model -> d_ff -> d_model with a GELU in between.
        return nn.Sequential(
            nn.Linear(in_features=d_model, out_features=d_ff),
            nn.GELU(),
            nn.Linear(in_features=d_ff, out_features=d_model),
        )

    def __init__(self, configs):
        super(PastDecomposableMixing, self).__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.down_sampling_window = configs.down_sampling_window

        # NOTE(review): layer_norm and dropout are registered but never used in forward.
        self.layer_norm = nn.LayerNorm(configs.d_model)
        self.dropout = nn.Dropout(configs.dropout)
        self.channel_independence = configs.channel_independence

        method = configs.decomp_method
        if method == 'moving_avg':
            self.decompsition = series_decomp(configs.moving_avg)
        elif method == "dft_decomp":
            self.decompsition = DFT_series_decomp(configs.top_k)
        else:
            raise ValueError('decompsition is error')

        if configs.channel_independence == 0:
            # Channel-mixing MLP, applied to both components before scale mixing.
            self.cross_layer = self._feed_forward(configs.d_model, configs.d_ff)

        # Cross-scale mixing: seasons bottom-up, trends top-down.
        self.mixing_multi_scale_season = MultiScaleSeasonMixing(configs)
        self.mixing_multi_scale_trend = MultiScaleTrendMixing(configs)

        self.out_cross_layer = self._feed_forward(configs.d_model, configs.d_ff)

    def forward(self, x_list):
        """Run one PDM step over all scales.

        Args:
            x_list: per-scale tensors with time on dim 1 and ``d_model`` features
                last (they are permuted to time-last for the mixing layers).

        Returns:
            List of mixed tensors, one per scale, cut to each input's length.
        """
        lengths = []
        seasons = []
        trends = []
        for x in x_list:
            _, steps, _ = x.size()
            lengths.append(steps)

        for x in x_list:
            season, trend = self.decompsition(x)
            if self.channel_independence == 0:
                season = self.cross_layer(season)
                trend = self.cross_layer(trend)
            # The mixing layers operate on the time axis, so move it last.
            seasons.append(season.permute(0, 2, 1))
            trends.append(trend.permute(0, 2, 1))

        mixed_seasons = self.mixing_multi_scale_season(seasons)
        mixed_trends = self.mixing_multi_scale_trend(trends)

        results = []
        for ori, season_part, trend_part, steps in zip(x_list, mixed_seasons, mixed_trends, lengths):
            merged = season_part + trend_part
            if self.channel_independence:
                # Residual through the output MLP (channel-independent mode only).
                merged = ori + self.out_cross_layer(merged)
            results.append(merged[:, :steps, :])
        return results
class Model(nn.Module):
    """
    TimeMixer: decomposable multiscale mixing for time-series tasks.

    Pipeline:
    1. The input is down-sampled into ``down_sampling_layers + 1`` scales.
    2. Every scale is embedded.
    3. The list of scales passes through ``e_layers`` Past-Decomposable-Mixing
       (PDM) blocks.
    4. The result is decoded according to ``configs.task_name``.

    For forecasting, every scale has its own predictor and the per-scale
    predictions are summed (Future-Multipredictor-Mixing).
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.configs = configs
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.down_sampling_window = configs.down_sampling_window
        self.channel_independence = configs.channel_independence
        self.pdm_blocks = nn.ModuleList([PastDecomposableMixing(configs)
                                         for _ in range(configs.e_layers)])

        self.preprocess = series_decomp(configs.moving_avg)
        self.enc_in = configs.enc_in

        # Channel-independent mode embeds every channel as its own 1-channel series.
        if self.channel_independence == 1:
            self.enc_embedding = DataEmbedding_wo_pos(1, configs.d_model, configs.embed, configs.freq,
                                                      configs.dropout)
        else:
            self.enc_embedding = DataEmbedding_wo_pos(configs.enc_in, configs.d_model, configs.embed,
                                                      configs.freq, configs.dropout)

        self.layer = configs.e_layers

        # One normalizer per scale; normalize_layers[0] also de-normalizes outputs.
        self.normalize_layers = torch.nn.ModuleList(
            [
                Normalize(self.configs.enc_in, affine=True, non_norm=True if configs.use_norm == 0 else False)
                for i in range(configs.down_sampling_layers + 1)
            ]
        )

        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            # Per-scale temporal predictors: length seq_len // w**i -> pred_len.
            self.predict_layers = torch.nn.ModuleList(
                [
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.pred_len,
                    )
                    for i in range(configs.down_sampling_layers + 1)
                ]
            )

            if self.channel_independence == 1:
                self.projection_layer = nn.Linear(
                    configs.d_model, 1, bias=True)
            else:
                self.projection_layer = nn.Linear(
                    configs.d_model, configs.c_out, bias=True)

                # Residual branch used only by out_projection (channel-mixing mode).
                self.out_res_layers = torch.nn.ModuleList([
                    torch.nn.Linear(
                        configs.seq_len // (configs.down_sampling_window ** i),
                        configs.seq_len // (configs.down_sampling_window ** i),
                    )
                    for i in range(configs.down_sampling_layers + 1)
                ])

                self.regression_layers = torch.nn.ModuleList(
                    [
                        torch.nn.Linear(
                            configs.seq_len // (configs.down_sampling_window ** i),
                            configs.pred_len,
                        )
                        for i in range(configs.down_sampling_layers + 1)
                    ]
                )

        if self.task_name in ('imputation', 'anomaly_detection'):
            if self.channel_independence == 1:
                self.projection_layer = nn.Linear(
                    configs.d_model, 1, bias=True)
            else:
                self.projection_layer = nn.Linear(
                    configs.d_model, configs.c_out, bias=True)

        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.d_model * configs.seq_len, configs.num_class)

        # The conv down-sampler has learnable weights. It must be registered
        # here, so it is trained, saved and moved with .to(device), instead of
        # being rebuilt with fresh random weights on every forward pass.
        if getattr(configs, 'down_sampling_method', None) == 'conv':
            # NOTE: plain string comparison of versions, carried over from the original.
            padding = 1 if torch.__version__ >= '1.5.0' else 2
            self.down_sampling_conv = nn.Conv1d(in_channels=configs.enc_in, out_channels=configs.enc_in,
                                                kernel_size=3, padding=padding,
                                                stride=configs.down_sampling_window,
                                                padding_mode='circular',
                                                bias=False)

    def out_projection(self, dec_out, i, out_res):
        """Project scale ``i`` to output channels and add its residual-branch forecast.

        ``out_res`` (time on dim 1) is mapped over time by ``out_res_layers[i]``
        and then ``regression_layers[i]`` (seq_len // w**i -> pred_len).
        """
        dec_out = self.projection_layer(dec_out)
        out_res = out_res.permute(0, 2, 1)
        out_res = self.out_res_layers[i](out_res)
        out_res = self.regression_layers[i](out_res).permute(0, 2, 1)
        dec_out = dec_out + out_res
        return dec_out

    def pre_enc(self, x_list):
        """Prepare the embedding inputs as a pair of lists.

        Channel-independent mode returns ``(x_list, None)`` unchanged. Otherwise
        every scale is split by ``self.preprocess`` (``series_decomp``). The
        first outputs are embedded; the second outputs feed ``out_projection``.
        """
        if self.channel_independence == 1:
            return (x_list, None)
        out1_list = []
        out2_list = []
        for x in x_list:
            x_1, x_2 = self.preprocess(x)
            out1_list.append(x_1)
            out2_list.append(x_2)
        return (out1_list, out2_list)

    def __multi_scale_process_inputs(self, x_enc, x_mark_enc):
        """Down-sample ``x_enc`` (``[B, T, C]``) and its time marks into a list of scales.

        Always returns ``(x_list, x_mark_list)``. ``x_list[0]`` is the input
        itself. ``x_mark_list`` is ``None`` when ``x_mark_enc`` is ``None``.
        With an unknown ``down_sampling_method``, a single scale is returned.
        """
        method = self.configs.down_sampling_method
        window = self.configs.down_sampling_window
        if method == 'max':
            down_pool = torch.nn.MaxPool1d(window, return_indices=False)
        elif method == 'avg':
            down_pool = torch.nn.AvgPool1d(window)
        elif method == 'conv':
            down_pool = self.down_sampling_conv
        else:
            # No down-sampling: still return lists so callers can iterate scales.
            return [x_enc], (None if x_mark_enc is None else [x_mark_enc])

        x_enc_sampling_list = [x_enc]
        x_mark_sampling_list = None if x_mark_enc is None else [x_mark_enc]
        x_cur = x_enc.permute(0, 2, 1)  # B,T,C -> B,C,T for the 1-D pooling/conv
        x_mark_cur = x_mark_enc
        for _ in range(self.configs.down_sampling_layers):
            x_cur = down_pool(x_cur)
            x_enc_sampling_list.append(x_cur.permute(0, 2, 1))
            if x_mark_cur is not None:
                # Striding yields ceil(T / w) steps while pooling yields floor(T / w);
                # trim so each scale's marks match its series length.
                x_mark_cur = x_mark_cur[:, ::window, :][:, :x_cur.size(-1), :]
                x_mark_sampling_list.append(x_mark_cur)
        return x_enc_sampling_list, x_mark_sampling_list

    def _prepare_scales(self, x_enc, x_mark_enc, normalize=True):
        """Optionally normalize and channel-flatten every scale; align marks with them."""
        x_list = []
        x_mark_list = None if x_mark_enc is None else []
        for i, x in enumerate(x_enc):
            B, T, N = x.size()
            if normalize:
                x = self.normalize_layers[i](x, 'norm')
            if self.channel_independence == 1:
                # [B, T, N] -> [B*N, T, 1]; row b*N + n holds channel n of sample b.
                x = x.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
            x_list.append(x)
            if x_mark_list is not None:
                x_mark = x_mark_enc[i]
                if self.channel_independence == 1:
                    # Rows b*N .. b*N + N-1 all belong to sample b, so repeat each
                    # sample's marks N times consecutively (not tile the batch).
                    x_mark = x_mark.repeat_interleave(N, dim=0)
                x_mark_list.append(x_mark)
        return x_list, x_mark_list

    def _embed(self, x_list, x_mark_list=None):
        """Embed every scale, paired with its time marks when available."""
        if x_mark_list is None:
            return [self.enc_embedding(x, None) for x in x_list]  # [B,T,C]
        return [self.enc_embedding(x, x_mark) for x, x_mark in zip(x_list, x_mark_list)]

    def _project_sequence(self, enc_out, B):
        """Project finest-scale encoder output to ``[B, T, c_out]``."""
        dec_out = self.projection_layer(enc_out)
        if self.channel_independence == 1:
            # [B*N, T, 1] -> [B, N, T] -> [B, T, N]. In channel-mixing mode the
            # projection already yields [B, T, c_out], and reshaping would scramble it.
            dec_out = dec_out.reshape(B, self.configs.c_out, -1).permute(0, 2, 1).contiguous()
        return dec_out

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Forecast ``pred_len`` steps ahead. ``x_dec`` and ``x_mark_dec`` are unused."""
        x_enc, x_mark_enc = self.__multi_scale_process_inputs(x_enc, x_mark_enc)
        B = x_enc[0].size(0)
        x_list, x_mark_list = self._prepare_scales(x_enc, x_mark_enc, normalize=True)

        # embedding
        x_list = self.pre_enc(x_list)
        enc_out_list = self._embed(x_list[0], x_mark_list)

        # Past Decomposable Mixing as encoder for past
        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        # Future Multipredictor Mixing as decoder for future
        dec_out_list = self.future_multi_mixing(B, enc_out_list, x_list)

        dec_out = torch.stack(dec_out_list, dim=-1).sum(-1)
        dec_out = self.normalize_layers[0](dec_out, 'denorm')
        return dec_out

    def future_multi_mixing(self, B, enc_out_list, x_list):
        """Predict ``pred_len`` steps from every scale; returns one forecast per scale."""
        dec_out_list = []
        if self.channel_independence == 1:
            for i, enc_out in enumerate(enc_out_list):
                dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
                    0, 2, 1)  # align temporal dimension
                dec_out = self.projection_layer(dec_out)
                # [B*N, pred_len, 1] -> [B, pred_len, N]
                dec_out = dec_out.reshape(B, self.configs.c_out, self.pred_len).permute(0, 2, 1).contiguous()
                dec_out_list.append(dec_out)
        else:
            for i, (enc_out, out_res) in enumerate(zip(enc_out_list, x_list[1])):
                dec_out = self.predict_layers[i](enc_out.permute(0, 2, 1)).permute(
                    0, 2, 1)  # align temporal dimension
                dec_out = self.out_projection(dec_out, i, out_res)
                dec_out_list.append(dec_out)
        return dec_out_list

    def classification(self, x_enc, x_mark_enc):
        """Classify each sequence.

        ``x_mark_enc`` is multiplied into the encoder output via
        ``unsqueeze(-1)``, i.e. it is used as a per-time-step mask here.
        """
        x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
        # NOTE(review): inputs are neither normalized nor channel-flattened here,
        # so this path presumably expects channel_independence == 0 — confirm.
        enc_out_list = self._embed(x_enc)

        # MultiScale-CrissCrossAttention as encoder for past
        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        enc_out = enc_out_list[0]
        # Output
        # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.act(enc_out)
        output = self.dropout(output)
        # zero-out padding embeddings
        output = output * x_mark_enc.unsqueeze(-1)
        # (batch_size, seq_length * d_model)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)  # (batch_size, num_classes)
        return output

    def anomaly_detection(self, x_enc):
        """Reconstruct ``x_enc`` (``[B, T, N]``) from the finest mixed scale."""
        B, T, N = x_enc.size()
        x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
        x_list, _ = self._prepare_scales(x_enc, None, normalize=True)

        # embedding
        enc_out_list = self._embed(x_list)

        # MultiScale-CrissCrossAttention as encoder for past
        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        dec_out = self._project_sequence(enc_out_list[0], B)
        dec_out = self.normalize_layers[0](dec_out, 'denorm')
        return dec_out

    def imputation(self, x_enc, x_mark_enc, mask):
        """Impute ``x_enc`` (``[B, T, N]``) where ``mask == 0``.

        Standardization statistics use only observed entries. ``x_mark_enc`` is
        not used by this path.
        """
        # Guard against fully masked channels (0 / 0 -> NaN).
        observed = torch.sum(mask == 1, dim=1).clamp(min=1)
        # NOTE(review): the sum runs over all positions, so this presumes masked
        # entries of x_enc are already zero — confirm with the caller.
        means = torch.sum(x_enc, dim=1) / observed
        means = means.unsqueeze(1).detach()
        x_enc = x_enc - means
        x_enc = x_enc.masked_fill(mask == 0, 0)
        stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) / observed + 1e-5)
        stdev = stdev.unsqueeze(1).detach()
        x_enc /= stdev

        B, T, N = x_enc.size()
        x_enc, _ = self.__multi_scale_process_inputs(x_enc, None)
        x_list, _ = self._prepare_scales(x_enc, None, normalize=False)

        # embedding
        enc_out_list = self._embed(x_list)

        # MultiScale-CrissCrossAttention as encoder for past
        for i in range(self.layer):
            enc_out_list = self.pdm_blocks[i](enc_out_list)

        dec_out = self._project_sequence(enc_out_list[0], B)
        dec_out = dec_out * \
                  (stdev[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
        dec_out = dec_out + \
                  (means[:, 0, :].unsqueeze(1).repeat(1, self.seq_len, 1))
        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the method for ``self.task_name``.

        Raises:
            ValueError: if the task name is not supported.
        """
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            return self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
        if self.task_name == 'imputation':
            return self.imputation(x_enc, x_mark_enc, mask)  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            return self.anomaly_detection(x_enc)  # [B, L, D]
        if self.task_name == 'classification':
            return self.classification(x_enc, x_mark_enc)  # [B, N]
        raise ValueError('Other tasks not implemented yet')
思考
TimeMixer 论文的核心创新在于将多尺度混合架构应用于时间序列预测,并通过 PDM 块和 FMM 块有效地利用不同尺度上信息的互补性,从而提高了时间序列预测的准确性。
熟悉 CV 领域的同学一看就会发现,这个结构和 YOLO 有点像。以 YOLOv5s 为例:先做多尺度特征提取,再用 PANet 结构做尺度间融合,然后由多个检测头分别预测,最后做 NMS。TimeMixer 则是:先做多尺度分析(针对时间序列做季节-趋势分解),再做类似 PANet 的尺度间混合(针对时间序列做了有效的修改),然后由多个预测头分别预测,最后合并各尺度的预测结果。
对比一下具体流程,是不是很相似?不过,TimeMixer 的工作确实很好,它针对时间序列所做的修改都很有效。YOLO 有很多魔改思路,我也尝试把其中一些用在 TimeMixer 上,但效果不太理想,这也说明 TimeMixer 本身的效果确实不错。同时,它的结构相对简单,比如季节分量和趋势分量在尺度间的混合,是否可以做一些更细粒度的设计?应该还有一定的改进空间。
参考文献:
[ICLR 2024]TIMEMIXER: DECOMPOSABLE MULTISCALE MIXING FOR TIME SERIES FORECASTING
图片来源:
[ICLR 2024]TIMEMIXER: DECOMPOSABLE MULTISCALE MIXING FOR TIME SERIES FORECASTING
https://cloud.tencent.com/developer/article/1922832