首页 > 其他分享 >深度学习--RNN基础

深度学习--RNN基础

时间:2023-04-25 18:22:15浏览次数:33  
标签:word RNN nn -- torch 单词 深度 run

深度学习--RNN基础

​ RNN(Recurrent Neutral Network,循环神经网络),主要应用于自然语言处理NLP。

RNN表示方法

1.编码

因为Pytorch中没有String类型数据,需要引入序列表示法(sequence representation)对文本进行表示。

​ 表示方法:[seq_len:一句话的单词数,feature_len:每个单词的表示方法]

文本信息的表达方式:

  1. one-hot:多少个单词就有多少位编码。缺点:非常稀疏(sparse),维度太高,缺乏语意相关性(semantic similarity)
  2. word2vec
import torch
import torch.nn as nn

word_to_ix = {"hello":0,"world":1}

embeds = nn.Embedding(2,5)  #2行5列  一共有2个单词,用5位的feature来表示
lookup_tensor = torch.tensor([word_to_ix["hello"]],dtype=torch.long)
hello_embed = embeds(lookup_tensor)
print(hello_embed)
#tensor([[-0.2169,  0.3653,  0.7812, -0.8420, -0.2815]],
#       grad_fn=<EmbeddingBackward0>)


word_to_ix = {"hello":0,"world":1}

embeds = nn.Embedding(2,5)  #2行5列  一共有2个单词,用5位的feature来表示
lookup_tensor = torch.tensor([word_to_ix["hello"]],dtype=torch.long)
hello_embed = embeds(lookup_tensor)
print(hello_embed)
#tensor([[-0.2169,  0.3653,  0.7812, -0.8420, -0.2815]],
#       grad_fn=<EmbeddingBackward0>)
  1. glove
from torchnlp.word_to_vector import GloVe
vectors = GloVe()
vector["hello"]

2. batch

两种引入方式:[word num, b, word vec] 或者 [b, word num, word vec ] 第一种常用

RNN原理

naive version

对每一个单词进行 x@w1+b1 操作,每个单词都有不同的参数

Weight sharing

共享参数,用同一个w和b

Consistent memory 持续记忆

每一个单词运算表示:x @ wxh +h @ whh

增加了一个h单元,相当于一个memory单元。

总结:

RNN的网络为yt = why*ht

ht=激活函数(Whh* ht-1+Wxh *xt) 常用的为tanh

模型的反向传播:BPTT(back propagation through time)

RNN层的使用方法

run = nn.RNN(100,10)  #word vec 单词的表示位数, memory 记忆节点

run._parameters.keys()
#odict_keys(['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'])

run.weight_hh_l0.shape, run.weight_ih_l0.shape
#(torch.Size([10, 10]), torch.Size([10, 100]))

run.bias_hh_l0.shape, run.bias_ih_l0.shape
#(torch.Size([10]), torch.Size([10]))

1.nn.RNN

nn.RNN(input_size:单词的表示方法维度,hidden_size:记忆的维度:,num_layers:默认是1)

前向传播,一步到位 out, ht = forward(x, h0)

​ x:[一句话单词数,batch几句话,表示的维度]

​ h0/ht:[层数,batch,记忆(参数)的维度]

​ out:[一句话单词数,batch,参数的维度]

import torch
import torch.nn as nn

run = nn.RNN(input_size=100, hidden_size=20, num_layers=1)
print(run)
#RNN(100, 20)

x = torch.randn(10,3,100)
h = torch.zeros(1,3,20)
out,h1 = run(x,h)
print(out.shape,h1.shape)
#torch.Size([10, 3, 20]) torch.Size([1, 3, 20])

2. nn.RNNCell:只完成一个计算

nn.RNNCell(input_size:单词的表示方法维度,hidden_size:记忆的维度:,num_layers:默认是1)

前向传播:ht=rnncell(xt,ht_1)

​ xt:[batch,word维度]

​ ht_1/ht:[层数,batch,参数的维度]

#RNNCell
x = torch.randn(10,3,100)
cell = nn.RNNCell(100,20)
h1 = torch.zeros(3,20)
#人为控制一句话的单词数
for xt in x:
    print(xt)
    h1 = cell(xt,h1)
print(h1.shape)
#torch.Size([3, 20])

标签:word,RNN,nn,--,torch,单词,深度,run
From: https://www.cnblogs.com/ssl-study/p/17353488.html

相关文章

  • Java+GeoTools实现WKT数据根据EPSG编码进行坐标系转换
    场景Java+GeoTools(开源的JavaGIS工具包)快速入门-实现读取shp文件并显示:https://blog.csdn.net/BADAO_LIUMANG_QIZHI/article/details/130367852在上面实现Java中集成Geotools之后,需求是将WKT数据转换成其他坐标系的WKT。比如说将EPSG:4524的坐标系转换成EPSG:2334的坐标系......
  • 如何解决Gridea部分主题不渲染Katex的问题
    很多好看的主题因为对象不是信息学,所以忽视了公式,即\(\LaTeX\)。导致,如果你想渲染一个\(n\),结果成了nn这个简单,导入文件即可。找到主题文件夹,打开templates->post.ejs。添加以下这行代码:<linkrel="stylesheet"`href="https://cdn.jsdelivr.net/npm/[email protected]/......
  • MarkDown基本操作
    MarkDown学习标题的创建+空格+标题名:创建一级标题+空格+标题吗:创建二级标题文字格式的设置文本前后+***:斜体加粗文本前后+**:粗体文本前后+*:斜体文本前后+~~:删除线效果引用文字前+>空格:实现引用效果分割线三个减号-或三个星号图片方法一:![图片名](本地路径)方法......
  • 蓝牙基础
    蓝牙目前已更新的版本(assignednumbers文档中): 各个版本之间的差异:在core_v5.3中的卷1中: ......
  • Linux common clock framework(1)_概述
    1.前言commonclockframework是用来管理系统clock资源的子系统,根据职能,可分为三个部分:1)向其它driver提供操作clocks的通用API。2)实现clock控制的通用逻辑,这部分和硬件无关。3)将和硬件相关的clock控制逻辑封装成操作函数集,交由底层的platform开发者实现,由通用逻辑调用。因此......
  • 2022CSP游记
    目录CSP-J20227:458:158:278:389:129:2310:3411:57中午CSP-S20222:274:156:12估分普及提高自查出分废物鸭子菜菜菜CSP-J2022废了7:45跟随校车到了考场,纪中考点不给矿泉水可还行老朋友都见到了LJHDZRLAFZWTWTCZHWYWJ....WTC已经是ISIJ的金牌了,当年我还跟他是一个......
  • 多线程-从os层面理解常见概念
    如何创建一个线程在Linux系统中有一个方法,他有四个参数,其中第一个参数是利用指针传入,后期如果被修改也会同步修改,第三个参数和自己定义的run方法有关,后面会详细说。intpthread_create(pthread_t*thread,constpthread_attr_t*attr,void*(*start_routine)(void*),vo......
  • 百度首页静态展示页面HTML+CSS
    一直觉得百度首页很复杂的,有那么多的东西,跟这个博主学习了之后,仿写了一下,样式好像很简单只设置的一些组件的高度而已,不得不说,CSS真是个好东西呀话不多说,直接上代码<!DOCTYPEhtml><htmllang="en"><head><metacharset="utf-8"><title>百度首页</titl......
  • 第一阶段自评
    具体计算方法:计算公式=工作质量加和的平均分*40%+工作量加和的平均分*20%+主动性加和的平均分*20%+帮助团队加和的平均分*10%+自身成长加和的平均分*10%=总分(不会超过100的)根据以上方法计算:小组第一:曹立辉(100)小组第二:赵悦恒(95)小组第三:李佳桧(90)......
  • c++遍历搜索关键字
    #include<iostream>#include<windows.h>#include<string.h>#include<strsafe.h>#defineMAX_INPUT_LENGTH255usingnamespacestd;voidprintMemory(char*location,longsize){ printf("\n\n---------------------location......