首页 > 其他分享 >邻接矩阵、稀疏矩阵(torch, sparse, numpy)相互转换 [转载]

邻接矩阵、稀疏矩阵(torch, sparse, numpy)相互转换 [转载]

时间:2023-03-21 21:33:05浏览次数:53  
标签:torch 矩阵 邻接矩阵 sparse numpy mx coo

邻接矩阵转稀疏矩阵

Example:

import scipy.sparse as sp
import numpy as np
import torch

adj_matrix = torch.randint(0,2,(4,4))
print(adj_matrix)
# 输出:
# tensor([[1, 1, 0, 0],
#        [0, 1, 0, 1],
#        [0, 0, 1, 1],
#        [1, 0, 0, 0]])
# adj_matrix 是邻接矩阵
tmp_coo = sp.coo_matrix(adj_matrix)
values = tmp_coo.data
indices = np.vstack((tmp_coo.row,tmp_coo.col))
i = torch.LongTensor(indices)
v = torch.LongTensor(values)
edge_index=torch.sparse_coo_tensor(i,v,tmp_coo.shape)
print(edge_index)
# 输出:
#tensor(indices=tensor([[0, 0, 1, 1, 2, 2, 3],
#                       [0, 1, 1, 3, 2, 3, 0]]),
#       values=tensor([1, 1, 1, 1, 1, 1, 1]),
#       size=(4, 4), nnz=7, layout=torch.sparse_coo)

torch 矩阵转numpy 矩阵

A = torch.load('adj1.pt')
A = A.numpy()

numpy 矩阵转 scipy 稀疏矩阵

A = sp.coo_matrix(A)

scipy 稀疏矩阵转numpy 矩阵

A.toarray()

将 Scipy Sparse 矩阵转换成 torch sparse 矩阵

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)

torch sparse矩阵转 torch 稠密矩阵

sparse_adj.to_dense()

标签:torch,矩阵,邻接矩阵,sparse,numpy,mx,coo
From: https://www.cnblogs.com/mercurysun/p/17241546.html

相关文章

  • Android数据结构-SparseArray实现原理
    SparseArray家族SparseArray基于键值对存储数据,key为int,value为object,简单使用如下://声明SparseArray<String>sparseArray=newSparseArray<>();......
  • Pytorch安装与基础知识
    Pytorch安装与基础知识安装环境:Win10专业版显卡:NviidaGeforceGTX1660Ti安装Anacodna官网下载安装安装CudaCuda官网下载安装包。进入CMD,使用命令nvcc-V测......
  • pytorch cnn 手写数字识别
    结果   训练好的模型呢训练过程中,不断变化   官网:是https://github.com/pytorch/examples/blob/main/mnist/main.py  test改个名字如......
  • RuntimeError: NCCL error in: /pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:784, u
    ​ 发现报错:RuntimeError:NCCLerrorin:/pytorch/torch/lib/c10d/ProcessGroupNCCL.cpp:784,unhandledsystemerror​编辑想在linux上跑跑mmclassification......
  • 【2023-Pytorch-检测教程】手把手教你使用YOLOV5做电线绝缘子缺陷检测
    随着社会和经济的持续发展,电力系统的投资与建设也日益加速。在电力系统中,输电线路作为电能传输的载体,是最为关键的环节之一。而绝缘子作为输电环节中的重要设备,在支撑固定导......
  • Pytorch安装
    Pytorch安装1.Anaconda的下载和安装Anaconda的官网在官网进行下载,一路next(可以修改安装路径,默认安装在C盘),安装完成后可以在菜单看到新增了一些文件打开这个Anaconda......
  • PyTorch学习笔记 8. 实现线性回归模型
    PyTorch学习笔记8.实现线性回归模型​​一、回归的概念​​​​1.概念​​​​2.目标​​​​3.应用​​​​4.训练线性回归的步骤​​​​二、数据集​​​​1.构......
  • 深度学习6. 多层感知机及PyTorch实现
    深度学习6.多层感知机及PyTorch实现​​一、概念​​​​1.MLP​​​​2.前向传播​​​​3.反向传播​​​​4.评估模式与训练模式​​​​二、模型定义​​​​1.加......
  • 人工智能-Pytorch案例实战(2)-CNN 的stride和zero-padding
    CNN的stridestride:是filter滑动图像的步长例如:stride=1,对于一个7*7的灰白图片,通过一个3*3大小的filter,输出下一个图片的大小为5*5(如何计算呢?公式呢?)(W-F+2P/......
  • 人工智能-Pytorch案例实战(1)-CNN Convolution Layer
    ConvolutionLayer左侧图示:一张彩色的图片,有三个部分组成(长度width宽度high深度depth),例如:32*32*3表示一彩色图片长度和宽度分别是32,32右侧图示:在CNN中,filter是一个......