
5. SimGNN in Practice


I. Overview

Paper title: SimGNN: A Neural Network Approach to Fast Graph Similarity Computation
Venue: WSDM 2019 (ACM International Conference on Web Search and Data Mining)
Paper link:
https://arxiv.org/abs/1808.05689
Code link:
https://paperswithcode.com/paper/graph-edit-distance-computation-via-graph#code

Goal: compute graph similarity with a graph neural network, while easing the computational burden.

Novelty: SimGNN combines two strategies. First, it designs a learnable embedding function that maps each graph to an embedding vector, which provides a global summary of the graph; here a novel attention mechanism is proposed to emphasize the nodes that are important under a given similarity metric. Second, it designs a pairwise node comparison method that supplements the graph-level embeddings with fine-grained node-level information.

Introduction

The authors design a neural-network-based function that maps a pair of graphs to a similarity score. In the training phase, the parameters of this function are learned by minimizing the difference between the predicted similarity scores and the ground truth (the true labels), where each training data point is a pair of graphs together with their true similarity score. In the testing phase, feeding an arbitrary pair of graphs into the learned function yields a predicted similarity score. The method is named SimGNN: similarity computation via graph neural networks.
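
A minimal sketch of this training objective (the trainer in Section III minimizes the same MSE via F.mse_loss; the random predictions and targets here are placeholders):

import torch
import torch.nn.functional as F

prediction = torch.rand(128, requires_grad=True)  # placeholder for model output on 128 graph pairs
target = torch.rand(128)                          # placeholder ground-truth similarities
loss = F.mse_loss(prediction, target)             # minimized w.r.t. the model parameters
loss.backward()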

Advantages of the model

(1) Representation invariance. The same graph can be represented by different adjacency matrices simply by permuting its node order; the computed similarity score should be invariant to such changes (see the sketch after this list).

(2) Inductive. Similarity computation should generalize to unseen graphs, i.e., the model should be able to score graphs that appear in no training pair.

(3) Learnable. By adjusting its parameters through training, the model should be able to adapt to any similarity metric.
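
A hypothetical check of property (1): consistently relabeling a graph's nodes must leave everything the model sees unchanged. The shapes and the final model call are illustrative assumptions, not repository code:

import torch

num_nodes = 10
x = torch.rand(num_nodes, 29)                      # node features
edge_index = torch.randint(0, num_nodes, (2, 30))  # edges as index pairs

perm = torch.randperm(num_nodes)                   # old index placed at each new position
new_of_old = torch.empty_like(perm)
new_of_old[perm] = torch.arange(num_nodes)         # map: old index -> new index

x_perm = x[perm]                                   # reordered node features
edge_index_perm = new_of_old[edge_index]           # consistently relabeled edges
# A permutation-invariant scorer should give the same result for
# (x, edge_index) and (x_perm, edge_index_perm).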


II. Background

Background and Motivation

Graph similarity search is of practical importance, for example finding the compound most similar to a query compound. Graph similarity is usually measured by graph edit distance (GED) or maximum common subgraph (MCS); GED is the minimum number of edit operations (inserting, deleting, or relabeling nodes and edges) needed to turn one graph into the other. Computing either metric exactly is very expensive (NP-complete). This paper proposes a graph-neural-network-based method to address this problem.
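
A tiny GED illustration using networkx (exact GED takes exponential time in general, which is precisely the burden SimGNN avoids; it is fine for toy graphs):

import networkx as nx

g1 = nx.path_graph(4)   # 0-1-2-3
g2 = nx.cycle_graph(4)  # 0-1-2-3-0
print(nx.graph_edit_distance(g1, g2))  # 1.0: inserting edge (3, 0) turns g1 into g2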

Main idea

The neural network learns a mapping from an input pair of graphs to an output similarity score for the two graphs. This is therefore supervised learning: the ground-truth similarity of each input pair must be known.
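
A sketch of how that ground truth is obtained in the code below: GEDDataset provides a normalized GED per pair, and the trainer's transform method maps it to a similarity in (0, 1] via exp(-nGED), so identical graphs get similarity 1:

import torch

normalized_ged = torch.tensor([0.0, 0.5, 2.0])
target = torch.exp(-normalized_ged)  # tensor([1.0000, 0.6065, 0.1353])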

Network structure

A simple and direct idea: given a pair of graphs, first represent each graph as a vector and then compute the similarity from the corresponding vectors, i.e., graph embedding. On top of this, since graph embeddings alone may miss differences at the level of individual nodes, the authors further model the correlations and differences between the nodes of the two graphs (pairwise node comparison).
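
A minimal sketch of the pairwise node comparison idea, which the calculate_histogram method below implements in batched form (shapes are illustrative):

import torch

h1 = torch.rand(10, 16)             # node embeddings of graph 1
h2 = torch.rand(10, 16)             # node embeddings of graph 2
sim = torch.sigmoid(h1 @ h2.t())    # (10, 10) node-to-node similarities
hist = torch.histc(sim.view(-1), bins=16)
hist = hist / hist.sum()            # normalized 16-bin histogram feature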

III. Code

1. SimGNN and the histogram computation

import torch
import random
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm, trange
from scipy.stats import spearmanr, kendalltau

from layers import AttentionModule, TensorNetworkModule, DiffPool
from utils import calculate_ranking_correlation, calculate_prec_at_k, gen_pairs

from torch_geometric.nn import GCNConv, GINConv
from torch_geometric.data import DataLoader, Batch
from torch_geometric.utils import to_dense_batch, to_dense_adj, degree
from torch_geometric.datasets import GEDDataset
from torch_geometric.transforms import OneHotDegree

import matplotlib.pyplot as plt


class SimGNN(torch.nn.Module):
    """
    SimGNN: A Neural Network Approach to Fast Graph Similarity Computation
    https://arxiv.org/abs/1808.05689
    """

    def __init__(self, args, number_of_labels):
        """
        :param args: Arguments object.
        :param number_of_labels: Number of node labels.
        """
        super(SimGNN, self).__init__()
        self.args = args
        self.number_labels = number_of_labels
        self.setup_layers()

    def calculate_bottleneck_features(self):
        """
        Deciding the shape of the bottleneck layer.
        """
        if self.args.histogram:
            self.feature_count = self.args.tensor_neurons + self.args.bins
        else:
            self.feature_count = self.args.tensor_neurons

    def setup_layers(self):
        """
        Creating the layers.
        """
        self.calculate_bottleneck_features()
        if self.args.gnn_operator == "gcn":
            self.convolution_1 = GCNConv(self.number_labels, self.args.filters_1)
            self.convolution_2 = GCNConv(self.args.filters_1, self.args.filters_2)
            self.convolution_3 = GCNConv(self.args.filters_2, self.args.filters_3)
        elif self.args.gnn_operator == "gin":
            nn1 = torch.nn.Sequential(
                torch.nn.Linear(self.number_labels, self.args.filters_1),
                torch.nn.ReLU(),
                torch.nn.Linear(self.args.filters_1, self.args.filters_1),
                torch.nn.BatchNorm1d(self.args.filters_1),
            )

            nn2 = torch.nn.Sequential(
                torch.nn.Linear(self.args.filters_1, self.args.filters_2),
                torch.nn.ReLU(),
                torch.nn.Linear(self.args.filters_2, self.args.filters_2),
                torch.nn.BatchNorm1d(self.args.filters_2),
            )

            nn3 = torch.nn.Sequential(
                torch.nn.Linear(self.args.filters_2, self.args.filters_3),
                torch.nn.ReLU(),
                torch.nn.Linear(self.args.filters_3, self.args.filters_3),
                torch.nn.BatchNorm1d(self.args.filters_3),
            )

            self.convolution_1 = GINConv(nn1, train_eps=True)
            self.convolution_2 = GINConv(nn2, train_eps=True)
            self.convolution_3 = GINConv(nn3, train_eps=True)
        else:
            raise NotImplementedError("Unknown GNN-Operator.")

        if self.args.diffpool:
            self.attention = DiffPool(self.args)
        else:
            self.attention = AttentionModule(self.args)

        self.tensor_network = TensorNetworkModule(self.args)
        self.fully_connected_first = torch.nn.Linear(
            self.feature_count, self.args.bottle_neck_neurons
        )
        self.scoring_layer = torch.nn.Linear(self.args.bottle_neck_neurons, 1)

    def calculate_histogram(
        self, abstract_features_1, abstract_features_2, batch_1, batch_2
    ):
        """
        Calculate histogram from similarity matrix.
        :param abstract_features_1: Feature matrix for target graphs.
        :param abstract_features_2: Feature matrix for source graphs.
        :param batch_1: Batch vector for source graphs, which assigns each node to a specific example
        :param batch_1: Batch vector for target graphs, which assigns each node to a specific example
        :return hist: Histsogram of similarity scores.
        """
        print(abstract_features_1.shape)#torch.Size([1156, 16])
        print(abstract_features_2.shape)#torch.Size([1156, 16])
        #to_dense_batch意思是稀疏向量转稠密
        #另外,每个图最大10个节点,不足10个的补上。因此,mask返回了一个128行,10列的【true,false】矩阵,一行一个图,10列表示10个节点。补上的节点为false
        abstract_features_1, mask_1 = to_dense_batch(abstract_features_1, batch_1)
        print(abstract_features_1.shape)#torch.Size([128, 10, 16])
        print(mask_1.shape)#torch.Size([128, 10])
        abstract_features_2, mask_2 = to_dense_batch(abstract_features_2, batch_2)
        print(abstract_features_2.shape)#torch.Size([128, 10, 16])
        print(mask_2.shape)#torch.Size([128, 10])
        B1, N1, _ = abstract_features_1.size()
        B2, N2, _ = abstract_features_2.size()
        #b=128 n=10
        mask_1 = mask_1.view(B1, N1)
        mask_2 = mask_2.view(B2, N2)
        num_nodes = torch.max(mask_1.sum(dim=1), mask_2.sum(dim=1))
        #128的数组,数组中每一个数是一对图的最大节点数。
        """
        .detach()方法用于将张量从计算图中分离出来,得到一个新的张量。
        在PyTorch中,计算图是用于自动求导的一种机制。当我们对张量进行操作时,
        PyTorch会自动构建一个计算图,用于跟踪张量之间的依赖关系,并计算梯度。计算图的构建过程会消耗一定的内存和计算资源。
        使用.detach()方法可以将张量从计算图中分离出来,得到一个新的张量,新的张量不再与原始计算图相关联。
        这意味着新的张量不会再参与梯度计算,也不会影响原始张量的梯度计算。
        具体来说,.detach()方法会返回一个新的张量,该张量与原始张量的值相同,
        但是不再具有梯度信息。这对于需要保留中间结果但不需要进行梯度计算的情况非常有用。
        例如,在训练神经网络时,有时我们需要计算某个中间结果,并将其用于后续的计算,但是不希望中间结果对网络参数进行梯度传播。
        这时,可以使用.detach()方法将中间结果从计算图中分离出来,保留其值,并将其用于后续的计算,而不会对网络参数进行梯度计算。
        总之,.detach()方法用于将张量从计算图中分离出来,得到一个新的张量,新的张量不再参与梯度计算。
        """
        scores = torch.matmul(
            abstract_features_1, abstract_features_2.permute([0, 2, 1])
        ).detach()  # (128, 10, 16) x (128, 16, 10)
        print(scores.shape)  # (128, 10, 10)
        hist_list = []  # will hold 128 histograms of shape (1, 16)
        for i, mat in enumerate(scores):
            # mat is the (10, 10) score matrix of one graph pair: the inner products
            # measure the similarity of each of its 10 nodes to the other graph's 10 nodes.
            mat = torch.sigmoid(mat[: num_nodes[i], : num_nodes[i]]).view(-1)  # flatten to at most 100 values
            print(mat.shape)
            hist = torch.histc(mat, bins=self.args.bins)  # bins = 16: bucket the scores into a 16-bin histogram
            print(hist.shape)  # (16,)
            hist = hist / torch.sum(hist)  # normalize
            hist = hist.view(1, -1)  # (1, 16)
            print(hist.shape)
            hist_list.append(hist)
        print(torch.stack(hist_list).view(-1, self.args.bins).shape)
        """
        import torch
        x = torch.tensor([1, 2, 3])
        y = torch.tensor([4, 5, 6])
        z = torch.stack([x, y], dim=0)
        print(z)
        输出结果为:
        tensor([[1, 2, 3],
                [4, 5, 6]])
        .view(-1, self.args.bins)将堆叠后的张量形状重塑为(-1, self.args.bins),
        其中-1表示根据其他维度的大小自动计算该维度的大小,而self.args.bins表示指定的维度大小。
        """
        return torch.stack(hist_list).view(-1, self.args.bins)  # (128, 16)

    def convolutional_pass(self, edge_index, features):
        """
        Making convolutional pass.
        :param edge_index: Edge indices.
        :param features: Feature matrix.
        :return features: Abstract feature matrix.
        """
        features = self.convolution_1(features, edge_index)
        features = F.relu(features)
        features = F.dropout(features, p=self.args.dropout, training=self.training)
        features = self.convolution_2(features, edge_index)
        features = F.relu(features)
        features = F.dropout(features, p=self.args.dropout, training=self.training)
        features = self.convolution_3(features, edge_index)
        return features

    def diffpool(self, abstract_features, edge_index, batch):
        """
        Making differentiable pooling.
        :param abstract_features: Node feature matrix.
        :param edge_index: Edge indices
        :param batch: Batch vector, which assigns each node to a specific example
        :return pooled_features: Graph feature matrix.
        """
        x, mask = to_dense_batch(abstract_features, batch)
        adj = to_dense_adj(edge_index, batch)
        return self.attention(x, adj, mask)

    def forward(self, data):
        """
        Forward pass with graphs.
        :param data: Data dictionary.
        :return score: Similarity score.
        """
        edge_index_1 = data["g1"].edge_index
        edge_index_2 = data["g2"].edge_index
        features_1 = data["g1"].x
        print(features_1.shape) #torch.Size([1152, 29])
        features_2 = data["g2"].x
        batch_1 = (
            data["g1"].batch
            if hasattr(data["g1"], "batch")
            else torch.tensor((), dtype=torch.long).new_zeros(data["g1"].num_nodes)
        )
        batch_2 = (
            data["g2"].batch
            if hasattr(data["g2"], "batch")
            else torch.tensor((), dtype=torch.long).new_zeros(data["g2"].num_nodes)
        )
        # both graphs pass through the same GNN (shared weights)
        abstract_features_1 = self.convolutional_pass(edge_index_1, features_1)
        print(abstract_features_1.shape)#torch.Size([1156, 16])
        abstract_features_2 = self.convolutional_pass(edge_index_2, features_2)

        # compute the histogram feature vector
        if self.args.histogram:
            hist = self.calculate_histogram(
                abstract_features_1, abstract_features_2, batch_1, batch_2
            )

        # compute the graph-level embeddings
        if self.args.diffpool:
            pooled_features_1 = self.diffpool(
                abstract_features_1, edge_index_1, batch_1
            )
            pooled_features_2 = self.diffpool(
                abstract_features_2, edge_index_2, batch_2
            )
        else:
            pooled_features_1 = self.attention(abstract_features_1, batch_1)
            print(pooled_features_1.shape)
            pooled_features_2 = self.attention(abstract_features_2, batch_2)

        # NTN module: similar in spirit to learning latent relation vectors with SVD
        # (e.g., the relation between two entities such as a tiger and its tail, or between a user and an item)
        scores = self.tensor_network(pooled_features_1, pooled_features_2)
        print(scores.shape)
        if self.args.histogram:
            scores = torch.cat((scores, hist), dim=1)

        scores = F.relu(self.fully_connected_first(scores))
        print(scores.shape)
        score = torch.sigmoid(self.scoring_layer(scores)).view(-1)
        print(score.shape)
        return score


class SimGNNTrainer(object):
    """
    SimGNN model trainer.
    """

    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        self.args = args
        self.process_dataset()
        self.setup_model()

    def setup_model(self):
        """
        Creating a SimGNN.
        """
        self.model = SimGNN(self.args, self.number_of_labels)

    def save(self):
        """
        Saving model.
        """
        torch.save(self.model.state_dict(), self.args.save)
        print(f"Model is saved under {self.args.save}.")

    def load(self):
        """
        Loading model.
        """
        self.model.load_state_dict(torch.load(self.args.load))
        print(f"Model is loaded from {self.args.save}.")

    def process_dataset(self):
        """
        Downloading and processing dataset.
        """
        print("\nPreparing dataset.\n")

        self.training_graphs = GEDDataset(
            "datasets/{}".format(self.args.dataset), self.args.dataset, train=True
        )
        self.testing_graphs = GEDDataset(
            "datasets/{}".format(self.args.dataset), self.args.dataset, train=False
        )
        self.nged_matrix = self.training_graphs.norm_ged
        self.real_data_size = self.nged_matrix.size(0)

        if self.args.synth:
            # self.synth_data_1, self.synth_data_2, _, synth_nged_matrix = gen_synth_data(500, 10, 12, 0.5, 0, 3)
            self.synth_data_1, self.synth_data_2, _, synth_nged_matrix = gen_pairs(
                self.training_graphs.shuffle()[:500], 0, 3
            )

            real_data_size = self.nged_matrix.size(0)
            synth_data_size = synth_nged_matrix.size(0)
            self.nged_matrix = torch.cat(
                (
                    self.nged_matrix,
                    torch.full((real_data_size, synth_data_size), float("inf")),
                ),
                dim=1,
            )
            synth_nged_matrix = torch.cat(
                (
                    torch.full((synth_data_size, real_data_size), float("inf")),
                    synth_nged_matrix,
                ),
                dim=1,
            )
            self.nged_matrix = torch.cat((self.nged_matrix, synth_nged_matrix))

        if self.training_graphs[0].x is None:
            max_degree = 0
            for g in (
                self.training_graphs
                + self.testing_graphs
                + (self.synth_data_1 + self.synth_data_2 if self.args.synth else [])
            ):
                if g.edge_index.size(1) > 0:
                    max_degree = max(
                        max_degree, int(degree(g.edge_index[0]).max().item())
                    )
            one_hot_degree = OneHotDegree(max_degree, cat=False)
            self.training_graphs.transform = one_hot_degree
            self.testing_graphs.transform = one_hot_degree

            # labeling of synth data according to real data format
            if self.args.synth:
                for g in self.synth_data_1 + self.synth_data_2:
                    g = one_hot_degree(g)
                    g.i = g.i + real_data_size
        elif self.args.synth:
            for g in self.synth_data_1 + self.synth_data_2:
                g.i = g.i + real_data_size
                # g.x = torch.cat((g.x, torch.zeros((g.x.size(0), self.training_graphs.num_features-1))), dim=1)

        self.number_of_labels = self.training_graphs.num_features

    def create_batches(self):
        """
        Creating batches from the training graph list.
        :return batches: Zipped loaders as list.
        """
        if self.args.synth:
            synth_data_ind = random.sample(range(len(self.synth_data_1)), 100)

        source_loader = DataLoader(
            self.training_graphs.shuffle()
            + (
                [self.synth_data_1[i] for i in synth_data_ind]
                if self.args.synth
                else []
            ),
            batch_size=self.args.batch_size,
        )
        target_loader = DataLoader(
            self.training_graphs.shuffle()
            + (
                [self.synth_data_2[i] for i in synth_data_ind]
                if self.args.synth
                else []
            ),
            batch_size=self.args.batch_size,
        )

        return list(zip(source_loader, target_loader))

    def transform(self, data):
        """
        Getting ged for graph pair and grouping with data into dictionary.
        :param data: Graph pair.
        :return new_data: Dictionary with data.
        """
        new_data = dict()

        new_data["g1"] = data[0]
        new_data["g2"] = data[1]

        normalized_ged = self.nged_matrix[
            data[0]["i"].reshape(-1).tolist(), data[1]["i"].reshape(-1).tolist()
        ].tolist()
        new_data["target"] = (
            torch.from_numpy(np.exp([(-el) for el in normalized_ged])).view(-1).float()
        )
        return new_data

    def process_batch(self, data):
        """
        Forward pass with a batch of data.
        :param data: Data that is essentially a pair of batches, for source and target graphs.
        :return loss: Loss on the data.
        """
        self.optimizer.zero_grad()
        data = self.transform(data)
        target = data["target"]
        prediction = self.model(data)
        loss = F.mse_loss(prediction, target, reduction="sum")
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def fit(self):
        """
        Training a model.
        """
        print("\nModel training.\n")
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.args.learning_rate,
            weight_decay=self.args.weight_decay,
        )
        self.model.train()

        epochs = trange(self.args.epochs, leave=True, desc="Epoch")
        loss_list = []
        loss_list_test = []
        for epoch in epochs:

            if self.args.plot:
                if epoch % 10 == 0:
                    self.model.train(False)
                    cnt_test = 20
                    cnt_train = 100
                    t = tqdm(
                        total=cnt_test * cnt_train,
                        position=2,
                        leave=False,
                        desc="Validation",
                    )
                    scores = torch.empty((cnt_test, cnt_train))

                    for i, g in enumerate(self.testing_graphs[:cnt_test].shuffle()):
                        source_batch = Batch.from_data_list([g] * cnt_train)
                        target_batch = Batch.from_data_list(
                            self.training_graphs[:cnt_train].shuffle()
                        )
                        data = self.transform((source_batch, target_batch))
                        target = data["target"]
                        prediction = self.model(data)

                        scores[i] = F.mse_loss(
                            prediction, target, reduction="none"
                        ).detach()
                        t.update(cnt_train)

                    t.close()
                    loss_list_test.append(scores.mean().item())
                    self.model.train(True)

            batches = self.create_batches()
            main_index = 0
            loss_sum = 0
            for index, batch_pair in tqdm(
                enumerate(batches), total=len(batches), desc="Batches", leave=False
            ):
                loss_score = self.process_batch(batch_pair)
                main_index = main_index + batch_pair[0].num_graphs
                loss_sum = loss_sum + loss_score
            loss = loss_sum / main_index
            epochs.set_description("Epoch (Loss=%g)" % round(loss, 5))
            loss_list.append(loss)

        if self.args.plot:
            plt.plot(loss_list, label="Train")
            plt.plot(
                [*range(0, self.args.epochs, 10)], loss_list_test, label="Validation"
            )
            plt.ylim([0, 0.01])
            plt.legend()
            filename = self.args.dataset
            filename += "_" + self.args.gnn_operator
            if self.args.diffpool:
                filename += "_diffpool"
            if self.args.histogram:
                filename += "_hist"
            filename = filename + str(self.args.epochs) + ".pdf"
            plt.savefig(filename)

    def measure_time(self):
        import time

        self.model.eval()
        count = len(self.testing_graphs) * len(self.training_graphs)

        t = np.empty(count)
        i = 0
        tq = tqdm(total=count, desc="Graph pairs")
        for g1 in self.testing_graphs:
            for g2 in self.training_graphs:
                source_batch = Batch.from_data_list([g1])
                target_batch = Batch.from_data_list([g2])
                data = self.transform((source_batch, target_batch))

                start = time.process_time()
                self.model(data)
                t[i] = time.process_time() - start
                i += 1
                tq.update()
        tq.close()

        print(
            "Average time (ms): {}; Standard deviation: {}".format(
                round(t.mean() * 1000, 5), round(t.std() * 1000, 5)
            )
        )

    def score(self):
        """
        Scoring.
        """
        print("\n\nModel evaluation.\n")
        self.model.eval()

        scores = np.empty((len(self.testing_graphs), len(self.training_graphs)))
        ground_truth = np.empty((len(self.testing_graphs), len(self.training_graphs)))
        prediction_mat = np.empty((len(self.testing_graphs), len(self.training_graphs)))

        rho_list = []
        tau_list = []
        prec_at_10_list = []
        prec_at_20_list = []

        t = tqdm(total=len(self.testing_graphs) * len(self.training_graphs))

        for i, g in enumerate(self.testing_graphs):
            source_batch = Batch.from_data_list([g] * len(self.training_graphs))
            target_batch = Batch.from_data_list(self.training_graphs)

            data = self.transform((source_batch, target_batch))
            target = data["target"]
            ground_truth[i] = target
            prediction = self.model(data)
            prediction_mat[i] = prediction.detach().numpy()

            scores[i] = (
                F.mse_loss(prediction, target, reduction="none").detach().numpy()
            )

            rho_list.append(
                calculate_ranking_correlation(
                    spearmanr, prediction_mat[i], ground_truth[i]
                )
            )
            tau_list.append(
                calculate_ranking_correlation(
                    kendalltau, prediction_mat[i], ground_truth[i]
                )
            )
            prec_at_10_list.append(
                calculate_prec_at_k(10, prediction_mat[i], ground_truth[i])
            )
            prec_at_20_list.append(
                calculate_prec_at_k(20, prediction_mat[i], ground_truth[i])
            )

            t.update(len(self.training_graphs))

        self.rho = np.mean(rho_list).item()
        self.tau = np.mean(tau_list).item()
        self.prec_at_10 = np.mean(prec_at_10_list).item()
        self.prec_at_20 = np.mean(prec_at_20_list).item()
        self.model_error = np.mean(scores).item()
        self.print_evaluation()

    def print_evaluation(self):
        """
        Printing the error rates.
        """
        print("\nmse(10^-3): " + str(round(self.model_error * 1000, 5)) + ".")
        print("Spearman's rho: " + str(round(self.rho, 5)) + ".")
        print("Kendall's tau: " + str(round(self.tau, 5)) + ".")
        print("p@10: " + str(round(self.prec_at_10, 5)) + ".")
        print("p@20: " + str(round(self.prec_at_20, 5)) + ".")

2. Attention module for the global (graph-level) embedding

import torch
from torch_scatter import scatter_mean, scatter_add


class AttentionModule(torch.nn.Module):
    """
    SimGNN Attention Module to make a pass on graph.
    """

    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        super(AttentionModule, self).__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Defining weights.
        """
        self.weight_matrix = torch.nn.Parameter(
            torch.Tensor(self.args.filters_3, self.args.filters_3)
        )

    def init_parameters(self):
        """
        Initializing weights.
        """
        torch.nn.init.xavier_uniform_(self.weight_matrix)

    def forward(self, x, batch, size=None):
        """
        Making a forward propagation pass to create a graph level representation.
        :param x: Result of the GNN.
        :param size: Dimension size for scatter_
        :param batch: Batch vector, which assigns each node to a specific example
        :return representation: A graph level representation matrix.
        输入张量input_tensor的形状为(3, 3),内容如下:
        tensor([[1, 2, 3],
                [4, 5, 6],
                [7, 8, 9]])
        索引张量index的形状为(3,),内容如下:
        tensor([0, 1, 0])
        聚合操作的过程如下:
        根据索引张量index的值,将输入张量input_tensor中的值聚合到对应的位置上。在这个例子中,索引张量index的第一个元素为0,
        表示将输入张量的第一行([1, 2, 3])聚合到输出张量的第一行上;
        索引张量的第二个元素为1,表示将输入张量的第二行([4, 5, 6])聚合到输出张量的第二行上;索引张量的第三个元素为0,
        表示将输入张量的第三行([7, 8, 9])聚合到输出张量的第一行上。
        对于每个聚合位置,使用指定的聚合操作进行聚合。在这个例子中,我们使用的是scatter_add()函数,它将输入张量中的值累加到聚合位置上。
        聚合结果保存在输出张量output_tensor中。在这个例子中,输出张量的形状为(2, 3),内容如下:
        tensor([[8, 10, 12],
                [4,  5,  6]])
        输出张量的第一行是将输入张量的第一行和第三行聚合得到的,第二行是将输入张量的第二行聚合得到的。
        通过聚合操作,我们可以将输入张量中的值按照指定的索引聚合到输出张量的指定位置上,从而实现灵活的聚合操作。
        """
        size = batch[-1].item() + 1 if size is None else size  # number of graphs in the batch, e.g. 128
        mean = scatter_mean(x, batch, dim=0, dim_size=size)  # (128, 16): mean over each graph's nodes
        print(mean.shape)  # x: (1151, 16), batch: (1151,), self.weight_matrix: (16, 16)
        transformed_global = torch.tanh(torch.mm(mean, self.weight_matrix))  # (128, 16): global context = per-graph mean times a learnable matrix
        print(self.weight_matrix.shape)
        print(transformed_global.shape)
        coefs = torch.sigmoid((x * transformed_global[batch]).sum(dim=1))  # (1151,): each node's affinity to its graph's global context
        print(coefs.shape)
        weighted = coefs.unsqueeze(-1) * x

        return scatter_add(weighted, batch, dim=0, dim_size=size)

    def get_coefs(self, x):
        mean = x.mean(dim=0)
        transformed_global = torch.tanh(torch.matmul(mean, self.weight_matrix))

        return torch.sigmoid(torch.matmul(x, transformed_global))
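
To see the attention pooling in isolation, here is a self-contained sketch of the same computation with made-up shapes (two graphs, 25 nodes in total, 16-dim embeddings; W stands in for the learned weight_matrix):

import torch
from torch_scatter import scatter_mean, scatter_add

x = torch.rand(25, 16)                                       # node embeddings
batch = torch.cat([torch.zeros(10), torch.ones(15)]).long()  # node -> graph id
W = torch.rand(16, 16)                                       # stands in for weight_matrix

mean = scatter_mean(x, batch, dim=0)                # (2, 16) per-graph mean
c = torch.tanh(mean @ W)                            # (2, 16) global context
a = torch.sigmoid((x * c[batch]).sum(dim=1))        # (25,) per-node attention
g = scatter_add(a.unsqueeze(-1) * x, batch, dim=0)  # (2, 16) graph embeddings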

3. Neural tensor network (NTN) for the latent similarity vector

import torch
import torch.nn.functional as F


class TensorNetworkModule(torch.nn.Module):
    """
    SimGNN Tensor Network module to calculate similarity vector.
    """

    def __init__(self, args):
        """
        :param args: Arguments object.
        """
        super(TensorNetworkModule, self).__init__()
        self.args = args
        self.setup_weights()
        self.init_parameters()

    def setup_weights(self):
        """
        Defining weights.
        """
        self.weight_matrix = torch.nn.Parameter(
            torch.Tensor(
                self.args.filters_3, self.args.filters_3, self.args.tensor_neurons
            )
        )
        self.weight_matrix_block = torch.nn.Parameter(
            torch.Tensor(self.args.tensor_neurons, 2 * self.args.filters_3)
        )
        self.bias = torch.nn.Parameter(torch.Tensor(self.args.tensor_neurons, 1))

    def init_parameters(self):
        """
        Initializing weights.
        """
        torch.nn.init.xavier_uniform_(self.weight_matrix)
        torch.nn.init.xavier_uniform_(self.weight_matrix_block)
        torch.nn.init.xavier_uniform_(self.bias)

    def forward(self, embedding_1, embedding_2):
        """
        Making a forward propagation pass to create a similarity vector.
        :param embedding_1: Result of the 1st embedding after attention.
        :param embedding_2: Result of the 2nd embedding after attention.
        :return scores: A similarity score vector.
        """
        batch_size = len(embedding_1)  # 128; embedding_1 and embedding_2 are both (128, 16)
        # The two 16-dim graph embeddings interact through tensor_neurons = 16 learnable
        # relation slices; viewed as (16, 256), the weight tensor stacks all 16 slices.
        scoring = torch.matmul(
            embedding_1, self.weight_matrix.view(self.args.filters_3, -1)
        )  # scoring: (128, 256)
        scoring = scoring.view(batch_size, self.args.filters_3, -1).permute([0, 2, 1])
        # (128, 16, 16): filters_3 can be read as the number of relation types being learned
        scoring = torch.matmul(
            scoring, embedding_2.view(batch_size, self.args.filters_3, 1)
        ).view(batch_size, -1)
        print(scoring.shape)  # (128, 16)
        combined_representation = torch.cat((embedding_1, embedding_2), 1)  # (128, 32)
        print(combined_representation.shape)
        block_scoring = torch.t(
            torch.mm(self.weight_matrix_block, torch.t(combined_representation))
        )  # (128, 16) = ((16, 32) @ (32, 128))^T: concatenated embeddings times a learnable weight; the bias is added next
        print(block_scoring.shape)
        scores = F.relu(scoring + block_scoring + self.bias.view(-1))  # (128, 16)
        print(scores.shape)
        return scores
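
The bilinear part above can be hard to follow through the view/permute calls; this equivalent einsum sketch (assuming filters_3 = tensor_neurons = 16, with random stand-ins for the learned parameters) computes the scores the same way:

import torch
import torch.nn.functional as F

h1, h2 = torch.rand(128, 16), torch.rand(128, 16)    # the two graph embeddings
W = torch.rand(16, 16, 16)                           # stands in for weight_matrix
V = torch.rand(16, 32)                               # stands in for weight_matrix_block
b = torch.rand(16)                                   # stands in for bias

bilinear = torch.einsum("bi,ijk,bj->bk", h1, W, h2)  # h1^T W_k h2 for each slice k
block = torch.cat((h1, h2), dim=1) @ V.t()           # V [h1; h2]
scores = F.relu(bilinear + block + b)                # (128, 16), as in forward() above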

 

From: https://www.cnblogs.com/zhangxianrong/p/17731445.html

    今天我要跟你聊聊MySQL的锁。数据库锁设计的初衷是处理并发问题。作为多用户共享的资源,当出现并发访问的时候,数据库需要合理地控制资源的访问规则。而锁就是用来实现这些访问规则的重要数据结构。 根据加锁的范围,MySQL里面的锁大致可以分成全局锁、表级锁和行锁三类。今天这......