
2023.12.14


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# The original post elides the body of ModulatedAttLayer ("Unchanged code").
# SimpleCNN.forward below unpacks its output as (features, attention_map),
# so any implementation must return that tuple.
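# What follows is a minimal, hedged sketch of an embedded-Gaussian non-local
# attention layer with a residual connection, standing in for the elided class;
# the projection names (theta / phi / g) and shapes are assumptions, not the
# original code.
class ModulatedAttLayer(nn.Module):
    def __init__(self, in_channels, reduction=2, mode='embedded_gaussian'):
        super(ModulatedAttLayer, self).__init__()
        assert mode == 'embedded_gaussian', 'only this mode is sketched here'
        self.inter_channels = in_channels // reduction
        # 1x1 projections for query / key / value (theta / phi / g)
        self.theta = nn.Conv2d(in_channels, self.inter_channels, 1)
        self.phi = nn.Conv2d(in_channels, self.inter_channels, 1)
        self.g = nn.Conv2d(in_channels, self.inter_channels, 1)
        self.out_conv = nn.Conv2d(self.inter_channels, in_channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        q = self.theta(x).view(b, self.inter_channels, -1).permute(0, 2, 1)  # B x HW x C'
        k = self.phi(x).view(b, self.inter_channels, -1)                     # B x C' x HW
        v = self.g(x).view(b, self.inter_channels, -1).permute(0, 2, 1)      # B x HW x C'
        attn = torch.softmax(torch.bmm(q, k), dim=-1)                        # B x HW x HW
        y = torch.bmm(attn, v).permute(0, 2, 1).reshape(b, self.inter_channels, h, w)
        out = x + self.out_conv(y)  # residual connection around the attended features
        return out, attn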

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.mod_att = ModulatedAttLayer(in_channels=64, reduction=2, mode='embedded_gaussian')
        # Pool to a fixed 7x7 grid so the classifier input matches 64 * 7 * 7
        # regardless of input resolution; the original flattened the raw
        # feature map, which does not have that size.
        self.pool = nn.AdaptiveAvgPool2d(7)
        self.fc = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x, attention_map = self.mod_att(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        # Return the attention map too: the test code below unpacks
        # `test_output, attention_maps = model(test_data)`.
        return x, attention_map
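
# Quick smoke test of the (logits, attention_map) contract on a dummy CPU batch;
# a hypothetical check added here, not part of the original post.
_dummy = torch.randn(2, 3, 32, 32)
_logits, _maps = SimpleCNN()(_dummy)
print(_logits.shape, _maps.shape)  # torch.Size([2, 10]) torch.Size([2, 1024, 1024])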

# Dummy data loader for demonstration purposes
# CIFAR-10 images are 32x32; keep the native resolution, since resizing to
# 224x224 (as the original did here) would make the HW x HW attention matrix
# prohibitively large (50176 x 50176 per image).
transform = transforms.Compose([transforms.ToTensor()])
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(root='./data', train=True, download=True, transform=transform),
    batch_size=32, shuffle=True, num_workers=4)

# Select the device and move the model to it (batches are moved inside the loop)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
criterion = nn.CrossEntropyLoss()  # instantiate the loss once, not every batch
model.train()
for epoch in range(5):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)  # Move data to GPU
        optimizer.zero_grad()
        output, _ = model(data)  # forward returns (logits, attention_map); maps unused here
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')

# Testing the attention mechanism
model.eval()
test_data, _ = next(iter(train_loader))
test_data = test_data.to(device)
with torch.no_grad():  # inference only; no gradients needed
    test_output, attention_maps = model(test_data)

# Visualize or analyze the attention maps as needed
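
# One way to inspect the maps, assuming the B x HW x HW shape returned by the
# sketch layer above (matplotlib is an added dependency, not in the original):
import matplotlib.pyplot as plt

h = w = 32  # CIFAR-10 spatial size; conv1 (stride 1, padding 1) preserves it
# Average over query positions: how much attention each position receives.
heat = attention_maps[0].mean(dim=0).reshape(h, w).cpu().numpy()
plt.imshow(heat, cmap='viridis')
plt.colorbar()
plt.title('Mean attention received per position (first test image)')
plt.savefig('attention_map.png')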


From: https://www.cnblogs.com/ZarkY/p/17902068.html

    今天上课更加深入讲解了类图和时序图,并通过测试和让学生讲解让我们更加深入这些内容。其中面向对象建模过程识别信息系统的目标和边界(上下范围图);识别用例,建立用例图;识别对象和类,建立类图;设计用例的详细逻辑,建立时序图或协作图;必要时重复以上活动,精化并调整各图。让我们更......