首页 > 其他分享 >统计pytorch、caffe稀疏代码

统计pytorch、caffe稀疏代码

时间:2022-11-22 14:33:14浏览次数:64  
标签:layer weight para 稀疏 pytorch caffe np total name

pytorch

net = net_now()
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
net.load_state_dict(state_dict, strict=True)
#net.cuda()
net.eval()
print("==========load success===!!")



T = 1e-8
total_weight = 0
total_weight_avail = 0
for name,parameters in net.named_parameters():
    if "bias" in name:
        continue
    print(name,':',parameters.size())
    weight = parameters.detach().numpy()
    weights_np = abs(weight)

    total_weight += weights_np.size
    tmp = weights_np > T
        # tmp = weights_np != 0
    total_weight_avail += tmp.sum()

    ratio_zero = (1 - (tmp.sum() * 1.0 / weights_np.size))
    print("name=", name, "    ratio_zero=", ratio_zero)

print("ratio_conv_avail_weight=", total_weight_avail * 1.0 / total_weight, "     ratio_conv_not_avail_weight=",
        1 - total_weight_avail * 1.0 / total_weight)

exit(0)

caffe

    net.forward(**input_dict)

    ##################################
    #####bn
    print("==========>>bn"*5)
    T = 1e-5
    for layer_para_name, para in net.params.items():
        if "bn" in layer_para_name:
            weights_np = abs(para[3].data)  # para[3]是λ
            tmp = weights_np > T
            ratio_zero = (1 - (tmp.sum() * 1.0 / weights_np.size))
            print("layer_para_name=", layer_para_name, "    ratio_zero=", ratio_zero)

    ####conv
    print("==========>>conv" * 5, "            --->T=",T)
    total_weight = 0
    total_weight_avail = 0
    for layer_para_name, para in net.params.items():
        if "bn" in layer_para_name or "scale" in layer_para_name or "Scale" in layer_para_name or "bias" in layer_para_name:
            continue

        # if "conv1_1" == layer_para_name:
        #     # Statistics_weight(layer_para_name, abs(para[0].data))
        #
        #     Statistics_weight("/media/algo/data_1/project/oof/oof_sparse/caffe-jacinto/0000/deply/show_L2", layer_para_name, abs(para[0].data))
        #     a = 0



        # Statistics_weight("/media/algo/data_1/project/oof/oof_sparse/caffe-jacinto/0000/deply/show/0930/0930_L1+sprse", "L1+sparse", layer_para_name, abs(para[0].data))


        weights_np = abs(para[0].data)  # para[0]weight   para[1]bias   2  128  3  3
        weights_np_0 = weights_np[0]

        tmp_2 = weights_np <= 0.2
        ratio_123 = tmp_2.sum() * 1.0 / weights_np.size

        total_weight += weights_np.size
        tmp = weights_np > T
        # tmp = weights_np != 0
        total_weight_avail += tmp.sum()

        ratio_zero = (1 - (tmp.sum() * 1.0 / weights_np.size))
        print("layer_para_name=", layer_para_name, "    ratio_zero=", ratio_zero)

    print("ratio_conv_avail_weight=", total_weight_avail * 1.0 / total_weight, "     ratio_conv_not_avail_weight=",
          1 - total_weight_avail * 1.0 / total_weight)

    ##################################

标签:layer,weight,para,稀疏,pytorch,caffe,np,total,name
From: https://www.cnblogs.com/yanghailin/p/16915022.html

相关文章

  • Pytorch入门(4)—— Tensor和Module的保存与加载
    参考:动手学深度学习注意:由于本文是jupyter文档转换来的,代码不一定可以直接运行,有些注释是jupyter给出的交互结果,而非运行结果!!文章目录​​1.读写Tensor​​​​2.读写......
  • Pytorch入门(3)—— 构造网络模型
    参考:动手学深度学习注意:由于本文是jupyter文档转换来的,代码不一定可以直接运行,有些注释是jupyter给出的交互结果,而非运行结果!!文章目录​​1.模型构造​​​​1.1继承`M......
  • 【2022.11.21】pytorch的使用相关(五)
    资料来源ShusenTang/Dive-into-DL-PyTorch:本项目将《动手学深度学习》(DiveintoDeepLearning)原书中的MXNet实现改为PyTorch实现。(github.com)代码部分%matplotl......
  • Pytorch在训练时冻结某些层使其不参与反向传播
    笔记摘抄:https://blog.csdn.net/qq_36429555/article/details/118547133定义网络#定义一个简单的网络classnet(nn.Module):def__init__(self,num_class=10):......
  • 17.5 稀疏调拨的内存映射文件--《Windows核心编程》
    原文链接:https://www.likecs.com/show-306421749.html,原文中代码是C++MFC程序,更详细。本文是C语言测试代码。(1)稀疏文件(SparseFile)定义指的是文件中出现大量的0数据,这......
  • ACGAN-pytorch
    点击查看代码importargparseimportosimportnumpyasnpimporttorchimporttorch.nnasnnimporttorchvision.transformsastransformsfromtorch.autograd......
  • [PyTorch] 自定义数据集
    步骤:自定义Dataset实例:定义__init__方法:返回feature和label两个部分的数据;定义__getitem();定义_len_()方法;使用torch.utils.data.DataLoader加载数据;示......
  • RNN的PyTorch实现
    官方实现PyTorch已经实现了一个RNN类,就在torch.nn工具包中,通过torch.nn.RNN调用。使用步骤:实例化类;将输入层向量和隐藏层向量初始状态值传给实例化后的对象,获得RNN的......
  • pytorch学习笔记(1)
    pytorch学习笔记(1)   expand向左扩展维度、扩展元素个数a=t.ones(2,3)只能在左侧增加维度,而不能在右侧增加维度,也不能在中间增加维度新增维度的元素个数可以为任......
  • Pytorch基于MNIST数据集简单实现手写数字识别
    """模型训练代码"""importtorchimporttorchvision.datasetsfromtorchimportnnfromtorchvisionimporttransformsfromtorch.utils.dataimportDataLoaderi......