目录
- 题目地址
- 思路
- 代码
- MySQL代码
- 逐行翻译为Pandas代码
- 等效Cypher查询(未验证)
题目地址
https://leetcode.cn/problems/strong-friendship/
思路
把好友关系看作无向图,题目就是在图中寻找这样的pattern:两个互为好友的用户,并且至少有3个共同好友,即完全三部图 K(1,1,3):
(* Mathematica *)
GraphData[{"CompleteTripartite", {1, 1, 3}}]
这类图模式用SQL写起来还是比较麻烦。
更复杂的查询建议把数据迁移到Neo4j这样的图数据库,再用Cypher这类图查询语言来写。
代码
MySQL代码
-- t1: every path uid - one_degree_connected - two_degree_connected (v1-e1-v2-e2-v3)
-- in the friendship graph. The shared middle vertex can sit in either column of
-- either edge row, so the four UNION branches try every combination.
with t1 as (
    select *
    from (
        select f1.user2_id as uid, f1.user1_id as one_degree_connected, f2.user1_id as two_degree_connected
        from Friendship f1
        join Friendship f2
          on f1.user1_id = f2.user2_id
        union
        select f1.user2_id as uid, f1.user1_id as one_degree_connected, f2.user2_id as two_degree_connected
        from Friendship f1
        join Friendship f2
          on f1.user1_id = f2.user1_id
        union
        select f1.user1_id as uid, f1.user2_id as one_degree_connected, f2.user1_id as two_degree_connected
        from Friendship f1
        join Friendship f2
          on f1.user2_id = f2.user2_id
        union
        select f1.user1_id as uid, f1.user2_id as one_degree_connected, f2.user2_id as two_degree_connected
        from Friendship f1
        join Friendship f2
          on f1.user2_id = f2.user1_id
    ) tmp1
    -- The three vertices must be distinct. uid < two_degree_connected keeps each
    -- unordered pair once and already implies uid <> two_degree_connected.
    where uid < two_degree_connected
      and uid <> one_degree_connected
      and one_degree_connected <> two_degree_connected
)
select uid as user1_id,
       two_degree_connected as user2_id,
       count(distinct one_degree_connected) as common_friend
from t1
-- A strong friendship also requires the two users to be friends themselves.
-- Row-constructor IN compares the id columns directly instead of concatenated strings.
-- NOTE(review): only rows stored as (smaller id, larger id) can match here; this
-- presumes Friendship always has user1_id < user2_id - confirm against the schema.
where (uid, two_degree_connected) in (select user1_id, user2_id from Friendship)
group by uid, two_degree_connected
having common_friend >= 3
order by user1_id, user2_id, common_friend
逐行翻译为Pandas代码
import pandas as pd
# Widen pandas' console output so debug prints of intermediate frames are not truncated.
pd.options.display.max_rows = 500
pd.options.display.max_columns = 500
pd.options.display.width = 1000
def _two_hop_paths(friendship: pd.DataFrame, f1_mid: str, f2_mid: str) -> pd.DataFrame:
    """Return the paths uid - one_degree_connected - two_degree_connected of one join branch.

    ``f1_mid`` / ``f2_mid`` name the column of the first / second edge row that
    holds the shared middle vertex; the other column of each edge supplies the
    corresponding end point of the path.
    """
    other = {'user1_id': 'user2_id', 'user2_id': 'user1_id'}
    edges = friendship[['user1_id', 'user2_id']]
    first = edges.rename(columns={other[f1_mid]: 'uid', f1_mid: 'one_degree_connected'})
    second = edges.rename(columns={f2_mid: 'one_degree_connected', other[f2_mid]: 'two_degree_connected'})
    paths = first.merge(second, on='one_degree_connected')
    return paths[['uid', 'one_degree_connected', 'two_degree_connected']]


def strong_friendship(friendship: pd.DataFrame) -> pd.DataFrame:
    """Find pairs of friends that share at least three common friends.

    Args:
        friendship: One row per friendship with columns ``user1_id`` and
            ``user2_id``. The final friendship check matches a pair only when
            it is stored as (smaller id, larger id), exactly like the SQL
            version this function translates.

    Returns:
        A frame with columns ``user1_id``, ``user2_id`` and ``common_friend``
        (the number of distinct common friends), sorted by those columns and
        indexed 0..n-1. It is empty when no pair qualifies.
    """
    # Step 1: the four UNION branches of the SQL query, in the same order.
    # Each tuple is (column of f1, column of f2) holding the shared middle vertex.
    branches = [
        ('user1_id', 'user2_id'),  # f1.user1_id = f2.user2_id
        ('user1_id', 'user1_id'),  # f1.user1_id = f2.user1_id
        ('user2_id', 'user2_id'),  # f1.user2_id = f2.user2_id
        ('user2_id', 'user1_id'),  # f1.user2_id = f2.user1_id
    ]

    # Step 2: UNION = concatenate the branches and drop duplicate paths.
    all_patterns = pd.concat(
        [_two_hop_paths(friendship, f1_mid, f2_mid) for f1_mid, f2_mid in branches],
        ignore_index=True,
    ).drop_duplicates()

    # Step 3: keep paths over three distinct vertices. uid < two_degree_connected
    # keeps each unordered pair once and already implies the two ends differ.
    uid = all_patterns['uid']
    mid = all_patterns['one_degree_connected']
    far = all_patterns['two_degree_connected']
    filtered_patterns = all_patterns[(uid < far) & (uid != mid) & (mid != far)]

    # Step 4: count(distinct one_degree_connected) per (uid, two_degree_connected).
    grouped = (
        filtered_patterns
        .groupby(['uid', 'two_degree_connected'])['one_degree_connected']
        .nunique()
        .reset_index(name='common_friend')
        .rename(columns={'uid': 'user1_id', 'two_degree_connected': 'user2_id'})
    )

    # Step 5: a strong friendship requires the pair itself to be friends. An
    # inner merge on both id columns is vectorised and also behaves on empty
    # frames; de-duplicating the edges first keeps one output row per pair.
    friends = friendship[['user1_id', 'user2_id']].drop_duplicates()
    strong_pairs = grouped.merge(friends, on=['user1_id', 'user2_id'], how='inner')

    # Step 6: having common_friend >= 3, then order by the output columns.
    strong_pairs = strong_pairs[strong_pairs['common_friend'] >= 3]
    result = strong_pairs.sort_values(by=['user1_id', 'user2_id', 'common_friend'])
    return result[['user1_id', 'user2_id', 'common_friend']].reset_index(drop=True)
等效Cypher查询(未验证)
// Pairs of users who are friends with each other and share at least three common friends.
// NOTE(review): still unverified against a live database. Assumes every user node
// stores its id in a `user_id` property - adjust the property name to the real schema.
MATCH (u1)-[:FRIENDSHIP]-(common_friend)-[:FRIENDSHIP]-(u2),
      (u1)-[:FRIENDSHIP]-(u2)
// Undirected matching yields both (a, b) and (b, a); keep each pair once.
// (The former `WHERE NOT (u1)-(u2)-(common_friend)` negated what MATCH requires,
// so it filtered out every row.)
WHERE u1.user_id < u2.user_id
WITH u1, u2, COUNT(DISTINCT common_friend) AS common_friend_count
WHERE common_friend_count >= 3
RETURN u1.user_id AS user1_id, u2.user_id AS user2_id, common_friend_count
ORDER BY user1_id, user2_id, common_friend_count