首页 > 其他分享 >共轭梯度法示例代码

共轭梯度法示例代码

时间:2022-10-04 22:13:20浏览次数:43  
标签:示例 int 梯度 df minimum double found 共轭

当初第一次推导共轭梯度法的公式时,觉得有点难;现在回头再看,共轭梯度法真是个漂亮的算法。最近写了个示例代码,记录在这里。

参考文献:
[1] 马红儒,“计算物理讲义”

1. 算法

Ref. [1] 以及很多教材中都有详细介绍,我自己写的讲义里也有,所以这里不再赘述。另外,J. R. Shewchuk 的 "An Introduction to the Conjugate Gradient Method Without the Agonizing Pain"(常被称为 painless conjugate gradient)口碑很好。
我就简单把做法在这里说一下吧。

  • 首先取初始点 x0,沿 p0 = -g0(g0 为 x0 点的梯度)方向进行一维优化。
  • 对于每次循环,即从 xi 出发,沿 pi 方向进行一维优化,找到 pi 方向的最低点 x_{i+1}
  • 计算 x_{i+1} 点的梯度 g_{i+1},计算 alpha = N[ g_{i+1} ] / N[ g_i ](其中 N[g] = |g|^2 为梯度模的平方,即 Fletcher–Reeves 公式),据此计算新的搜索方向 p_{i+1} = alpha * p_i - g_{i+1}。
  • 若某次一维优化之后梯度的模小于给定阈值,则结束迭代。

2. 示例代码


#include<cmath>
#include<cstdlib>
#include<iostream>
#include<vector>
using namespace std;

#define max_iteration 100

// Demo objective: the axis-aligned quadratic bowl
//   f(x) = 10 (x[0] - 1)^2 + 20 (x[1] - 2)^2 + 30,
// whose minimum value 30 is attained at (1, 2).
// Only x[0] and x[1] are read; n is accepted so that f matches the
// double (*)(int, double *) objective signature used by op1d / cgm.
double f( int n, double * x ){
        // Alternative test objective: return x[0] * x[0] + x[1] * x[1];
        const double dx = x[0] - 1;
        const double dy = x[1] - 2;
        const double bowl = 10.0 * dx * dx + 20.0 * dy * dy;
        return bowl + 30;
}

// Analytic gradient of the demo objective f, written into df[0..1]:
//   df[0] = 20 (x[0] - 1),   df[1] = 40 (x[1] - 2).
// n is unused; exactly two components are read and written.
void df( int n, double * x, double * df ){
        // Alternative gradient (for f = |x|^2): df[0] = 2 * x[0]; df[1] = 2 * x[1];
        static const double slope[2]  = { 20.0, 40.0 };
        static const double center[2] = { 1.0, 2.0 };
        for( int k = 0; k < 2; ++k )
        {
                df[k] = slope[k] * ( x[k] - center[k] );
        }
}

// Writes x[0..n-1] to stdout as "[ x0  x1  ...  ]" followed by a newline;
// every element is followed by two spaces.
void PrintVec( int n, double * x ){
        cout << "[ ";
        int k = 0;
        while( k < n )
        {
                cout << x[k] << "  ";
                ++k;
        }
        cout << "]\n";
}

// Dot product sum_k a[k] * b[k] of two length-n vectors (0 when n <= 0).
double VecProd( int n, double * a, double * b)
{
        double acc = 0.0;
        for( int k = 0; k < n; ++k )
        {
                acc += a[k] * b[k];
        }
        return acc;
}

// Returns true iff (d - a) . (d - b) < 0, i.e. the vectors from d to a and
// from d to b form an obtuse angle. For a point d on the line through a and
// b this means d lies strictly between a and b (d == a or d == b gives false).
bool VecBetween( int n, double * a, double * b, double * d ){
        double dot = 0;
        for( int k = 0; k < n; ++k )
        {
                const double toA = d[k] - a[k];
                const double toB = d[k] - b[k];
                dot += toA * toB;
        }
        return dot < 0;
}

// Euclidean distance |a - b| between two length-n vectors.
double VecDistance( int n, double * a, double * b )
{
        double sq = 0;
        for( int k = 0; k < n; ++k )
        {
                const double diff = a[k] - b[k];
                sq += diff * diff;
        }
        return sqrt( sq );
}

// SQUARED Euclidean norm |a|^2 = sum_k a[k]^2 of a length-n vector.
// Despite the name this is not the norm itself: callers that need |a|
// take sqrt() of the result (e.g. sqrt( VecNorm(n, p) )).
double VecNorm( int n, double * a )
{
        double sumsq = 0;
        int k = 0;
        while( k < n )
        {
                const double v = a[k++];
                sumsq += v * v;
        }
        return sumsq;
}

// Line search: approximately minimize f along the ray x0 + t * p/|p|, t > 0.
//
//   f          objective, called as f(n, x)
//   n          dimension of x0 and p
//   x0         in: starting point; out: midpoint of the final zone [a, c]
//   p          search direction (need not be normalized, must be non-zero)
//   step       initial trial step length along p/|p|
//   precision  stop once the zone [a, c] is shorter than this
//   returns    f evaluated at the returned x0
//
// Three stages:
//   1. Trial step: while f(x0 + step * p/|p|) > f(x0), divide step by 10
//      (at most 10 times).
//   2. Bracketing: extend by the ratio 1.618 until f rises again, giving
//      points a, b, c on the ray with f(b) <= f(a) and f(b) < f(c).
//   3. Narrowing: sample d = a + 0.618 * (c - a) and drop one end of the
//      zone, always keeping the lowest point found so far as b.
// Progress is printed to stdout. The process is terminated with exit(1) if
// p has zero length, if stage 1 finds no decrease, or if stage 2 finds no
// zone within max_iteration expansions.
double op1d( double (*f)(int, double *), int n, double * x0, double * p, double step, double precision )
{
        // |p| is loop-invariant, so compute it once.
        const double pnorm = sqrt( VecNorm(n, p) );
        if( !(pnorm > 0) ){ // a zero (or NaN) length would turn every trial point into NaN
                cout<<" op1d: the search direction p has zero length. "<<endl;
                exit(1);
        }

        vector<double> a( x0, x0 + n ), b( n ), c( n ), d( n );
        double fa = f( n, x0 );
        for(int j=0;j<n;j++) b[j] = a[j] + step * p[j] / pnorm;
        double fb = f( n, b.data() );

        // Stage 1: shrink the trial step until f does not increase (at most 10 times).
        for(int attempt=0; attempt<10 && fb > fa; attempt++)
        {
                step /= 10;
                for(int j=0;j<n;j++) b[j] = a[j] + step * p[j] / pnorm;
                fb = f( n, b.data() );
        }
        if( fb > fa ){
                cout<<" op1d: f( x0 + step/10^10 * p/|p| ) > f( x0 ), that means x0 is already close to a local minimum. "<<endl;
                exit(1);
        }

        // Stage 2: find a suspicious zone [a, c] that brackets a minimum.
        double fc = 0;
        int expansions;
        for(expansions=0; expansions<max_iteration; expansions++)
        {
                for(int j=0;j<n;j++) c[j] = b[j] + 1.618 * (b[j] - a[j]);
                fc = f( n, c.data() );
                if( fc > fb ) break; // suspicious zone found: [a,c]
                a = b; fa = fb;
                b = c; fb = fc;
        }
        if( expansions == max_iteration ){
                cout<<" op1d: failed to find a suspicious zone. \n"; exit(1);
        }
        cout<<" After "<<expansions+1<<" iterations, found a suspicious zone : \n";
        cout<<" fa = "<<fa<<"  a = "; PrintVec(n, a.data());
        cout<<" fb = "<<fb<<"  b = "; PrintVec(n, b.data());
        cout<<" fc = "<<fc<<"  c = "; PrintVec(n, c.data());

        // Stage 3: narrow the suspicious zone [a, c] around the lowest point b.
        // A dedicated counter is used so the reported count is the real number
        // of narrowing steps.
        int shrinks = 0;
        bool converged = false;
        while( !converged && shrinks < max_iteration )
        {
                shrinks++;
                for(int j=0;j<n;j++) d[j] = a[j] + 0.618 * (c[j] - a[j]);
                const double fd = f( n, d.data() );
                if( VecBetween( n, a.data(), b.data(), d.data() ) )
                {
                        if( fd < fb ) { c = b; fc = fb; b = d; fb = fd; }  // [a, b, c] -> [a, d, b]
                        else          { a = d; fa = fd; }                  // [a, b, c] -> [d, b, c]
                }
                else
                {
                        if( fd < fb ) { a = b; fa = fb; b = d; fb = fd; }  // [a, b, c] -> [b, d, c]
                        else          { c = d; fc = fd; }                  // [a, b, c] -> [a, b, d]
                }
                converged = VecDistance( n, a.data(), c.data() ) < precision;
        }

        for(int j=0;j<n;j++) x0[j] = (a[j] + c[j])/2;
        if( converged )
                cout<<" After "<<shrinks<<" iterations, a 1d minimum is found.\n";
        else
                cout<<" op1d: the zone is still wider than the requested precision after "<<max_iteration<<" iterations.\n";
        return f(n, x0);
}


void cgm( double (*f)(int, double *), void (*df)(int, double *, double * ), int n, double * x0, double step, double precision, double epsabs ){

        double * g = new double [n]; df( n, x0, g );
        double * p = new double [n]; for(int i=0;i<n;i++) p[i] = -g[i];

        int k;
        for( k=0;k<max_iteration;k++)
        {
                double alpha = 1 / VecNorm(n, g);
                cout<<"-------------------------------------------------\n";
                cout<<" descend direction p = "; PrintVec(n, p);
                double f1dm = op1d( f, n, x0, p, step, precision );
                cout<<" op1d minimum: f = "<< f1dm <<", x = "; PrintVec(n, x0);
                cout<<"-------------------------------------------------\n";
                df( n, x0, g );
                if( sqrt(VecNorm(n, g)) < epsabs ) break;
                alpha *= VecNorm(n, g);
                for(int i=0;i<n;i++) p[i] = alpha * p[i] - g[i];
        }
        if( k == max_iteration )
        {
                cout<<" cgm: failed to find a minimum after "<<max_iteration<<" iterations.\n";
        }
        else
        {
                cout<<" cgm : found a minimum after "<< k+1 << " 1d optimizations\n";
                cout<<" \t f = "<<f(n, x0)<<", x = "; PrintVec( n, x0 );
        }
        delete [] g; delete [] p;
}

// Demo driver: minimize the quadratic f (minimum 30 at (1, 2)) starting from
// (4, 0), with initial trial step 100, line-search tolerance 1e-8 and
// gradient tolerance 1e-5.
int main(){
        double x0[] = {4,0};
        // To exercise only the line search along a fixed direction:
        //   double p[] = {-4, 3};
        //   op1d( f, 2, x0, p, 0.1, 1E-5 );
        cgm( f, df, 2, x0, 100, 1E-8, 1E-5 );
        return 0;
}

运行结果:

-------------------------------------------------
 descend direction p = [ -60  80  ]
 After 2 iterations, found a suspicious zone : 
 fa = 116.4  a = [ 3.4  0.8  ]
 fb = 50.6044  b = [ 2.4292  2.0944  ]
 fc = 126.012  c = [ 0.858446  4.18874  ]
 After 3 iterations, a 1d minimum is found.
 op1d minimum: f = 47.561, x = [ 2.17073  2.43902  ]
-------------------------------------------------
-------------------------------------------------
 descend direction p = [ -28.5544  -10.7079  ]
 After 1 iterations, found a suspicious zone : 
 fa = 47.561  a = [ 2.17073  2.43902  ]
 fb = 30.704  b = [ 1.2344  2.0879  ]
 fc = 51.011  c = [ -0.280578  1.51978  ]
 After 3 iterations, a 1d minimum is found.
 op1d minimum: f = 30, x = [ 1  2  ]
-------------------------------------------------
 cgm : found a minimum after 2 1d optimizations
 	 f = 30, x = [ 1  2  ]

可以看到,一共只需要两次一维优化,这符合理论预期:对于 n 元正定二次函数,在一维搜索精确的前提下,共轭梯度法至多经过 n 次一维优化即可到达极小点。

标签:示例,int,梯度,df,minimum,double,found,共轭
From: https://www.cnblogs.com/luyi07/p/16754603.html

相关文章

  • 李沐:用随机梯度下降来优化人生!
     Datawhale推荐 作者:李沐,CMU,亚马逊首席科学家1.要有目标。你需要有目标。短的也好,长的也好。认真定下的也好,别人那里捡的也好。就跟随机梯度下降需要有个目标函数一样。2.......
  • Springboot接收请求参数示例
    packagecom.example.demo.controller;importcom.example.demo.model.User;importorg.springframework.web.bind.annotation.*;importorg.springframework.web.mult......
  • vue-ant design示例大全——按钮本地css/js资源
    vue-antdesign示例大全——本地css/js资源示例资源来自官网:​​https://www.antdv.com/components/button-cn​​在AntDesignVue中我们提供了五种按钮。主按钮:用于主行......
  • Socket示例
    服务端代码:usingSystem;usingSystem.Collections.Generic;usingSystem.ComponentModel;usingSystem.Data;usingSystem.Drawing;usingSystem.Linq;usingSyst......
  • 0069-Tui-综合示例(一)
    环境Time2022-08-23Rust1.63.0Tui0.19.0前言说明参考:https://github.com/fdehau/tui-rs/tree/master/examples/demo目标实现tui-rs的综合示例程序,读取命令......
  • 0070-Tui-综合示例(二)
    环境Time2022-08-23Rust1.63.0Tui0.19.0前言说明参考:https://github.com/fdehau/tui-rs/tree/master/examples/demo目标实现tui-rs的综合示例程序,终端的开......
  • 0071-Tui-综合示例(三)
    环境Time2022-08-23Rust1.63.0Tui0.19.0前言说明参考:https://github.com/fdehau/tui-rs/tree/master/examples/demo目标实现tui-rs的综合示例程序,应用数据......
  • 0072-Tui-综合示例(四)
    环境Time2022-08-23Rust1.63.0Tui0.19.0前言说明参考:https://github.com/fdehau/tui-rs/tree/master/examples/demo目标实现tui-rs的综合示例程序。定义布......
  • 0061-Tui-迷你图示例
    环境Time2022-08-16Rust1.63.0Tui0.18.0前言说明参考:https://github.com/fdehau/tui-rs/blob/master/examples/sparkline.rs目标使用tui-rs显示迷你图。生......
  • 0064-Tui-图表示例
    环境Time2022-08-16Rust1.63.0Tui0.18.0前言说明参考:https://github.com/fdehau/tui-rs/blob/master/examples/chart.rs目标使用tui-rs显示图表。常量数据......