首页 > 编程语言 >Matching Network算法概述

Matching Network算法概述

时间:2023-10-21 16:12:34浏览次数:30  
标签:set Network self 算法 target images hat Matching size

什么是Matching Network

1. 论文地址:Matching Networks for One Shot Learning

2. 简介:基于Metric Learning部分思想,使用外部记忆来增强网络,提高网络的学习能力。

3. 创新点

  • 借鉴了注意力和外部记忆方面的经验来搭建网络
  • 基于meta-learning用task来训练,而不是metric-learning输入固定类别的图片

4. 算法描述

Matching Network有两个输入

  1. 输入任务S为一个N-way K-shot的任务(下图中是一个4way 1shot的任务),其中\(S=\left(x_i, y_i\right)_{i=1}^k\)
  2. 需要预测类别的图片\(\hat{x}\)

Matching Network的输出被定义为:

图片\(\hat{x}\)的预测类别\(\hat{y}\)

那么Matching Network算法就可以被构建为\(P(\hat{y} \mid \hat{x}, S)\)

其中,\(P(.)\)为网络的参数映射,即注意力和外部记忆

(1) 注意力
  1. 简单范式

​ 论文中给了一个简单的注意力范式:\(\hat{y}=\sum_{i=1}^k a\left(\hat{x}, x_i\right) y_i\) 这里用\(a(.)\)做注意力计算,计算\(\hat{x}\)和所有给定标签输入\(x\)的关系,然 后将这种关系与\(\hat{y}\)进行对应,从而求解需预测类别输入\(\hat{x}\)的预测类别\(\hat{y}\)

  1. 余弦距离注意力

​ 直观上想,很容易想到注意力\(a(.)\)的定义可以选择一种metric指标(如:余弦距离),在浅层的向量空间求解两张图片的类似度/距 离。

​ 论文中定义了一个余弦距离注意力:

\[a\left(\hat{x}, x_i\right)=e^{c\left(f(\hat{x}), g\left(x_i\right)\right)} / \sum_{j=1}^k e^{c\left(f(\hat{x}), g\left(x_j\right)\right)} \]

​ 其中\(c(.)\)为余弦距离,\(f(\hat{x})\)为输入\(\hat{x}\)的浅层向量表示,\(g(x_j)\)为输入标签\(x_j\)的浅层向量表示。论文中提到的\(f(.)\)和\(g(.)\)是共享参 数的(也就是同一个CNN网络)。

(2) 外部记忆

​ 作者作者认为上述的余弦注意力定义的时候,(输入任务S中)每个已知标签的输入\(x_i\)通过CNN后的embedding,也就是 \(g(\hat{x_i})\)是 独立的,前后没有关系,然后与\(f(\hat{x})\)进行逐个对比,这看起来就有点简单粗暴,没有考虑到输入任务S改变embedding \(\hat{x_i}\) 的方式, 也就是\(f(.)\)应该是受\(g(S)\)影响的。

​ 对此,作者提出了双向LSTM来解决这个问题。

5. 网络设计

算法描述

  1. 将任务S中所有图片\(x_i\)和目标图片\(\hat{x}\)全部通过CNN网络,以获得它们的浅层向量表示,然后将这\(k+1\)个向量进行堆叠
  2. 将以上堆叠的浅层向量全部输入到双向LSTM中,获得\(k+1\)个输出。然后使用余弦距离判断前\(k\)个输出中与最后一个输出之间的相似度
  3. 根据计算出的相似度,按照任务中\(S\)中的标签信息求解目标图片\(\hat{x}\)的类别标签

核心代码

class MatchingNetwork(nn.Module):
    def __init__(self, keep_prob, \
                 batch_size=100, num_channels=1, learning_rate=0.001, fce=False, num_classes_per_set=5, \
                 num_samples_per_class=1, nClasses = 0, image_size = 28):
        super(MatchingNetwork, self).__init__()

        """
        Builds a matching network, the training and evaluation ops as well as data augmentation routines.
        :param keep_prob: A tf placeholder of type tf.float32 denotes the amount of dropout to be used
        :param batch_size: The batch size for the experiment
        :param num_channels: Number of channels of the images
        :param is_training: Flag indicating whether we are training or evaluating
        :param rotate_flag: Flag indicating whether to rotate the images
        :param fce: Flag indicating whether to use full context embeddings (i.e. apply an LSTM on the CNN embeddings)
        :param num_classes_per_set: Integer indicating the number of classes per set
        :param num_samples_per_class: Integer indicating the number of samples per class
        :param nClasses: total number of classes. It changes the output size of the classifier g with a final FC layer.
        :param image_input: size of the input image. It is needed in case we want to create the last FC classification 
        """
        self.batch_size = batch_size
        self.fce = fce
        self.g = Classifier(layer_size = 64, num_channels=num_channels,
                            nClasses= nClasses, image_size = image_size )
        if fce:
            self.lstm = BidirectionalLSTM(layer_sizes=[32], batch_size=self.batch_size, vector_dim = self.g.outSize)
        self.dn = DistanceNetwork()
        self.classify = AttentionalClassify()
        self.keep_prob = keep_prob
        self.num_classes_per_set = num_classes_per_set
        self.num_samples_per_class = num_samples_per_class
        self.learning_rate = learning_rate

    def forward(self, support_set_images, support_set_labels_one_hot, target_image, target_label):
        """
        Builds graph for Matching Networks, produces losses and summary statistics.
        :param support_set_images: A tensor containing the support set images [batch_size, sequence_size, n_channels, 28, 28]
        :param support_set_labels_one_hot: A tensor containing the support set labels [batch_size, sequence_size, n_classes]
        :param target_image: A tensor containing the target image (image to produce label for) [batch_size, n_channels, 28, 28]
        :param target_label: A tensor containing the target label [batch_size, 1]
        :return: 
        """
        # produce embeddings for support set images
        # (batch_size,shot_num,3,img_size,img_size)
        encoded_images = []
        for i in np.arange(support_set_images.size(1)):
            gen_encode = self.g(support_set_images[:,i,:,:,:])
            encoded_images.append(gen_encode)

        # produce embeddings for target images
        for i in np.arange(target_image.size(1)):
            gen_encode = self.g(target_image[:,i,:,:,:])
            encoded_images.append(gen_encode)
            outputs = torch.stack(encoded_images)

            if self.fce:
                outputs, hn, cn = self.lstm(outputs)

            # get similarity between support set embeddings and target
            similarities = self.dn(support_set=outputs[:-1], input_image=outputs[-1])
            similarities = similarities.t()

            # produce predictions for target probabilities
            preds = self.classify(similarities,support_set_y=support_set_labels_one_hot)

            # calculate accuracy and crossentropy loss
            values, indices = preds.max(1)
            if i == 0:
                accuracy = torch.mean((indices.squeeze() == target_label[:,i]).float())
                crossentropy_loss = F.cross_entropy(preds, target_label[:,i].long())
            else:
                accuracy = accuracy + torch.mean((indices.squeeze() == target_label[:, i]).float())
                crossentropy_loss = crossentropy_loss + F.cross_entropy(preds, target_label[:, i].long())

            # delete the last target image encoding of encoded_images
            # make the embedding vector for each new target images to be at the end of the list
            encoded_images.pop()

        return accuracy/target_image.size(1), crossentropy_loss/target_image.size(1)

标签:set,Network,self,算法,target,images,hat,Matching,size
From: https://www.cnblogs.com/horolee/p/mn.html

相关文章

  • 棋盘覆盖——分治算法的典例
    问题描述在一个\({2^k}\times{2^k}(K\geqslant0)\)个方格组成的棋盘中,恰有一个方格与其他方格不同,称该方格为特殊方格。棋盘覆盖问题要求用图所示的4种不同形状的\(L\)型骨牌覆盖给定棋盘上除特殊方格以外的所有方格,且任何2个\(L\)型骨牌不得重叠覆盖。问题分析算法设计......
  • 贪心算法实现
    贪心算法顾名思义,贪心算法总是作出在当前看来最好的选择。也就是说贪心算法并不从整体最优考虑,它所作出的选择只是在某种意义上的局部最优选择。当然,希望贪心算法得到的最终结果也是整体最优的。虽然贪心算法不能对所有问题都得到整体最优解,但对许多问题它能产生整体最优解。如单......
  • 边缘检测算法
    边缘检测算法是在数字图像处理中常用的一种技术,用于检测图像中物体边缘的位置。以下是几种常见的边缘检测算法:Sobel算子:Sobel算子是一种基于梯度的算法,通过计算图像的水平和垂直方向的梯度值,并将其组合起来得到边缘强度。Sobel算子具有简单、快速的特点,常用于实时应用。Prewitt算子......
  • 【图像分割】基于回溯搜索算法BSA的多阈值图像分割算法研究附Matlab代码
    ✅作者简介:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可私信。......
  • 算法训练day39LeetCode738.968.
    算法训练day39LeetCode738.968.738.单调递增的数字题目738.单调递增的数字-力扣(LeetCode)题解代码随想录(programmercarl.com)classSolution{public:intmonotoneIncreasingDigits(intn){stringstrNum=to_string(n);//int转换string......
  • 什么是Nagle 算法和延迟确认
    一、Nagle算法和延迟确认是干什么的?当我们TCP报⽂的承载的数据⾮常⼩的时候,例如⼏个字节,那么整个⽹络的效率是很低的,因为每个TCP报⽂中都会有20个字节的TCP头部,也会有20个字节的IP头部,⽽数据只有⼏个字节,所以在整个报⽂中有效数据占有的比例就会⾮常低。这就好像快递......
  • 部分算法总结
    小部分算法总结部分题目请见:https://github.com/ZhangFirst1/Algorithm-problem-code异或运算a^=b相当于a=a^b,将十进制数字转化为二进制进行运算,相同为0,相异为1,0和任何数异或运算都是原来的那个数。可以用来判断数组中哪个数字只出现过一次(通过将所有数与0进行异或运算)快......
  • Linux (7) NetworkManager重置resolve.conf
    《WindowsAzurePlatform系列文章目录》 在默认情况下,AzureLinuxVM会安装waagent,而waagent会依赖于NetworkManager服务。当我们修改了resolve.conf的时候,如果重启NetworkManager或者重启了LinuxVM,NetworkManager会重置resolve.conf。 目前有两个......
  • 10.21算法
    颠倒二进制位颠倒给定的32位无符号整数的二进制位。提示:请注意,在某些语言(如Java)中,没有无符号整数类型。在这种情况下,输入和输出都将被指定为有符号整数类型,并且不应影响您的实现,因为无论整数是有符号的还是无符号的,其内部的二进制表示形式都是相同的。在Java中,编译器使用二......
  • 常见密码学算法简介
    1.常见对称加解密算法对称加密算法是一种加密算法,使用相同的密钥来加密和解密数据。这些算法在保护数据安全性方面起着重要作用。下面是一些常用的对称加密算法的介绍:1.1AdvancedEncryptionStandard(AES)简介:AES是一种高级加密标准,用于保护敏感数据。它使用128、192或......