What is a Matching Network?
1. Paper: Matching Networks for One Shot Learning
2. Overview: Based in part on ideas from metric learning, it augments the network with external memory to improve its learning ability.
3. Contributions
- Draws on prior work on attention and external memory to build the network
- Trains on tasks in a meta-learning fashion, rather than feeding images of fixed classes as in plain metric learning
4. Algorithm description
A Matching Network takes two inputs:
- A support set S defining an N-way K-shot task (for example, a 4-way 1-shot task), where \(S=\left\{\left(x_i, y_i\right)\right\}_{i=1}^{k}\)
- A query image \(\hat{x}\) whose class is to be predicted
The output of the Matching Network is defined as the predicted class \(\hat{y}\) of the query image \(\hat{x}\). The algorithm can therefore be framed as modeling \(P(\hat{y} \mid \hat{x}, S)\), where \(P(\cdot)\) is a parametric mapping realized by the network, i.e. by attention and external memory.
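To make the episodic setup concrete, here is a minimal sketch (not from the paper) of how an N-way K-shot task S plus one query \(\hat{x}\) might be sampled from a labeled dataset. The layout data_by_class (a dict mapping a class label to a list of image tensors) and the function name sample_episode are illustrative assumptions.

import random

def sample_episode(data_by_class, n_way=4, k_shot=1):
    """Return a support set S = [(x_i, y_i)] and one query (x_hat, y_hat)."""
    classes = random.sample(list(data_by_class.keys()), n_way)
    support, query = [], None
    for episode_label, cls in enumerate(classes):
        # draw k_shot support examples (plus one spare, used as the query for the first sampled class)
        images = random.sample(data_by_class[cls], k_shot + 1)
        support += [(img, episode_label) for img in images[:k_shot]]
        if query is None:
            query = (images[k_shot], episode_label)
    return support, query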
(1) Attention
- Basic form
The paper gives a simple attention formulation: \(\hat{y}=\sum_{i=1}^{k} a\left(\hat{x}, x_i\right) y_i\). Here \(a(\cdot)\) computes an attention weight between \(\hat{x}\) and each labeled input \(x_i\); these weights are then combined with the corresponding labels \(y_i\) to produce the prediction \(\hat{y}\) for the query \(\hat{x}\).
- Cosine-similarity attention
Intuitively, a natural choice for \(a(\cdot)\) is a metric (e.g. cosine similarity) that measures how similar, or how close, two images are in an embedding space. The paper defines a cosine-softmax attention:
\[a\left(\hat{x}, x_i\right)=\frac{e^{c\left(f(\hat{x}), g\left(x_i\right)\right)}}{\sum_{j=1}^{k} e^{c\left(f(\hat{x}), g\left(x_j\right)\right)}}\]
where \(c(\cdot)\) is the cosine similarity, \(f(\hat{x})\) is the embedding of the query \(\hat{x}\), and \(g(x_j)\) is the embedding of the labeled input \(x_j\). In the paper, \(f(\cdot)\) and \(g(\cdot)\) share parameters (i.e. they are the same CNN).
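As a concrete illustration of the two formulas above, here is a minimal sketch of the cosine-softmax attention and the weighted-label prediction \(\hat{y}=\sum_i a(\hat{x}, x_i)\, y_i\). The tensor names (f_emb for \(f(\hat{x})\), g_emb for the stacked \(g(x_i)\), support_y_one_hot for the one-hot labels) are assumptions for illustration.

import torch.nn.functional as F

def cosine_attention_predict(f_emb, g_emb, support_y_one_hot):
    # f_emb: [d], g_emb: [k, d], support_y_one_hot: [k, n_classes]
    cos = F.cosine_similarity(f_emb.unsqueeze(0), g_emb, dim=1)  # c(f(x_hat), g(x_i)), shape [k]
    a = F.softmax(cos, dim=0)                                    # softmax over the support set
    return a @ support_y_one_hot                                 # y_hat: class probabilities, shape [n_classes]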
(2) External memory
The authors point out that, under the cosine attention above, the embedding \(g(x_i)\) of each labeled input \(x_i\) in the support set S is computed by the CNN independently of the others and then compared one by one against \(f(\hat{x})\). This is rather crude: it ignores that the support set S should be able to change how an input is embedded, i.e. \(f(\cdot)\) should be conditioned on \(g(S)\). To address this, the authors propose full context embeddings, using a bidirectional LSTM over the support-set embeddings.
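The sketch below illustrates this full-context-embedding idea under simple assumptions: a bidirectional LSTM is run over the stacked \(k+1\) embeddings, and its output is added back to the original embeddings as a skip connection. (The paper sums the forward and backward hidden states with the original embedding; concatenating the two directions, as done here, is a simplification.) The class name FullContextEmbedding is illustrative, not an actual module from the reference code.

import torch.nn as nn

class FullContextEmbedding(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # hidden size dim // 2 per direction, so the concatenated output matches the input size
        self.lstm = nn.LSTM(input_size=dim, hidden_size=dim // 2, bidirectional=True)

    def forward(self, embeddings):      # embeddings: [k+1, batch, dim]
        out, _ = self.lstm(embeddings)  # each position now sees the whole set
        return out + embeddings         # skip connection back to the original embeddings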
5. Network design
Algorithm steps
- Pass all support images \(x_i\) in the task S and the query image \(\hat{x}\) through the CNN to obtain their embeddings, then stack these \(k+1\) vectors
- Feed the stacked embeddings into a bidirectional LSTM to obtain \(k+1\) outputs, then use cosine similarity to score each of the first \(k\) outputs against the last one
- Based on these similarities and the labels in the task \(S\), predict the class label of the query image \(\hat{x}\)
Core code
# Note: Classifier, BidirectionalLSTM, DistanceNetwork and AttentionalClassify are
# helper modules defined elsewhere in the accompanying implementation; the imports
# below are added here for completeness.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class MatchingNetwork(nn.Module):
    def __init__(self, keep_prob, batch_size=100, num_channels=1, learning_rate=0.001,
                 fce=False, num_classes_per_set=5, num_samples_per_class=1,
                 nClasses=0, image_size=28):
        """
        Builds a Matching Network together with its training and evaluation ops.
        :param keep_prob: dropout keep probability
        :param batch_size: batch size for the experiment
        :param num_channels: number of channels of the input images
        :param learning_rate: learning rate used by the optimizer
        :param fce: whether to use full context embeddings (i.e. apply an LSTM on the CNN embeddings)
        :param num_classes_per_set: number of classes per task (N-way)
        :param num_samples_per_class: number of samples per class (K-shot)
        :param nClasses: total number of classes; if non-zero, it adds a final FC layer to the classifier g
        :param image_size: size of the input image, needed to build the final FC classification layer
        """
        super(MatchingNetwork, self).__init__()
        self.batch_size = batch_size
        self.fce = fce
        # CNN embedding function g, shared between support and target images
        self.g = Classifier(layer_size=64, num_channels=num_channels,
                            nClasses=nClasses, image_size=image_size)
        if fce:
            # bidirectional LSTM for full context embeddings
            self.lstm = BidirectionalLSTM(layer_sizes=[32], batch_size=self.batch_size,
                                          vector_dim=self.g.outSize)
        self.dn = DistanceNetwork()            # cosine similarities between embeddings
        self.classify = AttentionalClassify()  # softmax attention over the support labels
        self.keep_prob = keep_prob
        self.num_classes_per_set = num_classes_per_set
        self.num_samples_per_class = num_samples_per_class
        self.learning_rate = learning_rate
    def forward(self, support_set_images, support_set_labels_one_hot, target_image, target_label):
        """
        Builds the graph for Matching Networks and produces loss and accuracy.
        :param support_set_images: support set images [batch_size, sequence_size, n_channels, 28, 28]
        :param support_set_labels_one_hot: one-hot support set labels [batch_size, sequence_size, n_classes]
        :param target_image: target images (images to produce labels for) [batch_size, n_target, n_channels, 28, 28]
        :param target_label: target labels [batch_size, n_target]
        :return: mean accuracy and mean cross-entropy loss over the target images
        """
        # produce embeddings g(x_i) for the support set images
        encoded_images = []
        for i in np.arange(support_set_images.size(1)):
            gen_encode = self.g(support_set_images[:, i, :, :, :])
            encoded_images.append(gen_encode)

        # process each target image in turn
        for i in np.arange(target_image.size(1)):
            # embed the target image and append it as the (k+1)-th element
            gen_encode = self.g(target_image[:, i, :, :, :])
            encoded_images.append(gen_encode)
            outputs = torch.stack(encoded_images)

            if self.fce:
                # full context embeddings: condition all embeddings on the whole set
                outputs, hn, cn = self.lstm(outputs)

            # cosine similarities between the support set embeddings and the target embedding
            similarities = self.dn(support_set=outputs[:-1], input_image=outputs[-1])
            similarities = similarities.t()

            # attention over the support labels gives the class probabilities
            preds = self.classify(similarities, support_set_y=support_set_labels_one_hot)

            # accumulate accuracy and cross-entropy loss
            values, indices = preds.max(1)
            if i == 0:
                accuracy = torch.mean((indices.squeeze() == target_label[:, i]).float())
                crossentropy_loss = F.cross_entropy(preds, target_label[:, i].long())
            else:
                accuracy = accuracy + torch.mean((indices.squeeze() == target_label[:, i]).float())
                crossentropy_loss = crossentropy_loss + F.cross_entropy(preds, target_label[:, i].long())

            # remove the target embedding so the next target image goes at the end of the list
            encoded_images.pop()

        return accuracy / target_image.size(1), crossentropy_loss / target_image.size(1)
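For orientation, here is a hedged usage sketch of the class above with random tensors; it assumes the helper modules (Classifier, BidirectionalLSTM, DistanceNetwork, AttentionalClassify) behave as in the reference implementation, and the shapes follow the forward() docstring. It sets up a 5-way 1-shot task on 28x28 single-channel images, with a batch of 32 episodes and one target image per episode; the parameter values are illustrative only.

# random data for a shape check only; in real training, labels come from the sampled episode
batch_size, n_way, k_shot, channels, img_size = 32, 5, 1, 1, 28
sequence_size = n_way * k_shot

model = MatchingNetwork(keep_prob=0.0, batch_size=batch_size, num_channels=channels,
                        fce=True, num_classes_per_set=n_way,
                        num_samples_per_class=k_shot, nClasses=0, image_size=img_size)

support_images = torch.rand(batch_size, sequence_size, channels, img_size, img_size)
support_labels = torch.zeros(batch_size, sequence_size, n_way)             # one-hot support labels
support_labels.scatter_(2, torch.randint(0, n_way, (batch_size, sequence_size, 1)), 1)
target_images = torch.rand(batch_size, 1, channels, img_size, img_size)   # one query per episode
target_labels = torch.randint(0, n_way, (batch_size, 1))

accuracy, loss = model(support_images, support_labels, target_images, target_labels)
print(accuracy.item(), loss.item())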
From: https://www.cnblogs.com/horolee/p/mn.html