首页 > 其他分享 >线性表示代码

线性表示代码

时间:2025-01-09 10:56:24浏览次数:1  
标签:表示 loss get 梯度 代码 torch indices 线性 data

import torch
import matplotlib.pyplot as plt

Python 中用于导入 matplotlib 库并将其 pyplot 模块简称为 plt 的常见语句。matplotlib 是一个功能强大的绘图库,而 pyplot 是其提供的一个基于状态机的接口,用于创建各种类型的可视化图表

y = x * w + b

def create_data(w, b, data_num):  #生成数据
    x = torch.normal(0, 1, (data_num, len(w))) #长度是数据数量,w多长,x多宽
    y = torch.matmul(x, w) + b  #matmul表示矩阵相乘

    noise = torch.normal(0, 0.01, y.shape) #噪声加到y上 不可能那么准.通过这种方式,对原始数据 y 添加了随机噪声,模拟了数据在实际场景中可能受到的噪声干扰
    y += noise

    return x, y
num = 500 #生成500个数据

true_w = torch.tensor([8.1, 2, 2, 4]) #torch.tensor() 函数用于从 Python 列表创建一个 PyTorch 张量。在这个例子中,列表 [8.1, 2, 2, 4] 被转换为一个张量。
true_b = torch.tensor(1.1)

X, Y = create_data(true_w, true_b, num)

plt.scatter(X[:, 3], Y, 1) #画散点图 大小为1 标签Y 只取第三列 y:500*1
plt.show()
def data_provider(data, label, batchsize):  #每次访问这个函数,就能提供一批数据 500个数据,16个数据取一次loss
    length = len(label)
    indices = list(range(length)) #indices变为0-500的列表
    #我不能按顺序取  把数据打乱,原来indices[0]=0现在indices[0]=随机数
    random.shuffle((indices))

    for each in range(0, length, batchsize):
        get_indices = indices[each: each+batchsize] #0-15 16-31
        get_data = data[get_indices]
        get_label = label[get_indices]

        yield get_data, get_label   #有存档点的return each,0,16,32

以上,为取数据

def fun(x, w, b): #定义一个模型
    pred_y = torch.matmul(x, w) + b
    return pred_y #pred_y 就是模型对输入 x 的预测输出
def maeLoss(pred_y, y): #用于计算平均绝对误差(Mean Absolute Error,MAE)损失
    return torch.sum(abs(pred_y-y))/len(y) #得到平均loss

#梯度下降:参数减去参数的梯度乘以学习率
def sgd(paras, lr):   #随机梯度下降,更新参数
    with torch.no_grad():  #张量网上,所有的计算都会计算梯度.但不是所有梯度都是我们想要的,梯度更新回传时不需要更新梯度,输入这句代码的部分,不计算梯度
        for para in paras:
            para -= para.grad* lr  #不能写成para= para - para.
            para.grad.zero_() #在完成一次参数更新后,需要将参数的梯度清零。这是因为在 PyTorch 中,梯度是累加的,如果不清零,在下一次反向传播时,新计算的梯度会与之前的梯度累加,导致错误的更新。
lr = 0.3
w_0 = torch.normal(0, 0.01, true_w.shape, requires_grad=True)  #这个w需要计算梯度
b_0 = torch.tensor(0.01, requires_grad=True)
epochs = 50

for epoch in range(epochs):
    data_loss = 0
    for batch_x, batch_y in data_provider(X, Y, batchsize):
        pred_y = fun(batch_x, w_0, b_0) #得出预测的y
        loss = maeLoss(pred_y, batch_y) #batch_y是真实的y
        loss.backward() #调用 loss.backward() 进行反向传播。这一步会计算损失函数关于模型参数的梯度
        sgd([w_0, b_0], lr)
        data_loss += loss

    print("epoch %03d:loss: %.6f"%(epoch, data_loss))

print("真实的函数值是", true_w, true_b)
print("训练得到的参数值是", w_0, b_0)
idx = 3 #定义变量 idx 为 3,用于后续从数据集中选择特定的列。
plt.plot(X[:, idx].detach().numpy(), X[:, idx].detach().numpy()*w_0[idx].detach().numpy()+b_0.detach().numpy())
plt.scatter(X[:, idx], Y, 1) #detach() 方法将张量从计算图中分离,避免在转换为 numpy 数组时因计算图的存在引发问题,numpy() 方法将 torch.Tensor 转换为 numpy.ndarray,这部分数据作为线图的 x 轴数据。
plt.show()

标签:表示,loss,get,梯度,代码,torch,indices,线性,data
From: https://www.cnblogs.com/jyp02/p/18661716

相关文章

  • 全网最简单、免费的零代码平台一键 Docker 搭建,快速搭建应用
    本文档docker组中使用的镜像已经上传阿里云docker私服,方便安装。第一步:复制下面内容创建docker-compose.ymlservices:qiaoqiaoyun-mysql:image:registry.cn-hangzhou.aliyuncs.com/jeecgdocker/qiaoqiaoyun-mysql:2.0.1environment:MYSQL_ROOT_PA......
  • XTR105 XTR105UA/2K5规格书具有传感器激励和线性化的 4mA 至 20mA 电流变送器芯片
    XTR105是一款带有两个精准电流源的单片4mA至20mA、2线制电流发送器。该器件在一个单集成电路上提供针对铂RTD温度传感器和桥、仪表放大器以及电流输出电路的完整电流激励。多用途线性化电流提供一个对RTD的第二阶修正,通常可以实现一个40:1的线性改进。仪器放大器增益可......
  • JS将docx转为html代码--Vue3(简易版)
    这两天突然接了一个把节气文章转成html页面的需求,本来只是需要多按几下ctrl+c,ctrl+v能解决的事,但是想想后续一年24个节气,就做个自动转换的工具吧。由于做软件还涉及到其他语言和配置环境,所以还是选择了web。首先创建一个vue3项目,我用的vite搭建的,不会的请自行移步到vite官网。......
  • 【数据结构与算法】之线性表:栈和队列个人总结
    进度好慢呀!冲冲冲!希望能在17号之前过完一遍数据结构基础!现在也有在做题,但是做题好慢,有的看题解也不理解,......
  • BOOST 在计算机视觉方面的应用及具体代码分析(二)
    摘要: 本论文聚焦于BOOST库在计算机视觉领域的多元应用,深入探究其在图像预处理、目标识别、图像分割以及运动分析等关键任务中的作用机制。通过详实的代码剖析,揭示BOOST如何助力开发人员优化算法、提升性能,进而推动计算机视觉技术迈向新高度,为相关领域的研究与实践提供坚实......
  • StringBuilder练习项目代码及相关知识点
    1.动态字符串操作需求:编写一个程序,接收用户输入的多个单词,并将它们组合成一个完整的句子,同时支持以下功能:动态添加单词删除某些单词将句子反转importjava.util.Scanner;publicclassStringBuilderDemo{publicstaticvoidmain(String[]args){StringB......
  • 代码精简之路-模板模式
    1.前言程序员怕重复CRUD,总是做一些简单繁琐的事情。“不要重复造轮子”,“把基础功能提炼出来封装成工具类”我喜欢把这些话挂在嘴边,写起来常不知从何下手。下面拆解一个项目中的功能。记录从复制粘贴到对业务抽象、实现功能分层的详细过程。如何着手提升代码重构优化能力,拿到......
  • 腾讯云AI代码助手编程挑战赛-月事小助手
    作品简介通过腾讯云AI代码助手来每月记录例假日期和周期,为你推算出下次的月经期和排卵期,轻松安排经期!更加了解自己,一切从容应对!小助手还会有针对性地推送大姨妈小贴士,做最了解你的闺蜜!使用工具腾讯AI代码助手技术架构通过python的tk库来完成页面设计,通过代码来实现页面......
  • Python Mixin 模式:解锁代码复用的艺术
    在面向对象编程中,代码复用是一个至关重要的概念。它不仅能够减少重复劳动,还能提高代码的可维护性和一致性。Python作为一种高度灵活且功能强大的编程语言,提供了多种机制来支持代码复用,其中Mixin模式便是实现这一目标的一种重要手段。一、什么是Mixin?Mixin是一种设计模式和......
  • Codeforces Round 986 (Div. 2) CF2028 代码集
    CodeforcesRound986(Div.2)CF2028代码集目录CodeforcesRound986(Div.2)CF2028代码集CF2028A-Alice'sAdventuresin''Chess''CF2028B-Alice'sAdventuresinPermutingCF2028C-Alice'sAdventuresinCuttingCakeCF2024D-A......