
Implementing word2vec with skip-gram

Posted: 2022-08-24 17:35:15

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm 
sentences = ["jack like dog", "jack like cat", "jack like animal",
             "dog cat animal", "banana apple cat dog like", "dog fish milk like",
             "dog cat animal like", "jack like apple", "apple like", "jack like banana",
             "apple banana jack movie book music like", "cat dog hate", "cat dog like"]

sentences_list = " ".join(sentences).split()
# word-to-index mapping
vocab = list(set(sentences_list))
word2idx = {j: i for i, j in enumerate(vocab)}
idx2word = {i: j for i, j in enumerate(vocab)}
vocab_size = len(vocab)
window_size = 2
embedding_size = 2

def make_data(seq_data):
    skip_gram = []
    words = " ".join(seq_data).split()
    for step in range(window_size, len(words) - window_size):
        # center word position
        center = step
        # context word positions: window_size words on each side of the center
        context_arr = list(range(step - window_size, step)) + list(range(step + 1, step + window_size + 1))
        for context_i in context_arr:
            # input: one-hot vector of the center word; target: index of the context word
            skip_gram.append([np.eye(vocab_size)[word2idx[words[center]]], word2idx[words[context_i]]])
    input_data = []
    target_data = []
    for a, b in skip_gram:
        input_data.append(a)
        target_data.append(b)
    return torch.FloatTensor(np.array(input_data)), torch.LongTensor(target_data)
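
A quick sanity check (not part of the original post, assumes the code above): each input row is a one-hot vector over the vocabulary and each target is the index of one context word, so both tensors share the same first dimension.

inp, tgt = make_data(sentences)
print(inp.shape, tgt.shape)  # (num_pairs, vocab_size) and (num_pairs,)
print(idx2word[int(inp[0].argmax())], "->", idx2word[int(tgt[0])])  # one (center, context) pair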

class my_dataset(Dataset):
    def __init__(self, input_data, target_data):
        super(my_dataset, self).__init__()
        self.input_data = input_data
        self.target_data = target_data
 
    def __getitem__(self, index):
        return self.input_data[index], self.target_data[index]
 
    def __len__(self):
        return self.input_data.size(0)  # number of samples (first dimension of the tensor)
    
# given a center word, predict its context words
class SkipGram(nn.Module):
    def __init__(self, embedding_size):
        super(SkipGram, self).__init__()
        self.embedding_size = embedding_size
        self.fc1 = torch.nn.Linear(vocab_size, self.embedding_size)
        self.fc2 = torch.nn.Linear(self.embedding_size, vocab_size)
        self.loss = nn.CrossEntropyLoss()
 
    def forward(self, center, context):
        """
        :param center: [Batch_size]
        :param context:[Batch_size, vocab_size]
        :return:
        """
        center = self.fc1(center)
        center = self.fc2(center)
        loss = self.loss(center, context)
        return loss
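
An illustrative forward pass (not in the original post): fc1 maps a one-hot row of length vocab_size to an embedding_size-dimensional vector, fc2 maps it back to vocab_size logits, and CrossEntropyLoss compares those logits with the context-word indices.

demo = SkipGram(embedding_size)  # throwaway instance, only for the shape check
one_hot = torch.FloatTensor(np.eye(vocab_size)[[word2idx["jack"], word2idx["dog"]]])
targets = torch.LongTensor([word2idx["like"], word2idx["cat"]])
print(demo(one_hot, targets))  # a scalar cross-entropy loss over the two pairs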
    
batch_size = 2
center_data, context_data = make_data(sentences)
train_data = my_dataset(center_data, context_data)
train_loader = DataLoader(train_data, batch_size, shuffle=True)
epochs = 5
model = SkipGram(embedding_size=embedding_size)
model.train()

optim = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(epochs):
    loop = tqdm(enumerate(train_loader), total = len(train_loader))
    for index, (center, context) in loop:
        loss = model(center, context)
        loop.set_description(f'Epoch [{epoch}/{epochs}]')
        loop.set_postfix(loss = loss.item())
        optim.zero_grad()
        loss.backward()
        optim.step()
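
Since embedding_size is 2, the learned vectors can be inspected directly. A possible continuation (not in the original post): the columns of fc1.weight act as the word vectors, so they can be scatter-plotted with matplotlib.

import matplotlib.pyplot as plt

W = model.fc1.weight.detach().numpy()  # shape: [embedding_size, vocab_size]
for i, word in enumerate(vocab):
    x, y = float(W[0, i]), float(W[1, i])
    plt.scatter(x, y)
    plt.annotate(word, xy=(x, y), xytext=(5, 2), textcoords='offset points')
plt.show()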

  

From: https://www.cnblogs.com/ygxdbmx/p/16620960.html
