这题的方法口糊一下没有很难,没达到3500的水准。但是写起来才发现是真的恶心(主要是容易写错),没写过这么累的题,可能难度就体现在这里吧。
计数的时候是要分类讨论的,但是核心算法都一样:启发式合并,线段树合并。把\(m^2\)对路径分成以下三类,分别统计合法的:
-
两条路径的LCA不同(路径的LCA指的是两个端点的LCA)。发现这两条路径的LCA必须是祖先和后代的关系,不然两条路径不可能有重合。
比如图中的红蓝两条路径就属于这一类,考虑在×处(下面两个端点的LCA)把它们统计进答案。可以在dfs的同时用线段树合并维护所有 有端点在子树内的路径的LCA的深度。在合并两个儿子的时候,把线段树中值的数量较小的拿出来,遍历其中所有的元素,并在大的那个儿子的线段树中询问得到能和当前元素匹配的数量。这部分的复杂度是\(O(nlog^2n)\)。由于n和m同阶,都用n表示了。
-
两条路径的LCA相同,且它们重合的部分分布在LCA的两个子树中。像下面这样:
这种情况和下面的一种情况都需要把所有LCA为x的路径都放到点x处,统一处理它们之间产生的贡献。假设现在处理LCA为root的所有的路径。把这些路径的端点以及root都拿出来建一棵虚树。为了避免重复计数,对于任意两条需要被计数的路径,我们都在它们在原树中dfs序较小的两个端点的LCA处统计,比如上面图中的×处。还是用线段树合并+启发式合并,但这次线段树中只维护每条路径dfs序较小的那个端点的信息。令当前点为pos,在遍历较小的儿子线段树中的一条路径(x,y)时,假设x在pos子树内,y在root的另外一个子树内,则如果我们沿着x→y的方向走k步到点z,那么合法的匹配路径的端点都在z的子树内。同样可以在线段树上查询来统计。
-
两条路径的LCA相同,且它们重合的部分分布在LCA的一个子树中。
这种情况的统计方法和上面是类似的。为了保证重合部分只在一个子树内,需要一次额外dfs对每个点求出它在root的哪个子树里。
总时间复杂度\(O(nlog^2n)\)。
调试太痛苦了
点击查看代码
#include <bits/stdc++.h>
#define rep(i,n) for(int i=0;i<n;++i)
#define repn(i,n) for(int i=1;i<=n;++i)
#define LL long long
#define pii pair <LL,LL>
#define fi first
#define se second
#define mpr make_pair
#define pb push_back
void fileio()
{
#ifdef LGS
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
}
void termin()
{
#ifdef LGS
std::cout<<"\n\nEXECUTION TERMINATED";
#endif
exit(0);
}
using namespace std;
LL n,q,t,fa[150010][23],dep[150010],dfn[150010],ed[150010],ans=0,X[150010],Y[150010],LCA[150010];
vector <LL> g[150010],tg[150010],dford;
LL ll=0;
void dfsPre(int pos,int par,int d)
{
fa[pos][0]=par;dep[pos]=d;dford.pb(pos);
dfn[pos]=ll++;
rep(i,g[pos].size()) if(g[pos][i]!=par) dfsPre(g[pos][i],pos,d+1);
ed[pos]=ll-1;
}
int getLCA(int x,int y)
{
for(int i=19;i>=0;--i) if(fa[x][i]>0&&dep[fa[x][i]]>=dep[y]) x=fa[x][i];
for(int i=19;i>=0;--i) if(fa[y][i]>0&&dep[fa[y][i]]>=dep[x]) y=fa[y][i];
if(x==y) return x;
for(int i=19;i>=0;--i) if(fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
LL getAnces(LL x,LL y){rep(i,20) if(y&(1<<i)) x=fa[x][i];return x;}
namespace st//线段树合并
{
LL n2,dat[10000000],len;
int ls[10000000],rs[10000000];
void init(LL nn)
{
n2=1;while(n2<nn) n2*=2;
len=0;
}
LL newNode()
{
dat[++len]=0;ls[len]=rs[len]=0;
return len;
}
LL newTree(LL lb,LL ub,LL to)
{
LL ret=newNode();dat[ret]=1;
if(lb==ub) return ret;
LL mid=(lb+ub)>>1;
if(to<=mid) ls[ret]=newTree(lb,mid,to);
else rs[ret]=newTree(mid+1,ub,to);
return ret;
}
LL upd(LL k,LL lb,LL ub,LL to)
{
if(k==0) k=newNode();
++dat[k];
if(lb==ub) return k;
LL mid=(lb+ub)>>1;
if(to<=mid) ls[k]=upd(ls[k],lb,mid,to);
else rs[k]=upd(rs[k],mid+1,ub,to);
return k;
}
vector <LL> res;
void getAll(LL k,LL lb,LL ub)
{
if(k==0) return;
if(lb==ub)
{
rep(i,dat[k]) res.pb(lb);
return;
}
LL mid=(lb+ub)>>1;
getAll(ls[k],lb,mid);getAll(rs[k],mid+1,ub);
}
vector <LL> getAll(LL root)
{
res.clear();
getAll(root,0,n2-1);
return res;
}
LL qry(LL k,LL lb,LL ub,LL tlb,LL tub)
{
if(k==0||ub<tlb||tub<lb) return 0;
if(tlb<=lb&&ub<=tub) return dat[k];
return qry(ls[k],lb,(lb+ub)>>1,tlb,tub)+qry(rs[k],((lb+ub)>>1)+1,ub,tlb,tub);
}
LL merge(LL a,LL b)
{
if(a==0||b==0) return a|b;
dat[a]+=dat[b];
ls[a]=merge(ls[a],ls[b]);rs[a]=merge(rs[a],rs[b]);
return a;
}
}
namespace part1
{
vector <LL> v[150010];
LL combine(LL a,LL b,LL curdep)
{
if(a==0||b==0) return a|b;
if(st::dat[a]<st::dat[b]) swap(a,b);
vector <LL> vec=st::getAll(b);
rep(i,vec.size())
{
if(vec[i]>curdep-t) continue;
LL v1=st::qry(a,0,st::n2-1,0,vec[i]-1),v2=st::qry(a,0,st::n2-1,vec[i]+1,curdep-t);
ans+=v1+v2;
}
a=st::merge(a,b);
return a;
}
LL dfs(LL pos,LL par)
{
LL ret=0;
rep(i,v[pos].size())
{
LL nxt=st::newTree(0,st::n2-1,v[pos][i]);
ret=combine(ret,nxt,dep[pos]);
}
rep(i,g[pos].size()) if(g[pos][i]!=par)
{
LL nxt=dfs(g[pos][i],pos);
ret=combine(ret,nxt,dep[pos]);
}
return ret;
}
void countDiffLCA()
{
rep(i,q)
{
v[X[i]].pb(dep[LCA[i]]);
v[Y[i]].pb(dep[LCA[i]]);
}
st::init(n);
dfs(1,0);
}
}
namespace part2
{
vector <pii> pths[150010];
LL curroot,rootdep;
vector <LL> realver;
void buildVT(vector <LL> vers)
{
realver.clear();
rep(i,vers.size()) tg[vers[i]].clear();
sort(vers.begin(),vers.end());vers.erase(unique(vers.begin(),vers.end()),vers.end());
sort(vers.begin(),vers.end(),[](LL xx,LL yy){return dfn[xx]<dfn[yy];});
stack <LL> stk;stk.push(vers[0]);
realver=vers;
repn(i,vers.size()-1)
{
LL pos=vers[i],lca=getLCA(pos,stk.top());
if(lca==stk.top()) stk.push(pos);
else
{
while(dep[stk.top()]>dep[lca])
{
int pp=stk.top();stk.pop();
int nn=stk.top();if(dep[nn]<dep[lca]) nn=lca,tg[lca].clear(),realver.pb(lca);
tg[nn].pb(pp);
}
if(stk.top()!=lca) stk.push(lca);
stk.push(pos);
}
}
while(stk.size()>1)
{
int pp=stk.top();stk.pop();
tg[stk.top()].pb(pp);
}
}
vector <LL> v[150010];
LL fr[150010];
LL walk(LL curpos,LL to,LL stp)
{
LL rd=dep[getLCA(curpos,to)];
LL tot=dep[curpos]+dep[to]-rd*2;
if(tot<stp) return -1;
if(stp<=dep[curpos]-rd) return getAnces(curpos,stp);
return getAnces(to,tot-stp);
}
LL combineTwo(LL a,LL b,LL curpos)
{
if(a==0||b==0) return a|b;
if(st::dat[a]<st::dat[b]) swap(a,b);
vector <LL> vec=st::getAll(b);rep(i,vec.size()) vec[i]=dford[vec[i]];
rep(i,vec.size())
{
LL walkdist=max(t,dep[curpos]-rootdep+1),to=walk(curpos,vec[i],walkdist);
if(to==-1) continue;
LL vv=st::qry(a,0,st::n2-1,dfn[to],ed[to]);
ans+=vv;
}
a=st::merge(a,b);
return a;
}
LL dfsTwo(LL pos)
{
LL ret=0;
rep(i,v[pos].size())
{
LL nxt=st::newTree(0,st::n2-1,dfn[v[pos][i]]);
if(pos!=curroot) ret=combineTwo(ret,nxt,pos);
}
rep(i,tg[pos].size())
{
LL nxt=dfsTwo(tg[pos][i]);
if(pos!=curroot) ret=combineTwo(ret,nxt,pos);
}
return ret;
}
void dfsMarkFr(LL pos,LL mk)
{
if(mk==-1&&pos!=curroot) mk=dfn[pos];
fr[pos]=mk;
rep(i,tg[pos].size()) dfsMarkFr(tg[pos][i],mk);
}
LL combineOne(LL a,LL b)
{
if(a==0||b==0) return a|b;
if(st::dat[a]<st::dat[b]) swap(a,b);
vector <LL> vec=st::getAll(b);
rep(i,vec.size())
{
if(vec[i]==dfn[curroot])
{
ans+=st::dat[a];
continue;
}
LL v1=st::qry(a,0,st::n2-1,0,vec[i]-1),v2=st::qry(a,0,st::n2-1,vec[i]+1,st::n2-1);
ans+=v1+v2;
}
a=st::merge(a,b);
return a;
}
LL dfsOne(LL pos)
{
LL ret=0;
rep(i,v[pos].size())
{
LL nxt=st::newTree(0,st::n2-1,fr[v[pos][i]]);
if(dep[pos]-rootdep>=t) ret=combineOne(ret,nxt);
}
rep(i,tg[pos].size())
{
LL nxt=dfsOne(tg[pos][i]);
if(dep[pos]-rootdep>=t) ret=combineOne(ret,nxt);
}
return ret;
}
void countSameLCA()
{
rep(i,q)
{
if(dfn[X[i]]>dfn[Y[i]]) swap(X[i],Y[i]);
pths[LCA[i]].pb(mpr(X[i],Y[i]));
}
repn(root,n) if(pths[root].size())
{
curroot=root;rootdep=dep[root];
vector <LL> vers={root};
rep(i,pths[root].size()) vers.pb(pths[root][i].fi),vers.pb(pths[root][i].se);
buildVT(vers);
rep(i,realver.size()) v[realver[i]].clear();
rep(i,pths[root].size()) if(pths[root][i].fi!=root&&pths[root][i].se!=root) v[pths[root][i].fi].pb(pths[root][i].se);
st::init(n);
dfsTwo(root);
dfsMarkFr(root,-1);fr[root]=dfn[root];
rep(i,realver.size()) v[realver[i]].clear();
rep(i,pths[root].size()) v[pths[root][i].fi].pb(pths[root][i].se),v[pths[root][i].se].pb(pths[root][i].fi);
st::init(n);
dfsOne(root);
}
}
}
int main()
{
fileio();
cin>>n>>q>>t;
LL x,y;
rep(i,n-1)
{
scanf("%lld%lld",&x,&y);
g[x].pb(y);g[y].pb(x);
}
dfsPre(1,0,0);
rep(i,20) repn(j,n) fa[j][i+1]=fa[fa[j][i]][i];
rep(i,q)
{
scanf("%lld%lld",&X[i],&Y[i]);
LCA[i]=getLCA(X[i],Y[i]);
}
part1::countDiffLCA();
part2::countSameLCA();
cout<<ans<<endl;
termin();
}