目录
概
本文主要是解决 importance sampling 中的一个效率问题.
符号说明
-
\(i \in \mathcal{I}\), item;
-
\(c \in \mathcal{C}\), context (e.g., user);
-
\(\bm{\theta}\), 网络参数;
-
BPR loss:
\[\min_{\bm{\theta}} \: \sum_{(c, i) \in \mathcal{D}} \mathbb{E}_{j \sim P_{ns}(j|c)} [-\ln \sigma(r(c, i|\bm{\theta})- r(c, j|\bm{\theta}))] + \lambda \|\bm{\theta}\|^2, \]其中 \(\sigma(\cdot)\) 表示 sigmoid, \(P_{ns}\) 是某种负样本采样策略.
-
记:
\[r_{cij} := r(c, i|\bm{\theta})- r(c, j|\bm{\theta}). \]
Motivation
-
以往最简单的采样方式是 uniform sampler, 但是这种采样方式在模型训练得比较好之后就很难再提供有效的负样本了, 所以作者认为依赖 item score 的 adaptive sampelr 是最优的:
\[P^*_{ns} (i|c) = \frac{\exp(r_{ci})}{\sum_{j \in \mathcal{I}} \exp r_{cj}}. \]注: 我有点不确定 \(\mathcal{I}\) 是不是只包含负样本了, 他们团队之前的工作倒是只在负样本中采样的.
-
不过直接采样有点费时 (考虑到随着模型的变化, 这个概率分布是在不断变化的). 作者的策略是, 为每个 \(c\) 采样一个固定 pool, 然后从 pool 中进行采样.
本文方法
注: 以下操作每一个 epoch 都要重新计算.
-
对每个 \(c\) 利用 proposal distribution \(Q\) (e.g., uniform or popularity sampling) 采样得到 size 较小的 \(\mathcal{K}_c\);
-
计算 resampling weight:
\[w_i = \frac{\exp(r_{ci} - \log Q(i))}{\sum_{j \in \mathcal{K}_c} \exp(r_{cj} - \log Q(j))}, \: \forall i \in \mathcal{K}_c; \] -
根据上面的概率从 \(\mathcal{K}_c\) 中重新采样得到 \(\mathcal{R}_c\) (with replacedment);
-
训练的时候, 对于每个 \((c, i)\) 从 \(\mathcal{R}_{c}\) 采样负样本集合 \(\mathcal{J}_{ci}\);
-
根据如下损失进行训练:
\[\min_{\bm{\theta}} \: \sum_{(c, i) \in \mathcal{D}} \sum_{j \in \mathcal{J}_{ci}} -\hat{w}(j|c) \ln \sigma(\hat{r}_{ci} - \hat{r}_{cj}) + \lambda \|\bm{\theta}\|^2, \\ \hat{w}(j|c) = \frac{\exp(\hat{r}_{ci} - \log w(i|c))}{\sum_{j \in \mathcal{J}_c} \exp(\hat{r}_{cj} - \log w(j|c))}. \]
几个需要注意的点:
-
假设 \(|\mathcal{R}_c| = |\mathcal{K}_c| = K\), 则当 \(|\mathcal{K}_c|\) 足够大的时候 \(\mathcal{R}_c\) 中样本所构成的分布 (朴素贝叶斯) 可以逼近 \(P_{ns}^*\). 即, 假设 \(\mathcal{S}\) 为 \(\mathcal{R}_c\) 中的一部分样本集合, \(N_\mathcal{S}\) 表示 \(\mathcal{R}_s\)中包含样本 \(s \in \mathcal{S}\) 的个数, 则 \(N_{\mathcal{S}} / K \approx P_{ns}^*(S|c), \: K \rightarrow +\infty\):
\[\begin{array}{ll} & \frac{N_\mathcal{S}}{K} \\ \approx & \sum_{i \in \mathcal{S}} w(i|c) \cdot [\sum_{j \in \mathcal{K}_c} \mathbb{I}[j = i]] \\ = & \sum_{j \in \mathcal{K}_c} \sum_{i \in \mathcal{S}} w(i|c) \mathbb{I}[j = i] \\ = & \sum_{j \in \mathcal{K}_c} \sum_{i \in \mathcal{S}} w(j|c) \mathbb{I}[j = i] \\ = & \sum_{j \in \mathcal{K}_c} w(j|c) [\sum_{i \in \mathcal{S}}\mathbb{I}[j = i]] \\ = & \sum_{j \in \mathcal{K}_c} w(j|c) \mathbb{I}[i \in \mathcal{S}]\\ = & \frac{\sum_{j \in \mathcal{K}_c} \mathbb{I}[i \in \mathcal{S}] \exp(r_{ci} - \log Q(i))}{\sum_{j \in \mathcal{K}_c} \exp(r_{cj}- \log Q(j))}\\ \approx & \frac{ \mathbb{E}_{i \sim Q} [\mathbb{I}[i \in \mathcal{S}] \exp(r_{ci} - \log Q(i))]}{\mathbb{E}_{j \in Q} [\exp(r_{cj}- \log Q(j))]}\\ = & \frac{\sum_{i \in \mathcal{S}} \exp(r_{ci})}{\sum_{j \in \mathcal{I}} \exp(r_{cj})} = P^*(\mathcal{S}|c). \end{array} \] -
关于损失为什么这么设计请参考 here.
-
关于如何采样 \(\mathcal{J}\), 作者有更多的讨论.