首页 > 其他分享 >【深度学习基础模型】径向基函数网络(Radial Basis Function Networks, RBFN)详细理解并附实现代码。

【深度学习基础模型】径向基函数网络(Radial Basis Function Networks, RBFN)详细理解并附实现代码。

时间:2024-09-25 10:50:28浏览次数:9  
标签:Function RBFN 函数 Basis self torch RBF 径向

【深度学习基础模型】径向基函数网络(Radial Basis Function Networks, RBFN)

【深度学习基础模型】径向基函数网络(Radial Basis Function Networks, RBFN)


文章目录


参考地址:https://www.asimovinstitute.org/neural-network-zoo/
论文地址:https://apps.dtic.mil/sti/pdfs/ADA196234.pdf

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

1.算法原理介绍:径向基函数网络(Radial Basis Function Networks, RBFN)

1.1 径向基函数网络 (RBFN) 概述

径向基函数网络(RBFN)是一种特殊的前馈神经网络(FFNN),其隐藏层的激活函数为径向基函数。与传统的FFNN不同,RBFN在输入和隐藏层之间使用非线性变换,具体地,RBF激活函数根据输入与中心的距离计算响应,通常为高斯函数或其他径向对称的函数。RBFN在分类、回归和函数逼近任务中表现优异,尤其是对局部特征敏感的问题。

1.2 网络结构

RBF网络的结构通常包括以下三层:

  • 输入层:输入数据直接传递到隐藏层,不进行任何处理。
  • 隐藏层:每个神经元使用径向基函数来计算输入与预先定义的中心的距离。
  • 输出层:通常是线性输出,表示隐藏层的加权和。

1.3 径向基函数

RBFN的核心是径向基函数,它根据输入与中心点之间的距离来决定输出。最常见的径向基函数是高斯函数,形式如下:

ϕ ( r ) = e − r 2 2 σ 2 ϕ(r)=e^{-\frac{r^2}{2σ^2}} ϕ(r)=e−2σ2r2​

其中:

  • r r r 是输入 x x x 和中心 c c c 之间的欧氏距离 r = ∥ x − c ∥ r=∥x−c∥ r=∥x−c∥。
  • σ σ σ 是控制函数扩展范围的参数。

1.4 训练方法

RBFN的训练包括两个步骤:

  • 确定径向基函数的中心、宽度:通常通过聚类算法(如k-means)来选择中心点。
  • 线性层权重的训练:可以通过最小二乘法或梯度下降法来学习线性层的权重。

1.5 RBF网络的应用

  • 分类:RBFN可用于模式分类问题,尤其是在数据呈现出类簇分布的情况下,RBFN能通过径向基函数对局部区域敏感,提供较好的分类性能
  • 回归:RBFN常用于函数逼近问题,通过对输入空间的局部区域响应,可以实现高效的回归任务
  • 时间序列预测:由于RBFN对非线性问题的优良建模能力,它也常用于时间序列预测和动态系统控制。

2.Python实现RBF网络的应用实例

下面的代码使用Python和深度学习框架PyTorch实现了一个简单的RBF网络,用于解决分类问题(例如二分类任务)。

2.1代码实现:RBF网络的实现及二分类应用

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# 定义RBF激活函数
class RBF(nn.Module):
    def __init__(self, input_size, num_centers, sigma):
        super(RBF, self).__init__()
        self.num_centers = num_centers
        self.centers = nn.Parameter(torch.randn(num_centers, input_size))  # 随机初始化中心点
        self.sigma = sigma  # 高斯函数的宽度参数

    def forward(self, x):
        # 计算输入到每个中心点的欧氏距离
        expanded_input = x.unsqueeze(1).expand(-1, self.num_centers, -1)
        expanded_centers = self.centers.unsqueeze(0).expand(x.size(0), -1, -1)
        distances = torch.norm(expanded_input - expanded_centers, dim=2)
        # 通过高斯径向基函数计算输出
        return torch.exp(-distances ** 2 / (2 * self.sigma ** 2))

# 定义RBF网络
class RBFNetwork(nn.Module):
    def __init__(self, input_size, num_centers, output_size, sigma):
        super(RBFNetwork, self).__init__()
        self.rbf = RBF(input_size, num_centers, sigma)  # RBF层
        self.linear = nn.Linear(num_centers, output_size)  # 线性层

    def forward(self, x):
        rbf_output = self.rbf(x)
        return self.linear(rbf_output)

# 创建数据集(使用scikit-learn生成的二分类数据集)
X, y = make_classification(n_samples=200, n_features=2, n_classes=2, n_clusters_per_class=1, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 转换为Tensor
X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.float32).view(-1, 1)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_test = torch.tensor(y_test, dtype=torch.float32).view(-1, 1)

# 定义模型参数
input_size = X_train.shape[1]  # 输入特征数
num_centers = 10  # 径向基函数的中心数
output_size = 1  # 输出为1(用于二分类)
sigma = 1.0  # 高斯核宽度

# 创建RBF网络
model = RBFNetwork(input_size, num_centers, output_size, sigma)

# 定义损失函数和优化器
criterion = nn.BCEWithLogitsLoss()  # 二元交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练模型
epochs = 100
for epoch in range(epochs):
    optimizer.zero_grad()  # 梯度清零
    outputs = model(X_train)  # 前向传播
    loss = criterion(outputs, y_train)  # 计算损失
    loss.backward()  # 反向传播
    optimizer.step()  # 更新权重

    if (epoch+1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

# 测试模型
with torch.no_grad():
    test_outputs = model(X_test)
    predictions = torch.round(torch.sigmoid(test_outputs))
    accuracy = (predictions == y_test).float().mean()
    print(f'Accuracy: {accuracy.item():.4f}')

2.2代码解释

1.定义RBF激活函数:

class RBF(nn.Module):
    def __init__(self, input_size, num_centers, sigma):
        super(RBF, self).__init__()
        self.centers = nn.Parameter(torch.randn(num_centers, input_size))
        self.sigma = sigma
  • RBF类中,初始化了中心点和高斯核的宽度(sigma)。
  • 每个中心点对应隐藏层中的一个神经元。

2.计算欧氏距离:

distances = torch.norm(expanded_input - expanded_centers, dim=2)
  • 计算每个输入样本与各个中心点之间的距离,使用欧氏距离衡量输入与中心点的相似性。

3.RBF网络的前向传播:

rbf_output = self.rbf(x)
return self.linear(rbf_output)
  • 通过RBF层计算隐藏层输出,然后通过线性层计算最终输出。

4.生成二分类数据集:

X, y = make_classification(n_samples=200, n_features=2, n_classes=2)
  • 这里生成了一个二分类问题数据集,包含200个样本和2个特征。

5.训练与测试:

  • 使用PyTorch内置的优化器和损失函数进行训练和优化。
  • 训练结束后,使用测试集评估模型的准确性。

3.总结

径向基函数网络(RBFN)是一种特殊的前馈神经网络,使用径向基函数作为激活函数,使得网络能够对输入空间的局部区域敏感。RBF网络在分类、回归、时间序列预测等领域中有广泛的应用。通过PyTorch实现的RBF网络示例展示了RBFN在简单二分类问题中的应用。

标签:Function,RBFN,函数,Basis,self,torch,RBF,径向
From: https://blog.csdn.net/gaoxiaoxiao1209/article/details/142422313

相关文章

  • Call to undefined function think\exception\config()
    错误信息 Calltoundefinedfunctionthink\exception\config() 表示在尝试调用 think\exception\config() 函数时,该函数未被定义。这可能是由于以下几个原因导致的:命名空间问题:可能是命名空间没有正确引入或定义。类文件未加载:可能是某个类文件没有正确加载或包含。......
  • JavaScript 中 new Function() 和 new function() 的区别
    javascript确实很灵活,但它也带来了一些混乱。例如,你可以使用多种方式来做同一件事情,比如创建函数、对象等。那么标题中提到的两者有什么区别呢?newfunction是另一种创建函数的方式,其语法:constfunc=newfunction([arg1,arg2,...argn],functionbody);一个简单的例子:constsu......
  • qwen2.5 vllm推理;openai function call调用中文离线agents使用
    参考:https://qwenlm.github.io/zh/blog/qwen2.5/https://qwen.readthedocs.io/zh-cn/latest/framework/function_call.html#vllm安装:pipinstall-Uvllm-ihttps://pypi.tuna.tsinghua.edu.cn/simplevllm-0.6.1.post2运行:</......
  • 使用swig映射c++function
    swig可以自动生成从c++到其他语言如Java、Python等转换的中间语言,目前swig已经支持很多c++11的特性了,但是这次项目中发现function特性还没有支持,只能自己生成。从网上找了一份Java的java-HowtouseSWIGtowrapstd::functionobjects?-StackOverflow,我需要的c#的,故需要稍......
  • 帝国CMS报错Deprecated: Function get_magic_quotes
    当使用帝国CMS时遇到“Deprecated:Functionget_magic_quotes”这类报错,通常是因为PHP版本升级后,某些旧的函数被弃用。get_magic_quotes_gpc() 函数在PHP5.4中已被弃用,并在PHP7.0中被移除。原因分析PHP版本升级:如果你的服务器从较旧的PHP版本(如5.3或更低)升级到了PHP7.......
  • 兼收并蓄 TypeScript - 类: function
    源码https://github.com/webabcd/TypeScriptDemo作者webabcd兼收并蓄TypeScript-类:function示例如下:class\function.ts{//定义函数时要指定参数的类型和返回值的类型,无返回值时可以用void表示functionf1(x:number,y:number):number{retur......
  • VUE 使用用Echart 报错:this.dom.getContext is not a function
    问题:在VUE中 如果使用了 tabs 关在tab 中加入了<div>标签;在初始化中 执行echart.init() 可能会报错:this.dom.getContextisnotafunction;大致如下所示:<el-tabs> <el-tab-pane><div>    <divstyle="height:500px;widows:1000px;"ref="homeLineRe......
  • 易优eyoucms网站报错 \core\library\think\App.php Fatal error: Call to undefin
    当你遇到 Fatalerror:Calltoundefinedfunctionthink\switch_citysite() 这样的错误时,说明在代码中调用了一个未定义的函数 think\switch_citysite()。这种情况通常是因为函数没有被正确地引入或者该函数根本不存在于当前的代码库中。解决方案确认函数的存在检查 s......
  • 解决React Warning: Function components cannot be given refs. Attempts to access
    问题当我使用如下方式调用组件子组件UploadModal并且绑定Ref时React报错“Warning:Functioncomponentscannotbegivenrefs.Attemptstoaccessthisrefwillfail.DidyoumeantouseReact.forwardRef()?”;constUploadModalRef=useRef(null);constopenUploadModa......
  • ZBLOG PHP提示"Call to undefined function mysql_connect()"错误
    当遇到Z-BlogPHP在PHP7.2上出现 mysql_connect() 未定义的错误时,这是因为PHP7.2默认不再支持MySQL扩展(mysql 扩展)。你需要进行一些调整来使Z-BlogPHP兼容PHP7.2。以下是两种解决方案:解决方案一:降级PHP版本如果你暂时不想修改代码,可以选择降级PHP版本到一......