首页 > 其他分享 >maxpool3d修改成maxpool2d与maxpool1d方法

maxpool3d修改成maxpool2d与maxpool1d方法

时间:2023-03-02 11:34:00浏览次数:40  
标签:kernel maxpool3d self torch maxpool2d padding stride maxpool1d size

有时候遇到不支持maxpool3d的硬件或算子时候,可将其改成maxpool2d加上maxpool1d组合方式表示,经验证与maxpool3d结果完全一致,其实现细节如下:

代码:

import torch


class MaxPool3d_modify(torch.nn.Module):
    def __init__(self, kernel_size=(2, 2, 2), stride=(2, 2, 2), padding=(0, 0, 0)):
        super(MaxPool3d_modify, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.max_pool_2d = torch.nn.MaxPool2d(kernel_size[1:], self.stride[1:], padding[1:])
        self.max_pool_1d = torch.nn.MaxPool1d(kernel_size=kernel_size[0], stride=self.stride[0],
                                              padding=self.padding[0])  # stride is kernal_size

    def forward(self, x):
        x1 = self.max_pool_2d(x)
        x = x1.squeeze(0).permute(1, 2, 0)
        x = self.max_pool_1d(x)

        x = x.permute(2, 0, 1).unsqueeze(0)

        return x


if __name__ == '__main__':
    '''
    torch.nn.MaxPool3d处理维度4或5,[b,c,h,w]或[b,c,f,h,w] 处理维度为c,h,w或f,h,w
    torch.nn.MaxPool2d处理维度为4,[b,c,h,w]处理h,w维度pool
    torch.nn.MaxPool1d处理维度为3,[d1,d2,d3]处理d3维度pool
    '''
    input_ori = torch.rand(1, 128, 20, 90)  # 64,18,44  kernel_size=(1, 3, 3), stride=(2, 1, 2)
    model1 = torch.nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(4, 3, 2), padding=(0, 0, 0))
    model2 = MaxPool3d_modify(kernel_size=(1, 3, 3), stride=(4, 3, 2), padding=(0, 0, 0))
    o1 = model1(input_ori)
    print('\noutput1 shape: ', o1.shape)
    o2 = model2(input_ori)
    print('\noutput2 shape: ', o2.shape)
    output1 = o1.reshape(-1)
    output2 = o2.reshape(-1)
    n = 0
    for i, o in enumerate(output1):
        if o == output2[i]:
            n = n + 1
    print('precision', n / len(output1))
    

结果展示:

 

 

 

标签:kernel,maxpool3d,self,torch,maxpool2d,padding,stride,maxpool1d,size
From: https://www.cnblogs.com/tangjunjun/p/17171198.html

相关文章

  • torch.nn.MaxPool2d()
    torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)\(2D\)最大池化。参数:kernel_size:最大池化......