首页 > 其他分享 >使用torch pruning工具进行结构化剪枝

使用torch pruning工具进行结构化剪枝

时间:2022-12-04 22:13:40  浏览次数:46  
标签:剪枝 nn self torch stride planes model pruning out

网络结构定义

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_pruning as tp
from torchvision.datasets import CIFAR10
from torchvision import transforms
import numpy as np 
import time
  
class BasicBlock(nn.Module):
    """Two-convolution 3x3 residual block.

    Main path: conv1 (3x3, carries ``stride``) -> BN -> ReLU -> conv2 (3x3)
    -> BN. The shortcut branch is added to the main path and a final ReLU is
    applied. The shortcut is the identity (an empty ``nn.Sequential``) unless
    the stride is not 1 or the channel count changes. In that case it is a
    1x1 strided conv followed by BN.
    """

    # Output channels = planes * expansion (no widening for this block).
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        width_out = self.expansion * planes

        # Construction order fixes the parameter / state_dict ordering; keep it.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != width_out
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)
 
 
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block.

    ``conv1`` (1x1) maps ``in_planes`` to ``planes``, ``conv2`` (3x3) carries
    the stride, and ``conv3`` (1x1) expands to ``planes * expansion``
    channels. Each conv is followed by BN, and the first two by ReLU as well.
    The shortcut is added before the final ReLU. It is the identity unless
    the stride is not 1 or the channel count changes, in which case a 1x1
    strided conv + BN projection is used.
    """

    # Output channels = planes * expansion.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        width_out = self.expansion * planes

        # Construction order fixes the parameter / state_dict ordering; keep it.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, width_out, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)

        if stride != 1 or in_planes != width_out:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, width_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(width_out),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        y = x
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            y = F.relu(bn(conv(y)))
        y = self.bn3(self.conv3(y))
        return F.relu(y + residual)
 
 
class ResNet(nn.Module):
    """ResNet backbone with a linear classifier, built from a pluggable block.

    Stage strides: stem 3x3 conv (2) -> layer1 (1) -> layer2 (1) ->
    layer3 (2) -> layer4 (2), i.e. 8x total downsampling. A 32x32 input
    therefore reaches the pooling step as a 4x4 map. Global average pooling
    reduces the final map to 1x1 before the linear head.

    Args:
        block: residual block class; must define ``expansion`` and be
            constructible as ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the four stages.
        num_classes: size of the output logit vector.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running input width; _make_layer advances it as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming (fan_out) init for convs; norm layers start with weight=1, bias=0.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one applies ``stride``."""
        layers = []
        for s in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, out_feature=False):
        """Run the network.

        Args:
            x: input image batch.
            out_feature: when true, also return the pooled, flattened features.

        Returns:
            ``logits``, or ``(logits, feature)`` if ``out_feature`` is true.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pool to 1x1. For a 4x4 map (32x32 input) this equals
        # a 4x4 average pool. It also handles other input resolutions instead
        # of producing a feature size the linear head cannot accept.
        out = F.adaptive_avg_pool2d(out, 1)
        feature = out.view(out.size(0), -1)
        out = self.linear(feature)
        if out_feature:
            return out, feature
        return out
 
 
def ResNet18(num_classes=10):
    return ResNet(BasicBlock, [2,2,2,2], num_classes)
def ResNet50(num_classes=10):
    return ResNet(Bottleneck, [3,4,6,3], num_classes)

speed test

原始模型 ResNet18
剪枝策略：L1Strategy，各 BasicBlock 的裁剪比率为 [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
比较三种情况下的推理速度：原始网络、剪枝后通道数不取整、剪枝后通道数按 16 的倍数取整。


def measure_inference_time(net, input, repeat=100):
    """Return the mean wall-clock time (seconds) of one ``net(input)`` call.

    Args:
        net: callable to benchmark, typically an ``nn.Module`` in eval mode.
        input: argument passed to ``net`` on every call.
        repeat: number of timed calls; must be >= 1.

    Returns:
        Total elapsed time divided by ``repeat``.

    Raises:
        ValueError: if ``repeat`` is smaller than 1.
    """
    if repeat < 1:
        raise ValueError("repeat must be >= 1, got %r" % (repeat,))
    # CUDA kernels launch asynchronously; synchronize so the timer brackets
    # the actual device work when the input lives on the GPU.
    use_cuda = getattr(input, "is_cuda", False)
    with torch.no_grad():  # pure inference: skip autograd graph construction
        if use_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(repeat):
            net(input)
        if use_cuda:
            torch.cuda.synchronize()
        end = time.perf_counter()
    return (end - start) / repeat

def prune_model(model, round_to=1, block_prune_probs=None):
    """Structurally prune every ``BasicBlock`` in *model* in place with torch_pruning.

    For each BasicBlock, in ``model.modules()`` order, the output channels of
    ``conv1`` and then ``conv2`` are pruned by ``tp.strategy.L1Strategy``.
    Each plan comes from a dependency graph traced once with a
    1x3x32x32 dummy input, so coupled layers are pruned together.

    Args:
        model: network to prune; it is moved to CPU first.
        round_to: forwarded to ``L1Strategy``; 1 means no rounding,
            16 rounds the channel counts to multiples of 16.
        block_prune_probs: per-block pruning ratio, one entry per BasicBlock.
            Defaults to the ResNet18 schedule
            ``[0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]``.

    Returns:
        The same *model* object, pruned.

    Raises:
        ValueError: if *model* contains more BasicBlocks than ratios given
            (checked before anything is pruned).
    """
    if block_prune_probs is None:
        block_prune_probs = [0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.3, 0.3]
    blocks = [m for m in model.modules() if isinstance(m, BasicBlock)]
    if len(blocks) > len(block_prune_probs):
        raise ValueError(
            "model has %d BasicBlocks but only %d pruning ratios were given"
            % (len(blocks), len(block_prune_probs))
        )

    model.cpu()
    DG = tp.DependencyGraph().build_dependency(model, torch.randn(1, 3, 32, 32))

    def prune_conv(conv, amount=0.2, round_to=1):
        """Drop the lowest-L1 output channels of *conv* (plus coupled channels)."""
        strategy = tp.strategy.L1Strategy()
        pruning_index = strategy(conv.weight, amount=amount, round_to=round_to)
        plan = DG.get_pruning_plan(conv, tp.prune_conv, pruning_index)
        plan.exec()

    for blk, amount in zip(blocks, block_prune_probs):
        prune_conv(blk.conv1, amount, round_to)
        prune_conv(blk.conv2, amount, round_to)
    return model
 
# Benchmark settings; use torch.device('cuda') instead to time on a GPU.
device = torch.device('cpu')
repeat = 100

# Baseline: unpruned ResNet18. The model is built before the dummy batch is
# drawn, which fixes the order in which the random-number stream is consumed.
model = ResNet18().eval().to(device)
fake_input = torch.randn(16, 3, 32, 32).to(device)
inference_time_before_pruning = measure_inference_time(model, fake_input, repeat)
print("before pruning: inference time=%f s, parameters=%d"%(inference_time_before_pruning, tp.utils.count_params(model)))

# Pruned, channel counts not rounded (round_to defaults to 1).
model = prune_model(ResNet18().eval())
print(model)
model = model.to(device)
inference_time_without_rounding = measure_inference_time(model, fake_input, repeat)
print("w/o rounding: inference time=%f s, parameters=%d"%(inference_time_without_rounding, tp.utils.count_params(model)))

# Pruned, remaining channel counts rounded to multiples of 16.
model = prune_model(ResNet18().eval(), round_to=16)
print(model)
model = model.to(device)
inference_time_with_rounding = measure_inference_time(model, fake_input, repeat)
print("w/ rounding: inference time=%f s, parameters=%d"%(inference_time_with_rounding, tp.utils.count_params(model)))

accuracy test

from cifar_resnet import ResNet18
import cifar_resnet as resnet

def get_dataloader(root='./chapter3_data', batch_size=256, num_workers=2):
    """Build CIFAR-10 train/test DataLoaders, downloading the data if needed.

    The training split uses RandomCrop(32, padding=4) + RandomHorizontalFlip
    augmentation and is shuffled each epoch. The test split is only converted
    to tensors and iterated in a fixed order.

    Args:
        root: dataset directory passed to ``CIFAR10``.
        batch_size: batch size for both loaders.
        num_workers: worker processes for both loaders.

    Returns:
        ``(train_loader, test_loader)``.
    """
    train_set = CIFAR10(root, train=True, download=True, transform=transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]))
    test_set = CIFAR10(root, train=False, download=True, transform=transforms.Compose([
        transforms.ToTensor(),
    ]))
    # Training batches must be reshuffled every epoch; evaluation order does not matter.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader

def eval(model, test_loader, device=None):
    """Return the top-1 accuracy of *model* over *test_loader* as a float.

    Side effects: the model is moved to *device* and switched to eval mode,
    and both changes persist after the call.

    Args:
        model: classifier module; the argmax over dim 1 of its output is
            taken as the predicted class.
        test_loader: iterable of ``(images, targets)`` batches with 1-D
            integer targets.
        device: device to evaluate on; defaults to CUDA when available,
            otherwise CPU.

    Returns:
        ``correct / total`` in ``[0, 1]``.

    Raises:
        ValueError: if *test_loader* yields no samples.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for img, target in test_loader:
            img = img.to(device)
            target = target.to(device)
            pred = model(img).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += len(target)
    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return correct / total

_, test_loader = get_dataloader()

# Each checkpoint is a whole pickled model (torch.load's result is used
# directly as the network), not a state_dict. Presumably this is why
# cifar_resnet is imported: its classes must be importable to unpickle.
# map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines;
# eval() then moves the model to the best available device.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. PyTorch releases that default to weights_only=True may also need
# weights_only=False to load a full model.
checkpoints = [
    ("before pruning", 'resnet18-round0.pth'),
    ("w/o rounding", 'resnet18-pruning-noround.pth'),
    ("w/ rounding", 'resnet18-pruning-round_to16.pth'),
]
for label, previous_ckpt in checkpoints:
    model = torch.load(previous_ckpt, map_location='cpu')
    acc = eval(model, test_loader)
    print("%s: Acc=%.4f" % (label, acc))

标签:剪枝,nn,self,torch,stride,planes,model,pruning,out
From: https://www.cnblogs.com/whiteBear/p/16950942.html

相关文章

  • Torch-Pruning工具箱
    Torch-Pruning通道剪枝网络实现加速的工作。Torchpruning是进行结构剪枝的pytorch工具箱,和pytorch官方提供的基于mask的非结构化剪枝不同,工具箱移除整个通道剪枝,自动发......
  • pytorch安装
    pytorch安装1、查看本机的CUDA版本cmd命令行输入nvidia-smi,在第一行最右边可以看到CUDA的版本号![version](C:\Users\nice7\Pictures\SavedPictures\version.png)2、......
  • pytorch 如何从checkpoints中继续训练
    左1:从头开始训练时,lr的变化。左2:从epoch100时开始训练......
  • win10 中 anaconda3 安装 pytorch 教程
    anaconda中自带python,所以不需要提前安装python。1.安装anaconda3下载链接:https://mirrors.tuna.tsinghua.edu.cn/anaconda/archive/下载文件:Anaconda3-2021.11-Windo......
  • Pytorch tensor操作 gather、expand、repeat、reshape、view、permute、transpose
    文章目录​​tensor.gather​​​​tensor.expand​​​​tensor.repeat​​​​reshape()和view()​​​​permute()和transpose()​​​​torch.matmul()​​​​torc......
  • torch.nn.CrossEntropyLoss
    文章目录​​交叉熵损失函数`torch.nn.CrossEntropyLoss`​​​​F.cross_entropy​​​​F.nll_loss​​交叉熵损失函数​​torch.nn.CrossEntropyLoss​​weight(Tensor......
  • Pytorch mask:上三角和下三角
    上三角triuPytorch上三角和下三角的调用与numpy是相同的。np.triu(np.ones((5,5)),k=0)#k控制对角线开始的位置Out[25]:array([[1.,1.,1.,1.,1.],[0.,1.,1......
  • 矩池云 | GPU 分布式使用教程之 Pytorch
    GPU分布式使用教程之PytorchPytorch官方推荐使用DistributedDataParallel(DDP)模块来实现单机多卡和多机多卡分布式计算。DDP模块涉及了一些新概念,如网络(WorldSize......
  • torch.autograd.Function 用法及注意事项
    众所周知,作为深度学习框架之一的PyTorch和其他深度学习框架原理几乎完全一致,都有着自动求导机制,当然也可以说成是自动微分机制。有些时候,我们不想要它自带的求导机制,需要......
  • 树莓派安装torch与torchvision测试
    在官方下载页搜索下载cpu/torch-1.8.1-cp39-cp39-manylinux2014_aarch64.whlcpu/torchvision-0.9.1-cp39-cp39-manylinux2014_aarch64.whl注意版本对应,且PyTorch仅能......