首页 > 其他分享 >模型选择 欠拟合与过拟合

模型选择 欠拟合与过拟合

时间:2023-07-29 17:11:42浏览次数:33  
标签:loss num 模型 选择 train 拟合 test net

# 模型选择 欠拟合与过拟合

# 创建数据集
from mxnet import autograd
from mxnet import ndarray as nd
from mxnet import gluon

num_train = 100
num_tset = 100

true_w = [1.2, -3.4, 5.6]
true_b = 5.0

# 生成数据集

X = nd.random_normal(shape=(num_train + num_tset, 1))
x = nd.concat(X, nd.power(X, 2), nd.power(X, 3))
y = true_w[0] * x[:, 0] + true_w[1] * x[:, 1] + true_w[2] * x[:, 2] + true_b
y += .1 * nd.random_normal(shape=y.shape)
y_train, y_test = y[:num_train], y[num_train:]
print('X:', X[:5], 'x:', x[:5], 'y:', y[:5])

# 定义训练和测试步骤
import matplotlib as mpl

mpl.rcParams['figure.dpi'] = 120
import matplotlib.pyplot as plt


def square_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2


def test(net, x, y):
    return square_loss(net(x), y).mean().asscalar()


def train(x_train, x_test, y_train, y_test):
    # 定义模型
    net = gluon.nn.Sequential()
    with net.name_scope():
        net.add(gluon.nn.Dense(1))  # 线性回归模型,只有一个输出单元
    net.initialize()  # 初始化模型的参数

    # 定义训练参数
    learning_rate = 0.01
    epochs = 100
    batch_size = 10

    # 创建数据迭代器
    dataset_train = gluon.data.ArrayDataset(x_train, y_train)
    data_iter_train = gluon.data.DataLoader(dataset_train, batch_size, shuffle=True)

    # 定义优化器和损失函数
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': learning_rate})
    square_loss = gluon.loss.L2Loss()  # 使用均方误差损失函数

    train_loss = []  # 用于记录训练集上的损失值
    test_loss = []  # 用于记录测试集上的损失值

    # 开始训练模型
    for e in range(epochs):
        for data, label in data_iter_train:
            with autograd.record():
                output = net(data)  # 前向传播计算输出
                loss = square_loss(output, label)  # 计算损失
            loss.backward()  # 反向传播求梯度
            trainer.step(batch_size)  # 更新模型参数

        # 计算并记录每个epoch结束后的训练集和测试集上的损失值
        train_loss.append(square_loss(net(x_train), y_train).mean().asscalar())
        test_loss.append(square_loss(net(x_test), y_test).mean().asscalar())

    # 绘制训练集和测试集上的损失曲线
    plt.plot(train_loss)
    plt.plot(test_loss)
    plt.legend(['train', 'test'])
    plt.show()

    # 返回学到的权重和偏置
    return ('learned weight', net[0].weight.data(),
            'learned bias', net[0].bias.data())


# 三阶多项式拟合-正常

train(x[:num_train, :], x[num_train:, :], y[:num_train], y[num_train:])

# 欠拟合 -->数据不够

train(X[:num_train, :], X[num_train:, :], y[:num_train], y[num_train:])

# 过拟合 -->训练量不足

train(x[0:2, :], x[num_train:, :], y[0:2], y[num_train:])

 

标签:loss,num,模型,选择,train,拟合,test,net
From: https://www.cnblogs.com/o-Sakurajimamai-o/p/17590113.html

相关文章

  • SAP Fiori Elements 应用 OData 元数据请求 url 里的模型名称决定逻辑
    问题我用yarnstart本地启动一个SAPFioriElements应用,在Chrome开发者工具network面板,观察到一个ODatametadata请求的url如下:http://localhost:8080/sap/opu/odata/sap/SEPMRA_PROD_MAN/$metadata?sap-value-list=none&sap-language=EN这个OData服务名称SEPM......
  • 基于C语言设计的全局光照明模型
    完整资料进入【数字空间】查看——搜索"writebug"Part1Whitted-StyleRayTracingStep0.算法流程为了渲染出一张图片,RayTrace()计算了给定像素点的色彩取值。根据光路可逆原理,可以从人眼作为出发点,沿着指向该pixel的某一点的方向发出一条ray。Step1:射线求交这条ray会碰到一个......
  • OSI(Open Systems Interconnection)的五层(七层)模型
    OSI(OpenSystemsInterconnection)是一个用于计算机网络通信的参考模型,由国际标准化组织(ISO)于1984年提出。它将计算机网络通信过程划分为七个不同的层次,从物理传输层到应用层,每个层次都有其特定的功能和任务。然而,常见的网络模型实际上是TCP/IP模型,它是OSI模型的一种实际应用。TCP/......
  • xshell连接liunx服务器身份验证不能选择password
    ssh用户身份验证不能选择password 只能用publickey的解决办法 问题现象使用密码通过Workbench或SSH方式(例如PuTTY、Xshell、SecureCRT等)远程登录ECS实例时,遇到服务器禁用了密码登录方式错误. 可能原因该问题是由于SSH服务对应配置文件/etc/ssh/sshd_config中的参数Pa......
  • 推荐带500创作模型的付费创作V2.1.0独立版系统源码
    ChatGPT付费创作系统V2.1.0提供最新的对应版本小程序端,上一版本增加了PC端绘画功能,绘画功能采用其他绘画接口–意间AI,本版新增了百度文心一言接口。后台一些小细节的优化及一些小BUG的处理,前端进行了些小细节优化,针对上几版大家非常关心的卡密兑换H5端及小程序端......
  • 3DSOM软件基于物体的照片构建空间三维模型的方法
      本文介绍基于3DSOM软件,实现侧影轮廓方法的空间三维模型重建。(基于3DSOM的侧影轮廓方法空间三维模型重建)  我们首先从侧影轮廓建模方法开始,对空间三维建模的一些内容加以介绍。本文我们将基于3DSoftwareObjectModeler(3DSOM)这一软件,对上述方法加以完整的操作,并对结果加......
  • AIGC与NLP大模型实战-经典CV与NLP大模型及其下游应用任务实现
    点击下载:AIGC与NLP大模型实战-经典CV与NLP大模型及其下游应用任务实现提取码:hqq8当今社会是科技的社会,是算力快速发展的时代。随着数据中心、东数西算、高性能计算、数据分析、数据挖掘的快速发展,大模型得到了快速地发展。大模型是“大算力+强算法”相结合的产物,是人工智能的发展......
  • Switch选择结构
    Switch选择结构shift+tab:反向缩进tab:缩进打开项目结构快捷键:ctrl+Alt+Shift+s/IDEA窗口—>File—>ProjectStructure打开文件夹:OPenin然后点Explorer就文件夹了多选择结构还有一个实现方式就是switchcase语句。switchcase语句判断一个变量与一系列值中某个值是......
  • softmax回归模型——pytroch版
    importtorchfromIPythonimportdisplayfromd2limporttorchasd2l#fromd2l.mxnetimportAccumulatorbatch_size=256#每次读256张图片,返回训练iter和测试itertrain_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)num_inputs=784num_outputs......
  • softmax回归模型simple——pytroch版
    importtorchfromtorchimportnnfromd2limporttorchasd2lbatch_size=256train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)#PyTorch不会隐式地调整输入的形状。因此,#我们在线性层前定义了展平层(flatten),来调整网络输入的形状net=nn.Sequenti......