首页 > 其他分享 >机器学习-线性回归-梯度下降法-03

机器学习-线性回归-梯度下降法-03

时间:2023-12-11 20:23:56浏览次数:34  
标签:03 梯度 样本 50 np 线性 theta 100

1. 梯度下降法

梯度:


是一个theta 与 一条样本x 组成的 映射公式


可以看出梯度的计算量主要来自于 左边部分

所有样本参与 -- 批量梯度下降法
随机抽取一条样本参与 -- 随机梯度下降法
一小部分样本参与 -- 小批量 梯度下降法

2. epoch 与 batch

epoch:一次迭代 w t --> w t+1
batch:一次迭代中 一部分样本

例如: 背诵唐诗300首
循环背诵20次
n_sample = 300 (样本总量)
n_epoch = 20 (循环20次)

一次300首太多了 50首 50首的进行
batch_size = 50 (批次的大小 50)
n_batch = 6 (6个批次)

3. 代码实现

import numpy as np

# 创建数据集
X = 2 * np.random.rand(100, 1)  # (100, 1)  0-1之间均匀分布
y = 4 + 3 * X + np.random.randn(100, 1)  # 这里的 4,3需要求解的

X_b = np.c_[np.ones((100, 1)), X]

# 超参
learn_rate = 0.001
n_iterations = 10000


# 初始化theta
theta = np.random.randn(2, 1)


for _ in range(n_iterations):
    gradient = X_b.T.dot((X_b.dot(theta) - y))  # 梯度的计算公式
    theta = theta - learn_rate*gradient  # 梯度下降法 为什么将梯度下降法 因为梯度的前面 有个负号

print(theta)  # 可以看到 接近于 (4, 3)


标签:03,梯度,样本,50,np,线性,theta,100
From: https://www.cnblogs.com/cavalier-chen/p/17895462.html

相关文章

  • 机器学习-线性回归-模型解析解-02
    1.解析解解析解的公式importnumpyasnpimportmatplotlib.pyplotasplt#有监督机器学习#XyX=2*np.random.rand(100,1)#np.random.rand#100行1列的[0,1)之间均匀分布*2之后则变成[0,2)之间均匀分布e=np.random.randn(100,1)#误差均值0......
  • [持续更新][数据结构][算法]涵盖线性表、栈、链表、队列、图、动态规划、分治递归、回
    备考考点整理内部排序表格树的主要考点二叉树的常考紧紧抓住\(n_0=n_2+1\)\(n=n_0+n_1+n_2...n_m\)\(n=n_1+2*n_2+3*n_3...m*n_m\)+1哈夫曼树没有度为1的结点,也就是\(n_1=0\)完全二叉树常考总结最大岛屿问题(dfs模板)#include<iostream>#include<algorith......
  • CF1764H Doremy's Paint 2 题解
    题目链接先断环成链,由于对于多组询问不好一起处理,我们先考虑单组询问的处理方式。一个很暴力的想法是每次模拟题目要求的操作并且最后数颜色,我们这是在通过下标进行操作最后再数颜色,而很多对于下标的操作都是不必要的,考虑直接枚举颜色进行判定。对于每种颜色,它对于最后答案有贡......
  • Si24R03—低功耗 SOC 芯片(集成RISC-V内核+2.4GHz无线收发器)
    Si24R03是一款高度集成的低功耗SOC芯片,其集成了基于RISC-V核的低功耗MCU和工作在2.4GHzISM频段的无线收发器模块。MCU模块具有低功耗、LowPinCount、宽电压工作范围,集成了13/14/15/16位精度的ADC、LVD、UART、SPI、I2C、TIMER、WUP、IWDG、RTC等丰富的外设。内核采用RISC-VRV......
  • What's new in Pika v3.5.2
    Pika社区近期发布了备受期待的v3.5.2版本https://github.com/OpenAtomFoundation/pika/releases/tag/v3.5.2-alpha,不仅解决了历史遗留的Bug问题,还引入了多项新特性。这些新特性主要包括Pika支持Redis事务、Pika上层增加缓存层实现冷热数据分离、提升读性能、Codis-Proxy......
  • MBR60300PT-ASEMI大电流肖特基二极管MBR60300PT
    编辑:llMBR60300PT-ASEMI大电流肖特基二极管MBR60300PT型号:MBR60300PT品牌:ASEMI封装:TO-247正向电流:60A反向电压:300V引线数量:3芯片个数:2芯片尺寸:150MIL漏电流:<10ua恢复时间:5ns浪涌电流:500A芯片材质:正向电压:0.85V~0.90V工作结温:-40℃~175℃包装方式:500/箱MBR60300PT......
  • 奥特曼被指爱权力胜过金钱;人类才是「幻觉问题」根本原因丨 RTE 开发者日报 Vol.103
       开发者朋友们大家好:这里是「RTE开发者日报」,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享RTE(RealTimeEngagement)领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「有看点的会议」,但内......
  • Confluence7.4.6突然爆事务隔离级别问题-解决方案-MySQL session isolation level 'RE
    MySQLsessionisolationlevel'REPEATABLE-READ'isnolongersupported.Sessionisolationlevelmustbe'READ-COMMITTED'.Seehttp://confluence.atlassian.com/x/GAtmDg  成功解决方案:查看http://confluence.atlassian.com/x/GAtmDgFORMYSQL8.X......
  • Linux学习35- python3.9出现ModuleNotFoundError: No module named '_ctypes'的解决
    遇到问题pip安装第三方库的时候报错ModuleNotFoundError:Nomodulenamed'_ctypes'File"/usr/local/python3/lib/python3.9/ctypes/__init__.py",line7,in<module>from_ctypesimportUnion,Structure,ArrayModuleNotFoundError:Nomodulen......
  • 【算法】【线性表】两个排序数组的中位数
    1 题目两个排序的数组A和B分别含有m和n个数,找到两个排序数组的中位数,要求时间复杂度应为O(log(m+n))。中位数的定义:这里的中位数等同于数学定义里的中位数。中位数是排序后数组的中间值。如果有数组中有n个数且n是奇数,则中位数为 A((n-1)/2)。如果有数组中有n个数且n......