首页 > 其他分享 >torch.nn.Embedding的导入与导出

torch.nn.Embedding的导入与导出

时间:2024-09-12 20:24:29浏览次数:9  
标签:nn weight torch Embedding word1 embedding

简介及导入转自:torch.nn.Embedding使用

在RNN模型的训练过程中,需要用到词嵌入,使用torch.nn.Embedding可以快速的完成:只需要初始化torch.nn.Embedding(n,m)即可(n是单词数,m是词向量的维度)(n是嵌入字典的大小,m是嵌入向量的维度。)。

注意: embedding开始是随机的,在训练的时候会自动更新。

简单使用

举个简单的例子:

  • 输入:word1和word2是两个长度为3的句子,保存的是单词所对应的词向量的索引号。
  • 输出:随机生成(4,5)维度大小的embedding,可以通过embedding.weight查看embedding的内容。
  • 过程:输入word1时,embedding会输出第0、1、2行词向量的内容; word2同理。
import torch

word1 = torch.LongTensor([0, 1, 2])
word2 = torch.LongTensor([3, 1, 2])
embedding = torch.nn.Embedding(4, 5)
 
print(embedding.weight)
print('word1:')
print(embedding(word1))
print('word2:')
print(embedding(word2))

image

导出

创建一个嵌入层并将其导出为numpy数组。

import torch
import numpy as np

# 创建嵌入层
embedding = torch.nn.Embedding(10, 5)

# 将权重转换为numpy数组
embedding_weights = embedding.weight.data.numpy()

# 保存权重到文件
np.savetxt("embedding_weights.txt", embedding_weights)

导入

导入已经训练好的词向量,需要设置训练过程中不更新(固定embedding)。

如下所示,emb是已经训练得到的词向量,先初始化等同大小的embedding,然后将emb的数据复制过来,最后一定要设置weight.requires_grad为False。

self.embedding = torch.nn.Embedding(emb.size(0), emb.size(1))
self.embedding.weight = torch.nn.Parameter(emb)

# 固定embedding
self.embedding.weight.requires_grad = False

标签:nn,weight,torch,Embedding,word1,embedding
From: https://www.cnblogs.com/kingwz/p/18411016

相关文章

  • GNN图神经网络简单理解
    GNN简单理解文章目录一、GNN图神经网络综述1什么是图1.1图基础1.2图的分类1.3数据成图1.3.1图像转图1.3.2文本转图1.3.3其他转图1.4图结构化数据的问题类型1.4.1图层面任务graph-leveltask1.4.2节点层面任务node-leveltask1.4.3边层面任务edge-leve......
  • js | TypeError: Cannot read properties of null (reading ‘indexOf’) 【解决】
    js|TypeError:Cannotreadpropertiesofnull(reading‘indexOf’)【解决】描述概述在前端开发中,遇到TypeError:Cannotreadpropertiesofnull(reading'indexOf')这类错误并不罕见。这个错误通常表明你试图在一个null值上调用indexOf方法,而null是一......
  • 为什么需要用到channel
    Channel是Go语言中并发编程的核心工具之一,主要用于解决以下问题:1.数据传递和通信在并发编程中,不同的goroutine可能需要交换数据。使用channel可以安全地在goroutine之间传递数据,而无需显式地使用锁。channel提供了类型安全的通信机制,使得数据传输既简洁又安全。2.......
  • 什么是golang中的channel
    在Go语言中,channel是一种用于在goroutine之间进行通信和同步的工具。它允许一个goroutine发送数据到channel,另一个goroutine从channel接收数据,从而实现并发编程中的数据交换。 Channel的关键特性类型安全:每个channel都有一个指定的类型,确保发送到channel的......
  • MySQL学习笔记(三)InnoDB索引
    索引概念        索引在关系型数据库中,是一种单独的、物理的对数据库表中的一列或者多列值进行排序的一种存储结构,它是某个表中一列或者若干列值的集合,还有指向表中物理标识这些值的数据页的逻辑指针清单。        索引的作用相当于图书的目录,可以根据目......
  • torch.normal的用法和实例说明 normal函数的用法? 正态分布?
    torch.normal()是PyTorch中生成正态分布(也称为高斯分布)随机数据的函数。正态分布的特点是数据集中在均值附近,标准差描述数据的散布情况。接下来,详细解释正态分布和torch.normal()的用法。1、什么是正态分布?正态分布(NormalDistribution)是一种常见的概率分布,用两个......
  • 新版本torchtext的安装办法
    之前导入torchtext的时候出现报错信息“nomodulenamedtorchtext”,通过上网搜索加上自己摸索发现torchtext版本要与自己的pytorch版本对应网上给出的版本对应如下图但是没有最新的版本对应(目前是2024年9月份,pytorch版本已经到了2.2.*)接下来给出教程首先自己确定一下自己pyt......
  • pytorch安装: cuda、cudatoolkit、torch版本对照
    在PyTorch官网上有如下安装对照表,同时也有历史版本安装对照表从零开始配置python深度学习环境大概有如下配置步骤:方案一:电脑安装显卡驱动,然后安装CUDA、cuDNN,安装miniconda3。前面都是在电脑基础环境配置,后面的操作都是在conda环境中,安装torch、cudatoolkits等深度学习包方......
  • 通过ModelScope开源Embedding模型将图片转换为向量
    本文介绍如何通过ModelScope魔搭社区中的视觉表征模型将图片转换为向量,并入库至向量检索服务DashVector中进行向量检索。ModelScope魔搭社区旨在打造下一代开源的模型即服务共享平台,为泛AI开发者提供灵活、易用、低成本的一站式模型服务产品,让模型应用更简单。ModelScope魔搭......
  • ROS2 - Moveit2 - Planning with Approximated Constraint Manifolds(使用近似约束流
    使用近似约束流形进行规划OMPL支持自定义约束,以使规划轨迹遵循所需的行为。约束可以在关节空间和笛卡尔空间中定义,后者基于方向或位置。在规划轨迹时,每个关节状态都需要遵循所有设置的约束,默认情况下,这是通过拒绝采样来执行的。然而,这可能会导致非常长的规划时间,特别是当约束非......