首页 > 其他分享 >使用libtorch训练一个异或逻辑门

使用libtorch训练一个异或逻辑门

时间:2024-03-13 09:25:55浏览次数:17  
标签:LOSS 逻辑 features int torch libtorch 异或 LOOP Tensor

本文以一个例子介绍如何使用libtorch创建一个包含多层神经元的感知机,训练识别异或逻辑。即${ z = x \text{^} y }$。本例的测试环境是VS2017和libtorch1.13.1。从本例可以学到如何复用网络结构,如下方的LinearSigImpl类的写法。该测试网络结构如下图。一个线性层2输入3输出,一个Sigmoid激活函数3输入3输出,一个线性输出层:

头文件代码如下:

class LinearSigImpl : public torch::nn::Module
{
public:
    LinearSigImpl(int intput_features, int output_features);
    torch::Tensor forward(torch::Tensor x);

private:
    torch::nn::Linear ln;
    torch::nn::Sigmoid bn;
};

TORCH_MODULE(LinearSig);

class Mlp : public torch::nn::Module
{
public:
    Mlp(int in_features, int out_features);
    torch::Tensor forward(torch::Tensor x);

private:
    LinearSig ln1;
    torch::nn::Linear output;
};

CPP文件:

LinearSigImpl::LinearSigImpl(int in_features, int out_features) : 
    ln(nullptr), bn(nullptr)
{
    ln = register_module("ln", torch::nn::Linear(in_features, out_features));
    bn = register_module("bn", torch::nn::Sigmoid());
}

torch::Tensor LinearSigImpl::forward(torch::Tensor x)
{
    x = ln->forward(x);
    x = bn->forward(x);
    return x;
}

Mlp::Mlp(int in_features, int out_features) : 
    ln1(nullptr), output(nullptr)
{
    ln1 = register_module("ln1", LinearSig(in_features, 3));
    output = register_module("output", torch::nn::Linear(3, out_features));
}

torch::Tensor Mlp::forward(torch::Tensor x)
{
    x = ln1->forward(x);
    x = output->forward(x);
    return x;
}

int main()
{
    Mlp linear(2, 1);

    /* 30个样本。在这里是一行一个样本 */
    at::Tensor b = torch::rand({ 30, 2 });
    at::Tensor c = torch::zeros({ 30, 1 });
    for (int i = 0; i < 30; i++)
    {
        b[i][0] = (b[i][0] >= 0.5f);
        b[i][1] = (b[i][1] >= 0.5f);
        c[i] = b[i][0].item().toBool() ^ b[i][1].item().toBool();
    }

    //cout << b << endl;
    //cout << c << endl;

    /* 训练过程 */
    torch::optim::SGD optim(linear.parameters(), torch::optim::SGDOptions(0.01));
    torch::nn::MSELoss lossFunc;
    linear.train();
    for (int i = 0; i < 50000; i++)
    {
        torch::Tensor predict = linear.forward(b);
        torch::Tensor loss = lossFunc(predict, c);
        optim.zero_grad();
        loss.backward();
        optim.step();
        if (i % 2000 == 0)
        {
            /* 每2000次循环输出一次损失函数值 */
            cout << "LOOP:" << i << ",LOSS=" << loss.item() << endl;
        }
    }
    /* 非线性的网络就不输出网络参数了 */
    /* 太过玄学,输出也看不懂 */

    /* 做个测试 */
    at::Tensor x = torch::tensor({ { 1.0f, 0.0f }, { 0.0f, 1.0f }, { 1.0f, 1.0f }, { 0.0f, 0.0f} });
    at::Tensor y = linear.forward(x);
    cout << "输出为[1100]=" << y;

    /* 看看能不能泛化 */
    x = torch::tensor({ { 0.9f, 0.1f }, { 0.01f, 0.2f } });
    y = linear.forward(x);
    cout << "输出为[10]=" << y;

    return 0;
}

控制台输出如下。如果把0.5作为01分界线,从输出上看网络是有一定的泛化能力的。当然每次运行输出数字都不同,绝大多数泛化结果都正确:

LOOP:0,LOSS=1.56625
LOOP:2000,LOSS=0.222816
LOOP:4000,LOSS=0.220547
LOOP:6000,LOSS=0.218447
LOOP:8000,LOSS=0.215877
LOOP:10000,LOSS=0.212481
LOOP:12000,LOSS=0.207645
LOOP:14000,LOSS=0.199905
LOOP:16000,LOSS=0.187244
LOOP:18000,LOSS=0.168875
LOOP:20000,LOSS=0.145476
LOOP:22000,LOSS=0.118073
LOOP:24000,LOSS=0.087523
LOOP:26000,LOSS=0.0554768
LOOP:28000,LOSS=0.0280211
LOOP:30000,LOSS=0.0109953
LOOP:32000,LOSS=0.00348786
LOOP:34000,LOSS=0.000959343
LOOP:36000,LOSS=0.000243072
LOOP:38000,LOSS=5.89887e-05
LOOP:40000,LOSS=1.40228e-05
LOOP:42000,LOSS=3.3041e-06
LOOP:44000,LOSS=7.82167e-07
LOOP:46000,LOSS=1.85229e-07
LOOP:48000,LOSS=4.43763e-08
输出为[1100]= 0.9999
 1.0000
 0.0002
 0.0001
[ CPUFloatType{4,1} ]输出为[10]= 0.9999
 0.4588
[ CPUFloatType{2,1} ]

 

标签:LOSS,逻辑,features,int,torch,libtorch,异或,LOOP,Tensor
From: https://www.cnblogs.com/mengxiangdu/p/18023716

相关文章

  • 41. 抽卡面板的实际逻辑
    本节目标打开抽卡面板之后,显示三张可以抽取的卡牌,点击选择卡牌之后将卡牌添加到牌堆中,然后游戏胜利面板隐藏抽卡按钮实现方法添加卡牌的UIDocument需要给Card也添加turnbutton样式,这样鼠标移动过去的时候会有放大的效果抽卡面板调试布局调试的时候,我们可以把Project......
  • R语言逻辑回归、决策树、随机森林、神经网络预测患者心脏病数据混淆矩阵可视化
    全文链接:https://tecdat.cn/?p=33760原文出处:拓端数据部落公众号概述:众所周知,心脏疾病是目前全球最主要的死因。开发一个能够预测患者心脏疾病存在的计算系统将显著降低死亡率并大幅降低医疗保健成本。机器学习在全球许多领域中被广泛应用,尤其在医疗行业中越来越受欢迎。机器......
  • 39. 对战胜负逻辑
    本节目标让用户从地图场景进入对战场景,然后对战胜利或失败之后收起卡牌实现过程让用户从地图场景进入对战场景Persistent场景隐藏Player和GameplayPanelHierarchy窗口中移除对战场景,添加map场景清除MapLayoutSO中的数据然后启动游戏,将Map场景设置为激活然后......
  • Unity3D 逻辑服的ECS框架设计架构与原理详解
    ECS(Entity-Component-System)是一种游戏开发架构模式,它将游戏对象划分为实体(Entity)、组件(Component)和系统(System),并通过数据驱动的方式来实现游戏逻辑。在Unity3D中,ECS框架的设计架构与原理是非常重要的,本文将详细介绍Unity3D逻辑服的ECS框架设计架构与原理,并给出技术详解以及代码实......
  • 256. 最大异或和
    可持久化字典树#include<iostream>#include<stdio.h>#include<algorithm>#include<string>#include<cmath>#defineFor(i,j,n)for(inti=j;i<=n;++i)usingnamespacestd;constintN=6e5+5,M=N*24;intn,m,......
  • SAP中五个报废率的计算逻辑
    废话不多说,SAP中有几个地方都有报废率的字段,对应到不同的业务场景,这些不同的报废字段会起到不同的作用,希望能通过这篇博文能整理出这些报废字段的逻辑,以及适用的业务场景.首先看看哪些地方有报废率字段,要注意一点,既然讲到报废率,它们的单位都是百分号%:1.物料主数据MRP1视......
  • 38. 敌人的动画执行逻辑
    本节目标上节只是把逻辑写好了,为了让游戏看起来更好看,我们需要将敌人的动画也加上去实现动画状态机敌人的状态比较简单,只有站立、加Buff、攻击、受伤、死亡这五种状态各状态之间的转换关系如下站立->加Buff通过skill触发,立刻执行因为是从AnyState出来的,所以不......
  • 37. 敌人意图 AI 逻辑
    本节目标在玩家回合,需要显示敌人的意图,然后在敌人回合执行意图代码实现拼UI在HealthBar上面添加意图的图片和文字然后在HealthBarController上面添加意图图片和意图文字添加敌人意图敌人意图ScriptableObject敌人意图实例Effect把意图添加到Enemy类......
  • 逻辑卷
    linux扩容vg空间******************[root@ymgit01~]#fdisk-l[root@ymgit01~]#fdisk/dev/sdbnpt8epw[root@ymgit01~]#partprobe[root@ymgit01~]#pvcreate/dev/sdcPhysicalvolume"/dev/sdb1"successfullycreated[root@ymgit01~]#vgextendap......
  • libtorch入门例程
    libtorchC++版可以直接在官网下载。自己学习如果没有合适的显卡可以选择下载CPU版的。下面是官网链接:PyTorch下载后就可以把开发包包含到VS的项目中使用。注意libtorch官网提供的Release/Debug的开发包,Debug版的程序用Debug版的库,Release版的程序用Release版的库,不能混用。另......