首页 > 其他分享 >MAMO模型复现

MAMO模型复现

时间:2022-09-26 20:48:19浏览次数:76  
标签:phi embedding sum 用户 self 复现 MAMO grad 模型

原文为MAMO: Memory-Augmented Meta-Optimization for Cold-start Recommendation
本文重点是用几个记忆网络保存用户、物品及任务的一些信息,在遇到新任务时,用之前任务学到的一些先验知识进行初始化,之后经过少量迭代就可以收敛到一个不错的效果。

for i in range(self.n_epoch):
    start = time.time()
    u_grad_sum, i_grad_sum, r_grad_sum = self.model.get_zero_weights()
    
    for u in self.train_users[:100]:
        bias_term, att_values = user_mem_init(self.user_feature[u], self.device, self.FeatureMEM, self.x1_loading,self.alpha)

        self.model.init_u_mem_weights(self.phi_u, bias_term, self.tao, self.phi_i, self.phi_r)

        self.model.init_ui_mem_weights(att_values, self.TaskMEM)
        
        interaction_info = [[self.user_feature[u]] * (self.support_size + self.query_size) , 
                            [self.item_feature[item_id] for item_id in self.dataset_['ui_interactions'][u]],
                            self.dataset_['ui_ratings'][u] - 1,
                            self.dataset_['ui_states'][u]]
        
        user_module = LOCALUpdate(self.model, u, self.dataset, self.support_size, self.query_size, self.batch_size,
                                  self.n_inner_loop, self.rho, top_k=3, device=self.device, interaction_info=interaction_info)
        

        u_grad, i_grad, r_grad = user_module.train()
        u_grad_sum, i_grad_sum, r_grad_sum = grads_sum(u_grad_sum, u_grad), grads_sum(i_grad_sum, i_grad), \
                                             grads_sum(r_grad_sum, r_grad)
        
        self.FeatureMEM.write_head(u_grad, self.beta)
        u_mui = self.model.get_ui_mem_weights()
        self.TaskMEM.write_head(u_mui[0], self.gamma)
        
    self.phi_u, self.phi_i, self.phi_r = maml_train(self.phi_u, self.phi_i, self.phi_r,
                                                    u_grad_sum, i_grad_sum, r_grad_sum, self.lamda)
    
    self.test_with_meta_optimization(i)
    end = time.time()
    print('Epoch {} cost {:.2f}s'.format(i+1, end-start))

实际流程和文章中给出的伪代码略有些不同。
第6行是获取\(b_u\)和\(a_u\),\(b_u\)是个性化偏置,\(a_u\)是用户与用户记忆的attention分数,同时更新\(M_P\)。
第8行是获取\(\theta_u\)、\(\theta_i\)、\(\theta_r\),分别对应用户embedding生成模块、物品embedding生成模块、推荐模块的参数。
第10行是获取\(M_{u,I}\),获取用户偏好矩阵。
第12~15行是获取本轮训练的数据。
第17~18和21行进行inner update,更新模型参数,并且获取到用户embedding生成模块、物品embedding生成模块、推荐模块的梯度。
第22~23行,梯度累加。
第25行,更新\(M_U\)。
第27行,更新\(M_{U,I}\)。
第29~30行,更新全局的用户embedding生成模块、物品embedding生成模块、推荐模块参数。

这里想要吐槽一下,作者给出的源码把每个用户的数据单独存储,处理后的数据大小变成1个G,我优化了一下存储数据的方法,数据只有几十M,而且运行速度和作者的基本上没差。
本文的思路很好,为每个用户生成独有的embedding生成器和推荐器,但是实验结果不太稳定,甚至在一些情况下MeLU比它要好不少,我觉得该方法应该还有较大的改进空间。

在movielens-1m上进行实验,四个场景分别为老用户-老物品、老用户-新物品、新用户-老物品、新用户-新物品,评估指标为MAE。

W-W W-C C-W C-C
原文结果 0.8725 0.9306 0.8967 0.8894
复现结果 0.8234 1.0043 0.7904 0.7778

标签:phi,embedding,sum,用户,self,复现,MAMO,grad,模型
From: https://www.cnblogs.com/ambition-hhn/p/16732200.html

相关文章

  • 软件质量模型保证(SQA)
    软件质量模型保证(SQA)目的:使软件制作的过程对于领导层是可见的定义:它是一套计划和方法来向领导层保证五个基本目标:1.保证有计划地进行2.保证遵循了步骤和需求3.及时通......
  • 软件质量模型
    软件质量模型(ISO9126)1.功能性2.可靠性(1.尽量不出问题2.出了问题不影响主体功能3.如果影响了主体功能要能尽快修复)3.易用性(用户体验要好)4.效率5.可维持性(更新)6.可移......
  • MS17-010复现
    一、环境准备功击方:kali(192.168.43.132)目标机:win7(192.168.43.134)win7打开smb服务漏洞的产生:Sbm服务445端口二、扫描1.nmap扫描2、使用msf的auxiliary二次判......
  • Yolov3模型转caffe模型
    Yolov3模型转caffe模型前提下载docker和nvidia-docker百度去下载即可基础镜像拉取sudodockerpullbvlc/caffe:gpu配置/etc/docker/daemon.json{"registry......
  • 寻找领域不变量:从生成模型到因果表征
    1领域不变的表征在迁移学习/领域自适应中,我们常常需要寻找领域不变的表征(Domain-invariantRepresentation)[1],这种表示可被认为是学习到各领域之间的共性,并基于此共性......
  • 在强化学习算法性能测试时使用训练好的模型运行游戏,此时如何控制实时游戏画面的帧数
    问题:在强化学习算法性能测试时使用训练好的模型运行游戏,此时如何控制实时游戏画面的帧数?  ========================================  看到很多训练好的模型......
  • WEB自动化-10-Page Object 模型
    10PageObject模型10.1概述  在针对一个WEB页面编写自动化测试用例时,需要引用页面中的元素(数据)才能进行操作(动作)并显示出页面内容。如果编写的测试用例是直接针对......
  • 重复暴FRB 20201124A的观测和模型
    重复暴FRB20201124A的观测和模型ArticlePublished:21September2022AfastradioburstsourceatacomplexmagnetizedsiteinabarredgalaxyH.Xu,J.R.Ni......
  • 候捷-C++程序设计(Ⅱ)兼谈对象模型
    目录笔记参考学习目标转换函数与explicitpointer-likeclassesfunction-likeclasses模板template模板特化与偏特化模板模板参数引用(reference)关于虚指针(vptr)和虚表(vtbl)关......
  • 【博学谷学习记录】超强总结,用心分享|Java基础分享-TCP/IP 网络模型有哪几层?
    目录1.应用层2.传输层3.网络层4.网络接口层5.总结TCP/IP网络模型有哪几层?问大家,为什么要有TCP/IP网络模型?对于同一台设备上的进程间通信,有很多种方式,比如有管道......