import torch
class MaxState(torch.nn.Module):
def __init__(self, hidden_dim, heads):
super(MaxState, self).__init__()
assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."
self.head_size = hidden_dim // heads
self.head0 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head1 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
self.head2 = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
# self.h_linear=torch.nn.Parameter(torch.empty(1, 1))
# torch.nn.init.xavier_uniform_(self.h_linear,0.5)
# self.layer_nor = torch.nn.LayerNorm(hidden_dim)
# self.norm = torch.nn.LayerNorm(hidden_dim)
# self.alpha = torch.nn.Parameter(torch.tensor(0.5))
self.head_num = heads
self.hidden = hidden_dim
self.layer_nor = torch.nn.LayerNorm(hidden_dim)
def forward(self, input_data, state=None):
# self.head.to(device)
b, s, k, h = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size
out = self.head0(input_data)
out1 = self.head1(input_data)
out2 = self.head2(input_data)
out = out.reshape([b, s, k, h]).permute([0, 2, 1, 3])
out1 = out1.reshape([b, s, k, h]).permute([0, 2, 1, 3])
# out2 = out2.reshape([b, s, k, h]).permute([0, 2, 1, 3])
# out1 = self.head1(input_data).reshape([b, s, k, h]).permute([0, 2, 1, 3])
out = torch.cummax((out + out1) / h ** 0.5, 2)[0]
# out = torch.cummin((out + out1)/k**0.5 , 2)[0]
# out_sum = torch.cumsum((out + out1)/k**0.5 , 2)
# out=(out-out_min)*out
out = out.permute([0, 2, 1, 3])
out1 = out1.permute([0, 2, 1, 3])
# out2 = out2.permute([0, 2, 1, 3])
out = out.reshape([b, s, -1])
out1 = out1.reshape([b, s, -1])
# out2 = out2.reshape([b, s, -1])
# out = self.layer_nor(out)
# out = (out + out2) * out+out1
# out3=torch.cummax(out,1)[0]
# out = (out + out2) * out + out1
out = self.layer_nor(out + out2 + out1)
# out = self.alpha * out * (out + out2) + (1 - self.alpha) * out1
return out
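# Shape walkthrough (illustrative, e.g. hidden_dim=256, heads=16, so head_size=16):
# head0/head1 outputs are reshaped and permuted from (b, s, 256) to (b, 16, s, 16),
# their scaled sum passes through cummax over dim 2 (the sequence axis),
# the result is reshaped back to (b, s, 256), and the final output is
# layer_nor(out + out2 + out1), mixing the cummax branch with the head1 and head2 branches.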
class FeedForward(torch.nn.Module):
def __init__(self, hidden_size):
super(FeedForward, self).__init__()
self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)
self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)
self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)
# self.h_linear=torch.nn.Parameter(torch.empty(1, 1))
# self.gate = torch.nn.Parameter(torch.empty(hidden_size, hidden_size * 2))
# torch.nn.init.xavier_uniform_(self.gate,0.5)
self.relu = torch.nn.ReLU()
self.dr = torch.nn.Dropout(0.1)
def forward(self, x):
x1 = self.ffn1(x)
x2 = self.relu(self.gate(x))
xx = self.dr(x1 * x2)
x = self.ffn2(xx)
return x
class DecoderLayer(torch.nn.Module):
def __init__(self, hidden_size, num_heads):
super(DecoderLayer, self).__init__()
self.state = MaxState(hidden_size, num_heads)
self.state1 = MaxState(hidden_size, num_heads)
self.state2 = MaxState(hidden_size, num_heads)
self.decoder = FeedForward(hidden_size)
self.decoder1 = FeedForward(hidden_size)
self.decoder2 = FeedForward(hidden_size)
self.layer_nor = torch.nn.LayerNorm(hidden_size)
def forward(self, x):
x = self.state(x) + x
x1 = self.decoder(x)
x2 = self.state1(x) + x1
x2 = self.decoder1(x2)
x3 = self.state2(x1) + x2
x3 = self.layer_nor(x3)
x3 = self.decoder2(x3)
x = self.layer_nor(x1 + x2 + x3)
return x
class SamOut(torch.nn.Module):
def __init__(self, voc_size, hidden_size, num_heads, num_layers):
super(SamOut, self).__init__()
self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)
self.decoder = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])
self.head = torch.nn.Linear(hidden_size, voc_size, False)
def forward(self, x):
x = self.em(x)
for decoder in self.decoder:
x = decoder(x)
return self.head(x), ""
device = "cuda"
if __name__ == '__main__':
net = SamOut(235, 256, 16, 4)
net.to(device)
net(torch.randint(0, 200, [2, 8 * 13]).to(device))
# epoch___0____loss___8.586270____steps___65760: 0%| | 0/1 [01:21<?, ?it/s] cummax
# epoch___0____loss___6.930531____steps___67040: 0%| | 0/1 [01:21<?, ?it/s] cummax no layer_nor
# epoch___0____loss___7.680687____steps___77840: 0%| | 0/1 [01:35<?, ?it/s] cummax layer_nor
# epoch___0____loss___6.994579____steps___68240: 0%| | 0/1 [01:25<?, ?it/s] cummax cos
# epoch___0____loss___6.707716____steps___70640: 0%| | 0/1 [01:24<?, ?it/s] cummax no sin no cos
# epoch___0____loss___6.895388____steps___65200: 0%| | 0/1 [01:21<?, ?it/s] cummin
# epoch___0____loss___7.079460____steps___66720: 0%| | 0/1 [01:22<?, ?it/s] cummax no x
# epoch___0____loss___6.174834____steps___45360: 0%| | 0/10 [01:00<?, ?it/s] cummax 2 2 no pos
# epoch___0____loss___6.239753____steps___45120: 0%| | 0/10 [01:00<?, ?it/s] cummax 2 2 pos
# epoch___0____loss___6.547979____steps___36240: 0%| | 0/10 [01:00<?, ?it/s] cummax 3 3 no pos
# epoch___0____loss___6.947957____steps___17600: 0%| | 0/10 [01:01<?, ?it/s] src samout
# epoch___0____loss___6.108305____steps___52640: 0%| | 0/10 [02:54<?, ?it/s] src samout
# epoch___0____loss___6.069768____steps___55280: 0%| | 0/10 [03:03<?, ?it/s] src samout
# epoch___0____loss___6.058203____steps___54560: 0%| | 0/10 [01:11<?, ?it/s] current samout
# epoch___0____loss___5.996508____steps___52560: 0%| | 0/10 [01:27<?, ?it/s]
# epoch___0____loss___6.067177____steps___54400: 0%| | 0/10 [01:30<?, ?it/s]
# epoch___0____loss___5.974577____steps___52720: 0%| | 0/10 [01:44<?, ?it/s]
# epoch___0____loss___5.869751____steps___55520: 0%| | 0/10 [01:57<?, ?it/s]
# epoch___0____loss___5.749324____steps___55440: 0%| | 0/10 [02:03<?, ?it/s] maxstate no cat
# epoch___0____loss___5.715099____steps___55440: 0%| | 0/10 [02:26<?, ?it/s] cat
# epoch___0____loss___5.704436____steps___55520: 0%| | 0/10 [02:04<?, ?it/s] x1 +x2+x3
# epoch___0____loss___5.710885____steps___55360: 0%| | 0/10 [02:04<?, ?it/s] x1 +x2+x3 beats cat and also cuts the parameter count
# epoch___0____loss___5.673217____steps___55360: 0%| | 0/10 [02:00<?, ?it/s] out+out1+out2
# epoch___0____loss___5.669157____steps___55360: 0%| | 0/10 [02:13<?, ?it/s]
# epoch___0____loss___5.677723____steps___55360: 0%| | 0/10 [02:42<?, ?it/s]
# epoch___0____loss___5.494996____steps___55360: 0%| | 0/10 [03:43<?, ?it/s]
# epoch___0____loss___5.319009____steps___55280: 0%| | 0/10 [03:42<?, ?it/s] 0.0003
# epoch___0____loss___4.823767____steps___54160: 0%| | 0/10 [03:38<?, ?it/s] 0.0003, layer norm at the end
# epoch___0____loss___4.830925____steps___54240: 0%| | 0/10 [03:39<?, ?it/s] 0.0003, layer norm everywhere
# epoch___0____loss___4.843996____steps___56160: 0%| | 0/10 [03:46<?, ?it/s] 0.0003, relu in the middle
# epoch___0____loss___4.821821____steps___55520: 0%| | 0/10 [03:44<?, ?it/s] 0.0003, gelu in the middle
# epoch___0____loss___5.115138____steps___60400: 0%| | 0/10 [04:03<?, ?it/s] 0.0003, layer norm in the middle
# epoch___0____loss___4.672063____steps___55290: 0%| | 0/10 [05:41<?, ?it/s] doubled (2x)
# epoch___0____loss___4.671307____steps___53220: 0%| | 0/10 [05:31<?, ?it/s] removed pos, added x2
# epoch___0____loss___4.685665____steps___56100: 0%| | 0/10 [05:49<?, ?it/s]
# epoch___0____loss___4.640556____steps___55200: 0%| | 0/10 [05:45<?, ?it/s]
# epoch___0____loss___4.643009____steps___54390: 0%| | 0/10 [05:40<?, ?it/s]
# epoch___0____loss___6.634222____steps___45870: 0%| | 0/10 [04:46<?, ?it/s] cumsum
# epoch___0____loss___4.633770____steps___55080: 0%| | 0/10 [05:48<?, ?it/s]
# epoch___0____loss___4.637132____steps___53820: 0%| | 0/10 [05:42<?, ?it/s]
This code defines a PyTorch-based neural network model for sequence-to-sequence learning tasks such as machine translation or text generation. The main components are explained below:

Model structure

- MaxState class:
  - A custom network layer that processes the input sequence.
  - It contains several linear projections (head0, head1, head2), each mapping the input into a new space.
  - It applies cummax (cumulative maximum) along the sequence, presumably to capture the salient information seen so far at each position.
  - It ends with a layer normalization (LayerNorm) layer that normalizes the combined output.
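A minimal sketch of that sequence-mixing step (assuming only `torch`, with made-up numbers): `cummax` gives each position the running element-wise maximum over all earlier positions, so it is causal by construction, much like masked attention but far cheaper.

```python
import torch

x = torch.tensor([[1.0, 0.2],
                  [0.5, 0.9],
                  [2.0, 0.1]])            # (seq_len=3, dim=2)
running_max, _ = torch.cummax(x, dim=0)   # cumulative maximum along the sequence axis
print(running_max)
# tensor([[1.0000, 0.2000],
#         [1.0000, 0.9000],
#         [2.0000, 0.9000]])
```

In MaxState the same operation runs over dim 2 of a (batch, heads, seq, head_size) tensor, i.e. the sequence axis.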
- FeedForward class:
  - A feed-forward network built from two linear layers.
  - A gating mechanism (the gate linear layer) controls the information flow.
  - A ReLU activation and a Dropout layer add non-linearity and help prevent overfitting.
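A toy illustration of the gating idea (hypothetical shapes, not the model's actual modules): channels whose gate activation is non-positive are zeroed by the ReLU before the down-projection.

```python
import torch

h = torch.randn(2, 4, 8)               # (batch, seq, hidden)
ffn1 = torch.nn.Linear(8, 16)          # up-projection
gate = torch.nn.Linear(8, 16)          # gating branch
ffn2 = torch.nn.Linear(16, 8)          # down-projection

gated = ffn1(h) * torch.relu(gate(h))  # gate <= 0 silences the channel
out = ffn2(gated)
print(out.shape)                       # torch.Size([2, 4, 8])
```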
- DecoderLayer class:
  - Defines one decoder layer containing three MaxState instances and three FeedForward instances.
  - Chaining these sub-layers lets the model process the sequence data more deeply.
- SamOut class:
  - The entry point of the whole model, containing the embedding layer (em) and the list of decoder layers (decoder).
  - A final linear layer maps the decoder output to the vocabulary size.
Code execution

- In the if __name__ == '__main__': block, a SamOut instance is created and moved to the GPU.
- The model is then called on a random integer tensor (simulating an input sequence) to run a forward pass.
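A quick shape sanity check mirroring that block, but on CPU (the sizes are just the ones used above; the second value returned by `forward` is an empty placeholder string):

```python
import torch
# assumes the SamOut class defined in the model code above is in scope
net = SamOut(voc_size=235, hidden_size=256, num_heads=16, num_layers=4)
tokens = torch.randint(0, 200, [2, 8 * 13])   # (batch=2, seq_len=104)
logits, _ = net(tokens)
print(logits.shape)                           # torch.Size([2, 104, 235])
```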
Performance at scale

From the experiment-log comments after the model code, the following can be concluded:

- On large datasets the model's loss converges quickly.
- Although its loss floor may be slightly higher than that of other models, it converges fast and shows little hallucination (i.e. overfitting or other undesirable behavior).
- This suggests the model generalizes well and trains efficiently on large datasets.

Overall, the design aims at deep processing of sequence data, using techniques such as cumulative maximum and layer normalization to improve performance. Its behavior on large datasets makes it an attractive choice, especially where fast convergence and low overfitting risk matter.
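The pretraining and SFT scripts below all train with next-token prediction: the model reads `input[:, :-1]` and the loss compares its logits against `input[:, 1:]`, with the pad id 3 ignored. A minimal illustration of that shift (toy numbers, hypothetical vocabulary size of 12):

```python
import torch

batch = torch.tensor([[5, 9, 7, 3, 3]])              # one padded sequence, pad id = 3
inputs = batch[:, :-1]                               # tokens fed to the model
targets = batch[:, 1:]                               # tokens the model must predict
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=3)  # pad positions contribute no loss

logits = torch.randn(1, 4, 12)                       # stand-in for model output (batch, seq, vocab)
loss = loss_fn(logits.reshape(-1, 12), targets.reshape(-1))
print(loss.item())
```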
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from glob import glob
from tqdm import tqdm
from model_one import SamOut
import polars as pl
from collections import Counter
from torch.optim.lr_scheduler import CosineAnnealingLR
def train():
voc = pd.read_pickle("total_voc.pkl")
net = SamOut(len(voc["voc"]), 1024, 32, 3)
print(sum([i.shape[0]*i.shape[1] for i in net.parameters() if len(i.shape)>1])+sum([i.shape[0] for i in net.parameters() if len(i.shape)==1]))
net.load_state_dict(torch.load("pretrain_1024_sin.pth"))
net.to("cuda")
opt = torch.optim.Adam(params=net.parameters(), lr=0.0003)
    # Set the cosine-annealing period (the learning rate resets at the end of each period)
    # T_max = 50  # e.g., 50 epochs
    # Initialize the cosine-annealing learning-rate scheduler
loss_func0 = torch.nn.CrossEntropyLoss(ignore_index=3)
bar = tqdm(range(10))
steps = 0
epoch_loss = []
batch_size=30
scheduler = CosineAnnealingLR(opt, 5000000 //batch_size,eta_min=1e-4)
#
for epoch in bar:
paths=glob("./pre_data_set_*.pkl")
data_set=[]
for ii in range(0,len(paths),2):
for one_path in paths[ii:ii+2]:
data_set+= pd.read_pickle(one_path)
np.random.shuffle(data_set)
loss_list = []
for i in range(0, len(data_set), batch_size):
# weights.append(list(net.state_dict().values())[0])
j = i + batch_size
input_one = data_set[i:j]
out0, _ = net(torch.Tensor(input_one)[:, :-1].int().to("cuda"))
loss = loss_func0(out0.reshape([-1, out0.shape[-1]]),
torch.Tensor(input_one)[:, 1:].reshape([-1]).long().to("cuda"))
loss0=torch.mean(2-torch.sum((torch.nn.functional.softmax(out0,-1)-1/len(voc["voc"])),-1))
loss_list.append(loss.item())
bar.set_description(
"epoch___{}____loss___{:.6f}____steps___{}".format(epoch, np.mean(loss_list), steps))
loss+=loss0
opt.zero_grad()
loss.backward()
opt.step()
steps += batch_size
scheduler.step()
# if steps%8==0:
# pd.to_pickle(loss_list, "loss91611111")
torch.save(net.state_dict(), "pretrain_1024_sin.pth")
# eval_model()
epoch_loss.append(np.mean(loss_list))
pd.to_pickle(epoch_loss, "loss9161")
def gen_one_voc():
data = pd.read_csv("pretrain_data.csv")
data = data["text"].values.tolist()
data = "".join(data)
count = Counter()
for ii in tqdm(range(0, len(data), len(data) // 8)):
jj = ii + len(data) // 8
for k, v in Counter(data[ii:jj]).items():
count[k] = count.get(k, 0) + v
data = ""
data0 = pd.read_csv("sft_data_multi.csv")
for ii in tqdm(range(0, len(data0), len(data0) // 8)):
jj = ii + len(data0) // 8
for k, v in Counter(data0[ii:jj]).items():
count[k] = count.get(k, 0) + v
data0 = ""
data1 = pd.read_csv("sft_data_single.csv")
for ii in tqdm(range(0, len(data1), len(data1) // 8)):
jj = ii + len(data1) // 8
for k, v in Counter(data1[ii:jj]).items():
count[k] = count.get(k, 0) + v
data1 = ""
# plt.plot(sorted(count.values()))
# plt.show()
count = pd.DataFrame({"voc": count.keys(), "count": count.values()})
voc = count.loc[count["count"] > 100, "voc"].values.tolist()
voc0 = [[[["<|pos_{}_{}|>".format(jj, ii) for jj, ii in enumerate(list(str(i)))], j] for i, j in
enumerate(count.loc[count["count"] <= 100, "voc"].values.tolist())]]
pd.to_pickle(voc, "voc.pkl")
pd.to_pickle(voc0, "voc0.pkl")
def gen_voc():
voc = pd.read_pickle("voc.pkl")
voc0 = pd.read_pickle("voc0.pkl")
voc0 = {j: i for i, j in voc0[0]}
for i in range(6):
for j in range(10):
voc.append("<|pos_{}_{}|>".format(i, j))
voc = ["<|sos|>", "<|user|>", "<|agent|>", "<|pad|>", "<|history|>"] + sorted(voc)
pd.to_pickle({"voc": voc, "voc0": voc0}, "total_voc.pkl")
def gen_pre_data_align(num, total_num):
voc = pd.read_pickle("total_voc.pkl")
voc["voc0"] = [[i, [voc["voc"].index(j) for j in ii]] for i, ii in voc["voc0"].items()]
voc["voc"] = [i for i in voc["voc"]]
voc = {"voc": voc["voc"] + [i for i, j in voc["voc0"]],
"voc_id": [[i] for i in list(range(len(voc["voc"])))] + [j for i, j in voc["voc0"]]}
voc = pd.DataFrame(voc)
# voc=pl.DataFrame(voc)
pre_data = pl.read_csv("pretrain_data.csv")
pre_data = pre_data["text"].to_numpy().tolist()
count = len(pre_data) // total_num
pre_data = pre_data[(num - 1) * count:count * num]
data_set = []
bar = tqdm(range(len(pre_data)))
while pre_data:
bar.update()
one = pre_data.pop()
one = pd.merge(pd.DataFrame({"voc": list(one)}), voc, on="voc", how="left")
thr = np.hstack(one["voc_id"].to_numpy()).tolist()
thr += (518 - len(thr)) * [3]
thr = thr[:512]
data_set.append(thr)
pd.to_pickle(data_set, "pre_data_set_{}.pkl".format(num))
def gen_sft_single_data_align():
voc = pd.read_pickle("total_voc.pkl")
voc["voc0"] = {i: [voc["voc"].index(j) for j in ii] for i, ii in voc["voc0"].items()}
voc["voc"] = {v: i for i, v in enumerate(voc["voc"])}
pre_data = pl.read_csv("sft_data_single.csv")
pre_data = pre_data.to_numpy().tolist()
data_set = []
index_id=0
for h, q, a in tqdm(pre_data):
index_id+=1
one = ["<|user|>"] + list(q) + ["<|agent|>"] + list(a)
one_list = []
for i in one:
voc_id = voc["voc"].get(i, None)
            if voc_id is not None:
one_list.append(voc_id)
else:
one_list += voc["voc0"].get(i, [3])
one_list += (512 - len(one_list)) * [3]
data_set.append(one_list[:512])
if len(data_set)>1000000:
pd.to_pickle(data_set, "sft_data_single_{}.pkl".format(index_id))
data_set=[]
pd.to_pickle(data_set, "sft_data_single_{}.pkl".format(index_id))
def train_single():
voc = pd.read_pickle("total_voc.pkl")
net = SamOut(len(voc["voc"]), 1024, 32, 3)
net.load_state_dict(torch.load("pretrain_1024_sin.pth"))
net.to("cuda")
opt = torch.optim.Adam(params=net.parameters(), lr=0.00003)
loss_func0 = torch.nn.CrossEntropyLoss(ignore_index=3)
scheduler = CosineAnnealingLR(opt, 5000000 // 800, eta_min=1e-6)
bar = tqdm(range(1))
steps = 0
epoch_loss = []
for epoch in bar:
paths=glob("./sft_data_*.pkl")
np.random.shuffle(paths)
for o in tqdm(range(0,len(paths),2)):
data_set =[]
for one_path in paths[o:o+2]:
data_set+=pd.read_pickle(one_path)
np.random.shuffle(data_set)
loss_list = []
for i in range(0, len(data_set), 20):
# weights.append(list(net.state_dict().values())[0])
j = i + 20
input_one = data_set[i:j]
out0, _ = net(torch.Tensor(input_one)[:, :-1].int().to("cuda"))
loss = loss_func0(out0.reshape([-1, out0.shape[-1]]),
torch.Tensor(input_one)[:, 1:].reshape([-1]).long().to("cuda"))
loss_list.append(loss.item())
bar.set_description(
"epoch___{}____loss___{:.6f}____steps___{}".format(epoch, np.mean(loss_list), steps))
opt.zero_grad()
loss.backward()
opt.step()
steps += 20
scheduler.step()
torch.save(net.state_dict(), "pretrain_sft_single_768_sin.pth")
# eval_model()
epoch_loss.append(np.mean(loss_list))
pd.to_pickle(epoch_loss, "loss916")
def show_loss():
from matplotlib import pyplot as plt
# plt.plot(pd.read_pickle("loss916"))
plt.plot(pd.read_pickle("loss9161"))
plt.plot(pd.read_pickle("loss91611"))
plt.plot(pd.read_pickle("loss916111"))
plt.plot(pd.read_pickle("loss91611111"))
plt.legend(["no_layer_norm","layer_norm_center","state_layer_nor"])
plt.show()
if __name__ == '__main__':
# print(pd.read_pickle("loss916"))
# gen_one_voc()
# gen_voc()
# for i in range(17,18):
# gen_pre_data_align(i, 16)
# show_loss()
# train()
# gen_sft_single_data_align()
train_single()
# epoch___0____loss___6.297972____steps___905600: 0%| | 0/10 [13:53<?, ?it/s]
# epoch___3____loss___3.958826____steps___21054080: 30%|███ | 3/10 [5:36:55<13:06:09, 6738.55s/it]
# epoch___6____loss___2.984237____steps___36832160: 60%|██████ | 6/10 [23:38:53<15:45:55, 14188.98s/it]
# epoch___0____loss___2.595664____steps___3948720: 100%|██████████| 1/1 [2:22:14<00:00, 8534.37s/it]
# epoch___3____loss___2.945595____steps___17904560: 30%|███ | 3/10 [20:52:47<43:51:20, 22554.43s/it]
# epoch___0____loss___5.747346____steps___55680: 0%| | 0/10 [03:47<?, ?it/s]
# epoch___1____loss___2.989535____steps___9665840: 10%|█ | 1/10 [11:14:05<55:27:39, 22184.39s/it]
# epoch___0____loss___2.810906____steps___1280560: 0%| | 0/10 [1:27:18<?, ?it/s]
# epoch___0____loss___2.399217____steps___3948720: 100%|██████████| 1/1 [4:11:14<00:00, 15074.85s/it]
#epoch___0____loss___2.239312____steps___1948640: 0%| | 0/1 [2:02:11<?, ?it/s]
# epoch___0____loss___2.269428____steps___3948720: 100%|██████████| 1/1 [4:08:10<00:00, 14890.38s/it]
# epoch___0____loss___2.288093____steps___3948720: 100%|██████████| 1/1 [4:13:04<00:00, 15184.23s/it]
# epoch___0____loss___3.565186____steps___854160: 0%| | 0/10 [58:33<?, ?it/s]
# epoch___0____loss___2.923572____steps___4204480: 0%| | 0/10 [4:49:18<?, ?it/s]
# epoch___0____loss___2.121723____steps___3948720: 100%|██████████| 1/1 [6:07:46<00:00, 22066.76s/it]
# epoch___0____loss___5.660226____steps___7120: 0%| | 0/10 [00:51<?, ?it/s]
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from glob import glob
from tqdm import tqdm
from model_one import SamOut
import polars as pl
from collections import Counter
def train():
voc = pd.read_pickle("total_voc.pkl")
net = SamOut(len(voc["voc"]), 768, 32, 16)
print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum(
[i.shape[0] for i in net.parameters() if len(i.shape) == 1]))
net.load_state_dict(torch.load("pretrain_768.pth"))
net.to("cuda")
opt = torch.optim.Adam(params=net.parameters(), lr=0.00002)
loss_func0 = torch.nn.CrossEntropyLoss(ignore_index=3)
bar = tqdm(range(10))
steps = 0
epoch_loss = []
batch_size = 30
for epoch in bar:
paths = glob("./pre_data_set_*.pkl")
data_set = []
for ii in range(0, len(paths), 2):
for one_path in paths[ii:ii + 2]:
data_set = pd.read_pickle(one_path)
np.random.shuffle(data_set)
loss_list = []
for i in range(0, len(data_set), batch_size):
# weights.append(list(net.state_dict().values())[0])
j = i + batch_size
input_one = data_set[i:j]
out0, _ = net(torch.Tensor(input_one)[:, :-1].int().to("cuda"))
loss = loss_func0(out0.reshape([-1, out0.shape[-1]]),
torch.Tensor(input_one)[:, 1:].reshape([-1]).long().to("cuda"))
loss_list.append(loss.item())
bar.set_description(
"epoch___{}____loss___{:.6f}____steps___{}".format(epoch, np.mean(loss_list), steps))
opt.zero_grad()
loss.backward()
opt.step()
steps += batch_size
torch.save(net.state_dict(), "pretrain_768.pth")
# eval_model()
epoch_loss.append(np.mean(loss_list))
pd.to_pickle(epoch_loss, "loss916")
def gen_one_voc():
data = pd.read_csv("pretrain_data.csv")
data = data["text"].values.tolist()
data = "".join(data)
count = Counter()
for ii in tqdm(range(0, len(data), len(data) // 8)):
jj = ii + len(data) // 8
for k, v in Counter(data[ii:jj]).items():
count[k] = count.get(k, 0) + v
data = ""
data0 = pd.read_csv("sft_data_multi.csv")
for ii in tqdm(range(0, len(data0), len(data0) // 8)):
jj = ii + len(data0) // 8
for k, v in Counter(data0[ii:jj]).items():
count[k] = count.get(k, 0) + v
data0 = ""
data1 = pd.read_csv("sft_data_single.csv")
for ii in tqdm(range(0, len(data1), len(data1) // 8)):
jj = ii + len(data1) // 8
for k, v in Counter(data1[ii:jj]).items():
count[k] = count.get(k, 0) + v
data1 = ""
# plt.plot(sorted(count.values()))
# plt.show()
count = pd.DataFrame({"voc": count.keys(), "count": count.values()})
voc = count.loc[count["count"] > 100, "voc"].values.tolist()
voc0 = [[[["<|pos_{}_{}|>".format(jj, ii) for jj, ii in enumerate(list(str(i)))], j] for i, j in
enumerate(count.loc[count["count"] <= 100, "voc"].values.tolist())]]
pd.to_pickle(voc, "voc.pkl")
pd.to_pickle(voc0, "voc0.pkl")
def gen_voc():
voc = pd.read_pickle("voc.pkl")
voc0 = pd.read_pickle("voc0.pkl")
voc0 = {j: i for i, j in voc0[0]}
for i in range(6):
for j in range(10):
voc.append("<|pos_{}_{}|>".format(i, j))
voc = ["<|sos|>", "<|user|>", "<|agent|>", "<|pad|>", "<|history|>"] + sorted(voc)
pd.to_pickle({"voc": voc, "voc0": voc0}, "total_voc.pkl")
def gen_pre_data_align(num, total_num):
voc = pd.read_pickle("total_voc.pkl")
voc["voc0"] = [[i, [voc["voc"].index(j) for j in ii]] for i, ii in voc["voc0"].items()]
voc["voc"] = [i for i in voc["voc"]]
voc = {"voc": voc["voc"] + [i for i, j in voc["voc0"]],
"voc_id": [[i] for i in list(range(len(voc["voc"])))] + [j for i, j in voc["voc0"]]}
voc = pd.DataFrame(voc)
# voc=pl.DataFrame(voc)
pre_data = pl.read_csv("pretrain_data.csv")
pre_data = pre_data["text"].to_numpy().tolist()
count = len(pre_data) // total_num
pre_data = pre_data[(num - 1) * count:count * num]
data_set = []
bar = tqdm(range(len(pre_data)))
while pre_data:
bar.update()
one = pre_data.pop()
one = pd.merge(pd.DataFrame({"voc": list(one)}), voc, on="voc", how="left")
thr = np.hstack(one["voc_id"].to_numpy()).tolist()
thr += (518 - len(thr)) * [3]
thr = thr[:512]
data_set.append(thr)
pd.to_pickle(data_set, "pre_data_set_{}.pkl".format(num))
def gen_sft_single_data_align():
voc = pd.read_pickle("total_voc.pkl")
voc["voc0"] = {i: [voc["voc"].index(j) for j in ii] for i, ii in voc["voc0"].items()}
voc["voc"] = {v: i for i, v in enumerate(voc["voc"])}
pre_data = pl.read_csv("sft_data_single.csv")
pre_data = pre_data.to_numpy().tolist()
data_set = []
index_id = 0
for h, q, a in tqdm(pre_data):
index_id += 1
one = ["<|user|>"] + list(q) + ["<|agent|>"] + list(a)
one_list = []
for i in one:
voc_id = voc["voc"].get(i, None)
            if voc_id is not None:
one_list.append(voc_id)
else:
one_list += voc["voc0"].get(i, [3])
one_list += (512 - len(one_list)) * [3]
data_set.append(one_list[:512])
if len(data_set) > 1000000:
pd.to_pickle(data_set, "sft_data_single_{}.pkl".format(index_id))
data_set = []
pd.to_pickle(data_set, "sft_data_single_{}.pkl".format(index_id))
def train_single():
voc = pd.read_pickle("total_voc.pkl")
net = SamOut(len(voc["voc"]), 512, 32, 8)
net.load_state_dict(torch.load("pretrain_sft_single.pth"))
net.to("cuda")
opt = torch.optim.Adam(params=net.parameters(), lr=0.000003)
loss_func0 = torch.nn.CrossEntropyLoss(ignore_index=3)
bar = tqdm(range(2))
steps = 0
epoch_loss = []
for epoch in bar:
paths = glob("./sft_data_*.pkl")
np.random.shuffle(paths)
for o in range(0, len(paths), 2):
data_set = []
for one_path in paths[o:o + 2]:
data_set += pd.read_pickle(one_path)
np.random.shuffle(data_set)
loss_list = []
for i in range(0, len(data_set), 80):
# weights.append(list(net.state_dict().values())[0])
j = i + 80
input_one = data_set[i:j]
out0, _ = net(torch.Tensor(input_one)[:, :-1].int().to("cuda"))
loss = loss_func0(out0.reshape([-1, out0.shape[-1]]),
torch.Tensor(input_one)[:, 1:].reshape([-1]).long().to("cuda"))
loss_list.append(loss.item())
bar.set_description(
"epoch___{}____loss___{:.6f}____steps___{}".format(epoch, np.mean(loss_list), steps))
opt.zero_grad()
loss.backward()
opt.step()
steps += 80
torch.save(net.state_dict(), "pretrain_sft_single.pth")
# eval_model()
epoch_loss.append(np.mean(loss_list))
pd.to_pickle(epoch_loss, "loss916")
def load_model_and_voc(device="cpu"):
voc = pd.read_pickle("total_voc.pkl")
net = SamOut(len(voc["voc"]), 1024, 32, 3)
# net = SamOut(len(voc["voc"]), 512, 32, 8)
print(sum([i.shape[0] * i.shape[1] for i in net.parameters() if len(i.shape) > 1]) + sum(
[i.shape[0] for i in net.parameters() if len(i.shape) == 1]))
# net.load_state_dict(torch.load("pretrain_768.pth", map_location=device))
# net.load_state_dict(torch.load("pretrain_sft_single.pth", map_location=device))
net.load_state_dict(torch.load("pretrain_sft_single_768_sin.pth", map_location=device))
# net.load_state_dict(torch.load("pretrain.pth", map_location=device))
net.to(device)
net.eval()
return net, voc
def gen_token(voc, model, prompt, max_len, rp=1.2, temp=0.5, top_k=16, device="cpu"):
print("agent:", end="", flush=True)
for _ in range(max_len):
prompt_list = []
for i in prompt:
if i not in voc["voc"]:
prompt_list += [voc["voc"].index(ii) for ii in voc["voc0"].get(i)]
else:
prompt_list.append(voc["voc"].index(i))
out, _ = model(torch.Tensor([prompt_list]).to(device).long())
out = out[:, -1:]
        # repetition penalty: scale down the logits of tokens already present in the prompt
        for token_id in set(prompt_list):
            out[:, :, token_id] /= rp
score = torch.softmax(out, -1)[0, 0]
score, score_index = torch.sort(score)
score=score.detach().numpy()
score_sum = np.cumsum(score)
score_index = score_index.detach().numpy()
score=score[score_sum>0.2]
score_index=score_index[score_sum>0.2]
score=score[::-1]
score_index=score_index[::-1]
        # temperature scaling
out = score / temp
v= out[:min(top_k, score.size)]
idx_next = torch.multinomial(torch.Tensor(v), num_samples=1, generator=None)
if voc["voc"][score_index[idx_next.item()]] == "<|sos|>":
break
prompt += [voc["voc"][score_index[idx_next.item()]]]
print(prompt[-1], end="", flush=True)
def t_infre():
model, voc = load_model_and_voc()
while True:
text = input("user:")
gen_token(voc, model, ["<|user|>"] + list("{}".format(text)) + ["<|agent|>"], 320)
print()
if __name__ == '__main__':
# print(pd.read_pickle("loss916"))
# gen_one_voc()
# gen_voc()
# for i in range(17,18):
# gen_pre_data_align(i, 16)
# train()
# gen_sft_single_data_align()
# train_single()
    # SFT inference: the model has learned to talk nonsense with a perfectly straight face
t_infre()
From: https://blog.csdn.net/weixin_32759777/article/details/144165816