首页 > 其他分享 >可视化全连接层(蒙特卡洛法)

可视化全连接层(蒙特卡洛法)

时间:2023-06-01 18:23:35浏览次数:43  
标签:__ random nn self torch 可视化 import 蒙特卡洛 连接

import random
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import math
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

epochs=1000
class pt:
    def __init__(self,x,y):
        self.x=x
        self.y=y

class logistic(nn.Module):
    def __init__(self):
        super(logistic, self).__init__()
        self.w = torch.nn.Parameter(torch.randn(2, 1))
        self.b = torch.nn.Parameter(torch.zeros([1]))
        self.line1=torch.nn.Linear(2,1)
        self.line2=nn.Sequential(
          nn.Linear(2,5000),
            nn.ReLU(),
        nn.Linear(5000, 1),
        )
        self.pred=None
    def forward(self, X):
        #self.pred=torch.matmul(X,self.w)+self.b
        self.pred=self.line2(X)
        return torch.sigmoid(self.pred),self.pred
def generate_point(th=0.4,start=50,end=90):
    class_list=[]
    point_list=[]
    for angle in range(360):
        theta=3.14/180.0*angle
        x=math.cos(theta)+random.random()*th
        y=math.sin(theta)+random.random()*th
        point_list.append(list([x,y]))
        if angle>start and angle<end or angle>180 and angle<230 or angle>250 and angle<300:#or angle>180 and angle<230
            class_list.append(0)
        else:
            class_list.append(1)
    return np.array(point_list),np.array(class_list)

class fdata(Dataset):
    # dirname 为训练/测试数据地址,使得训练/测试分开
    def __init__(self, train=True):
        super(fdata, self).__init__()
        self.data,self.label = generate_point()

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        image = self.data[index]
        image = image.astype(np.float32)
        image = torch.unsqueeze(torch.from_numpy(image),0)
        label = self.label[index]
        label = np.array(label.astype(np.float32)).reshape(1)
        label = torch.unsqueeze(torch.from_numpy(label), 0)
        return image,label


def draw(pt_list,cls_list,module):
    plt.title('circle')
    pt_r,pt_b=[],[]
    for n in range(len(cls_list)):
        if(cls_list[n]==1):
            pt_r.append(pt_list[n])
        else:
            pt_b.append(pt_list[n])
    pt_r=np.array(pt_r)
    pt_b = np.array(pt_b)
    line_list=[]
    for n in range(-10,20):
        n=n*0.1
        for m in range(-10,20):
            m=m*0.1
            line_list.append([n,m])

    line_array=[]
    line_tensor=torch.from_numpy(np.array(line_list)).reshape(-1,1,2).float()
    output,pred=module(line_tensor)
    pred=pred.squeeze().detach().numpy().tolist()
    for n in range(len(pred)):
        if pred[n]<0:
            line_array.append(line_list[n])
    line_array=np.array(line_array)
    plt.scatter(line_array[:, 0], line_array[:, 1], c="g")
    plt.scatter(pt_r[:,0],pt_r[:,1],c="r")
    plt.scatter(pt_b[:, 0],pt_b[:, 1],c="b")
    plt.xlim(-1,2)
    plt.ylim(-1, 2)
    plt.show()

criterion = nn.BCELoss()
md=logistic()
opt=torch.optim.Adam(md.parameters(),lr=0.001)
pt_list,cls_list=generate_point()
# input=torch.from_numpy(pt_list).reshape(-1,1,2).float()
# label=torch.from_numpy(cls_list).reshape(-1,1,1).float()

train_dataset = fdata()
train_dataloder = DataLoader(train_dataset, batch_size=10,
                            num_workers=0, drop_last=True,shuffle=True)
for i in range(epochs):
    for input,label in train_dataloder:
        output,pred=md(input)
        loss=criterion(output,label)
        opt.zero_grad()
        loss.backward()
        opt.step()
        print("第"+str(i)+":"+str(loss))
draw(pt_list,cls_list,md)

这个实验揭示了一个结果:带有激活函数的全连接层(至少两层)越宽,其拟合能力越强。

其实我不是很明白,为什么“低维映射到高维,经过激活函数,低维线性映射到高维后,全连接层就具有了很强的非线性能力”?

标签:__,random,nn,self,torch,可视化,import,蒙特卡洛,连接
From: https://www.cnblogs.com/xmds/p/17449853.html

相关文章

  • C# 连接SQLite数据库与建表
    SQLite是⼀个软件库,实现了自给自足的、无服务器的、零配置的、事务性的轻量级SQL数据库引擎。声明连接SQLite的变量Conn添加SQLite操作驱动dll引用:System.Data.SQLite.dllusingSystem.Data.SQLite;SQLiteConnectionConn;直接NuGet包搜索System.Data.SQLite......
  • 大模型可视化
     说明:左边是一个与Showprobabilities设置为的OpenAIPlayground基本一致的界面Fullspectrum。提示是Arebugsreal?,随后突出显示的文本是模型生成的完成。令牌根据模型预测的概率进行着色,绿色最有可能,红色最不可能。左侧的下拉菜单显示了在特定位置(在本例中为are采样位置)预......
  • 如何使用PL/SQL Developer 连接remote 数据库
    https://www.allroundautomations.com/plsqldev.html 1.下载并安装OracleInstantClient Free,light-weightandeasilyinstalledOracleDatabaselibrariesandSDKsforbuildingandconnectingclientapplicationstolocalorremoteOracleDatabases.可以通过......
  • postgresql 的 idle_session_timeout 与连接池的 max-ide-time参数
    看下面的异常:下面的错误说:terminatingconnectionduetoidle-sessiontimeout下面的这个错误说:Causedby:reactor.pool.PoolShutdownException:Poolhasbeenshutdownreactor.core.Exceptions$ErrorCallbackNotImplemented:org.springframework.dao.DataAccessResou......
  • python neo4j将新节点连接到旧标签
    要在Python中使用py2neo将新节点连接到已存在的节点标签,你可以执行以下步骤:导入所需的类和函数:frompy2neoimportGraph,Node,Relationship连接到Neo4j数据库:graph=Graph("bolt://localhost:7687",auth=("username","password"))确保将"username"和&q......
  • 数据可视化:地图类可视化图表大全
    导语随着数据在各行业中的应用越来越广泛,大家也逐渐认识到数据可视化在企业生产经营中的重要作用,在数据可视化过程中,图表是处理数据的重要组成部分,因为它们是一种将大量数据压缩为易于理解的格式的方法。数据可视化可以让受众快速Get到重点。 今天,数维图小编将为大家介绍数据......
  • Spark GraphX 的数据可视化
    概述SparkGraphX本身并不提供可视化的支持,我们通过第三方库GraphStream和Breeze来实现这一目标详细Spark和GraphX对并不提供对数据可视化的支持,它们所关注的是数据处理。但是,一图胜千言,尤其是在数据分析时。接下来,我们构建一个可视化分析图的Sp......
  • 说一下朗数可视化快速开发平台
    朗数他们的业务定位比较特别:他们卖的是居于eclipse体系上的基础平台代码、技术、服务这些,而不是卖一个体系完整的平台成品,简单的来说,如果你不懂eclipse插件开发技术,缺乏可视化平台研发经验,找他们就对了。  目前国内的平台不少,不过基本上都是卖成品,也有些是卖代码的,但卖代码的基......
  • C#中通过连接池连接mysql数据库
       使用连接池可以提高C#程序连接MySQL数据库的性能,使得不必每次建立新的物理连接。 usingSystem.Data;usingMySql.Data.MySqlClient;namespaceConsoleApp1{classProgram{privateconstintMAX_POOL_SIZE=100;//设置最大连接数......
  • 蒙特卡洛算法
    从今天开始要研究SamplingMethods,主要是MCMC算法。本文是开篇文章,先来了解蒙特卡洛算法。Contents  1.蒙特卡洛介绍  2.蒙特卡洛的应用  3.蒙特卡洛积分   1.蒙特卡洛介绍   蒙特卡罗方法(MonteCarlomethod),也称统计模拟方法,是二十世纪四十年代中期由于科学技......