首页 > 编程语言 >推荐算法之-召回中的随机负采样

推荐算法之-召回中的随机负采样

时间:2024-01-17 17:46:31浏览次数:36  
标签:采样 pv val withColumn trainItems 算法 召回 id row

// 二分查找
  def fetchBinarySearch(trainItems: Array[(String, Double)], target: Double): String = {
//    val trainItems = Array(("1", 0), ("2", 1), ("3", 3), ("4", 4), ("5", 6))
//    val target = 6.0000000000018032
    if (trainItems.length == 0) {
      ""
    } else {
      var left = 0
      var right = trainItems.length - 1
      while(left < right) {
        val mid = ((left + right)/2).toInt
        if (trainItems(mid)._2 < target) {
          left = mid + 1
        } else {
          right = mid
        }
      }
      trainItems(left)._1
    }
  }

  // 获取采样负样本用户集
  def fetchFullSampleItemsUdf(trainItems: Array[(String, Double)], trainItemsSize: Int, negNum: Int): UserDefinedFunction = udf(
    (app_info: String) => {
      val sampleItems = app_info.split(",").map(t => t.split(":")(0)).toBuffer
      val sampleItemsSet = scala.collection.mutable.Set[String]() ++ sampleItems.toSet
      val posNum = sampleItems.size
      var tmpNegNum = posNum*negNum
//      val trainItems = Array(("1", 0.1), ("2", 0.2), ("3", 0.3), ("4", 0.4))
      val probabilities = DenseVector(trainItems.map(_._2))
      while(tmpNegNum > 0) {
//        // 随机负采样
//        val randomIndex = (new Random).nextInt(trainItemsSize)
//        val negItem = trainItems(randomIndex)._1
        // 带权负采样(二分查找)
        val randomTarget = (new Random).nextDouble()
        val negItem = fetchBinarySearch(trainItems, randomTarget)
//        // 带权负采样(调用接口函数)
//        val randomIndex =  new Multinomial(probabilities).sample(1).head
//        val negItem = trainItems(randomIndex)._1
        if (!sampleItemsSet.contains(negItem)) {
          sampleItems.append(negItem)
          tmpNegNum = tmpNegNum - 1
        }
      }
      sampleItems.zipWithIndex.map{
        case (item, i) =>
          val label = if (i < posNum) 1 else 0
          (item, label)
      }
    }
  )




// 样本数据拼接
  def fetchSampleData(spark: SparkSession, day: String, part: String, negNum: Int): DataFrame = {
//    val part = "0"
    val targetData = fetchTargetData(spark, day, part)
    val userMap = {
      targetData.select("user_id").dropDuplicates("user_id").rdd
        .map {row =>
          val user_id = row.getAs[String]("user_id")
          (user_id, "1")
        }.collect().toMap
    }
//    val trainItems = fetchItemSampleData(spark, day).dropDuplicates("appid").rdd.map{
//      row => row.getAs[String]("appid")
//    }.collect()

    val win = Window.partitionBy("day")
    val win2 = Window.partitionBy("day").orderBy("pv")
    val win3 = Window.partitionBy("day").orderBy("rank")
    val trainItems = {
      fetchItemSampleData(spark, day).groupBy("day", "appid").agg(expr("power(count(user_id), 0.75) as pv"))
        .withColumn("pv_sum", sum("pv").over(win))
        .withColumn("fw", col("pv")/col("pv_sum"))
        .withColumn("rank", row_number().over(win2))
        .withColumn("fp", sum("fw").over(win3)).rdd        //相同值累计求和有问题
        .map{row =>
          val appid = row.getAs[String]("appid")
          val fp = row.getAs[Double]("fp")
          val pv = row.getAs[Double]("pv")
          val fw = row.getAs[Double]("fw")
          (appid, fp)
        }.collect()
    }.sortBy(_._2)
//    trainItems.reverse.take(10)
    val trainItemsSize = trainItems.length
//    targetData.
//      withColumn("fullSampleItems", fetchFullSampleItemsUdf(trainItems, trainItemsSize, 5)(col("app_info"))).
//      withColumn("fullSampleItems", explode(col("fullSampleItems"))).
//      withColumn("item_id", col("fullSampleItems").getField("_1")).
//      withColumn("target", col("fullSampleItems").getField("_2")).
//      groupBy("item_id").agg(expr("count(if(target == '1', user_id, null)) as pos_pv"),
//      expr("count(if(target == '0', user_id, null)) as neg_pv")).orderBy(desc("pos_pv")).
//      show(10, false)
    val userFeatures = fetchUserFeatures(spark, day, userMap)
    val itemFeatures = fetchItemFeatures(spark, day)
    val sampleData = {
      targetData.join(userFeatures, Seq("user_id"), "left")
        .withColumn("fullSampleItems", fetchFullSampleItemsUdf(trainItems, trainItemsSize, negNum)(col("app_info")))
        .withColumn("fullSampleItems", explode(col("fullSampleItems")))
        .withColumn("item_id", col("fullSampleItems").getField("_1"))
        .withColumn("target", col("fullSampleItems").getField("_2"))
        .join(broadcast(itemFeatures), Seq("item_id"), "left")
    }
    sampleData
  }

 

标签:采样,pv,val,withColumn,trainItems,算法,召回,id,row
From: https://www.cnblogs.com/WSX1994/p/17970597

相关文章

  • 区域人数统计AI智能分析网关V4客流统计AI算法介绍及应用场景
    客流量统计AI算法是一种基于人工智能技术的数据分析方法,通过机器学习、深度学习等算法,实现对客流量的实时监测和统计。该算法主要基于机器学习和计算机视觉技术,其基本流程包括图像采集、图像预处理、目标检测、目标跟踪和客流量统计等步骤,通过在监控视频中识别和跟踪人的轮廓或特......
  • 基础算法(二)归并排序模板
    模板如下#include<iostream>usingnamespacestd;constintN=1000010;intq[N],tmp[N];voidmerge_sort(intq[],intl,intr){if(l>=r)return;intmid=(l+r)>>1;intk=0,i=l,j=mid+1;merge_sort(q,l,mid);merge_sort(q,......
  • HanLP — HMM隐马尔可夫模型 -- 维特比(Viterbi)算法 --完整示例代码
    完成代码importpicklefromtqdmimporttqdmimportnumpyasnpimportosdefmake_label(text_str):"""从单词到label的转换,如:今天---->BE麻辣肥牛:--->BMME的--->S"""text_len=len(text_str)iftext_len==1:......
  • 基于内容的电影推荐算法研究
    引言今天读的文章为一篇名为《基于内容的电影推荐算法研究》的文章,文章提出了一种基于内容的电影推荐算法,通过分析电影特征和用户兴趣,实现更精准的电影推荐。文章中使用到了TF-IDF向量化方法,将电影类型和导演信息转化为特征向量,然后使用余弦相似度来衡量电影之间的相关性,接下来......
  • 【源码系列#06】Vue3 Diff算法
    专栏分享:vue2源码专栏,vue3源码专栏,vuerouter源码专栏,玩具项目专栏,硬核......
  • 成为一个合格程序员所必备的三种常见LeetCode排序算法
    排序算法是一种通过特定的算法因式将一组或多组数据按照既定模式进行重新排序的方法。通过排序,我们可以得到一个新的序列,该序列遵循一定的规则并展现出一定的规律。经过排序处理后的数据可以更方便地进行筛选和计算,从而大大提高了计算效率。因此,掌握排序算法是每个程序员的基本功......
  • 路径刷新算法
    deffind_paths_and_modify(data,target_path,new_value):paths=[]stack=[([],data)]whilestack:current_path,current_data=stack.pop()ifisinstance(current_data,dict):forkey,valueincurrent_data.item......
  • 算法-二分
    1.整数二分适用于有单调性的数列和部分没有单调性的数列本质:通过一个性质把数列分成两个序列,然后找到分界点注意:每一次循环后数组区间都会变成1,即你所查到的分界点2.实数二分因为没有整除的问题,每次区间都会严格减小一半如果题目要求保留x位小数,则循环条件为r-l<1e-(x+2);......
  • m基于码率兼容打孔LDPC码ms最小和译码算法的LDPC编译码matlab误码率仿真
    1.算法仿真效果matlab2022a仿真结果如下:2.算法涉及理论知识概要码率兼容打孔LDPC码BP译码算法是一种改进的LDPC译码算法,能够在不同码率下实现更好的译码性能。该算法通过在LDPC码中引入打孔操作,使得码率可以灵活地调整,同时利用BP(BeliefPropagation)译码算法进行迭代译码,提高了......
  • m基于码率兼容打孔LDPC码ms最小和译码算法的LDPC编译码matlab误码率仿真
    1.算法仿真效果matlab2022a仿真结果如下:    2.算法涉及理论知识概要       码率兼容打孔LDPC码BP译码算法是一种改进的LDPC译码算法,能够在不同码率下实现更好的译码性能。该算法通过在LDPC码中引入打孔操作,使得码率可以灵活地调整,同时利用BP(BeliefPropagation......