首页 > 编程语言 >Python最速下降法实例

Python最速下降法实例

时间:2022-10-23 20:36:22浏览次数:42  
标签:plt Python 实例 vec x2 X0 x1 grad 最速

最速下降法的实现需要通过符号计算。

首先笔算一步如下,然后通过程序验证:

 

python程序如下,需要pip install sympy:

import numpy as np
from sympy import *
import math
import matplotlib.pyplot as plt# 定义符号
x1, x2, t = symbols('x1, x2, t')

def func():
    # 自定义一个函数
    return (x1-2) * (x1-2) + 4 * (x2-3) * (x2-3)

def grad(data):
    f = func()
    grad_vec = [diff(f, x1), diff(f, x2)]
    grad = []
    for item in grad_vec:
        grad.append(item.subs(x1, data[0]).subs(x2, data[1]))
    return grad

def grad_len(grad):
    vec_len = math.sqrt(pow(grad[0], 2) + pow(grad[1], 2))
    return vec_len

def zhudian(f):
    t_diff = diff(f)
    t_min = solve(t_diff)
    return t_min

绘制目标函数的图形如下:

def func_numeric(x):
    # 自定义一个函数
    return (x[0]-2) * (x[0]-2) + 4 * (x[1]-3) * (x[1]-3)
plt.figure(3)
plt_x = np.arange(0, 4, 0.05)
plt_y = np.arange(0, 4, 0.05)
X, Y = np.meshgrid(plt_x, plt_y)
Z1 = func_numeric([X, Y])
ax = plt.gca(projection='3d')
ax.plot_surface(X, Y, Z1, cmap=plt.get_cmap('rainbow'), linewidth=0.2)
plt.title("steepest descent method")
plt.xlabel("x1")
plt.ylabel("x2")
plt.show()

进行迭代计算:

X0 = [0, 0]
theta = 0.00001

data_x = [X0[0]]
data_y = [X0[1]]
data_f = [func().subs(x1, X0[0]).subs(x2, X0[1])]

f = func()
grad_vec = grad(X0)
print(f"grad: {grad_vec}")
grad_length = grad_len(grad_vec)  # 梯度向量的模长
print('grad_length', grad_length)
k = 1
while grad_length > theta:  # 迭代的终止条件
    k += 1
    p = -np.array(grad_vec)
    # 迭代
    X = np.array(X0) + t*p
    t_func = f.subs(x1, X[0]).subs(x2, X[1])
    t_min = zhudian(t_func)
    print(f"t_min: {t_min}")
    X0 = np.array(X0) + t_min*p
    print(f"X0: {X0}")
    # print(floatat(X0[0]))
    print(f"fx0: {f.subs(x1, float(X0[0])).subs(x2, float(X0[1]))}")
    grad_vec = grad(X0)
    print(f"grad: {grad_vec}")
    grad_length = grad_len(grad_vec)
    print('grad_length', grad_length)
    print('坐标', X0[0], X0[1])
    data_x.append(X0[0])
    data_y.append(X0[1])
    data_f.append(f.subs(x1, float(X0[0])).subs(x2, float(X0[1])))

打印的第一步如下:

grad: [-4, -24]
grad_length 24.331050121192877
t_min: [37/290]
X0: [74/145 444/145]
fx0: 2.2344827586206

和笔算结果一致。

绘制等高线图和迭代过程折线图:

# 绘图
fig = plt.figure(4)
plt.title(r'$Gradient \ method - steepest \ descent \ method$')
plt.plot(data_x, data_y, color='k', label=r'$f(x_1,x_2)=(x_1-2)^2+4*(x_2-3)^2$')

def func_numeric(x):
    # 自定义一个函数
    return (x[0]-2) * (x[0]-2) + 4 * (x[1]-3) * (x[1]-3)
plt_x = np.arange(0, 4, 0.05)
plt_y = np.arange(0, 4, 0.05)
X, Y = np.meshgrid(plt_x, plt_y)
Z1 = func_numeric([X, Y])
# ctf = plt.contourf(X, Y, Z1, 15)
ct = plt.contour(X, Y, Z1)
plt.clabel(ct, inline=True, fontsize=10)
# plt.colorbar(ctf)

plt.legend()
# plt.scatter(1, 1, marker=(5, 1), c=5, s=1000)
# plt.grid()
plt.xlabel(r'$x_1$', fontsize=20)
plt.ylabel(r'$x_2$', fontsize=20)
plt.show()

如图所示,结果正确:

 

标签:plt,Python,实例,vec,x2,X0,x1,grad,最速
From: https://www.cnblogs.com/zhaoke271828/p/16819412.html

相关文章

  • ParserWarning: Falling back to the 'python' engine because the 'c' engine does n
    Python3.9.10,Window64bit   警告:ParserWarning:Fallingbacktothe'python'enginebecausethe'c'enginedoesnotsupportregexseparators(separators......
  • 抽象类的实例应用
    1.员工packagebigguy;/***@authorliu$*@version1.0*@description:TODO*@date$$*/publicabstractclassEmployee{privateStringname;private......
  • Python实验报告(第7周)
    实验7:面向对象程序设计一、实验目的和要求1、了解面向对象的基本概念(对象、类、构造方法);2、学会类的定义和使用;3、掌握属性的创建和修改;4、掌握继承的基本语法。 ......
  • [Python]学习笔记之- __name__ == '__main__'
     if__name__=='__main__':大多数规范的Python源码中都可以看到这个语句,初学者可能不清楚这句话的用处。这句代码的字面意思就是在做判断__name__是否为'__main__'。这......
  • Python安装OCR识别库tesserocr_pytesseract教程
    Python安装OCR识别库tesserocr1.tesserocr下载https://digi.bib.uni-mannheim.de/tesseract/尽量选不带dev的版本,dev是开发版本,不带dev的是稳定版个人配置tesseract-......
  • Python 用户输入
    1.input()输入【实例】:data=input("请输入:")print(data,type(data))【运行结果】:请输入:3131<class'str'> 2.int类型转换从input()函数输入的内容都是str......
  • Python 字典
    目录导航1.一个简单的字典2.添加键值对3.创建空字典4.修改字典中的值5.删除键值对6.使用get()来访问值7.遍历键值对8.遍历字典的......
  • python7
    一、创建大雁类并定义飞行方法classGeese:'''大雁类'''def__init__(self,beak,wing,claw):print("我是大雁类!我有以下特征:")print(bea......
  • python模块、异常处理、软件开发目录规范总结
    本周总结异常处理生成器模块软件开发目录1.异常处理1.1异常处理语法结构1.基本语法 try:待检测的代码(可能会出错的代码)except错误......
  • Spring —— bean实例化
    bean实例化bean本质上就是对象,创建bean使用构造方法完成(反射)    构造方法(常用)      静态工厂*      实例工厂*      FactoryBean(实用)......