from pyspark.sql import DataFrame, Window
from pyspark.sql import functions as F
import functools
from datetime import datetime
def generate_new_rating_data(w_df, count_a, distinct_a, level_col, rank_order, flag_a, suffix):
    # level_col and rank_order are forwarded from window_params and are not used
    # directly below; the branch logic is driven by count_a, distinct_a and flag_a.
    window_spec = Window.partitionBy("ID").orderBy("rating_order")
    if flag_a:
        # Every rating for the ID is NR (NR_Count equals Rate_Count).
        w_df = w_df.where(
            (w_df[f"NR_Count{suffix}"] > 0) & (w_df[f"NR_Count{suffix}"] == w_df[f"Rate_Count{suffix}"])
        )
        return (
            w_df.where(F.col(f"Rating_Rank{suffix}") >= 250)
            .withColumn("rank", F.row_number().over(window_spec))
            .where(F.col("rank") == 1)
            .select(
                F.col("ID"),
                F.col("Source").alias(f"Source{suffix}"),
                F.col(f"Rating_Rank{suffix}"),
                F.col(f"NormCode{suffix}").alias(f"Rating{suffix}")
            )
        )
    elif count_a == 3 and distinct_a == 2:
        # Three non-NR ratings with exactly two distinct values: keep the value that appears twice.
        temp_df = w_df.where(
            (w_df[f"NormCode{suffix}"] != "NR")
            & (w_df[f"NonNR_Count{suffix}"] == count_a)
            & (w_df[f"NonNR_Distinct{suffix}"] == distinct_a)
        )
        count_df = temp_df.groupby("ID", f"Rating_Rank{suffix}") \
            .agg(F.count("*").alias("total_count")) \
            .where(F.col("total_count") == 2)
        return (
            temp_df.join(count_df, on=["ID", f"Rating_Rank{suffix}"], how="inner")
            .withColumn("rank", F.row_number().over(window_spec))
            .where(F.col("rank") == 1)
            .select(
                F.col("ID"),
                F.col("Source").alias(f"Source{suffix}"),
                F.col(f"Rating_Rank{suffix}"),
                F.col(f"NormCode{suffix}").alias(f"Rating{suffix}")
            )
        )
    else:
        # Remaining combinations: restrict to non-NR rows that match the expected
        # counts and keep the row whose rating rank equals lr_rating.
        w_df = w_df.where(F.col(f"NormCode{suffix}") != "NR")
        temp_df = w_df.where(
            (w_df[f"NonNR_Count{suffix}"] == count_a)
            & (w_df[f"NonNR_Distinct{suffix}"] == distinct_a)
        )
        return (
            temp_df.where(F.col(f"Rating_Rank{suffix}") == F.col("lr_rating"))
            .withColumn("rank", F.row_number().over(window_spec))
            .where(F.col("rank") == 1)
            .select(
                F.col("ID"),
                F.col("Source").alias(f"Source{suffix}"),
                F.col(f"Rating_Rank{suffix}"),
                F.col(f"NormCode{suffix}").alias(f"Rating{suffix}")
            )
        )
def loop_new_ratings(df, suffix=""):
window_params = [
[0, 0, f"Low{suffix}", F.asc(f"Rank{suffix}"), True],
[3, 3, f"Mid{suffix}", F.asc(f"Rank{suffix}"), False],
[3, 2, f"Low{suffix}", F.asc(f"Rank{suffix}"), False],
[3, 1, f"Low{suffix}", F.asc(f"Rank{suffix}"), False],
[2, 2, f"High{suffix}", F.desc(f"Rank{suffix}"), False],
[2, 1, f"Low{suffix}", F.asc(f"Rank{suffix}"), False],
[1, 1, f"Low{suffix}", F.asc(f"Rank{suffix}"), False],
]
shortened_df = (
df
.select(
"ID", "Source", f"NormCode{suffix}",
f"Rank{suffix}", f"Rating_Rank{suffix}",
f"Mid{suffix}", f"Distinct{suffix}", f"Rate_Count{suffix}", f"NR_Count{suffix}",
f"NonNR_Distinct{suffix}"
)
)
return functools.reduce(DataFrame.union, [generate_new_rating_data(shortened_df, *params, suffix) for params in window_params])
# rating_data (the prepared input DataFrame) and join_dfs (a join helper) are defined
# earlier in the full script and are not shown here.
final_df = loop_new_ratings(rating_data)
final_df_short = loop_new_ratings(rating_data, "Short")
final_df_long = loop_new_ratings(rating_data, "Long")
combined_final_df = (
    join_dfs(
        [
            final_df,
            final_df_short,
            final_df_long
        ],
        select_list=[
            "ID", "Source", "SourceShort", "SourceLong",
            "Rating_Rank", "Rating_RankShort", "Rating_RankLong",
            "Rating", "RatingShort", "RatingLong"
        ]
    )
)
print(datetime.now(), "final combined df:")
combined_final_df.show()
I have a PySpark script that processes rating data with window functions and aggregations. The code works, but it is not optimized: it uses a for loop and functools.reduce to combine DataFrames, and I believe this can be improved. I want to optimize the script by avoiding the for loop and the reduce while keeping the same functionality.
I tried to merge the DataFrame transformations into a single operation using window functions and conditions on the same DataFrame, but I could not find a way to eliminate the loop and functools.reduce entirely.
from pyspark.sql import Window
from pyspark.sql import functions as F
from datetime import datetime
def generate_new_rating_data(w_df, suffix):
    window_spec = Window.partitionBy("ID").orderBy("rating_order")
    return (
        w_df
        # flag_a: every rating for the ID is NR (NR_Count equals Rate_Count).
        .withColumn(
            "flag_a",
            (F.col(f"NR_Count{suffix}") > 0) & (F.col(f"NR_Count{suffix}") == F.col(f"Rate_Count{suffix}"))
        )
        # count_a / distinct_a: derived from the non-NR counts; NULL on NR rows.
        .withColumn(
            "count_a",
            F.when((F.col(f"NormCode{suffix}") != "NR") & (F.col(f"NonNR_Distinct{suffix}") == 2), 3).otherwise(
                F.when(F.col(f"NormCode{suffix}") != "NR", F.col(f"NonNR_Count{suffix}")).otherwise(None)
            )
        )
        .withColumn(
            "distinct_a",
            F.when(F.col(f"NormCode{suffix}") != "NR", F.col(f"NonNR_Distinct{suffix}")).otherwise(None)
        )
        .withColumn("rank", F.row_number().over(window_spec))
        # One combined filter replaces the per-parameter branches of the loop version.
        .filter(
            (F.col("flag_a") & (F.col(f"Rating_Rank{suffix}") >= 250)) |
            ((F.col("count_a") == 3) & (F.col("distinct_a") == 2) & (F.col("total_count") == 2)) |
            ((F.col("count_a").isin(3, 2, 1)) & (F.col(f"Rating_Rank{suffix}") == F.col("lr_rating")))
        )
        .groupBy("ID", "Source", f"Rating_Rank{suffix}", f"NormCode{suffix}", "count_a", "distinct_a")
        .agg(F.min("rank").alias("rank"))
        .filter(F.col("rank") == 1)
        .select(
            F.col("ID"),
            F.col("Source").alias(f"Source{suffix}"),
            F.col(f"Rating_Rank{suffix}"),
            F.col(f"NormCode{suffix}").alias(f"Rating{suffix}")
        )
    )
def process_ratings(df, suffix=""):
shortened_df = (
df
.select(
"ID", "Source", f"NormCode{suffix}",
f"Rank{suffix}", f"Rating_Rank{suffix}",
f"Mid{suffix}", f"Distinct{suffix}", f"Rate_Count{suffix}", f"NR_Count{suffix}",
f"NonNR_Distinct{suffix}", f"lr_rating", "rating_order"
)
.withColumn("total_count", F.count("*").over(Window.partitionBy("ID", f"Rating_Rank{suffix}")))
)
return generate_new_rating_data(shortened_df, suffix)
final_df = process_ratings(rating_data)
final_df_short = process_ratings(rating_data, "Short")
final_df_long = process_ratings(rating_data, "Long")
combined_final_df = (
    final_df
    .join(final_df_short, on="ID", how="inner")
    .join(final_df_long, on="ID", how="inner")
    .select(
        "ID", "Source", "SourceShort", "SourceLong",
        "Rating_Rank", "Rating_RankShort", "Rating_RankLong",
        "Rating", "RatingShort", "RatingLong"
    )
)
print(datetime.now(), "final combined df:")
combined_final_df.show()
Here is an explanation of the code changes and optimizations:
- Eliminated the loop_new_ratings function and the for loop:
  - We no longer need to iterate over the window_params list, because conditional aggregation and window functions produce the same result in a single DataFrame operation (the closing sketch at the end of this answer illustrates the general pattern).
- Simplified generate_new_rating_data with conditional aggregation:
  - F.when creates the new conditional columns flag_a, count_a and distinct_a, so different logic is applied depending on the conditions.
  - F.count("*").over(Window.partitionBy("ID", f"Rating_Rank{suffix}")) computes total_count, which is used by one of the conditions (a minimal sketch of this window-count pattern follows this list).
- Merged the DataFrame operations:
  - All filtering, grouping and aggregation is combined into a single chain of DataFrame transformations, avoiding unnecessary intermediate DataFrames.
- Used join instead of functools.reduce:
  - The final DataFrames are combined with explicit join operations rather than with functools.reduce.
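As a standalone illustration of the F.when and window-count ideas above, here is a minimal sketch on toy data (the DataFrame, its ID / Rating_Rank / NormCode columns and the values are made up for the example and are not taken from the script above):

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

# Toy data: two rows share an (ID, Rating_Rank) pair, one row is NR.
toy = spark.createDataFrame(
    [("a", 100, "AA"), ("a", 100, "AA"), ("a", 200, "NR")],
    ["ID", "Rating_Rank", "NormCode"],
)

# groupBy + join version: count per (ID, Rating_Rank), then join back.
counts = toy.groupBy("ID", "Rating_Rank").agg(F.count("*").alias("total_count"))
via_join = toy.join(counts, on=["ID", "Rating_Rank"], how="inner")

# Window + F.when version: the same total_count plus a conditional flag column,
# attached in one pass without a separate aggregate and join.
via_window = (
    toy
    .withColumn("total_count", F.count("*").over(Window.partitionBy("ID", "Rating_Rank")))
    .withColumn("is_nr", F.when(F.col("NormCode") == "NR", True).otherwise(False))
)

via_window.where((F.col("total_count") == 2) & ~F.col("is_nr")).show()

Both via_join and via_window carry the same total_count values; the window form simply avoids materializing a separate counts DataFrame and joining it back.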
These changes remove the for loop and functools.reduce, making the code more concise and efficient. By leaning on the PySpark DataFrame API, the complex logic is expressed in a single transformation chain, which improves performance and reduces code complexity.
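Finally, to show the loop-elimination idea outside of the rating logic, here is a small sketch on made-up data (the kind and value columns and the parameter list are hypothetical, not from the script above): a loop that filters a DataFrame once per parameter set and unions the pieces with functools.reduce can usually be rewritten as a single filter whose per-parameter predicates are OR'd together, provided those predicates select disjoint row sets.

import functools
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
toy = spark.createDataFrame([("x", 1), ("y", 2), ("z", 3)], ["kind", "value"])

# Loop + reduce version: one filtered DataFrame per parameter set, then a union.
params = [("x", 1), ("y", 2)]
parts = [toy.where((F.col("kind") == k) & (F.col("value") == v)) for k, v in params]
via_reduce = functools.reduce(DataFrame.union, parts)

# Single-pass version: the same rows selected with one OR'd predicate,
# assuming the per-parameter predicates do not overlap (no duplicate rows).
via_filter = toy.where(
    ((F.col("kind") == "x") & (F.col("value") == 1))
    | ((F.col("kind") == "y") & (F.col("value") == 2))
)

via_filter.show()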