目录
Liu J., Kumar A., Ba J., Kiros J. and Swersky K. Graph normalizing flows. NIPS, 2019.
概
基于 flows 的图的生成模型.
符号说明
- \(\mathcal{G} = (H, \Omega)\), 图;
- \(H = (\mathbf{h}^{(1)}, \cdots, \mathbf{h}^{(N)}) \in \mathbb{R}^{N \times d_n}\), node feature matrix;
- \(\Omega \in \mathbb{R}^{N \times N \times (d_e + 1)}\), 其中 \(\Omega_{:,:,0} \in \mathbb{R}^{N \times N}\) 表示邻接矩阵, \(\Omega_{:, :, :1:(d_e+1)}\) 表示 edge features.
- 寻常的 MPNN 可以表述为:\[\mathbf{m}_{t+1}^{(v)} = \text{Agg} \Bigg( \{ M_t (\mathbf{h}_t^{(v)}, \mathbf{h}_t^{(u)}, \Omega_{u, v}) \}_{u \in \mathcal{N}}(v) \Bigg) \\ \mathbf{h}_{t+1}^{(v)} = U_t(\mathbf{h}_t^{(v)}, \mathbf{m}_{t+1}^{v}), \]其中 \(M_t(\cdot), U_t(\cdot)\) 分别是 message generation function 和 vertex update function.
Graph Normalizing Flows
-
需要注意的是, 本文的 flows 和一般的 flows 有点区别, 它并不具有一个 encoder 先将 \(\mathbf{x}\) 转换为隐变量 \(\mathbf{z}\) 再 \(\mathbf{z}' = f(\mathbf{z})\) 的过程, 而是直接构造 flow \(\mathbf{z} = f(\mathbf{x})\).
-
简单来说, flow 需要保证 \(f(\cdot)\) 是可逆的, 此时:
\[P(\mathbf{z}) = P(\mathbf{x})|\text{det}(\frac{\partial f(\mathbf{x})}{\partial \mathbf{x}})|^{-1}. \] -
作者是基于 RealNVP 进行的, 该方法将 \(\mathbf{x}\) 切分为 \(\mathbf{x}^{(0)}, \mathbf{x}^{(1)}\), 然后:
\[\mathbf{z}^{(0)} = \mathbf{x}^{(0)} \\ \mathbf{z}^{(1)} = \mathbf{x}^{(1)} \odot \exp(s(\mathbf{x}^{(0)})) + t(\mathbf{x}^{(0)}), \]其中 \(s, t\) 为两个 non-linear 函数. 和显然 \(\nabla_{\mathbf{x}} \mathbf{z}\) 为一个下三角矩阵. 此时行列式就是对角线元素相乘.
GRevNets
-
让我们来看看作者是怎么构造可以的 flow 的.
-
首先, 对每个结点 \(v\), 将它的结点特征切分为 \(\mathbf{h}_t^{0}, \mathbf{h}_t^{1}\) (这里我们省略标识 \((v)\)).
-
前向的过程可以表述为:
\[H_{t+\frac{1}{2}}^0 = H_t^0 \odot \exp(F_1(H_t^1)) + F_2 (H_t^1), \quad H_{t+1}^0 = H_{t+\frac{1}{2}}^0, \\ H_{t+\frac{1}{2}}^1 = H_t^{1}, \quad H_{t+1}^{1} = H_{t+\frac{1}{2}}^1 \odot \exp(G_1(H_{t+\frac{1}{2}}^{0})) + G_2(H_{t+\frac{1}{2}}^0). \] -
给定 \(H_{t+1}^0, H_{t+1}^1\) 我们可以得到:
-
于是我们有:
\[P(H_0) = P(H_T) \prod_{t=1}^T |\text{det}(\frac{\partial H_t}{\partial H_{t-1}})|. \] -
我们可以通过极大化对数似然 \(\log P(H_0)\) 来优化参数.
-
但是, 我们最终希望的其实是生成离散的图 (通过邻接矩阵 \(A\) 来刻画).
-
所以在生成的时候, 比如我们采样 \(H_T \sim \mathcal{N}(0, 1)\), 然后通过 GNF 得到 \(H_0\), 那么我们实际上还需要一个 decoder 将 \(H_0\) 映射为 \(\hat{A}\).
-
为此, 作者还额外设计了一个 encoder, 将 \(A, H\) 映射为隐变量 \(X\), 不过我不是特别清楚为什么 \(H\) 也是采样子正态分布而不是直接用 node features.
-
训练编码器是通过如下损失:
\[\mathcal{L}(\theta) = -\sum_{i=1}^N \sum_{j=1}^{N/2} A_{ij} \log (\hat{A}_{ij}) + (1 - A_{ij}) \log (1 - \hat{A}_{ij}), \]这里 \(N/2\) 的原因是作者假设我们生成的是无向图, 所以 \(A\) 是对称的.
-
对于 decoder, 作者采用的是一种非常简单的方式:
\[\hat{A}_{ij} = \frac{1}{1 + \exp(C(\|\mathbf{x}_i - \mathbf{x}_j\|_2^2 - 1))}. \]