首页 > 其他分享 >基于pytorch的nonlocalblock

基于pytorch的nonlocalblock

时间:2024-10-28 13:20:42浏览次数:6  
标签:dim 基于 self phi nonlocalblock pytorch theta norm size

论文《Non-local Neural Networks》

为了实现即插即用,本文重写了 non-local 块:可以根据需要选择 2D 卷积或 3D 卷积,并可以选择是否使用 BN 层和池化层。

nonlocalblock模块图

(原文此处为 non-local block 的结构示意图,图片未能保留。)

3D代码如下:

class Nonlocal_3d(nn.Module):
    """Non-local block for 5-D video features of shape (N, C, T, H, W).

    Computes ``x + out(normalize(theta(x)^T phi(x')) applied to g(x'))`` where
    theta / phi / g / out are 1x1x1 convolutions and ``x'`` is ``x`` after the
    optional max pooling. The output has the same shape as the input.
    """

    def __init__(self, dim, dim_inner, pool_size=None, norm=False, norm_eps=1e-5, norm_momentum=0.1,
                 norm_module=nn.BatchNorm3d, instantiation="softmax"):
        """
        Args:
            dim (int): number of channels of the input (and of the output).
            dim_inner (int): number of channels inside the block, i.e. the
                output channels of the theta / phi / g projections.
            pool_size (list or None): kernel size (also used as the stride) of
                the max pooling applied to the input of the phi and g branches
                only, in (temporal, height, width) order. ``None`` disables
                pooling.
            norm (bool): if True, apply ``norm_module`` to the output of the
                final projection before the residual addition.
            norm_eps (float): ``eps`` passed to ``norm_module``.
            norm_momentum (float): ``momentum`` passed to ``norm_module``.
            norm_module (type): normalization layer class, constructed as
                ``norm_module(num_features=dim, eps=..., momentum=...)``.
                Defaults to nn.BatchNorm3d.
            instantiation (str): how the affinity matrix is normalized:
                "softmax": scale by ``dim_inner ** -0.5``, then softmax over
                    the key (phi) positions.
                "dot_product": divide by the number of key (phi) positions.
                Any other value raises NotImplementedError in ``forward``.
        """
        super().__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self.norm = norm
        self.pool_size = pool_size
        self.use_pool = pool_size is not None
        self.instantiation = instantiation

        # Build the layers.
        self._construct_nonlocal(norm_module)

    def _construct_nonlocal(self, norm_module):
        """Create the 1x1x1 projections and the optional norm / pool layers."""
        self.conv_theta = nn.Conv3d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_phi = nn.Conv3d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_g = nn.Conv3d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_out = nn.Conv3d(self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0)

        if self.norm:
            self.bn = norm_module(num_features=self.dim, eps=self.norm_eps, momentum=self.norm_momentum)
            # Marker attribute; nothing in this class reads it. Presumably
            # consumed by an external weight-init routine — TODO confirm.
            self.bn.transform_final_bn = True

        if self.use_pool:
            self.pool = nn.MaxPool3d(kernel_size=self.pool_size, stride=self.pool_size, padding=[0, 0, 0])

    def forward(self, x):
        """Apply the block.

        Args:
            x (Tensor): input of shape (N, dim, T, H, W).

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            NotImplementedError: if ``instantiation`` is not supported.
        """
        x_identity = x

        N, _, T, H, W = x.size()

        # theta (queries) uses the full-resolution input; only the phi / g
        # (keys / values) branches see the pooled input.
        theta = self.conv_theta(x)
        if self.use_pool:
            x = self.pool(x)
        phi = self.conv_phi(x)
        g = self.conv_g(x)

        theta = theta.view(N, self.dim_inner, -1)
        phi = phi.view(N, self.dim_inner, -1)
        g = g.view(N, self.dim_inner, -1)

        # (N, C, TxHxW) * (N, C, T'xH'xW') => (N, TxHxW, T'xH'xW'),
        # where T'xH'xW' is the (possibly pooled) number of key positions.
        theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
        # The original Non-local paper describes two ways to normalize the
        # affinity tensor:
        #   1) softmax normalization (after scaling by dim_inner ** -0.5);
        #   2) dot_product normalization (divide by the number of positions).
        if self.instantiation == "softmax":
            theta_phi = theta_phi * (self.dim_inner ** -0.5)
            theta_phi = nn.functional.softmax(theta_phi, dim=2)
        elif self.instantiation == "dot_product":
            spatial_temporal_dim = theta_phi.shape[2]
            theta_phi = theta_phi / spatial_temporal_dim
        else:
            raise NotImplementedError("Unknown norm type {}".format(self.instantiation))

        # (N, TxHxW, T'xH'xW') * (N, C, T'xH'xW') => (N, C, TxHxW).
        theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))

        # (N, C, TxHxW) => (N, C, T, H, W).
        theta_phi_g = theta_phi_g.view(N, self.dim_inner, T, H, W)

        p = self.conv_out(theta_phi_g)
        if self.norm:
            p = self.bn(p)

        return p + x_identity

2D代码如下

class Nonlocal_2d(nn.Module):
    """Non-local block for 4-D image features of shape (N, C, H, W).

    Computes ``x + out(normalize(theta(x)^T phi(x')) applied to g(x'))`` where
    theta / phi / g / out are 1x1 convolutions and ``x'`` is ``x`` after the
    optional max pooling. The output has the same shape as the input.
    """

    def __init__(self, dim, dim_inner, pool_size=None, norm=False, norm_eps=1e-5, norm_momentum=0.1,
                 norm_module=nn.BatchNorm2d, instantiation="softmax"):
        """
        Args:
            dim (int): number of channels of the input (and of the output).
            dim_inner (int): number of channels inside the block, i.e. the
                output channels of the theta / phi / g projections.
            pool_size (int, list or None): kernel size (also used as the
                stride) of the max pooling applied to the input of the phi and
                g branches only, in (height, width) order. ``None`` disables
                pooling.
            norm (bool): if True, apply ``norm_module`` to the output of the
                final projection before the residual addition.
            norm_eps (float): ``eps`` passed to ``norm_module``.
            norm_momentum (float): ``momentum`` passed to ``norm_module``.
            norm_module (type): normalization layer class, constructed as
                ``norm_module(num_features=dim, eps=..., momentum=...)``.
                Defaults to nn.BatchNorm2d.
            instantiation (str): how the affinity matrix is normalized:
                "softmax": scale by ``dim_inner ** -0.5``, then softmax over
                    the key (phi) positions.
                "dot_product": divide by the number of key (phi) positions.
                Any other value raises NotImplementedError in ``forward``.
        """
        super().__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self.norm = norm
        self.pool_size = pool_size
        self.use_pool = pool_size is not None
        self.instantiation = instantiation

        # Build the layers.
        self._construct_nonlocal(norm_module)

    def _construct_nonlocal(self, norm_module):
        """Create the 1x1 projections and the optional norm / pool layers."""
        self.conv_theta = nn.Conv2d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_phi = nn.Conv2d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_g = nn.Conv2d(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_out = nn.Conv2d(self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0)

        if self.norm:
            self.bn = norm_module(num_features=self.dim, eps=self.norm_eps, momentum=self.norm_momentum)
            # Marker attribute; nothing in this class reads it. Presumably
            # consumed by an external weight-init routine — TODO confirm.
            self.bn.transform_final_bn = True

        if self.use_pool:
            # A 2-D pool takes an int (or a pair) as padding; a 3-element
            # list is rejected by max_pool2d at forward time.
            self.pool = nn.MaxPool2d(kernel_size=self.pool_size, stride=self.pool_size, padding=0)

    def forward(self, x):
        """Apply the block.

        Args:
            x (Tensor): input of shape (N, dim, H, W).

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            NotImplementedError: if ``instantiation`` is not supported.
        """
        x_identity = x

        N, _, H, W = x.size()

        # theta (queries) uses the full-resolution input; only the phi / g
        # (keys / values) branches see the pooled input.
        theta = self.conv_theta(x)
        if self.use_pool:
            x = self.pool(x)
        phi = self.conv_phi(x)
        g = self.conv_g(x)

        theta = theta.view(N, self.dim_inner, -1)
        phi = phi.view(N, self.dim_inner, -1)
        g = g.view(N, self.dim_inner, -1)

        # (N, C, HxW) * (N, C, H'xW') => (N, HxW, H'xW'),
        # where H'xW' is the (possibly pooled) number of key positions.
        theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
        # The original Non-local paper describes two ways to normalize the
        # affinity tensor:
        #   1) softmax normalization (after scaling by dim_inner ** -0.5);
        #   2) dot_product normalization (divide by the number of positions).
        if self.instantiation == "softmax":
            theta_phi = theta_phi * (self.dim_inner ** -0.5)
            theta_phi = nn.functional.softmax(theta_phi, dim=2)
        elif self.instantiation == "dot_product":
            spatial_temporal_dim = theta_phi.shape[2]
            theta_phi = theta_phi / spatial_temporal_dim
        else:
            raise NotImplementedError("Unknown norm type {}".format(self.instantiation))

        # (N, HxW, H'xW') * (N, C, H'xW') => (N, C, HxW).
        theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))

        # (N, C, HxW) => (N, C, H, W).
        theta_phi_g = theta_phi_g.view(N, self.dim_inner, H, W)

        p = self.conv_out(theta_phi_g)
        if self.norm:
            p = self.bn(p)

        return p + x_identity

通用代码如下:

#!/usr/bin/env python3

import torch
import torch.nn as nn


class Nonlocal(nn.Module):
    """Non-local block that works on either 2-D image or 3-D video features.

    ``typeis`` selects the layer family. With "2D" the block uses
    nn.Conv2d / nn.MaxPool2d and expects inputs of shape (N, C, H, W). With
    "3D" it uses nn.Conv3d / nn.MaxPool3d and expects (N, C, T, H, W). The
    output has the same shape as the input (residual connection).
    """

    def __init__(self, dim, dim_inner, typeis='3D', pool_size=None, norm=False, norm_eps=1e-5, norm_momentum=0.1,
                 norm_module=nn.BatchNorm3d, instantiation="softmax"):
        """
        Args:
            dim (int): number of channels of the input (and of the output).
            dim_inner (int): number of channels inside the block, i.e. the
                output channels of the theta / phi / g projections.
            typeis (str): "2D" or "3D" (case-insensitive); selects 2-D or
                3-D convolution and pooling layers.
            pool_size (int, list or None): kernel size (also used as the
                stride) of the max pooling applied to the input of the phi and
                g branches only. Give [h, w] for 2D and [t, h, w] for 3D.
                ``None`` disables pooling.
            norm (bool): if True, apply ``norm_module`` to the output of the
                final projection before the residual addition.
            norm_eps (float): ``eps`` passed to ``norm_module``.
            norm_momentum (float): ``momentum`` passed to ``norm_module``.
            norm_module (type): normalization layer class, constructed as
                ``norm_module(num_features=dim, eps=..., momentum=...)``.
                Defaults to nn.BatchNorm3d. For a 2D block, nn.BatchNorm3d is
                replaced by nn.BatchNorm2d because it cannot take 4-D input.
            instantiation (str): how the affinity matrix is normalized:
                "softmax": scale by ``dim_inner ** -0.5``, then softmax over
                    the key (phi) positions.
                "dot_product": divide by the number of key (phi) positions.
                Any other value raises NotImplementedError in ``forward``.
        """
        super().__init__()
        self.dim = dim
        self.dim_inner = dim_inner
        self.norm_eps = norm_eps
        self.norm_momentum = norm_momentum
        self.norm = norm
        self.pool_size = pool_size
        self.use_pool = pool_size is not None
        self.instantiation = instantiation
        self.type = typeis.lower()
        assert self.type in ['2d', '3d'], "dim of Nonlocal must be at ['2D','3D']."

        # The default nn.BatchNorm3d rejects 4-D input, so a 2-D block using
        # it would always fail in forward; substitute the 2-D batch norm.
        if self.type == '2d' and norm_module is nn.BatchNorm3d:
            norm_module = nn.BatchNorm2d

        # Build the layers.
        self._construct_nonlocal(norm_module)

    def _construct_nonlocal(self, norm_module):
        """Create the 1x1(x1) projections and the optional norm / pool layers."""
        # self.type is '2d' or '3d', so this resolves to nn.Conv2d / nn.Conv3d.
        conv = getattr(nn, 'Conv' + self.type)
        self.conv_theta = conv(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_phi = conv(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_g = conv(self.dim, self.dim_inner, kernel_size=1, stride=1, padding=0)
        self.conv_out = conv(self.dim_inner, self.dim, kernel_size=1, stride=1, padding=0)

        if self.norm:
            self.bn = norm_module(num_features=self.dim, eps=self.norm_eps, momentum=self.norm_momentum)
            # Marker attribute; nothing in this class reads it. Presumably
            # consumed by an external weight-init routine — TODO confirm.
            self.bn.transform_final_bn = True

        if self.use_pool:
            pool = getattr(nn, 'MaxPool' + self.type)
            # An int padding is valid for both MaxPool2d and MaxPool3d; a
            # 3-element list is rejected by the 2-D pool.
            self.pool = pool(kernel_size=self.pool_size, stride=self.pool_size, padding=0)

    def forward(self, x):
        """Apply the block.

        Args:
            x (Tensor): (N, dim, H, W) for a 2D block, (N, dim, T, H, W) for 3D.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``x`` does not have the rank the block expects.
            NotImplementedError: if ``instantiation`` is not supported.
        """
        expected_rank = 5 if self.type == '3d' else 4
        if x.dim() != expected_rank:
            raise ValueError("expected a {}-D input for a {} Nonlocal block, got a {}-D input".format(
                expected_rank, self.type.upper(), x.dim()))

        x_identity = x
        N = x.shape[0]
        # (H, W) for 2D or (T, H, W) for 3D; restored after the attention.
        spatial_shape = x.shape[2:]

        # theta (queries) uses the full-resolution input; only the phi / g
        # (keys / values) branches see the pooled input.
        theta = self.conv_theta(x)
        if self.use_pool:
            x = self.pool(x)
        phi = self.conv_phi(x)
        g = self.conv_g(x)

        theta = theta.view(N, self.dim_inner, -1)
        phi = phi.view(N, self.dim_inner, -1)
        g = g.view(N, self.dim_inner, -1)

        # (N, C, P) * (N, C, P') => (N, P, P'), where P is the number of input
        # positions and P' the (possibly pooled) number of key positions.
        theta_phi = torch.einsum("nct,ncp->ntp", (theta, phi))
        # The original Non-local paper describes two ways to normalize the
        # affinity tensor:
        #   1) softmax normalization (after scaling by dim_inner ** -0.5);
        #   2) dot_product normalization (divide by the number of positions).
        if self.instantiation == "softmax":
            theta_phi = theta_phi * (self.dim_inner ** -0.5)
            theta_phi = nn.functional.softmax(theta_phi, dim=2)
        elif self.instantiation == "dot_product":
            spatial_temporal_dim = theta_phi.shape[2]
            theta_phi = theta_phi / spatial_temporal_dim
        else:
            raise NotImplementedError("Unknown norm type {}".format(self.instantiation))

        # (N, P, P') * (N, C, P') => (N, C, P), then back to the input layout.
        theta_phi_g = torch.einsum("ntg,ncg->nct", (theta_phi, g))
        theta_phi_g = theta_phi_g.view(N, self.dim_inner, *spatial_shape)

        p = self.conv_out(theta_phi_g)
        if self.norm:
            p = self.bn(p)

        return p + x_identity

构建参数说明:


    def __init__(self, dim, dim_inner, typeis='3D', pool_size=None, norm=False, norm_eps=1e-5, norm_momentum=0.1,
                 norm_module=nn.BatchNorm3d, instantiation="softmax"):

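dim: 输入特征的通道数(输出通道数与之相同)。
dim_inner: non-local block 内部 theta / phi / g 分支的通道数。
typeis: 决定卷积与池化的类别:如果是 3D,则使用 nn.Conv3d / nn.MaxPool3d;如果是 2D,则使用 nn.Conv2d / nn.MaxPool2d。
pool_size: 池化层尺寸(同时用作步长),只作用于 phi 和 g 分支;如果为 None,则表示不使用池化层;如果使用,则根据 typeis 输入池化大小,2D 为 [x, x],3D 为 [x, x, x]。
norm: 是否在输出卷积之后使用 BN 层。
norm_module: 使用的归一化层类(nn 中的类),默认 nn.BatchNorm3d。
instantiation: 相似度(affinity)矩阵的归一化方式,可选 "softmax"(先乘以 dim_inner 的 -1/2 次方再做 softmax)或 "dot_product"(除以参与计算的位置数)。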

标签:dim,基于,self,phi,nonlocalblock,pytorch,theta,norm,size
From: https://blog.csdn.net/weixin_44522636/article/details/143277272

相关文章