首页 > 其他分享 >代码笔记26 pytorch复现pointnet

代码笔记26 pytorch复现pointnet

时间:2022-10-13 21:55:29浏览次数:59  
标签:__ 26 nn self pointnet channels pytorch size out

1

简单记录一下 PointNet（分割网络）模型的 PyTorch 复现，之后整理成完整的工程再放到 GitHub 上。

2

import torch.nn as nn
import torch
import numpy as np


class tnet(nn.Module):
    """T-Net from PointNet: regress a k x k transform and apply it to the input.

    Architecture: a shared per-point MLP (k -> 64 -> 128 -> 1024, each layer a
    kernel-size-1 Conv1d + BatchNorm + ReLU), a max-pool over the point axis,
    then fully connected layers 1024 -> 512 -> 256 -> k*k (BN + ReLU on all but
    the last). The regressed matrix is offset by the identity, so an all-zero
    regression leaves the input unchanged.

    Args:
        inplanes: number of input channels k (``pointnet_encoder`` uses 3 for
            the input transform and 64 for the feature transform).
    """

    def __init__(self, inplanes: int):
        super(tnet, self).__init__()

        self.k = inplanes
        # conv layers in T-net (shared per-point MLP)
        self.relu = nn.ReLU(inplace=True)
        self.tconv1 = nn.Conv1d(in_channels=inplanes, out_channels=64, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.tconv2 = nn.Conv1d(in_channels=64, out_channels=128, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm1d(128)
        self.tconv3 = nn.Conv1d(in_channels=128, out_channels=1024, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm1d(1024)

        # fc layers in T-net; the last one regresses the k*k matrix entries
        self.tfc1 = nn.Linear(in_features=1024, out_features=512, bias=False)
        self.bnf1 = nn.BatchNorm1d(512)
        self.tfc2 = nn.Linear(in_features=512, out_features=256, bias=False)
        self.bnf2 = nn.BatchNorm1d(256)
        self.tfc3 = nn.Linear(in_features=256, out_features=inplanes ** 2, bias=False)

    def forward(self, x):
        """Apply the learned transform to ``x``.

        Args:
            x: tensor of shape (B, k, N), channels first (Conv1d layout).

        Returns:
            Tensor of shape (B, k, N): ``(T + I) @ x`` per batch element, where
            T is the regressed k x k matrix and I the identity.

        Raises:
            AssertionError: if the channel dimension of ``x`` is not k.
        """
        B, C, N = x.size()
        assert (C == self.k), "input size is not suitable for the T-Net model!"

        # Shared per-point MLP. Every layer except the final FC gets BN + ReLU,
        # including the 1024-d conv that feeds the max-pool.
        x1 = self.relu(self.bn1(self.tconv1(x)))
        x2 = self.relu(self.bn2(self.tconv2(x1)))
        x3 = self.relu(self.bn3(self.tconv3(x2)))

        # Max over the point axis -> (B, 1024) global descriptor, invariant to point order.
        maxpool_x, _ = torch.max(x3, dim=2)

        # fc operations
        x4 = self.relu(self.bnf1(self.tfc1(maxpool_x)))
        x5 = self.relu(self.bnf2(self.tfc2(x4)))
        x6 = self.tfc3(x5)

        # reshape from (B, k*k) to transform matrix (B, k, k)
        trans_matrix = torch.reshape(x6, (B, self.k, self.k))

        # Identity built on the same device/dtype as the regressed matrix; a
        # numpy-backed CPU tensor would fail to add to a CUDA tensor.
        iden_matrix = torch.eye(self.k, dtype=trans_matrix.dtype, device=trans_matrix.device)

        # (B, k, k) + (k, k) broadcasts over the batch; result is (B, k, N).
        out = torch.matmul(trans_matrix + iden_matrix, x)
        return out


class pointnet_encoder(nn.Module):
    """PointNet feature extractor shared by the task heads.

    Pipeline: 3x3 input transform -> shared MLP (3 -> 64) -> 64x64 feature
    transform -> shared MLP (64 -> 128 -> 1024) -> max-pool over points.

    ``forward`` takes a channels-last cloud of shape (B, N, 3) and returns
    ``(seg, glb)``:
        seg: (B, N, 64) per-point features taken right after the feature transform.
        glb: (B, 1024) global descriptor, the max over the N points.
    """

    def __init__(self):
        super(pointnet_encoder, self).__init__()
        # Submodules are registered in this exact order: it fixes both the
        # state_dict key order and the sequence of random initialisations.
        self.relu = nn.ReLU(inplace=True)

        self.trans1 = tnet(inplanes=3)
        self.econv1 = nn.Conv1d(3, 64, 1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)

        self.trans2 = tnet(inplanes=64)
        self.econv2 = nn.Conv1d(64, 128, 1, bias=False)
        self.bn2 = nn.BatchNorm1d(128)
        self.econv3 = nn.Conv1d(128, 1024, 1, bias=False)
        self.bn3 = nn.BatchNorm1d(1024)

    def forward(self, x):
        # Conv1d expects channels in dim 1: (B, N, 3) -> (B, 3, N), then align the cloud.
        h = self.trans1(x.transpose(1, 2))

        h = self.relu(self.bn1(self.econv1(h)))

        # Transformed 64-d features double as the per-point (local) output.
        local_feat = self.trans2(h)

        h = self.relu(self.bn2(self.econv2(local_feat)))
        h = self.relu(self.bn3(self.econv3(h)))

        # Max over the point axis -> (B, 1024).
        global_feat = h.max(dim=2).values

        return local_feat.transpose(1, 2), global_feat


class pointnet_seg(nn.Module):
    """PointNet per-point segmentation network.

    Concatenates each point's 64-d local feature with the 1024-d global
    descriptor from ``pointnet_encoder`` (1088 channels), then applies a shared
    MLP 1088 -> 512 -> 256 -> 128 -> num_classes.

    ``forward`` takes (B, N, 3) points and returns (B, N, num_classes) scores.

    Args:
        num_classes: number of output classes per point.
    """

    def __init__(self, num_classes):
        super(pointnet_seg, self).__init__()
        # Registration order preserved (state_dict keys and init RNG sequence).
        self.encoder = pointnet_encoder()

        self.relu = nn.ReLU(inplace=True)

        self.sconv1 = nn.Conv1d(1088, 512, 1, bias=False)
        self.bn1 = nn.BatchNorm1d(512)

        self.sconv2 = nn.Conv1d(512, 256, 1, bias=False)
        self.bn2 = nn.BatchNorm1d(256)

        self.sconv3 = nn.Conv1d(256, 128, 1, bias=False)
        self.bn3 = nn.BatchNorm1d(128)

        # Only the classifier layer carries a bias; the others are followed by BN.
        self.finalconv = nn.Conv1d(128, num_classes, 1, bias=True)

    def forward(self, x):
        _, num_points, _ = x.size()
        local_feat, global_feat = self.encoder(x)

        # Share the global descriptor with every point: (B, 1024) -> (B, N, 1024).
        # expand is a view; torch.cat below materialises the result anyway.
        global_feat = global_feat[:, None, :].expand(-1, num_points, -1)

        # (B, N, 64 + 1024) -> channels-first (B, 1088, N) for Conv1d.
        fused = torch.cat((local_feat, global_feat), dim=2).transpose(1, 2)

        h = fused
        for conv, bn in ((self.sconv1, self.bn1), (self.sconv2, self.bn2), (self.sconv3, self.bn3)):
            h = self.relu(bn(conv(h)))

        # Back to channels-last: (B, N, num_classes).
        return self.finalconv(h).transpose(1, 2)

def main():
    """Smoke-test the segmentation model.

    Builds a 13-class ``pointnet_seg``, runs a random batch of 10 clouds with
    100 points each, prints the output shape, then lists every state_dict
    entry with its shape and requires_grad flag.
    """
    net = pointnet_seg(num_classes=13)
    cloud = torch.randn(10, 100, 3)
    logits = net(cloud)
    print(f"score shape is {logits.size()}")

    entries = net.state_dict(keep_vars=True)
    for key in entries:
        tensor = entries[key]
        print(key, tensor.shape, tensor.requires_grad)


if __name__ == "__main__":
    main()

标签:__,26,nn,self,pointnet,channels,pytorch,size,out
From: https://www.cnblogs.com/HumbleHater/p/16789843.html

相关文章