首页 > 其他分享 >torch--模型选择-欠拟合-过拟合

torch--模型选择-欠拟合-过拟合

时间:2024-11-06 13:41:26浏览次数:1  
标签:loss labels features -- self torch train 拟合

"""
模型选择,欠拟合、过拟合
"""

import math
import numpy as np
import torch
from d2l import torch as d2l
from IPython import display
import matplotlib.pyplot as plt
from torch import nn


max_degree = 20                                # 多项式的最大阶数
n_train, n_test = 100, 100                     # 训练和测试数据集大小
true_w = np.zeros(max_degree)                  # 分配大量的空间
true_w[0:4] = np.array([5, 1.2, -3.4, 5.6])

features = np.random.normal(size=(n_train + n_test, 1))
np.random.shuffle(features)
poly_features = np.power(features, np.arange(max_degree).reshape(1, -1))
for i in range(max_degree):
    poly_features[:, i] /= math.gamma(i + 1)  # gamma(n)=(n-1)!
# labels的维度:(n_train+n_test,)
labels = np.dot(poly_features, true_w)
labels += np.random.normal(scale=0.1, size=labels.shape)


# NumPy ndarray转换为tensor
true_w, features, poly_features, labels = [torch.tensor(x, dtype=torch.float32) for x in [true_w, features, poly_features, labels]]

features[:2], poly_features[:2, :], labels[:2]


class Animator:
    """在动画中绘制数据"""
    def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
                 figsize=(3.5, 2.5), pic_name=None):
        # 增量地绘制多条线
        if legend is None:
            legend = []
        d2l.use_svg_display()
        self.fig, self.axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
        if nrows * ncols == 1:
            self.axes = [self.axes, ]
        # 使用lambda函数捕获参数
        self.config_axes = lambda: d2l.set_axes(
            self.axes[0], xlabel, ylabel, xlim, ylim, xscale, yscale, legend)
        self.X, self.Y, self.fmts = None, None, fmts
        self.pic_name = pic_name

    def add(self, x, y):
        # 向图表中添加多个数据点
        if not hasattr(y, "__len__"):
            y = [y]
        n = len(y)
        if not hasattr(x, "__len__"):
            x = [x] * n
        if not self.X:
            self.X = [[] for _ in range(n)]
        if not self.Y:
            self.Y = [[] for _ in range(n)]
        for i, (a, b) in enumerate(zip(x, y)):
            if a is not None and b is not None:
                self.X[i].append(a)
                self.Y[i].append(b)
        self.axes[0].cla()
        for x, y, fmt in zip(self.X, self.Y, self.fmts):
            self.axes[0].plot(x, y, fmt)
        self.config_axes()
        plt.draw()
        plt.pause(0.1)
        plt.savefig(self.pic_name)
        display.display(self.fig)
        display.clear_output(wait=True)

    def show(self):
        display.display(self.fig)


def evaluate_loss(net, data_iter, loss):
    """评估给定数据集上模型的损失"""
    metric = d2l.Accumulator(2)  # 损失的总和,样本数量
    for X, y in data_iter:
        out = net(X)
        y = y.reshape(out.shape)
        l = loss(out, y)
        metric.add(l.sum(), l.numel())
    return metric[0] / metric[1]


def train_epoch_ch3(net, train_iter, loss, updater):
    """训练模型一个迭代周期(定义见第3章)"""
    # 将模型设置为训练模式
    if isinstance(net, torch.nn.Module):
        net.train()
    # 训练损失总和、训练准确度总和、样本数
    # metric = d2l.Accumulator(3)
    for X, y in train_iter:
        # 计算梯度并更新参数
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # 使用PyTorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            # 使用定制的优化器和损失函数
            l.sum().backward()
            updater(X.shape[0])


def train(train_features, test_features, train_labels, test_labels,
          num_epochs=400, state='正常'):
    loss = nn.MSELoss(reduction='none')
    input_shape = train_features.shape[-1]
    # 不设置偏置,因为我们已经在多项式中实现了它
    net = nn.Sequential(nn.Linear(input_shape, 1, bias=False))
    batch_size = min(10, train_labels.shape[0])
    train_iter = d2l.load_array((train_features, train_labels.reshape(-1, 1)),
                                batch_size)
    test_iter = d2l.load_array((test_features, test_labels.reshape(-1, 1)),
                               batch_size, is_train=False)
    trainer = torch.optim.SGD(net.parameters(), lr=0.01)
    animator = Animator(xlabel='epoch', ylabel='loss', yscale='log',
                        xlim=[1, num_epochs], ylim=[1e-3, 1e2],
                        legend=['train', 'test'], pic_name=f'model_select_{state}')
    for epoch in range(num_epochs):
        train_epoch_ch3(net, train_iter, loss, trainer)
        if epoch == 0 or (epoch + 1) % 20 == 0:
            animator.add(epoch + 1, (evaluate_loss(net, train_iter, loss),
                                     evaluate_loss(net, test_iter, loss)))
    print('weight:', net[0].weight.data.numpy())


# 三阶多项式函数拟合(正常)
# 从多项式特征中选择前4个维度,即1,x,x^2/2!,x^3/3!
train(poly_features[:n_train, :4], poly_features[n_train:, :4],
      labels[:n_train], labels[n_train:], state='正常拟合')

# 线性函数拟合(欠拟合)
# 从多项式特征中选择前2个维度,即1和x
train(poly_features[:n_train, :2], poly_features[n_train:, :2],
      labels[:n_train], labels[n_train:], state='欠拟合')

# 高阶多项式函数拟合(过拟合)
# 从多项式特征中选取所有维度
train(poly_features[:n_train, :], poly_features[n_train:, :],
      labels[:n_train], labels[n_train:], num_epochs=1500, state='过拟合')

标签:loss,labels,features,--,self,torch,train,拟合
From: https://www.cnblogs.com/jackchen28/p/18530014

相关文章

  • Docker与k8s的联系
    本篇为帮助为帮助理解Docker与k8s大体的作用以及他们的联系和区别,没有对其进行深入刨析。产生的意义首先对于开发者来讲Docker与k8s都是为了去运行你写的代码的工具或者程序。在使用他们之前我们都需要先在本机把代码写好然后传到Linux服务器上部署运行。这样的方式也是没问题......
  • P4528 做题笔记
    神题。记\(f_{a,b,c,d}\)表示四个数排名依次为\(a,b,c,d\)的子序列的方案数(最小的排名为\(1\),以此类推)。闪电图腾就是\(f_{1,3,2,4}\),山峰图腾A为\(f_{1,2,4,3}\),B为\(f_{1,4,3,2}\)。我们所求的式子是\(f_{1,3,2,4}-f_{1,2,4,3}-f_{1,4,3,2}\)。\(=(f_{1,x,2,x}-......
  • 轻松识别报关单文字信息,翔云API海关报关单如何集成
    在全球化贸易和国际物流日益增长的今天,报关单是企业进出口活动中不可或缺的文件。报关单不仅记录了商品的进口和出口信息,还直接影响到海关的放行、货物的运输效率以及相关税费的计算。然而,传统的人工录入和处理报关单信息的方式,往往存在效率低下、错误频发等问题。为助力解决......
  • 数学建模_BP神经网络模型(多输入多输出)回归模型+Matlab代码包教会使用,直接替换数据
    BP神经网络模型(多输入多输出)回归模型神经网络回归模型原理讲解​该模型是一个典型的多层前馈神经网络(FeedforwardNeuralNetwork,FFNN),应用于回归问题。其基本原理包括数据输入、隐藏层神经元的处理、输出层的预测、以及通过反向传播算法优化权重和偏置的过程。下面将详......
  • Python 继承、多态、封装、抽象
    面向对象编程(OOP)是Python中的一种重要编程范式,它通过类和对象来组织代码。OOP的四个核心概念是继承(Inheritance)、多态(Polymorphism)、封装(Encapsulation)和数据抽象(DataAbstraction)。下面将详细介绍这四个概念。继承(Inheritance)继承是面向对象编程(OOP)的一个基本概念,它允......
  • 自激式开关电源:电路解析与实战心得
    自激式开关电源:电路解析与实战心得在现代电子产品里,开关电源就像一个隐形的心脏。它在悄无声息地为各种电路提供稳定的电流,维持整个系统的正常运转。而在开关电源的家族中,自激式开关电源因为其结构简单、成本低廉,一直是小功率应用里的明星选手。今天我们就通过这个电路图,......
  • [记录]安装 Python 中SPAM库失败
    报错信息:×pythonsetup.pyegg_infodidnotrunsuccessfully.│exitcode:1╰─>[41linesofoutput]runningegg_infocreating/private/var/folders/l9/f9rjm65s07bdf55y5xyk9f2c0000gn/T/pip-pip-egg-info-o3ic4gdp/progressbar.egg-infowriting/private/var/fo......
  • 从零到精通:BLDC电机驱动电路详解与设计思路
    BLDC驱动电路的设计解析这是一个经典的无刷直流电机(BLDC)驱动电路,用于控制三相电机的转速和扭矩。BLDC电机在各种领域都非常常见,比如无人机、电动汽车、电动滑板等,原因很简单:高效、耐用、响应快。而要设计一个稳定、可靠的BLDC驱动电路,电路设计者不仅需要懂得每个模块的功......
  • .NET 8 高性能跨平台图像处理库 ImageSharp
    前言传统的System.Drawing库功能丰富,但存在平台限制,不适用于跨平台开发。.NET8的发布,ImageSharp成为了一个更好的选择。ImageSharp是一个完全开源、高性能且跨平台的图像处理库,专为.NET设计。它提供丰富的图像处理功能和出色的性能,适用于桌面应用、Web应用和云端服务。......
  • 【C++】踏上C++学习之旅(五):auto、范围for以及nullptr的精彩时刻(C++11)
    文章目录前言1.auto关键字(C++11)1.1为什么要有auto关键字1.2auto关键字的使用方式1.3auto的使用细则1.4auto不能推导的场景2.基于范围的for循环(C++11)2.1范围for的语法2.2范围for的使用条件3.指针空值nullptr(C++11)3.1为什么会有nullptr这个关键字?前言本......