首页 > 其他分享 >论文解读(MAML)《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》

论文解读(MAML)《Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks》

时间:2024-04-27 21:24:39浏览次数:28  
标签:acc loss support Agnostic Fast label Meta meta query

Note:[ wechat:Y466551 | 可加勿骚扰,付费咨询 ]

论文信息

论文标题:Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
论文作者:Chelsea Finn、Pieter Abbeel、Sergey Levine
论文来源:2017 
论文地址:download 
论文代码:download
视屏讲解:click

1-摘要

  我们提出了一种与模型无关的元学习算法,在这个意义上,它与任何经过梯度下降训练的模型兼容,并适用于各种不同的学习问题,包括分类、回归和强化学习。元学习的目标是在各种学习任务上训练一个模型,这样它就可以只使用少量的训练样本来解决新的学习任务。在我们的方法中,模型的参数被明确地训练,这样少量的梯度步长和来自新任务的少量训练数据将在该任务上产生良好的泛化性能。实际上,我们的方法训练的模型易于微调。我们证明了这种方法在两个低镜头图像分类基准上取得了最先进的性能,在少镜头回归上产生了良好的结果,并加速了使用神经网络策略对策略梯度强化学习的微调。

2-方法

  

  代码:

def maml_train(model, support_images, support_labels, query_images, query_labels, inner_step, args, optimizer, is_train=True):
    meta_loss = []
    meta_acc = []
    for support_image, support_label, query_image, query_label in zip(support_images, support_labels, query_images, query_labels):
        fast_weights = collections.OrderedDict(model.named_parameters())
        for _ in range(inner_step):  #inner_step = 1
            # Update weight
            support_logit = model.functional_forward(support_image, fast_weights)
            support_loss = nn.CrossEntropyLoss().cuda()(support_logit, support_label)
            grads = torch.autograd.grad(support_loss, fast_weights.values(), create_graph=True)
            fast_weights = collections.OrderedDict((name, param - args.inner_lr * grad) for ((name, param), grad) in zip(fast_weights.items(), grads))

        # Use trained weight to get query loss
        query_logit = model.functional_forward(query_image, fast_weights)
        query_prediction = torch.max(query_logit, dim=1)[1]
        query_loss = nn.CrossEntropyLoss().cuda()(query_logit, query_label)
        query_acc = torch.eq(query_label, query_prediction).sum() / len(query_label)
        meta_loss.append(query_loss)
        meta_acc.append(query_acc.data.cpu().numpy())

    # Zero the gradient
    optimizer.zero_grad()
    meta_loss = torch.stack(meta_loss).mean()
    meta_acc = np.mean(meta_acc)

    if is_train:
        meta_loss.backward()
        optimizer.step()

    return meta_loss, meta_acc

 

标签:acc,loss,support,Agnostic,Fast,label,Meta,meta,query
From: https://www.cnblogs.com/BlairGrowing/p/17689137.html

相关文章

  • Fastbin attack&&Double free和Unsortbin leak的综合使用
    Fastbinattack&&Doublefree和Unsortbinleak的综合使用✅今天做一个综合题目,包括利用Fastbinattack实现多指针指向一个地址,以及利用Unsortbinleak泄露libc基地址和修改__malloc_hook地址为one_gadget题目是buuctf上面的一道题目,题目链接https://buuoj.cn/challenges#babyhe......
  • Fast Training Algorithms for Deep Convolutional Fuzzy Systems With Application t
    类似深度卷积神经网络DCNN,模糊系统领域有个深度卷积模糊系统deepconvolutionalfuzzysystem(DCFS),每一层都是一个模糊系统,上一层的输出是下一层的输入。这篇论文目的是加速DCFS的计算速度,解决可解释性1990年提出,也用反向传播训练DCFS受困于低维度小数据集,大数据量时计算负担太......
  • 使用 ForAttributeWithMetadataName 提高 IIncrementalGenerator 增量 Source Generat
    本文将告诉大家如何使用ForAttributeWithMetadataName方法用来提高IIncrementalGenerator增量SourceGenerator源代码生成的开发效率以及提高源代码生成器的运行效率这是一个在2022的6月15才合入的新功能。原因是Roslyn团队发现了大量的源代码生成器和分析器项目都......
  • 【翻译】RISC-V裸机编程指南(Bare metal programming with RISC-V guide)
    RISC-V裸机编程指南(BaremetalprogrammingwithRISC-Vguide)作者:Follow@popovicu94原文链接:https://popovicu.com/posts/bare-metal-programming-risc-v/今天,我们将探讨如何为RISC-V架构的机器编写一个裸机程序。为了确保可复现,目标平台选择为QEMUriscv64virt虚拟机......
  • centos8.2报错Failed to download metadata for repo 'BaseOS': Cannot prepare inter
    报错CentOS-8-Base68B/s|38B00:00错误:Failedtodownloadmetadataforrepo'BaseOS':......
  • sys.meta_path的作用
    `sys.meta_path`是Python导入系统中的一个关键特性,它是一个列表,包含了所有的元路径查找器(metapathfinders)。这些查找器在导入模块时会被依次查询,以便找到对应的模块。当你在Python中导入一个模块时,解释器会按照以下步骤进行:1.检查`sys.modules`缓存,看看模块是否已经被导......
  • 欢迎 Llama 3:Meta 的新一代开源大语言模型
    介绍Meta公司的Llama3是开放获取的Llama系列的最新版本,现已在HuggingFace平台发布。看到Meta持续致力于开放AI领域的发展令人振奋,我们也非常高兴地全力支持此次发布,并实现了与HuggingFace生态系统的深度集成。Llama3提供两个版本:8B版本适合在消费级GPU上高......
  • cf 393017C 石头剪刀布 Metacamp2022-onlineA-dev
     Problem-C-Codeforces 五维的DPg[i][D][r][s][p]i:到了第i个位置D:最后有D个点放在后面r,s,p:已经选择了r,s,p个石头,剪刀,布放到后面 四维的DPf[i][D][r][s][p]i:到了第i个位置D:目前有D个点放在后面r,s,p:已经选择了r,s,p个石头,剪刀,布放到后面其......
  • https://github.com/meta-llama/llama3 文生图
    https://github.com/meta-llama/llama3 Skiptocontent NavigationMenu Product Solutions OpenSource Pricing Searchorjumpto...  SigninSignup  meta-llama/llama3PublicNotificationsFork 1.4k Star ......
  • Fast Möbius Transform 学习笔记
    小Tips:在计算机语言中\(\cup\)=&/and,\(\cap\)=|/orFirstStep.定义定义长度为\(2^n\)的序列的and卷积\(A=B*C\)为\(A_i=\sum_{j\cupk=i}{B_j*C_k}\)考虑快速计算SecondStep.变换定义长度为\(2^n\)的序列的Zeta变换为\[\hat{A}_i=\sum......