来一个不一样的工程做法。
我们直接对于每一个颜色 \(i\) 建虚树,显然可以通过树形 DP 来判断颜色 \(i\) 的节点是否在一条路径上。原题。
然后存一下这条路径的左右端点,这个实现的话可以分类讨论一下,就分一下竖直链还是绕过 LCA 的链。
做完这个就可以算方案数。
-
如果颜色 \(i\) 没有点。那么方案数显然为路径总数,\(\frac{n(n-1)}{2}\)。
-
如果颜色 \(i\) 只有一个节点 \(u\),那么任何一个包含 \(u\) 的路径都可。这里分成两部分。
-
一种情况是,在它的各个儿子所在的子树中,每个子树选一个点,两两配对;
-
或者,考虑在它儿子中选一个点,和 \(u\) 子树外的节点配对。
这个容易实现,记录一下子树大小然后一遍前缀和即可。
-
-
如果 \(i\) 有多个节点,设两端点为 \(x,y\),\(x\) 深度较小。
-
如果 \(x\) 是 \(y\) 的祖先,那么合法路径的端点可存在于三部分,\(y\) 的儿子,\(x\) 的除了 \(x\to y\) 这个子树上的节点,以及 \(x\) 子树外的节点。这个的实现方法根第二种情况一样。
-
如果 \(x,y\) 绕过了 \(x,y\) 的 LCA,那么方案数肯定是 \(x,y\) 子树各选一个点的方案数。\(siz[x]\times siz[y]\)。
-
然后你就开始写,很快啊!压压行后 6kb 就直接上去了啊!不过完全不过瘾啊!
时间复杂度 \(\Theta(n\log n)\)。
高能代码:
const int MAXN(2e6+10);
int n,vir[MAXN];
vector<int>col[MAXN],G[MAXN];
bool flag[MAXN];
struct E{int to,nxt;};
E edge[MAXN<<1];
int head[MAXN],tot;
inline void add_edge(int u,int v)
{
edge[++tot].nxt=head[u];
head[u]=tot;
edge[tot].to=v;
return;
}
int par[MAXN][23],dep[MAXN],dfn[MAXN],t,siz[MAXN];
inline void dfs(int u,int fa)
{
par[u][0]=fa,dep[u]=dep[fa]+1;
dfn[u]=++t,siz[u]=1;
rep(i,1,22) par[u][i]=par[par[u][i-1]][i-1];
gra(i,u)
{
E e=edge[i];
if(e.to==fa) continue;
dfs(e.to,u);
siz[u]+=siz[e.to];
}
return;
}
inline int LCA(int x,int y)
{
if(dep[x]<dep[y]) Swap(x,y);
rev(i,22,0) if(dep[par[x][i]]>=dep[y]) x=par[x][i];
if(x==y) return x;
rev(i,22,0) if(par[x][i]!=par[y][i]) x=par[x][i],y=par[y][i];
return par[x][0];
}
inline bool cmp(int x,int y){return dfn[x]<dfn[y];}
int que[MAXN],tail,rt,in[MAXN],out[MAXN],f[MAXN];
inline void build(int k)
{
sort(vir+1,vir+1+k,cmp);
int now=k;
rep(i,2,k)
{
int g=LCA(vir[i],vir[i-1]);
if(g!=vir[i]&&g!=vir[i-1]) vir[++now]=g;
}
sort(vir+1,vir+1+now);
now=unique(vir+1,vir+1+now)-vir-1;
sort(vir+1,vir+1+now,cmp);
que[++tail]=vir[1];
rep(i,2,now)
{
int g=LCA(vir[i],vir[i-1]);
G[g].push_back(vir[i]);
que[++tail]=g;
que[++tail]=vir[i];
++in[vir[i]],++out[g];
}
rep(i,1,tail)
{
int u=que[i];
if(in[u]==0) rt=u;
}
return;
}
inline void MEMSET(int c)
{
rep(i,0,(int)col[c].size()-1)
{
int now=col[c][i];
flag[now]=false;
}
rep(i,1,tail)
{
int u=que[i];
G[u].clear();
in[u]=0;
f[u]=-1;
}
tail=0;
return;
}
int ans[2];
int tp2,st2[MAXN];
inline void findinline(int u)
{
bool pp=false;
if(!ans[1]&&flag[u]) ans[1]=u;
rep(i,0,(int)G[u].size()-1)
{
int v=G[u][i];
if(f[v]==1) findinline(v);
pp=true;
}
if(!pp) ans[0]=u;
return;
}
inline void dfs2(int u)
{
bool pp=false;
rep(i,0,(int)G[u].size()-1)
{
int v=G[u][i];
dfs2(v);
pp=true;
}
if(!pp)
{
if(!ans[0]) ans[0]=u;
else ans[1]=u;
}
return;
}
inline void findans()
{
if(f[rt]==1) findinline(rt);
else dfs2(rt);
}
inline ll C(int x)
{
return 1ll*x*(x-1)/2;
}
int st[MAXN],tp;
inline ll fuckyou(int u)
{
ll res=0;
tp=0;
gra(i,u)
{
int v=edge[i].to;
if(v!=par[u][0]) st[++tp]=siz[v];
else st[++tp]=n-siz[u];
}
ll sum=0;
rev(i,tp,1)
{
res=res+st[i]*sum;
sum+=st[i];
}
rep(i,1,tp) res+=st[i];
return res;
}
inline void checkline(int u)
{
bool pp(false);
int tot1=0,tot2=0;
rep(i,0,(int)G[u].size()-1)
{
int v=G[u][i];
checkline(v);
if(f[v]==1) ++tot1;
else if(f[v]==0) ++tot2;
else
{
f[u]=-1;
return;
}
pp=true;
}
if(!pp)
{
f[u]=1;
return;
}
if(tot1+tot2>=3) f[u]=-1;
else if(tot2==2) f[u]=-1;
else if(tot1==2) f[u]=0;
else if(tot2==2) f[u]=-1;
else if(tot1==1&&tot2==1) f[u]=-1;
else if(tot1==1&&tot2==0) f[u]=1;
else if(tot1==0&&tot2==1)
{
if(flag[u]) f[u]=-1;
else f[u]=0;
}
return;
}
inline ll fuckyou2(int x,int y)
{
tp=0,tp2=0;
gra(i,x)
{
int v=edge[i].to;
if(v!=par[x][0])
{
if(LCA(v,y)!=v) st[++tp]=siz[v];
}
else st[++tp]=n-siz[x];
}
gra(i,y)
{
int v=edge[i].to;
if(v!=par[y][0]) st2[++tp2]=siz[v];
}
ll res=0;
rep(i,1,tp) res+=st[i];
rep(i,1,tp2) res+=st2[i];
++res;
ll sum=0;
rep(i,1,tp2) sum+=st2[i];
rep(i,1,tp) res=res+1ll*st[i]*sum;
return res;
}
inline void solve(int k)
{
int fuck=0;
rep(i,1,tail) if(out[que[i]]==0) ++fuck;
if(fuck>2)
{
puts("0");
return;
}
checkline(rt);
if(f[rt]==-1)
{
puts("0");
return;
}
findans();
if(k==1)
{
printf("%lld\n",fuckyou(ans[0]));
return;
}
if(LCA(ans[0],ans[1])==ans[1])
{
// puts("YES");
printf("%lld\n",fuckyou2(ans[1],ans[0]));
return;
}
printf("%lld\n",1ll*siz[ans[0]]*siz[ans[1]]);
return;
}
int main()
{
n=read();
rep(i,1,n)
{
int x=read();
col[x].push_back(i);
}
rep(i,2,n)
{
int u=read(),v=read();
add_edge(u,v),add_edge(v,u);
}
dfs(1,0);
dep[0]=0,dep[n+1]=INF;
memset(f,-1,sizeof(f));
rep(i,1,n)
{
if(col[i].empty())
{
printf("%lld\n",1ll*(n-1)*n/2);
continue;
}
ans[0]=0,ans[1]=0;
int k=col[i].size();
rep(j,0,k-1) vir[j+1]=col[i][j];
rep(i,1,k) flag[vir[i]]=true;
build(k);
solve(k);
MEMSET(k);
}
return 0;
}
/*
Date : 2022/9/22
Author : UperFicial
Start coding at : 19:41
finish debugging at : 7:30
*/
标签:par,return,int,题解,rep,else,ans,hegm,P5588
From: https://www.cnblogs.com/UperFicial/p/16721441.html