首页 > 其他分享 >【rust】《Rust深度学习[4]-理解线性网络(Candle)》

【rust】《Rust深度学习[4]-理解线性网络(Candle)》

时间:2024-04-23 13:59:09浏览次数:21  
标签:线性网络 weight DType bias let 线性 Candle rust Tensor

全连接/线性

在神经网络中,全连接层,也称为线性层,是一种层,其中来自一层的所有输入都连接到下一层的每个激活单元。在大多数流行的机器学习模型中,网络的最后几层是完全连接的。实际上,这种类型的层执行基于在先前层中学习的特征输出类别预测的任务。

全连接层的示例,具有四个输入节点和八个输出节点。

全连接层在输入中接收在先前卷积层中激活的节点向量。这个向量在被发送到输出层之前,会经过一个或多个密集层。在到达输出层之前,使用激活函数进行预测。虽然卷积层和池化层通常使用ReLU函数,但基于分类问题的类型,全连接层可以使用两种类型的激活函数:

  Sigmoid:逻辑函数,用于二进制分类问题。

  Softmax:一个更广义的逻辑激活函数,它确保输出层中的值总和为1。通常用于多类分类。

依赖

[package]
name = "mnist-ml-linear"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
# 使用 cargo add [email protected] 下载
candle-core = "0.4.1"

代码

use candle_core::{DType, Device, Result, Tensor};

// 定义线性层
struct Linear {
    // 权重
    weight: Tensor,
    // 偏移量
    bias: Tensor,
}

// 线性层函数
impl Linear {
    fn forward(
        &self,
        x: &Tensor,
    ) -> Result<Tensor> {
        let x = x
            .contiguous()?
            // 将输入值乘以权重
            .matmul(&self.weight.contiguous()?)?;
        // 再加上偏移量
        x.broadcast_add(&self.bias)
    }
}

// 模型
struct Model {
    // 线性第一层
    first: Linear,
    // 线性第二层
    second: Linear,
}

// 模型函数
impl Model {
    fn forward(
        &self,
        image: &Tensor,
    ) -> Result<Tensor> {
        // 传入图片进行第一层线性分析
        let x = self.first.forward(image)?;
        // 将其与 ReLU 激活函数相乘
        let x = x.relu()?;
        // 进行第二层线性分析
        self.second.forward(&x)
    }
}

// 模仿线性模型分析过程
fn main() -> Result<()> {
    //  Device::new_cuda(0)?;
    //  Device::Cpu;
    // 使用CPU资源
    let device = Device::Cpu;

    // 创建demo(线性第一层)
    let weight = Tensor::zeros((784, 100), DType::F32, &device)?;
    let bias = Tensor::zeros((100,), DType::F32, &device)?;
    let first = Linear {
        weight,
        bias,
    };
    // 创建demo(线性第二层)
    let weight = Tensor::zeros((100, 10), DType::F32, &device)?;
    let bias = Tensor::zeros((10,), DType::F32, &device)?;
    let second = Linear {
        weight,
        bias,
    };
    // 创建模型
    let model = Model {
        first,
        second,
    };

    // demo图片
    let dummy_image = Tensor::zeros((1, 784), DType::F32, &device)?;

    // 开始模型推演
    let digit = model.forward(&dummy_image)?;

    println!("Digit {digit:?} digit");
    Ok(())
}

 

标签:线性网络,weight,DType,bias,let,线性,Candle,rust,Tensor
From: https://www.cnblogs.com/-CO-/p/18152701

相关文章

  • 【rust】《Rust深度学习[5]-理解卷积神经网络(Candle)》
    卷积神经网络ConvolutionalNeuralNetwork,简称为CNN。CNN与一般的顺传播型神经网络不同,它不仅是由全结合层,还由卷积层(ConvolutionLayer)和池层(PoolingLayer)构成的神经网络。在卷积层和池化层中,如下图所示,缩小输入神经元的一部分区域,局部地与下一层进行对应。每一层都有一个称......
  • 【rust】《Rust深度学习[2]-数据分析和挖掘库(Polars)》
    什么是Polars?Polars是一个用于操作结构化数据的高性能DataFrame库,可以用来进行数据清洗和格式转换、数据分析和统计、数据可视化、数据读取和存储、数据合并和拼接等等,相当于Rust版本的Pandas库。Polars读写数据支持如下:  常见数据文件:csv、parquet(不支持xlsx、json文件) ......
  • 【rust】《Rust深度学习[3]-数据可视化(Plotters)》
    什么是Plotters?Plotters是一个用纯Rust开发的图形库,用于中渲染图形、图表和数据可视化。它支持静态图片渲染和实时渲染,并支持多种后端,包括:位图格式(png、bmp、gif等)、矢量图(svg)、窗口和HTML5Canvas。Plotters对不同后端使用统一的高级API,并允许开发者自定义坐标系。在Plotters......
  • 【rust】《Rust深度学习[1]-科学计算库(Ndarray)》
    什么是Ndarray?ndarray是Rust生态中用于处理数组的库。它包含了所有常用的数组操作。简单地说ndarray相当于Rust版本的numpy。ndarray生态系统中crate的文档:ndarray基础库ndarray-rand随机数生成库ndarray-stats统计方法  顺序统计(最小、最大、中值、分位数等);  汇总......
  • three.js使用Instanced Draw+Frustum Cull+LOD来渲染大场景(开源)
    大家好,本文使用three.js实现了渲染大场景,在移动端也有较好的性能,并给出了代码,分析了关键点,感谢大家~关键词:three.js、InstancedDraw、大场景、LOD、FrustumCull、优化、Web3D、WebGL、开源代码:Github我正在承接Web3D数字孪生项目,具体介绍可看承接各种Web3D业务加QQ群交流:106......
  • vim配置rust开发环境
    vim配置需要环境首先需要安装rust,然后安装rust-analysis,还需要nodejs,npm。插件使用vim-plug管理,也是需要提前安装的安装coc之后还需要安装CocInstallcoc-rust-analysis下边是踩坑出来的配置文件"插件安装在callplug#begin('~/.vim/plugged')和callplug#end()之间。cal......
  • Windows快速安装Rust
    本文是最简最快最小化安装重点提示:如果不想安装VS消耗时间和6-8G的空间,可以按本文安装。如果系统中已经安装了VS,那么直接运行rustup-init安装Rust,并一路回车即可。前置条件:安装C++环境rust底层是依赖C环境的连接器,所以需要先安装C/C++编译环境,点击下载64位mingw-builds......
  • The Stack and the Heap栈与堆__Rust
    Manyprogramminglanguagesdon’trequireyoutothinkaboutthestackandtheheapveryoften.许多编程语言并不会要求你经常思考堆栈。ButinasystemprogramminglanguagelikeRust,whetheravalueisonthestackortheheapaffectshowthelanguagebehaves......
  • rust程序中设置和访问环境变量
    在项目中,我们通常需要设置一些环境变量,用来保存一些凭证或其它数据,这时我们可以使用dotenv这个crate。1、添加crate依赖首先在项目中添加dotenv这个依赖:  2、添加.env文件在开发环境下,我们可以在项目根目录下创建和编辑.env这个文件: 在运行环境下,这个.env文件要......
  • 50个Rust新手常犯的错误:看看你中过几条?
    错误地使用可变和不可变借用letmutdata=vec![1,2,3];letx=&data[0];data.push(4);println!("{}",x);不能在有不可变引用时修改数据。忘记处理Optionfnmain(){letsome_number=Some(5);letsum=some_number+5;//错误:Option类型不能这......