// Binary search over a table of (itemId, value) pairs.
//
// trainItems: pairs expected to be sorted ascending by value; fetchSampleData passes
//             cumulative sampling probabilities here.
// target:     the value to locate.
// Returns the id of the first entry whose value is >= target. If every value is smaller
// than target the last entry's id is returned, and "" is returned for an empty table.
// Example: values 0, 1, 3, 4, 6 (ids "1".."5") with target 6.0000000000018032 -> "5".
def fetchBinarySearch(trainItems: Array[(String, Double)], target: Double): String = {
  if (trainItems.isEmpty) {
    ""
  } else {
    var left = 0
    var right = trainItems.length - 1
    while (left < right) {
      // Overflow-safe midpoint.
      val mid = left + (right - left) / 2
      if (trainItems(mid)._2 < target) left = mid + 1 else right = mid
    }
    trainItems(left)._1
  }
}

// Builds a UDF that expands one `app_info` string into positive and sampled negative items.
//
// trainItems:     (itemId, cumulativeProbability) pairs sorted ascending, searched with
//                 fetchBinarySearch to draw items proportionally to their weight.
// trainItemsSize: unused by the weighted sampler; kept so existing callers keep compiling
//                 (the original uniform sampler, left commented out, indexed with it).
// negNum:         number of negatives to draw per positive item.
//
// The UDF input is a comma-separated list of tokens; the part of each token before the
// first ':' is taken as the item id. The UDF returns (itemId, label) pairs: the positives
// first with label 1, followed by posNum * negNum sampled negatives with label 0.
// Negatives are drawn with replacement, so the same negative may appear more than once.
def fetchFullSampleItemsUdf(trainItems: Array[(String, Double)], trainItemsSize: Int, negNum: Int): UserDefinedFunction = udf(
  (app_info: String) => {
    if (app_info == null) {
      // A SQL NULL reaches the UDF as null; emit no samples instead of failing with an NPE.
      scala.collection.mutable.Buffer.empty[(String, Int)]
    } else {
      // takeWhile yields the same id as split(":")(0) but cannot throw on a token like ":".
      val sampleItems = app_info.split(",").map(_.takeWhile(_ != ':')).toBuffer
      val positiveItems = sampleItems.toSet
      val posNum = sampleItems.size

      // Rejection sampling below can only finish if at least one candidate is not a
      // positive. This also covers an empty trainItems, for which fetchBinarySearch
      // would hand back "" as a bogus negative.
      // NOTE(review): assumes every candidate has a strictly positive weight, so any
      // non-positive candidate is reachable by the sampler - confirm for other callers.
      val hasNegativeCandidate = trainItems.exists { case (itemId, _) => !positiveItems.contains(itemId) }
      var remaining = if (hasNegativeCandidate) posNum * negNum else 0

      // One generator per invocation rather than one per draw.
      val rng = new Random
      while (remaining > 0) {
        // Weighted negative sampling: a uniform draw in [0, 1) located in the cumulative table.
        val negItem = fetchBinarySearch(trainItems, rng.nextDouble())
        if (!positiveItems.contains(negItem)) {
          sampleItems.append(negItem)
          remaining -= 1
        }
      }

      // The first posNum entries are the positives; everything appended afterwards is negative.
      sampleItems.zipWithIndex.map { case (item, i) => (item, if (i < posNum) 1 else 0) }
    }
  }
)

// Assembles the sample data: target rows joined with user features, expanded into one row
// per (positive or sampled negative) item, then joined with item features.
//
// spark:  active Spark session, forwarded to the fetch* helpers.
// day:    forwarded to the fetch* helpers.
// part:   forwarded to fetchTargetData.
// negNum: negatives sampled per positive item.
// Returns a DataFrame that adds `item_id` and `target` (1 = positive, 0 = negative) columns.
def fetchSampleData(spark: SparkSession, day: String, part: String, negNum: Int): DataFrame = {
  val targetData = fetchTargetData(spark, day, part)

  // Distinct user ids collected to the driver as a lookup map (every value is "1").
  val userMap = targetData.select("user_id").dropDuplicates("user_id").rdd
    .map(row => (row.getAs[String]("user_id"), "1"))
    .collect()
    .toMap

  val dayWindow = Window.partitionBy("day")
  val pvOrderWindow = Window.partitionBy("day").orderBy("pv")
  // The running sum is ordered by the unique row number rather than by pv: the original
  // author noted that cumulative sums over tied values misbehave. The ROWS frame is
  // spelled out so the result does not depend on the default frame.
  val cumulativeWindow = Window.partitionBy("day").orderBy("rank")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)

  // Cumulative sampling distribution over items:
  //   pv = count(user_id)^0.75, fw = pv / sum(pv), fp = running sum of fw in pv order.
  // NOTE(review): presumably fetchItemSampleData returns a single `day`; with several days
  // the per-day cumulative values would be interleaved by the final sort - confirm.
  val trainItems = fetchItemSampleData(spark, day)
    .groupBy("day", "appid").agg(expr("power(count(user_id), 0.75) as pv"))
    .withColumn("pv_sum", sum("pv").over(dayWindow))
    .withColumn("fw", col("pv") / col("pv_sum"))
    .withColumn("rank", row_number().over(pvOrderWindow))
    .withColumn("fp", sum("fw").over(cumulativeWindow))
    .rdd
    .map(row => (row.getAs[String]("appid"), row.getAs[Double]("fp")))
    .collect()
    .sortBy(_._2)
  val trainItemsSize = trainItems.length

  // Debug aid: inspect how often each item is used as a positive / negative sample.
  // targetData.
  //   withColumn("fullSampleItems", fetchFullSampleItemsUdf(trainItems, trainItemsSize, 5)(col("app_info"))).
  //   withColumn("fullSampleItems", explode(col("fullSampleItems"))).
  //   withColumn("item_id", col("fullSampleItems").getField("_1")).
  //   withColumn("target", col("fullSampleItems").getField("_2")).
  //   groupBy("item_id").agg(expr("count(if(target == '1', user_id, null)) as pos_pv"),
  //   expr("count(if(target == '0', user_id, null)) as neg_pv")).orderBy(desc("pos_pv")).
  //   show(10, false)

  val userFeatures = fetchUserFeatures(spark, day, userMap)
  val itemFeatures = fetchItemFeatures(spark, day)

  val sampleData = targetData.join(userFeatures, Seq("user_id"), "left")
    .withColumn("fullSampleItems", fetchFullSampleItemsUdf(trainItems, trainItemsSize, negNum)(col("app_info")))
    // One output row per sampled item.
    .withColumn("fullSampleItems", explode(col("fullSampleItems")))
    .withColumn("item_id", col("fullSampleItems").getField("_1"))
    .withColumn("target", col("fullSampleItems").getField("_2"))
    .join(broadcast(itemFeatures), Seq("item_id"), "left")
  sampleData
}
标签:采样,pv,val,withColumn,trainItems,算法,召回,id,row From: https://www.cnblogs.com/WSX1994/p/17970597