首页 > 编程语言 >李沐《动手学深度学习》权重衰退(正则化)python代码实现

李沐《动手学深度学习》权重衰退(正则化)python代码实现

时间:2024-11-06 21:17:50浏览次数:3  
标签:loss python torch test 正则 train d2l 李沐 data

一、L2正则化手动实现

# 权重衰退手动实现
%matplotlib inline
import torch
from d2l import torch as d2l
from torch import nn

# n_train个训练样本,n_test个测试样本,输入数据维度是200维
n_train, n_test, num_inputs, batch_size = 20, 200, 200, 5
true_w, true_b = torch.ones((num_inputs, 1))*0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w,b]

def L2_penalty(w):
    return torch.sum(w.pow(2)) / 2

def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log', 
                            xlim=[5, num_epochs], legend=['trian', 'test'])
    
    for epoch in range(num_epochs):
        for X, y in train_iter:
            l = loss(net(X), y) + lambd*L2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if(epoch+1) % 5 == 0:
            animator.add(epoch+1, (d2l.evaluate_loss(net, train_iter, loss),
                                   d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())    
    
train(5)

二、L2正则化利用torch框架实现

# 权重衰退简洁实现
%matplotlib inline
import torch
from d2l import torch as d2l
from torch import nn

# n_train个训练样本,n_test个测试样本,输入数据维度是200维
n_train, n_test, num_inputs, batch_size = 20, 200, 200, 5
true_w, true_b = torch.ones((num_inputs, 1))*0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)


def train_concise(lambd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_(0, 0.1)
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr= 100, 0.003
    
    trainer = torch.optim.SGD([
        {"params":net[0].weight, "weight_decay":lambd},
        {"params":net[0].bias}
    ], lr=lr)
    
    animator = d2l.Animator(xlabel='epoch', ylabel='loss', yscale='log',
                        xlim=[5, num_epochs], legend=['trian', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y) / batch_size
            l.sum().backward()
            trainer.step()
        if (epoch+1) % 5 == 0:
            animator.add(epoch+1, (d2l.evaluate_loss(net, train_iter, loss),
                                   d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', torch.norm(net[0].weight).item())
    
train_concise(5)

标签:loss,python,torch,test,正则,train,d2l,李沐,data
From: https://blog.csdn.net/yuzixuan233/article/details/143580255

相关文章

  • 李沐《动手学深度学习》多层感知机python代码实现
    一、多层感知机手动实现#多层感知机的手动实现%matplotlibinlineimporttorchfromtorchimportnnfromd2limporttorchasd2lbatch_size=256train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)num_inputs,num_outputs,num_first_hiddens=......
  • 李沐《动手学深度学习》线性回归python代码实现
    一、手动实现线性回归#线性回归的手动实现%matplotlibinlineimporttorchimportrandomfromd2limporttorchasd2l#随机按照参数w和b外加一些噪音来创造训练数据集data和labelsdefsynthetic_data(w,b,num_examples):X=torch.normal(0,1,(num_example......
  • Python学习笔记-生成器的应用与原理
    生成器是Python中一种特殊的迭代工具,通过延迟计算的方式来逐步生成序列中的元素。这种特性使得生成器在处理大数据、无限序列或需要惰性求值的场景中十分有效。生成器的核心思想是通过yield语句逐步返回值,暂停并保留当前状态,直到下次调用继续执行,从而节省内存并优化性能......
  • Python学习笔记-断点操作结合异常处理
    在编程中,调试和错误处理是提升代码质量和开发效率的关键环节。调试能帮助识别并修复问题,异常处理则使得程序能在出现错误时有效地管理而不至于崩溃。断点与异常处理的结合应用是高级编程中不可或缺的技巧,能够帮助更高效地定位问题,提高程序的鲁棒性。本文将通过详细的断点和......
  • Python——数据结构与算法-时间复杂度&空间复杂度-链表&树状结构
    1.数据结构和算法简介程序可以理解为:程序=数据结构+算法概述/目的:都可以提高程序的效率(性能)数据结构指的是存储,组织数据的方式.算法指的是为了解决实际业务问题而思考思路和方法,就叫:算法.2.算法的5大特性介绍概述:为了解决实际业务问题,......
  • python面向对象(一)
    前言Python是一种功能强大的编程语言,因其简洁的语法和强大的库而备受开发者喜爱。而在Python中,面向对象编程(OOP)是一种核心的编程范式,它通过模拟现实世界中的对象和交互来帮助我们设计清晰、易维护的代码。在本篇博客中,我们将深入探讨Python的面向对象编程的基本概念,了解如......
  • 【PAT_Python解】1114 全素日
    原题链接:PTA|程序设计类实验辅助教学平台Tips:以下Python代码仅个人理解,非最优算法,仅供参考!多学习其他大佬的AC代码!defis_prime(n):ifn<=3:returnn>=2ifn%6notin(5,1):returnFalseforiinrange(5,int(n**0.5)+1,6):......
  • 【PAT_Python解】1110 区块反转
    原题链接:PTA|程序设计类实验辅助教学平台Tips:以下Python代码仅个人理解,非最优算法,仅供参考!多学习其他大佬的AC代码!importsys#读取输入head,n,k=map(int,sys.stdin.readline().split())#初始化数据,装入字典,最终取值data={}next={}for_inrange(n):......
  • 【PAT_Python解】1113 钱串子的加法
    原题链接:PTA|程序设计类实验辅助教学平台Tips:以下Python代码仅个人理解,非最优算法,仅供参考!多学习其他大佬的AC代码!defadd_base30(num1,num2):max_length=max(len(num1),len(num2))#在前面补零,使两个字符串长度相同num1=num1.zfill(max_lengt......
  • 2024/11/6日 日志 正则表达式,web与HTTP
    正则表达式点击查看代码--正则表达式--·概念:正则表达式定义了字符串组成的规则--·定义:--1.直接量:注意不要加引号--varreg=/^lw{6,12}$/:--2.创建RegExp对象--varreg=newRegExp("^lw{6,12}$");--·方法:--· test(str):判断指定字符串是否......