首页 > 其他分享 >图伸神经网络GCN实现图内点云分类任务(物体的部件分类)

图伸神经网络GCN实现图内点云分类任务(物体的部件分类)

时间:2024-08-25 13:26:53浏览次数:9  
标签:图伸 acc torch 图内点 分类 channels device hidden data

点云分类任务


本项目是一个简单的使用图中点分类代码,内涵完整的网络搭建、模型训练、模型保存、模型调用、可视化、的全过程。可以帮助初学者快速熟悉流程。帮助入门。

数据集下载

关键代码

项目使用了shapenet数据集中的飞机类数据集,在使用图神经网络飞机上进行部件分割,本项目写了一个自动下载数据集的方法,直接运行项目会自动下载数据集。关键部分代码如下。

def load_data():
    path = './data'
    if not os.path.exists(path):
        # 如果目录不存在,则创建它
        os.makedirs(path)
        print(f"目录 '{path}' 已创建。")
    else:
        print(f"目录 '{path}' 已存在。")
    train_data = ShapeNet(root=path, categories=['Airplane'], split='trainval', pre_transform=T.KNNGraph(k=6))
    test_data = ShapeNet(root=path, categories=['Airplane'], split='test', pre_transform=T.KNNGraph(k=6))
    return train_data, test_data
// A code block
这个代码先检查了存储路径是否存在,不存在创建一个,之后在直接下载数据集。
我们下载的时候选定了预处理pre_transform=T.KNNGraph(k=6),会预先使用
k临近算法给图数据生成边。这个方法下载了训练集和测试集。

数据集结构

 t, t1 = load_data()
 print(t)
 print(t[0])
 print(len(t))
// 输出
ShapeNet(2349, categories=['Airplane'])
# 这句的意思一共2349个图,都是飞机类
Data(x=[2518, 3], y=[2518], pos=[2518, 3], category=[1], 
edge_index=[2, 15108])
# 这里是选中了一张图看看里面的结构,x是点特征每个点是三维的,意思是一个
# 点用三个数表示,可以联想成空间里面点的x,y,z吧。pos是空间的点坐标也是
# 2518个。y是各个点的标签,用来表示这个点是属于那个部件的。category
# 表示整个图是什么类,这里只有飞机类,所以说只有一个数。edge_index是
# 邻接矩阵,指明那个点之间有连接
2349

网络模型

class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, num_classes)

    def forward(self, data, edge_index):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        data = data.to(device)
        edge_index = edge_index.to(device)
        x = data.x
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = self.conv4(x, edge_index)
        x = self.lin(x)

        return x

1 __init__ 方法:
in_channels:输入特征的维度,即每个节点的特征向量的维度。
hidden_channels:隐藏层的维度,代表在每一层中节点特征被投影到的空间维度。
num_classes:输出类别的数量,表示模型要分类的类别数。
该方法初始化了网络的各个层,包括四个图卷积层(GCNConv)和一个线性层(nn.Linear):
conv1: 将输入特征从 in_channels 映射到 hidden_channels。
conv2, conv3, conv4: 每一层将特征从 hidden_channels 映射到相同的 hidden_channels。
lin: 线性层,将最后一个隐藏层的输出特征映射到 num_classes,用于最终的分类任务。

2 forward 方法: 该方法定义了前向传播的过程,即输入数据通过网络时的运算流程。
device:自动检测并选择使用 GPU(如果可用)或 CPU 作为计算设备。
data.to(device) 和edge_index.to(device): 将输入数据和边信息移动到指定的设备(CPU 或 GPU)。
x = data.x: 提取输入节点的特征矩阵 x。
conv1 -> conv4: 执行四次图卷积操作,每次卷积后使用 ReLU 激活函数。ReLU 引入非线性,使模型能够学习复杂的模式。
lin: 最后一层是线性层,将最终的图卷积输出特征映射到类别空间,用于分类任务。

模型训练

训练关键代码如下:

def train_model(epochs):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_, test_ = load_data()
    num_class = len(np.unique(train_[0].y))
    net = Net(in_channels=3, hidden_channels=32, num_classes=num_class).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    ok = None
    for epoch in range(epochs):
        train_loss, train_acc = 0, 0
        max_acc = 0
        for i in tqdm(range(len(train_)), desc='{}/{}'.format(epoch + 1, epochs)):
            data = train_[i].to(device)
            optimizer.zero_grad()
            output = net(data, data.edge_index)
            loss = criterion(output, data.y)
            # print(loss)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            _, preds = torch.max(output, 1)
            train_acc += (preds == data.y).sum().item() / len(data.y)
        avg_loss = train_loss / len(data.y)
        avg_acc = train_acc / len(data.y)
        if avg_acc > max_acc:
            max_acc = avg_acc
            torch.save(net.state_dict(), './models/best_{}.pt'.format(max_acc))
            print('找到更高准确率的模型,准确率为{}'.format(max_acc))
        print('平均损失{}   平均准确率{}'.format(avg_loss, avg_acc))

这段代码定义了一个训练方法,接受一个epochs,指明要训练多少次使用了CrossEntropyLoss做损失函数。训练过程中每个epoch检查一遍模型准确率若准确率比前者高就保存本次训练模型。
训练过程截图:
在这里插入图片描述

测试模型+可视化结果

关键代码:

# 使用open3D进行可视化
def visualize(data_x, data_y):
    # 创建点云对象
    data_x = data_x.cpu().detach().numpy()
    data_y = data_y.cpu().detach().numpy()
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(data_x)

    # 创建颜色映射字典
    color_map = {
        0: [1, 0, 0],  # 类别0:红色
        1: [0, 1, 0],  # 类别1:绿色
        2: [0, 0, 1],  # 类别2:蓝色
        3: [0, 1, 1],  # 类别3:青色
        # 添加更多类别及其颜色
    }

    # 将标签转换为颜色
    colors = np.array([color_map.get(label, [0, 0, 0]) for label in data_y])
    pcd.colors = o3d.utility.Vector3dVector(colors)

    # 可视化点云
    o3d.visualization.draw_geometries([pcd])

//进行测试并且可视化,这里就拿了其中一张图进行可视化
import numpy as np
from Train import load_data
from Net import Net
import os
import torch
from visulizd import visualize
from tqdm import tqdm
if __name__ == '__main__':
    # 取第120张图进行分类
    i = 120
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train, test = load_data()
    num_c = len(np.unique(test[0].y))
    # test = test.to(device)
    model = Net(in_channels=3, hidden_channels=32, num_classes=num_c)
    model.load_state_dict(torch.load("./models/best_0.7097560044925633.pt"))
    model.eval()
    model.to(device)
    out = model(test[i], test[i].edge_index)
    criterion = torch.nn.CrossEntropyLoss()
    loss = criterion(out, test[i].y.to(device))
    acc = out.argmax(dim=1).eq(test[i].y.to(device)).sum().item() / len(test[i].y)
    print('损失{}  准确率{}'.format(loss, acc))
    visualize(test[i].pos, out.argmax(dim=1))


这里我只用了训练20次的模型,若想追求更高准确率,可以增加训练次数和优化网络模型结构
在这里插入图片描述

在这里插入图片描述

可能会出现的问题

pyg没有配置好

可以查看我这个博客,能够完美配置好pyg
链接: 完美配置pyg

懒人专属(代码链接)

链接: 代码链接

标签:图伸,acc,torch,图内点,分类,channels,device,hidden,data
From: https://blog.csdn.net/JTYANGWEI/article/details/141527534

相关文章

  • 基于SSM的垃圾分类管理系统的设计与实现 (含源码+sql+视频导入教程+论文)
    ......
  • (苍穹外卖)day02 员工管理 分类管理
    目录一.新增员工1.需求分析和设计2.代码开发3.功能测试4.代码完善二.员工的分页查询1.需求分析和设计2.代码开发3.功能测试与代码完善三.启用禁用员工账号 1.需求分析和设计 2.代码开发四.编辑员工1.需求分析和设计2.代码开发五.导入模块功能代码一.新......
  • 相遇(容斥+最短路+分类,水紫)
    第5题   相遇 查看测评数据信息给定一个有n个节点m条边的无向图,在某一时刻节点st上有一个动点a,节点end上有一个动点b,动点a向节点end方向移动,要求是尽快到达end点,与此同时,动点b向节点st方向移动,要求是尽快到达st点,但是整个过程中a和b不能相遇,问两点不相遇一共有多少种......
  • IP地址的五大分类及回环地址
    你好,我是沐爸,欢迎点赞、收藏和关注。个人知乎IP地址根据网络号的不同可以分为五大类,即A类、B类、C类、D类和E类。以下是这五大类IP地址的详细介绍:1.A类地址地址范围:1.0.0.1~126.255.255.254特点:第1个字节为网络地址,其他3个字节为主机地址。网络地址的最高位始终是0......
  • Mac M1用tensorflow中的Keras进行基本图像分类
    一.为什么要进行图像分类、图像识别目的是为了利用计算机对图像进行处理、分析和理解,让计算机能够像人类一样理解和解释图像中的内容。‌这一技术的应用范围广泛,包括但不限于人脸识别和商品识别。人脸识别技术主要应用于安全检查、身份核验与移动支付等领域,而商品识别则广......
  • 如何建立一种检测漏油的不平衡分类模型
    如何建立一种检测漏油的不平衡分类模型许多不平衡的分类任务需要一个熟练的模型来预测清晰的类别标签,其中两个类别同等重要。不平衡分类问题的一个例子是检测卫星图像中的漏油或浮油,其中需要一个类别标签,并且两个类别同等重要。检测漏油需要动员昂贵的响应,而错过事件同样......
  • 常用Linux操作系统分类
    Linux操作系统由于其开源的特点,受到世界各国计算机软件企业和Linux操作系统爱好者的青睐。因此,各种发行版本的Linux操作系统出现在计算机操作系统市场和开源社区。为了能让大家对各种Linux操作系统进行区分认识,就让我对其进行梳理分类。一、按发展体系分类第一类是基于Debia......
  • 1075 链表元素分类——PAT乙级
    给定一个单链表,请编写程序将链表元素进行分类排列,使得所有负值元素都排在非负值元素的前面,而[0,K]区间内的元素都排在大于K的元素前面。但每一类内部元素的顺序是不能改变的。例如:给定链表为18→7→-4→0→5→-6→10→11→-2,K为10,则输出应该为-4→-6→-2→7→0→5→10......
  • 机器学习—KNN算法-分类及模型选择与调优
    KNN算法-分类样本距离判断:欧氏距离、曼哈顿距离、明可夫斯基距离KNN算法原理:        K-近邻算法(K-NearestNeighbors,简称KNN),根据K个邻居样本的类别来判断当前样本的类别;如果一个样本在特征空间中的k个最相似(最邻近)样本中的大多数属于某个类别,......
  • 网站分类错误怎么办
    网站分类错误通常是指网站的内容被错误地归类或者是在某些安全设备(如防火墙、安全网关等)中被标记为不正确的类别。这可能导致访问受限或被阻止。以下是解决网站分类错误的一些方法:检查网站内容:确认网站的实际内容与分类是否相符。如果网站内容发生了变化,可能需要更新分类。......