首页 > 其他分享 >Transformer 例子2

Transformer 例子2

时间:2024-02-28 11:48:21浏览次数:15  
标签:dim Transformer labels train 例子 output input data

一个多维数据输入的例子:

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# 构造简单的多维时间序列数据集
def generate_multivariate_time_series(num_samples, seq_length, input_dim):
    data = np.random.randn(num_samples, seq_length, input_dim)
    labels = np.sum(data, axis=1)[:, 0]  # 简单地将所有维度的值相加作为标签
    return data, labels

# 定义 Transformer 模型
class TransformerModel(nn.Module):
    def __init__(self, input_dim, output_dim, num_layers, heads, hidden_size):
        super(TransformerModel, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=heads, dim_feedforward=hidden_size)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers)
        self.decoder = nn.Linear(input_dim, output_dim)

    def forward(self, src):
        src = src.permute(1, 0, 2)  # 调整输入的维度顺序
        output = self.transformer_encoder(src)
        output = self.decoder(output[-1])  # 取最后一个时间步的输出
        return output

# 准备数据
input_dim = 3  # 输入维度
output_dim = 1  # 输出维度
num_samples = 1000  # 样本数
seq_length = 10  # 序列长度

data, labels = generate_multivariate_time_series(num_samples, seq_length, input_dim)
train_size = int(num_samples * 0.8)
train_data, train_labels = data[:train_size], labels[:train_size]
test_data, test_labels = data[train_size:], labels[train_size:]

# 准备训练数据加载器
train_dataset = torch.utils.data.TensorDataset(torch.tensor(train_data).float(), torch.tensor(train_labels).float())
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)

# 定义模型和优化器
model = TransformerModel(input_dim=input_dim, output_dim=output_dim, num_layers=2, heads=3, hidden_size=128)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# 训练模型
model.train()
for epoch in range(20):
    running_loss = 0.0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs.squeeze(), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch+1}/20, Loss: {running_loss}")

# 测试模型
model.eval()
test_inputs = torch.tensor(test_data).float()
with torch.no_grad():
    predicted = model(test_inputs).squeeze().numpy()

# 可视化预测结果
plt.plot(test_labels, label='True Labels')
plt.plot(predicted, label='Predicted Labels')
plt.legend()
plt.show()

  

标签:dim,Transformer,labels,train,例子,output,input,data
From: https://www.cnblogs.com/kingkaixuan/p/18039827

相关文章

  • Transformer 例子
    据说很好用,先写一个例子看看:importtorchimporttorch.nnasnnimportnumpyasnpimportmatplotlib.pyplotasplt#构造简单的时间序列数据集defgenerate_time_series():time=np.arange(0,100,0.1)amplitude=np.sin(time)returnamplitude#将......
  • offline RL · RLHF · PbRL | OPPO:PbRL 场景的 offline hindsight transformer
    论文题目:BeyondReward:OfflinePreference-guidedPolicyOptimization,ICML2023,3368reject。(已经忘记当初为何加进readinglist了,可能因为abstract太炫酷了?就当作学习经验教训吧…)材料:pdf版本:https://arxiv.org/pdf/2305.16217.pdfhtml版本:https://ar5iv.labs......
  • offline RL | 读读 Decision Transformer
    论文标题:DecisionTransformer:ReinforcementLearningviaSequenceModeling,NeurIPS2021,6679poster(怎么才poster)。pdf:https://arxiv.org/pdf/2106.01345.pdfhtml:https://ar5iv.labs.arxiv.org/html/2106.01345openreview:https://openreview.net/forum?id=a7APmM4......
  • 【论文随笔】多行为序列Transformer推荐(Multi-Behavior Sequential Transformer Reco
    前言今天读的论文为一篇于2022年7月发表在第45届国际计算机学会信息检索会议(SIGIR'22)的论文,文章主要为推荐系统领域提供了一个新的视角,特别是在处理用户多行为序列数据方面,提出了一种有效的Transformer模型框架。要引用这篇论文,请使用以下格式:[1]Yuan,Enming,etal."Multi......
  • StampedLock 使用例子
    StampedLock是Java8引入的一种新的锁机制,它是ReadWriteLock的改进版,提供了更高的并发性和更好的性能。下面是一个使用StampedLock的示例:importjava.util.concurrent.locks.StampedLock;publicclassStampedLockDemo{privatedoublex,y;privatefinalSt......
  • Qt的拖拽操作例子
    本文是一个拖拽文本的例子。演示了如何把按钮的标题拖拽到文本编辑框里。Qt对拖拽的封装很好,QDrag对象简单易用。本文程序测试环境是VS2017和Qt5.9。下面是程序拖拽时的效果图,可以看出来拖拽的时候光标下方也显示了文本内容:头文件。本功能是在主窗口中实现的。下面代码里QtTest......
  • OpenCL切换显卡的例子
    在一些有多个显卡,比如一个核芯显卡和一个独立显卡的系统中使用显卡加速,OpenCL默认的设备可能不是性能更好的独立显卡。这时候可以用下述方法更换显卡,代码如下。本例在VS2015和OpenCL3.0下测试通过:conststringkernel=u8R"(kernelvoidreduceSum(globalint*num,globa......
  • Mamba详细介绍和RNN、Transformer的架构可视化对比
    Transformer体系结构已经成为大型语言模型(llm)成功的主要组成部分。为了进一步改进llm,人们正在研发可能优于Transformer体系结构的新体系结构。其中一种方法是Mamba(一种状态空间模型)。Mamba:Linear-TimeSequenceModelingwithSelectiveStateSpaces一文中提出了Mamba,我们......
  • 运行 decision transformer 遇到的问题
    简介本质上强化学习也是为了预测下一个action,那能否借用大模型的方法来实现next-action的预测。业界有多篇借用大模型的方法(transfomer)来实现这个目的。伯克利的这篇算是最为彻底和简洁。https://sites.google.com/berkeley.edu/decision-transformer transfomer官方网站......
  • python-transformers库
    python-transformers库目录python-transformers库安装测试功能和优势Transformers术语模型与分词器加载预训练模型保存模型分词器编码和解码填充Padpipelinepipeline简介pipeline原理参考资料transformers是一个用于自然语言处理(NLP)任务,如文本分类、命名实体识别,机器翻译等,提供......