Basic principle
The clean-label posterior of a sample, \(P(\mathbf{Y}|X=\mathbf{x})\), can be recovered from the noisy-label posterior \(P(\bar{\mathbf{Y}}|X=\mathbf{x})\) and the noise transition matrix \(T(\mathbf{x})\):

\[P(\bar{\mathbf{Y}}|X=\mathbf{x}) = T(\mathbf{x})^{\top}\,P(\mathbf{Y}|X=\mathbf{x}) \quad\Longrightarrow\quad P(\mathbf{Y}|X=\mathbf{x}) = \bigl(T(\mathbf{x})^{\top}\bigr)^{-1} P(\bar{\mathbf{Y}}|X=\mathbf{x}),\]

where \(T_{ij}(\mathbf{x}) = P(\bar{Y}=j \mid Y=i, X=\mathbf{x})\).
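To make the relation concrete, here is a tiny numeric sketch (the 2-class matrix and posteriors are made-up numbers, not from the original): multiplying the clean posterior by \(T\) yields the noisy posterior, and inverting \(T\) recovers it.

import torch

# Hypothetical 2-class example: row i of T is P(Ȳ = · | Y = i).
T = torch.tensor([[0.9, 0.1],
                  [0.2, 0.8]])
clean = torch.tensor([0.7, 0.3])      # P(Y | x)
noisy = clean @ T                     # P(Ȳ | x): tensor([0.6900, 0.3100])
recovered = noisy @ torch.inverse(T)  # back to P(Y | x): tensor([0.7000, 0.3000])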
In general, the transition matrix \(T\) is not identifiable and is hard to learn without additional assumptions. As a result, estimating the noise transition matrix is a relatively uncommon approach to the noisy-label problem in practice, and this article discusses only its simplest form.
At the code level
The transition matrix is represented as a \(C \times C\) matrix, where \(C\) is the number of classes. Its parameters are updated along with the model during training.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransitionMatrix(nn.Module):
    def __init__(self, num_classes, device='cpu'):
        super().__init__()
        # Initialize the unconstrained weights so the off-diagonal entries of T
        # start small: sigmoid(-2) ≈ 0.12, sigmoid(-4.5) ≈ 0.011.
        init = -2.0 if num_classes == 10 else -4.5
        # Create the parameter directly on the target device; calling
        # self.w.to(device) after registration would be a no-op.
        w = torch.full((num_classes, num_classes), init, device=device)
        self.w = nn.Parameter(w)
        # Non-trainable masks, registered as buffers so they follow the module
        # across devices: the identity pins the diagonal at 1, and coeff zeroes
        # out the diagonal of the learned part.
        eye = torch.eye(num_classes, device=device)
        self.register_buffer("identity", eye)
        self.register_buffer("coeff", torch.ones(num_classes, num_classes, device=device) - eye)

    def forward(self):
        # Squash the off-diagonal entries into (0, 1).
        sig = torch.sigmoid(self.w)
        T = self.identity + sig * self.coeff
        # Row-normalize so each row is a valid conditional distribution
        # P(Ȳ = · | Y = i), i.e. rows sum to 1.
        T = F.normalize(T, p=1, dim=1)
        return T
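A quick sanity check of the module (hypothetical usage, not part of the original): every row of the returned matrix sums to 1, and the diagonal dominates at initialization.

tm = TransitionMatrix(num_classes=10)
T = tm()
print(T.shape)        # torch.Size([10, 10])
print(T.sum(dim=1))   # all ones: each row is a valid distribution P(Ȳ = · | Y = i)
print(T.diagonal())   # ≈ 0.48 each at init, i.e. 1 / (1 + 9 · sigmoid(-2))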
During training, the model output is adjusted by the noise transition matrix before the loss is computed. Note that the model output here is a class probability distribution, i.e., logits that have already passed through a softmax.
...
transition_matrix = TransitionMatrix(num_classes=num_classes, device=device)

for epoch in range(EPOCH):
    transition_matrix.train()
    ...
    for index, (batch_x, batch_y) in loop:
        ...
        clean = model(batch_x)            # predicted clean posterior P(Y|x), post-softmax
        t_hat = transition_matrix()       # current estimate of T
        y_tilde = torch.mm(clean, t_hat)  # predicted noisy posterior P(Ȳ|x) = P(Y|x) · T
        # Volume regularizer: |log|det T|| shrinks the simplex spanned by T's
        # rows, discouraging degenerate transition matrices.
        vol_loss = torch.abs(t_hat.slogdet().logabsdet)
        # The loss is taken against the *noisy* labels on the adjusted output;
        # loss_func_ce consumes log-probabilities, hence y_tilde.log().
        ce_loss = loss_func_ce(y_tilde.log(), batch_y.long())
        loss = ce_loss + opt.lam * vol_loss
        ...
    ...
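The excerpt elides how loss_func_ce and the optimizer are set up. Below is one plausible wiring, a minimal sketch under assumptions (the model architecture, learning rate, and loss choice are not from the original): since y_tilde.log() yields log-probabilities, an NLL loss is a natural fit, and the transition matrix's weights must be included in an optimizer or T never moves from its initialization.

import torch
import torch.nn as nn

num_classes, device = 10, 'cpu'
# Any classifier ending in a softmax works here; this tiny one is illustrative.
model = nn.Sequential(nn.Flatten(), nn.Linear(784, num_classes), nn.Softmax(dim=1)).to(device)
transition_matrix = TransitionMatrix(num_classes=num_classes, device=device)

# y_tilde.log() produces log-probabilities, so NLLLoss is the matching criterion.
loss_func_ce = nn.NLLLoss()
# Optimize the model and transition-matrix weights jointly (separate optimizers
# with different learning rates are also a common choice).
optimizer = torch.optim.SGD(
    list(model.parameters()) + list(transition_matrix.parameters()),
    lr=1e-2, momentum=0.9)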
Dependencies:
torch 2.4.1