首页 > 其他分享 >Pytorch中使用Embedding报错'IndexError'的解决方法

Pytorch中使用Embedding报错'IndexError'的解决方法

时间:2023-08-16 11:01:29浏览次数:51  
标签:IndexError embedding nn torch Pytorch 报错 Embedding import word

简介

  Pytorch中nn.Embedding为针对词向量的层,其用来实现词与词向量的映射。其调用形式如下

nn.Embedding(
    num_embeddings: int, embedding_dim: int, padding_idx: int | None = None,
    max_norm: float | None = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
    sparse: bool = False
)

参数解释

num_embeddings: 词典大小尺寸,类型为int,代表输入词的大小

embedding_dim: 每个词创建多少维度用于表示,类型为int

padding_idx: 句子不一样长时,后面所需的填充id

max_norm: 最大范数,如果词维度超过该范数便需要归一化

norm_type: 利用(默认:2)范数计算

scale_grad_by_freq: 根据单词在mini-batch中出现的频率,对梯度进行放缩。默认为False.

sparse: 若为True,则与权重矩阵相关的梯度转变为稀疏张量。


问题描述

  在使用nn.Embedding时报错 IndexError: index out of range in self,具体如下图:

报错图片

  报错含义为索引超出界限,错误原因可由以下例子来说明。


解决方法

import torch
from torch import nn

embedding = nn.Embedding(4, 1)
word = [
    [1, 2, 3, 2],
    [2, 3, 4, 1]
]

embed = embedding(torch.LongTensor(word))
print(embed)

  在该例中,Embedding的参数含义为:词的数目为4个,将每个词映射为长度为1的向量;word为语句的标签编码。可以看到语句中有 "1", "2", "3", "4" 4种词,按照上述描述,词的数目就应该为4,但该程序却抛出上图的报错。将程序更改为下述则便不会报错。

import torch
from torch import nn

embedding = nn.Embedding(5, 1)
word = [
    [1, 2, 3, 2],
    [2, 3, 4, 1]
]

embed = embedding(torch.LongTensor(word))
print(embed)

  为什么词的数目要比句子中数目多一个?是因为在Embedding层中,num_embeddings 可以理解为”索引的尺寸“,即当语句中出现”4“时,其索引应当为0-4,则其大小应为5,即 num_embeddings 应至少设置为5,又如下例:

import torch
from torch import nn

embedding = nn.Embedding(7, 1)
word = [
    [1, 2, 3, 2],
    [2, 3, 6, 1]
]

embed = embedding(torch.LongTensor(word))
print(embed)

  语句中有索引“6”,则其应当为0-6num_embeddings 则应当设为7。

参考来源:

标签:IndexError,embedding,nn,torch,Pytorch,报错,Embedding,import,word
From: https://www.cnblogs.com/ToryRegulus/p/17633208.html

相关文章

  • PyTorch神经网络工具箱-新手笔记
    神经网络核心组件利用PyTorch神经网路工具箱设计神经网络就像搭积木一样,可以极大简化构建模型的任务。神经网络核心组件如下:层:神经网络的基本结构,将输入张量转换为输出张量。模型:由层构成的网络。损失函数:参数学习的目标函数,通过最小化损失函数来学习各种参数。优化器:如在使损失值......
  • ubuntu 安装Android studio报错
    运行命令./studio.sh报错:CompileCommand:excludecom/intellij/openapi/vfs/impl/FilePartNodeRoot.trieDescendboolexclude=true[0.118s][error][jfr,startup]'intsun.nio.fs.UnixNativeDispatcher.init()'java.lang.UnsatisfiedLinkError:'intsun.nio.fs.......
  • git checkout 分支报错 error: invalid path
    同事提交了一波代码后,发现怎么也切换不到这个分支了百度后发现windows电脑的git路径不支持空格和特殊符号,让同事把路径中空格或者特殊符号删了就可以解决了 ......
  • zabbix-proxy报错:cannot send list of active checks to “x.x.x.x“:delete from hos
    最近新部署了zabbix,两台zabbix-proxy访问一台zabbix-serverproxy的日志里一直都有这些数据62827:20230813:032210.216cannotsendlistofactivechecksto"10.x.x.x":host[prod-nacos-2.sugon.local]notfound162826:20230813:032212.459cannotsendlistofactivech......
  • grafana报错too many outstanding requests
    grafana报错toomanyoutstandingrequests1、问题描述当grafana使用loki作为数据源查询数据时,面板报错toomanyoutstandingrequestsloki的版本是2.8.0报错截图2、解决办法loki的配置文件中添加下面这两行query_scheduler:max_outstanding_requests_per_tenant:1000......
  • dav 编译报错 v8内存溢出
    dav编译报错v8内存溢出FATALERROR:ReachedheaplimitAllocationfailed-JavaScriptheapoutofmemory  到node_modules中/.bin/roadhog.cmd把最后一句改成endLocal&goto#_undefined_#2>NUL||title%COMSPEC%&"%_prog%"--max_old_space_size=8192......
  • Oracle启动监听报错:The listener supports no services或出现 unknown状态解决
    1、查看$ORACLE_HOME/network/admin/listener.ora文件中的host是否正确,能不能ping通2、查看$ORACLE_HOME/network/admin/tnsnames.ora文件中的host是否与listener.ora中的一致3、查看/etc/hosts文件中的127.0.0.1是不是localhost,listener.ora中host跟这里的是否一样4、登录数......
  • SVN打开文件报错
    问题描述:'D:\WorkSpace\vvvvv\XXXXXXX-K3Cloud'isalreadylocked.“ 原因分析:上一次异常操作了。解决方案:找到项目目录:右键打开   ......
  • 引入feign注入报错 org.springframework.beans.factory.NoSuchBeanDefinitionExceptio
    引入feign注入报错org.springframework.beans.factory.NoSuchBeanDefinitionException解决[172.16.22.215]out:Causedby:org.springframework.beans.factory.NoSuchBeanDefinitionException:Noqualifyingbeanoftype'com.test.mydock.api.FeignRemoteTestService�......
  • json字符串转换对象或列表,多了字段不会报错
    json字符串转换对象或列表,多了字段不会报错//DEMO1转换对象应用riskIdpublicclassItem{privateStringid;privateStringrate;publicItem(Stringid,Stringrate){this.id=id;this.rate=rate;}@Overridepubl......