首页 > 其他分享 >从代码上解析Meta-learning

从代码上解析Meta-learning

时间:2023-05-04 21:32:26浏览次数:54  
标签:10 self Meta meta MAML learning np theta 解析


文章目录

  • 1.背景
  • 2.Meta-learning理解
  • 2.1 Meta-learning到底做什么
  • 2.2 MAML算法
  • 2.3 MAML算法步骤
  • 2.4 MAML代码分析和实现
  • 3.参考文章

1.背景

meta-learning区别于pretraining,它主要通过多个task来学习不同任务之间的内在联系,通俗点说,也即是通过多个任务来学习共同的参数。

举个例子,人类在进行分类的时候,由于见过太多东西了,且已经学过太多东西的分类了。那么我们可能只需每个物体一张照片,就可以对物体做到很多的区分了,那么人是怎么根据少量的图片就能学习到如此好的成果呢?

显然 ,我们已经掌握了各种用于图片分类的较巧了,比如根据物体的轮廓、纹理等信息进行分类,那么根据轮廓、根据纹理等区分物体的方法,就是我们在meta learning中需要教机器进行学习的学习技巧。

2.Meta-learning理解

meta-learning主要有以下几个概念,理解了概念我们就更容易理解这个算法到底在干什么。

2.1 Meta-learning到底做什么

meta-learning主要分为两个阶段:

  • meta-train:用来训练模型参数,使得模型能够学到不同任务中的共同参数
  • meta-test:类似于fine-tuneing阶段,用来微调下游任务。

首先假设我们首先有数据集从代码上解析Meta-learning_数据集,这个数据集有10个类别,每个类别有100个样本,共1000个样本数聚集。

我们把数据集进行分割,把100个样本分成3份,比例是1:4:5。这三份的样本数量为10,40,50。

  • N-way K-shot:在meta-train阶段,假设实验中设置5-way 10-shot,也就是在每个任务task中抽样5个类别,每个类别10份数据,构成一个support set,用在meta-train阶段。这5个类别中,另外的40份数据为query set,可以用在meta-test阶段;而还剩10个类别的50份数据,用来进行微调任务。

2.2 MAML算法

MAML是用来实现meta-learning的一种算法。下面用例子来说明MAML的算法实现过程。

从代码上解析Meta-learning_深度学习_02

假设我们目前有3个tasks,分别为从代码上解析Meta-learning_meta-learning_03。按照以前模型的训练方式,首先,我们随机初始化模型参数从代码上解析Meta-learning_深度学习_04。然后开始训练任务从代码上解析Meta-learning_人工智能_05,接着最小化损失函数从代码上解析Meta-learning_数据集_06来更新网络的参数,这样我们就会得到新的参数从代码上解析Meta-learning_深度学习_07。同理,我们可以接着更新其他两个任务。

但以前模型的训练方式,是每个任务都是随机初始化从代码上解析Meta-learning_深度学习_04开始,每个任务都是独立的。如果我们把三个任务初始化的从代码上解析Meta-learning_深度学习_04到公用的位置,则不需要更多的梯度更新步骤。MAML就是做这件事的。

MAML 试图找到许多相关任务共有的最佳参数从代码上解析Meta-learning_深度学习_04,因此我们可以用很少的数据相对快速地训练新任务,而无需通过采取许多梯度步骤来确定最佳状态初始化 从代码上解析Meta-learning_深度学习_04

如下图所示,从代码上解析Meta-learning_深度学习_04重新固定到一个新的位置来训练。对于一个新的任务task 从代码上解析Meta-learning_初始化_13,就不需要重新随机初始化参数来训练了。

从代码上解析Meta-learning_meta-learning_14

2.3 MAML算法步骤

具体的MAML算法如图所示:

从代码上解析Meta-learning_初始化_15


我们有模型从代码上解析Meta-learning_深度学习_16,其参数为从代码上解析Meta-learning_深度学习_04。同时又一系列任务从代码上解析Meta-learning_meta-learning_18

  1. 首先,随机初始化Meta模型参数从代码上解析Meta-learning_数据集_19
  2. 对于第从代码上解析Meta-learning_数据集_20任务从代码上解析Meta-learning_深度学习_21,抽样batch个从代码上解析Meta-learning_深度学习_21,这样就会构成一个batch,其中从代码上解析Meta-learning_深度学习_23

下面是内循环的操作:

  1. 如果我们有5个任务,则从代码上解析Meta-learning_数据集_24。从每个从代码上解析Meta-learning_数据集_25中,抽样从代码上解析Meta-learning_meta-learning_26个数据。对于每个任务,都需要更新一次参数:
    从代码上解析Meta-learning_数据集_27
    其中从代码上解析Meta-learning_深度学习_28是任务从代码上解析Meta-learning_数据集_25的最佳参数;从代码上解析Meta-learning_深度学习_30是学习率;从代码上解析Meta-learning_深度学习_31是梯度。
  2. 对每个任务进行更新,这样会得到5个最佳参数,从代码上解析Meta-learning_深度学习_32

下面是外循环的操作:

  1. 在外循环中,我们需要更新原始的meta模型参数从代码上解析Meta-learning_数据集_19。利用任务从代码上解析Meta-learning_深度学习_21,来生成每个任务的loss值,然后梯度更新参数从代码上解析Meta-learning_数据集_19
    从代码上解析Meta-learning_meta-learning_36
    其中,从代码上解析Meta-learning_数据集_19是我们原始meta模型的参数值;从代码上解析Meta-learning_人工智能_38是超参数;从代码上解析Meta-learning_meta-learning_39是每个任务从代码上解析Meta-learning_深度学习_21的梯度。

PS:

  • 内循环中:只需要利用support set一步更新就可以
  • 外循环中:需要利用query set进行多次迭代更新。

2.4 MAML代码分析和实现

自定义样本抽取代码,从代码上解析Meta-learning_深度学习_41的维度为从代码上解析Meta-learning_深度学习_42

def sample_points(k):
    x = np.random.rand(k,50)
    y = np.random.choice([0, 1], size=k, p=[.5, .5]).reshape([-1,1])
    return x,y

x, y = sample_points(10)
print x[0]
print y[0]

输出结果:

从代码上解析Meta-learning_数据集_43

利用简单的前馈神经网络作例子:

a = np.matmul(X, theta)
YHat = sigmoid(a)

MAML实现代码:

class MAML(object):
    def __init__(self):
        """
        定义参数,实验中用到10-way,10-shot
        """
        # 共有10个任务
        self.num_tasks = 10
        
        # 每个任务的数据量:10-shot
        self.num_samples = 10

        # 训练的迭代次数
        self.epochs = 10000
        
        # 内循环中,学习率,用来更新\theta'
        self.alpha = 0.0001
        
        # 外循环的学习率,用来更新meta模型的\theta
        self.beta = 0.0001
       
        # meta模型初始化的参数
        self.theta = np.random.normal(size=50).reshape(50, 1)
      
    # sigmoid函数
    def sigmoid(self,a):
        return 1.0 / (1 + np.exp(-a))
    
    #now let us get to the interesting part i.e training :P
    def train(self):
        
        # 循环epoch次数
        for e in range(self.epochs):        
            
            self.theta_ = []
            
            # 利用support set
            for i in range(self.num_tasks):
               
                # 抽样k个样本出来,k-shot
                XTrain, YTrain = sample_points(self.num_samples)
                
                # 前馈神经网络
                a = np.matmul(XTrain, self.theta)
                YHat = self.sigmoid(a)

                # 计算交叉熵loss
                loss = ((np.matmul(-YTrain.T, np.log(YHat)) - np.matmul((1 -YTrain.T), np.log(1 - YHat)))/self.num_samples)[0][0]
                
                # 梯度计算,更新每个任务的theta_,不需要更新meta模型的参数theta
                gradient = np.matmul(XTrain.T, (YHat - YTrain)) / self.num_samples
                self.theta_.append(self.theta - self.alpha*gradient)
                
     
            # 初始化meta模型的梯度
            meta_gradient = np.zeros(self.theta.shape)
            
            # 利用query set
            for i in range(self.num_tasks):
            
                # 在meta-test阶段,每个任务抽取10个样本出来进行
                XTest, YTest = sample_points(10)

                # 前馈神经网络
                a = np.matmul(XTest, self.theta_[i])
                YPred = self.sigmoid(a)
                           
                # 这里需要叠加每个任务的loss
                meta_gradient += np.matmul(XTest.T, (YPred - YTest)) / self.num_samples

  
            # 更新meat模型的参数theta
            self.theta = self.theta-self.beta*meta_gradient/self.num_tasks
                                       
            if e%1000==0:
                print "Epoch {}: Loss {}\n".format(e,loss)             
                print 'Updated Model Parameter Theta\n'
                print 'Sampling Next Batch of Tasks \n'
                print '---------------------------------\n'

最后输出结果:

model = MAML()
model.train()

从代码上解析Meta-learning_meta-learning_44

标签:10,self,Meta,meta,MAML,learning,np,theta,解析
From: https://blog.51cto.com/u_12243550/6244549

相关文章

  • 虚拟机安装 bind 9 及顶级域名解析
    如何将所有域名解析到同一个网关服务器中,手把手教学!!! #安装可以参考下面这个教程ISCBIND9-最详细、最认真的从零开始的BIND9-DNS服务搭建及其原理讲解(Debian/Windows)-DoHerasYang-博客园(cnblogs.com)我这边只能在虚拟机中安装成功,宿主机总是报unable......
  • iOS MachineLearning 系列(10)—— 自然语言分析之文本拆解
    iOSMachineLearning系列(10)——自然语言分析之文本拆解本系列的前几篇文章介绍了iOS中有关图像和视频处理的API,视觉处理主要有Vision框架负责,本篇起,将介绍在iOS中MachineLearning领域相关的自然语言处理框架:NaturalLanguage。1-简介NaturalLanguage是iOS种提供的一种处理自......
  • 4D成像毫米波雷达点云数据集VOD(含Python和MATLAB数据解析仿真代码)
    公众号【调皮连续波】【正文】编辑|  调皮哥的小助理     审核|调皮哥1、引言4D成像雷达开源数据集,其实好用的并不多,VOD数据集我个人感觉还可以。这其实也在之前分享过,但是为了更加清楚地展示这个数据集如何使用,本期文章就简单做个分享。在MATLAB环境下可以得到以下的......
  • 特斯拉双级联毫米波雷达解析
    公众号【调皮连续波】【正文】1、芯片屏蔽罩(1)屏蔽罩的作用是什么?屏蔽电磁干扰,芯片对天线的电磁干扰,天线对芯片的电磁干扰。但谁是主要,谁是次要的呢?我个人认为天线对芯片的干扰是主要的。(2)屏蔽罩上6个小孔的作用是什么?网友回答说屏蔽罩上的小孔,一方面是为了工作时内部器件的散热,开孔......
  • Nacos修改权重报错caused: errCode: 500, errMsg: do metadata operation failed ;caus
    今天修改Nacos权重时报错如下:caused:errCode:500, caused:errCode:500,errMsg:dometadataoperationfailed;caused:com.alibaba.nacos.con。解决方案:停掉nacos服务将nacos文件夹下data中的protocol文件夹删除重启nacos服务即可 ......
  • golang 解析处理word文档扩展包
    github.com/unidoc/unioffice该扩展包对word操作功能比较全,但为商业使用,注册后有100次的试用,具体使用就不详细说明了,具体可以看https://github.com/unidoc/unioffice-examplesgithub.com/carmel/gooxml该扩展包为unidoc/unioffice的免费版,为收费版的1.4.0版本,虽然功能没有......
  • opendrive数据格式解析思维导图 , opendrive高精地图是自动驾驶领域使用最为广泛的开源
    opendrive数据格式解析思维导图,opendrive高精地图是自动驾驶领域使用最为广泛的开源高精地图标准级地图格式。本思维导图将详细剖开高精路网地图内部的数据格式,涵盖:道路、车道、车道段、交叉口等相关名词及其属性、作用、链接关系等参数的解析。内容比较全面,希望对高精地图进行......
  • Lanelet2高精地图解析及全局路径规划, Lanelet2格式的高精地图是与opendrive高精地图并
    Lanelet2高精地图解析及全局路径规划,Lanelet2格式的高精地图是与opendrive高精地图并行的当前两大最流行的高精地图格式。在autoware停止维护AI版本推出Auto版本后,更是将原先的Lanelet地图格式进行升级为lanelet2。因此,如果大家有公司的产品依赖autoware的代码进行部署的,熟悉Lane......
  • Teachable Reinforcement Learning via Advice Distillation
    发表时间:2021(NeurIPS2021)文章要点:这篇文章提出了一种学习policy的监督范式,大概思路就是先结构化advice,然后先学习解释advice,再从advice中学policy。这个advice来自于外部的teacher,相当于一种human-in-the-loopdecisionmaking。另外这个advice不单单是reward的大小,可能具有......
  • 对港股实时数据的解析
    找页面,东财的实在太复杂了!新浪!这个页面还行https://stock.finance.sina.com.cn/hkstock/quotes/00700.html实时价格又是js价值的,跟踪,找到https://stock.finance.sina.com.cn/hkstock/api/openapi.php/HK_StockService.getHKMinline?symbol=00700&random=1683011712092&callback=va......