首页 > 其他分享 >元学习的简单示例

元学习的简单示例

时间:2024-09-20 23:19:57浏览次数:10  
标签:示例 support labels 学习 inner 简单 query model data

代码功能

模型结构:SimpleModel是一个简单的两层全连接神经网络。
元学习过程:在maml_train函数中,每个任务由支持集和查询集组成。模型先在支持集上进行训练,然后在查询集上进行评估,更新元模型参数。
任务生成:通过create_task_data函数生成随机任务数据,用于模拟不同的学习任务。
元训练和微调:在元训练后,代码展示了如何在新任务上进行模型微调和测试。
这个简单示例展示了如何使用元学习方法(MAML)在不同任务之间共享学习经验,并快速适应新任务。
在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# 构建一个简单的全连接神经网络作为基础学习器
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(2, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 创建元学习过程
def maml_train(model, meta_optimizer, tasks, n_inner_steps=1, inner_lr=0.01):
    criterion = nn.CrossEntropyLoss()
    
    # 遍历多个任务
    for task in tasks:
        # 模拟支持集和查询集
        support_data, support_labels, query_data, query_labels = task
        
        # 初始化模型参数,用于内循环训练
        inner_model = SimpleModel()
        inner_model.load_state_dict(model.state_dict())
        inner_optimizer = optim.SGD(inner_model.parameters(), lr=inner_lr)
        
        # 在支持集上进行内循环训练
        for _ in range(n_inner_steps):
            pred_support = inner_model(support_data)
            loss_support = criterion(pred_support, support_labels)
            inner_optimizer.zero_grad()
            loss_support.backward()
            inner_optimizer.step()
        
        # 在查询集上评估
        pred_query = inner_model(query_data)
        loss_query = criterion(pred_query, query_labels)
        
        # 计算梯度并更新元模型
        meta_optimizer.zero_grad()
        loss_query.backward()
        meta_optimizer.step()

# 生成一些简单的任务数据
def create_task_data():
    # 随机生成支持集和查询集
    support_data = torch.randn(10, 2)
    support_labels = torch.randint(0, 2, (10,))
    query_data = torch.randn(10, 2)
    query_labels = torch.randint(0, 2, (10,))
    return support_data, support_labels, query_data, query_labels

# 创建多个任务
tasks = [create_task_data() for _ in range(5)]

# 初始化模型和元优化器
model = SimpleModel()
meta_optimizer = optim.Adam(model.parameters(), lr=0.001)

# 进行元训练
maml_train(model, meta_optimizer, tasks)

# 测试新的任务
new_task = create_task_data()
support_data, support_labels, query_data, query_labels = new_task

# 进行模型微调(内循环)
inner_model = SimpleModel()
inner_model.load_state_dict(model.state_dict())
inner_optimizer = optim.SGD(inner_model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# 使用支持集进行一次更新
pred_support = inner_model(support_data)
loss_support = criterion(pred_support, support_labels)
inner_optimizer.zero_grad()
loss_support.backward()
inner_optimizer.step()

# 在查询集上测试
pred_query = inner_model(query_data)
print("预测结果:", pred_query.argmax(dim=1).numpy())
print("真实标签:", query_labels.numpy())

标签:示例,support,labels,学习,inner,简单,query,model,data
From: https://blog.csdn.net/C7211BA/article/details/142407626

相关文章

  • opencascade Bnd_OBB源码学习 OBB包围盒
    opencascadeBnd_OBBOBB包围盒前言类描述了定向包围盒(OBB),比轴对齐包围盒(AABB)更紧密地包围形状的体积。OBB由盒子的中心、轴以及三个维度的一半定义。与AABB相比,OBB在作为非干扰物体的排斥机制时可以更有效地使用。方法1.空构造函数//!空构造函数Bnd_OBB():myIsAABox(S......
  • [leetcode刷题]面试经典150题之3删除有序数组中的重复项(简单)
    题目 删除有序数组中的重复项给你一个 非严格递增排列 的数组 nums ,请你 原地 删除重复出现的元素,使每个元素 只出现一次 ,返回删除后数组的新长度。元素的 相对顺序 应该保持 一致 。然后返回 nums 中唯一元素的个数。考虑 nums 的唯一元素的数量为 k ,你......
  • [leetcode刷题]面试经典150题之5多数元素元素(简单)【附Boyer-Moore 投票算法(摩尔投票法
    很有意思的一个题,想了半天没想出来,最后发现两行代码就做出来了。写完后学习到还可以用Boyer-Moore投票算法,能减小空间复杂度,我把它写在后面,可以进一步学习。题目  多数元素给定一个大小为 n 的数组 nums ,返回其中的多数元素。多数元素是指在数组中出现次数 大于 ⌊......
  • 《深度学习》—— PyTorch的介绍及PyTorch的CPU版本安装
    文章目录一、PyTorch的简单介绍二、pytorch的CPU版本安装三、torch、torchvision、torchaudio三个库的介绍一、PyTorch的简单介绍PyTorch是一个由FacebookAI实验室开发的深度学习框架,它基于Python,并提供了高效的GPU加速和灵活的模型定义能力。1.PyTorch的基本特点......
  • Docker学习
    系列文章目录第一章基础知识、数据类型学习第二章万年历项目第三章代码逻辑训练习题第四章方法、数组学习第五章图书管理系统项目第六章面向对象编程:封装、继承、多态学习第七章封装继承多态习题第八章常用类、包装类、异常处理机制学习第九章集合学习第......
  • opencascade Adaptor3d_Curve源码学习
    opencascadeAdaptor3d_Curve前言用于几何算法工作的3D曲线的根类。适配曲线是曲线提供的服务与使用该曲线的算法所需服务之间的接口。提供了两个派生具体类:GeomAdaptor_Curve,用于Geom包中的曲线Adaptor3d_CurveOnSurface,用于Geom包中表面上的曲线。用于评估BSpline曲线......
  • Asp.net MVC 学习笔记Razor(一)
    接手一个古老的项目,DotNet4.0编写的一个ASP.NETMVC的网页软件,期间结果好几任开发者的不懈努力,编码风格至少有3种,看的头疼。当然最主要的是我一直是做c++开发、c#中的wpf和winform或者python,asp.NET代码看的有点眼生。不管怎么样,先把基础的东西过一遍吧。Razor语法主要的Raz......
  • java学习9.20
    今天是简单的java小测验,实现简单的增删改查操作。我先用数组完成。后面的话想实现连接数据库的增删改查,但是始有bug不知道咋改,写的少不清楚问题出在哪,多写几回应该就能对症下药。下面是数组的代码Student类publicclassStudent{Stringstunumber;Stringname;......
  • 一些本影演算的简单应用
    根据这篇文章第一节的分析,对于任意数列\(\{a_n\}\),存在一个线性泛函\(L\)满足\(L(z^n)=a_n\)(在这里因为没有对线性泛函\(L\)的分析,所以使用正常记号),这说明了基本的本影演算本身的严谨性.对于\(L(z^n)=a_n\),称\(z\)是数列\(\{a_n\}\)的本影(umbra),通过\(L(z^n)\)对数......
  • opencascade Adaptor3d_CurveOnSurface源码学习
    opencascadeAdaptor3d_CurveOnSurface前言用于连接由Geom包中表面上的曲线提供的服务,以及使用这条曲线的算法所要求的服务。该曲线被定义为一个二维曲线,来自Geom2d包,位于表面的参数空间中方法1默认构造函数Standard_EXPORTAdaptor3d_CurveOnSurface();2通过给定的表面......