首页 > 其他分享 >nn.Embedding torch.nn.Embedding

nn.Embedding torch.nn.Embedding

时间:2023-10-28 20:22:34浏览次数:26  
标签:文件 训练 nn torch 矩阵 Embedding 向量

nn.Embedding 

torch.nn.Embedding

 

随机初始化词向量矩阵:这种方式很容易理解,就是使用self.embedding = torch.nn.Embedding(vocab_size, embed_dim)命令直接随机生成个初始化的词向量矩阵,此时的向量值符合正态分布N(0,1),这里的vocab_size是指词向量矩阵能表征的词的个数,这个数值即是词向量文件中词的数量加1(加1的原因是,如果某个词在词向量文件中不存在,则获取不到索引,也就无法在词向量矩阵中获取对应的向量,这时我们默认这个词的索引为0,即将词向量的第一行作为这个词的向量表征。使用预训练的词向量文件时,这个方法同样适用),embed_dim是指表征每个词时,向量的维度(可自定义,如256)。对于随机初始化词向量矩阵的方式,词向量文件的生成方式一般是将当前所有的文本数据(包括训练数据、验证数据、测试数据)进行切词,再对所有词进行聚合统计,保留词的数量大于某个阈值(比如3)的词,并进行索引编号(编号从1开始,0作为上面提到的不在词向量文件中的其他词的索引),进而生成词向量文件。顺便提一句,词向量矩阵的初始化的方式也有很多种,比如Xavier、Kaiming初始化方法。
        使用预训练的词向量文件初始化词向量矩阵:本质上,词向量矩阵的作用是实现文本的向量表征,因此,如何用更合适的向量表示文本,逐渐成为了一个热门研究方向。预训练的词向量文件便是其中的一个研究成果,如通过word2vec、glove等预训练模型生成的词向量文件,通过大量的训练数据,来生成词的向量表征。以word2vec为例,训练后生成的词向量文件是以离线配置文件的形式存在,可通过gensim工具包进行加载,具体命令是wvmodel = gensim.models.KeyedVectors.load_word2vec_format(word2vec_file, binary=False, encoding='utf-8', unicode_errors='ignore'),加载后,可通过wvmodel.key_to_index获取词向量文件(要对词向量文件中的词索引进行重新编号,原索引从0开始,调整为从1开始,0作为不在词向量文件中的词的索引),通过wvmodel.get_vector("xxx")获取词向量文件中每个词对应的向量,将词向量文件中所有词对应的向量聚合在一起后(聚合的方式是,每个词的向量表征,按照词的索引,填充在词向量矩阵对应的位置),生成预训练词向量矩阵weight,再通过self.embedding = torch.nn.Embedding.from_pretrained(weight, freeze=False)完成词向量矩阵的初始化,参数freeze的作用,是指明训练时是否更新词向量矩阵的权重值,True为不更新,默认为True,等同于self.embedding.weight.requires_grad = False)。

         还有个细节需要介绍下,在获取到预训练的词向量文件后,由于预训练的词向量文件很大,因此在后续的训练过程中,可能会出现内存不足的错误,此时可对词向量文件及预训练词向量矩阵进行调整,具体来说,先对我们本身任务的所有文本数据进行切词统计,保留数量超过一定阈值的词,作为词向量文件(就是随机初始化词向量矩阵时,词向量文件的生成方法),再利用这个词向量文件,配合wvmodel.get_vector("xxx"),获取预训练词向量矩阵weight,最后进行后续的词向量矩阵初始化过程。这样操作之后,由于词向量文件中词的数量减少,词向量矩阵的行数减少,内存占用会随之减少很多。另外,生成词向量的预训练方法还有很多,参见【通俗易懂的词向量】。

 

 

转自:

https://www.cnblogs.com/emanlee/p/17455844.html

https://blog.csdn.net/qq_39439006/article/details/126760701

 

标签:文件,训练,nn,torch,矩阵,Embedding,向量
From: https://www.cnblogs.com/emanlee/p/17794553.html

相关文章

  • Cannot connect to the Docker
    执行docker基础命令失败!CannotconnecttotheDockerdaemonatunix:///var/run/docker.sock.Isthedockerdaemonrunning?原因:docker服务没有启动。解决方法:执行systemctlstartdocker即可。......
  • javaweb--JDBC的API-Connection
    1、获取执行SQL对象2、管理事务setAutoCommit(bool)true为自动提交false为手动提交commit()提交事务rollback()回滚事务packagecom.avb.jdbc;importjava.sql.Connection;importjava.sql.DriverManager;importjava.sql.SQLException;importjava.sql.Statement;public......
  • ERROR: Cannot unpack file C:\Users\17482\AppData\Local\Temp\pip-unpack-9g9
    ERROR:CannotunpackfileC:\Users\17482\AppData\Local\Temp\pip-unpack-9g93t3zt\simple.html(downloadedfromC:\Users\17482\AppData\Local\Temp\pip-req-build-35ukmesa,content-type:text/html);cannotdetectarchiveformatERROR:Cannotdeterm......
  • AtCoder Beginner Contest(abc) 310
    B-StrictlySuperior难度:⭐题目大意给定n个商品的价格,每个商品还有若干个属性,请问是否存在一个商品是另外一个商品的上位品;上位品的定义分两种,一是价格相同,但是商品A的属性不仅包括了商品B的属性,还比商品B多了至少一个属性;二是如果两商品的属性相同,但是......
  • pytorch:1.12-gpu-py39-cu113-ubuntu20.04
    docker-compose安装unbuntu20.04version:'3'services:ubuntu2004:image:ubuntu:20.04ports:-'2256:22'-'3356:3306'-'8058:80'volumes:-my-volume:/datacommand:tail......
  • AtCoder Beginner Contest 325
    感觉错失了上分机会A-Takahashisan(abc325A)题目大意给定姓和名,输出尊称,即姓+san。解题思路按照题意模拟即可。神奇的代码#include<bits/stdc++.h>usingnamespacestd;usingLL=longlong;intmain(void){ios::sync_with_stdio(false);cin.tie(......
  • Python 利用pandas和mysql-connector获取Excel数据写入到MySQL数据库
    如何将Excel数据插入到MySQL数据库中在实际应用中,我们可能需要将Excel表格中的数据导入到MySQL数据库中,以便于进行进一步的数据分析和处理。本文将介绍如何使用Python将Excel表格中的数据插入到MySQL数据库中。导入必要的库首先,我们需要导入pandas库和MySQLConnector/Python库......
  • [Spring框架学习]SSM 整合,使用maven构建项目的时候,启动项目报错class path resource
    错误:classpathresource[config/spring/springmvc.xml]cannotbeopenedbecauseitdoesnotexist错误原因:找不到我的springmvc.xml,在下面web.xml中是我引用路径,网上找到问题classpath指向路径不是resource路径,所以一直找不到我的xml文件,classpath:到你的class路径......
  • Python:爬取某软件站数据报错requests.exceptions.SSLError: HTTPSConnectionPool(hos
    使用Python爬取某网站数据时候,之前一直是好好的。突然就报错:requests.exceptions.SSLError:HTTPSConnectionPool(host='api.***.cn',port=443):Maxretriesexceededwithurl:/accounty1/login?analysis............检查发现,可能是IP地址存在代理导致网络环境一场。可以检......
  • anaconda+pytorch+pycharm
    1、安装anaconda,使用conda新建虚拟环境condacreate-npytorchpython=3.9numpymatplotlibpandasjupyternotebook(环境名为pytorch)condaactivatepytorchcondadeactivate2、在新建的虚拟环境下面下载pytorchcondainstallpytorchtorchvisiontorchaudio-cpytorch......