A[POJ1741].给定一棵树,边有权,求长度不超过\(k\)的路径数目。
B[HDU4871].给定一张图,边有权,求它的最短路径树上恰含\(k\)个点的路径中最长路径的长度及数目。
C[HDU4812].给定一棵树,点有权,求字典序最小的一个点对,其路径上的所有点权之积模\(100003\)等于\(k\)。
D[HDU5469].给定一棵树,点上有字母,给定一个询问串,求是否存在一条路径,其所有点的字母依次连接为询问串。
E[HDU4670].给定一棵树,点有权,点权的质因数集合给定(大小\(\le30\)),求所有点权乘积为立方数的路径数目。
F[HDU5664].给定一棵树,边有权,对所有不具有祖先关系的无序点对求距离,求距离中的第\(k\)大值。
G[HDU5977].给定一棵树,点有类型(数目\(\le10\)),求所有类型的点都出现的路径数。
H[HDU5314].给定一棵树,点有权,求点权极差不超过\(d\)的路径数目。
I[HDU5909].给定一棵树,点有权(\(\le2^m,m\le10\)),对\(0\le i \lt 2^m\),求点权异或和为\(i\)的连通块数目。
J[HDU5102].给定一棵树,求边数最小的\(k\)条路径的边数之和。
K[洛谷P6329].给定一棵树,在线修改点权,询问距离某个点的距离不超过\(k\)的点权和。
A点分治模板。
B还是模板。
C依然是模板。
D要用哈希优化,其它还是模板。
E用三进制状压,其它还是模板。
F二分,注意预处理点分治的结果。
G状压,统计答案时直接暴力枚举。
H用树状数组优化,比较板。
I点分治,求DFS序转化为序列DP,具体看代码。
J和F差不多。
K点分树模板。
点击查看A题代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<cstring>
using namespace std;
const int N=1e4+5,M=1e7+5,INF=1<<30;
int n,k,ans;
int rt,sum,vis[N],siz[N],mx[N];
int head[N],ver[N<<1],nxt[N<<1],val[N<<1],tot;
void adde(int u,int v,int w){
ver[++tot]=v;
val[tot]=w;
nxt[tot]=head[u];
head[u]=tot;
}
int c[M+5];
void add(int x,int v){for(x++;x<=M;x+=x&-x)c[x]+=v;}
int ask(int x){int res=0;for(x++;x;x-=x&-x)res+=c[x];return res;}
queue<int> tmp,tmp1;
void calcsize(int u,int fa){
siz[u]=1;mx[u]=0;
for(int i=head[u];i;i=nxt[i])
if(ver[i]!=fa&&!vis[ver[i]]){
calcsize(ver[i],u);
siz[u]+=siz[ver[i]];
mx[u]=max(mx[u],siz[ver[i]]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[rt])rt=u;
}
void calcdist(int u,int fa,int dis){
if(dis>k)return;
ans=ans+ask(k-dis);
tmp.push(dis);tmp1.push(dis);
for(int i=head[u];i;i=nxt[i])
if(ver[i]!=fa&&!vis[ver[i]])
calcdist(ver[i],u,dis+val[i]);
}
void dfs(int u,int fa){
vis[u]=1;add(0,1);
for(int i=head[u];i;i=nxt[i]){
if(ver[i]==fa||vis[ver[i]])continue;
calcdist(ver[i],u,val[i]);
while(!tmp.empty()){
add(tmp.front(),1);
tmp.pop();
}
}
while(!tmp1.empty()){
add(tmp1.front(),-1);
tmp1.pop();
}add(0,-1);
for(int i=head[u];i;i=nxt[i])
if(ver[i]!=fa&&!vis[ver[i]]){
sum=siz[ver[i]];rt=0;
mx[rt]=INF;
calcsize(ver[i],u);
calcsize(rt,-1);
dfs(rt,u);
}
}
void init(){
tot=ans=0;
memset(head,0,sizeof(head));
memset(vis,0,sizeof(vis));
}
int main(){
while(scanf("%d%d",&n,&k),n!=0&&k!=0){
init();
for(int i=1,u,v,w;i<n;i++){
scanf("%d%d%d",&u,&v,&w);
adde(u,v,w);adde(v,u,w);
}
rt=0;sum=n;mx[rt]=INF;
calcsize(1,-1);
calcsize(rt,-1);
dfs(rt,-1);
printf("%d\n",ans);
}
return 0;
}
点击查看B题代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1.2e5+5;
int Test,n,m,k;
struct Graph{
int head[N],nxt[N],ver[N],val[N],tot=0;
void add(int u,int v,int w){
ver[++tot]=v;val[tot]=w;
nxt[tot]=head[u];head[u]=tot;
}
void init(){
tot=1;
memset(head,0,sizeof(head));
}
}G,T;
struct Dijkstra{
int dis[N],vis[N],pre[N];
priority_queue<pair<int,int> >Q;
void solve(){
memset(dis,0x3f,sizeof(dis));
memset(vis,0,sizeof(vis));
Q.push(make_pair(1,0));dis[1]=0;
while(!Q.empty()){
int u=Q.top().first;Q.pop();
if(vis[u])continue;
for(int i=G.head[u],v;i;i=G.nxt[i])
if(dis[v=G.ver[i]]>dis[u]+G.val[i]||
dis[v]==dis[u]+G.val[i]&&u<G.ver[pre[v]]){
dis[v]=dis[u]+G.val[i];pre[v]=i^1;
Q.push(make_pair(v,-dis[v]));
}
}
}
}Dijk;
struct Point_Divide{
int sum,rt,sz[N],mx[N],vis[N],lim;
ll ans,cnt,maxd[N],cntd[N];
void csize(int u,int fa){
sz[u]=1;mx[u]=0;
for(int i=T.head[u],v;i;i=T.nxt[i])
if((v=T.ver[i])!=fa&&!vis[v])
csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
mx[u]=max(mx[u],sum-sz[u]);
if(mx[u]<mx[rt])rt=u;
}
void calc(int u,int fa,ll dis,int dep){
if(dep<k&&cntd[k-dep-1]){
if(dis+maxd[k-dep-1]>ans)ans=dis+maxd[k-dep-1],cnt=cntd[k-dep-1];
else if(dis+maxd[k-dep-1]==ans)cnt+=cntd[k-dep-1];
}
for(int i=T.head[u],v;i;i=T.nxt[i])
if((v=T.ver[i])!=fa&&!vis[v])calc(v,u,dis+T.val[i],dep+1);
}
void upd(int u,int fa,ll dis,int dep){
lim=max(lim,dep);
if(dis>maxd[dep])maxd[dep]=dis,cntd[dep]=1;
else if(dis==maxd[dep])++cntd[dep];
for(int i=T.head[u],v;i;i=T.nxt[i])
if((v=T.ver[i])!=fa&&!vis[v])upd(v,u,dis+T.val[i],dep+1);
}
void solve(int u){
maxd[0]=0;cntd[0]=1;lim=0;
for(int i=T.head[u],v;i;i=T.nxt[i])
if(!vis[v=T.ver[i]]){calc(v,u,T.val[i],1);upd(v,u,T.val[i],1);}
for(int i=1;i<=lim;i++)maxd[i]=cntd[i]=0;
}
void dfs(int u){
vis[u]=1;solve(u);
for(int i=T.head[u],v;i;i=T.nxt[i])
if(!vis[v=T.ver[i]]){sum=sz[v];rt=0;csize(v,-1);csize(rt,-1);dfs(rt);}
}
void work(){
memset(maxd,0,sizeof(maxd));
memset(cntd,0,sizeof(cntd));
memset(vis,0,sizeof(vis));
ans=cnt=0;mx[rt=0]=sum=n;
csize(1,-1);csize(rt,-1);dfs(rt);
}
}PD;
int main(){
scanf("%d",&Test);
while(Test--){
G.init();T.init();
scanf("%d%d%d",&n,&m,&k);
for(int i=1,u,v,w;i<=m;i++){
scanf("%d%d%d",&u,&v,&w);
G.add(u,v,w);G.add(v,u,w);
}
Dijk.solve();
for(int i=2;i<=n;i++){
T.add(i,G.ver[Dijk.pre[i]],G.val[Dijk.pre[i]]);
T.add(G.ver[Dijk.pre[i]],i,G.val[Dijk.pre[i]]);
}
PD.work();
printf("%lld %lld\n",PD.ans,PD.cnt);
}
return 0;
}
点击查看C题代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5,mod=1e6+3;
int n,k,sum,rt,siz[N],mx[N],vis[N];
int now,a[N],ansu,ansv,f[mod+5],inv[mod+5],z;
int head[N],nxt[N<<1],ver[N<<1],tot;
int tmp[N][2],num;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
void calcsize(int u,int fa){
siz[u]=1;mx[u]=0;
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v]){
calcsize(v,u);siz[u]+=siz[v];
mx[u]=max(mx[u],siz[v]);
}
mx[u]=max(mx[u],sum-siz[u]);
if(mx[u]<mx[rt])rt=u;
}
void chk(int u,int v){
if(u==v)return;
if(u>v)swap(u,v);
if(ansu==-1)ansu=u,ansv=v;
else if(u<ansu||u==ansu&&v<ansv)ansu=u,ansv=v;
}
void calc(int u,int fa,int val){
++num;tmp[num][0]=val;tmp[num][1]=u;
if(1ll*val*a[now]%mod==k)chk(u,now);
if(f[z=1ll*k*inv[val]%mod*inv[a[now]]%mod]!=-1)chk(u,f[z]);
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])
calc(v,u,1ll*val*a[v]%mod);
}
void dfs(int u,int fa){
now=u;vis[u]=1;num=0;
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v]){
int num0=num+1;calc(v,u,a[v]);
for(int j=num0;j<=num;j++){
if(f[tmp[j][0]]==-1)f[tmp[j][0]]=tmp[j][1];
else f[tmp[j][0]]=min(f[tmp[j][0]],tmp[j][1]);
}
}
for(int j=1;j<=num;j++)f[tmp[j][0]]=-1;
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v]){
sum=siz[v];rt=0;
calcsize(v,u);calcsize(rt,-1);
dfs(rt,u);
}
}
void init(){
ansu=ansv=-1;num=tot=0;
for(int i=1;i<=n;i++)head[i]=vis[i]=0;
}
int main(){
memset(f,-1,sizeof(f));inv[1]=1;
for(int i=2;i<mod;i++)inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod;
while(scanf("%d%d",&n,&k)!=EOF){
init();
for(int i=1;i<=n;i++)scanf("%d",a+i);
for(int i=1,u,v;i<n;++i){
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
sum=n;rt=0;mx[rt]=1<<29;
calcsize(1,-1);calcsize(rt,-1);dfs(rt,-1);
if(ansu!=-1)printf("%d %d\n",ansu,ansv);
else printf("No solution\n");
}
return 0;
}
点击查看D题代码
#pragma GCC optimize(2)
#pragma GCC optimize(3."Ofast","inline")
#include<bits/stdc++.h>
using namespace std;
typedef unsigned long long ull;
const int N=10005;
int T,n,len,ans,sum,rt,sz[N],mx[N],vis[N],y[N],a[2*N],f[2*N],cnt,p0,lenx;
char s[N],t[N];
int head[N],nxt[N<<1],ver[N<<1],tot;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
ull pre[N],suf[N],base=131,p[N],x[2*N];
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
inline void csize(int u,int fa){
sz[u]=1;mx[u]=0;
for(register int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v]){
csize(v,u);sz[u]+=sz[v];
mx[u]=max(mx[u],sz[v]);
}
mx[u]=max(mx[u],sum-sz[u]);
if(mx[u]<mx[rt])rt=u;
}
inline void calc(int u,int fa,ull val,int dep){
if(dep>len)return;
p0=lower_bound(x+1,x+lenx+1,val)-x;
if(x[p0]==val){y[++cnt]=p0;if(f[a[p0]])ans=1;}
for(register int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])
calc(v,u,val+p[dep+1]*s[v],dep+1);
}
inline void dfs(int u,int fa){
vis[u]=1;cnt=0;
p0=lower_bound(x+1,x+lenx+1,1llu*s[u])-x;
if(x[p0]==1llu*s[u])y[++cnt]=p0,f[p0]=1;
for(register int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v]){
int cnt0=cnt+1;calc(v,u,s[u]+base*s[v],1);
for(int j=cnt0;j<=cnt;++j)f[y[j]]=1;
}
for(register int j=1;j<=cnt;++j)f[y[j]]=0;
for(register int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v]){
sum=sz[v];rt=0;
csize(v,u);csize(rt,-1);dfs(rt,u);
}
}
int main(){
T=read();
for(register int cas=1;cas<=T;++cas){
ans=tot=0;
for(int i=1;i<=n;i++)head[i]=vis[i]=0;
n=read();
for(register int i=1,u,v;i<n;++i){
u=read();v=read();
add(u,v);add(v,u);
}
scanf("%s%s",s+1,t+1);
len=strlen(t+1);
if(len==1){
for(register int i=1;i<=n;++i)
if(s[i]==t[1])ans=1;
}
else{
p[0]=1;pre[0]=suf[len+1]=0;
for(register int i=1;i<=len;++i){
x[i]=pre[i]=pre[i-1]*base+t[i];
p[i]=p[i-1]*base;
}
for(register int i=len;i>=1;--i)
x[i+len]=suf[i]=suf[i+1]*base+t[i];
sort(x+1,x+2*len+1);
lenx=unique(x+1,x+2*len+1)-x-1;
for(register int i=1;i<=len;++i){
int p1=lower_bound(x+1,x+lenx+1,pre[i])-x,
p2=lower_bound(x+1,x+lenx+1,suf[i])-x;
a[p1]=p2;a[p2]=p1;
}
sum=n;rt=0;mx[rt]=1<<30;
csize(1,-1);csize(rt,-1);dfs(rt,-1);
}
printf("Case #%d: ",cas);
puts(ans?"Find":"Impossible");
}
return 0;
}
点击查看E题代码
#pragma GCC optimize(2)
#pragma GCC optimize(3."Ofast","inline")
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef vector<int> vi;
const int N=5e4+5;
int n,k,p[35],sum,rt,sz[N],mx[N],vis[N];
ll ans,pwr3[35],x;
int head[N],nxt[N<<1],ver[N<<1],tot;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
struct P{ll v;}a[N];
P operator +(const P&a,const P&b){
P c;c.v=0;
for(int i=0;i<k;i++)c.v+=(a.v/pwr3[i]%3+b.v/pwr3[i]%3)%3*pwr3[i];
return c;
}
P opp(P a){
P b;b.v=0;
for(int i=0;i<k;i++)b.v+=(3-a.v/pwr3[i]%3)%3*pwr3[i];
return b;
}
bool operator <(const P&a,const P&b){return a.v<b.v;}
map<P,ll> M;
void csize(int u,int fa){
sz[u]=1;mx[u]=0;
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v]){
csize(v,u);sz[u]+=sz[v];
mx[u]=max(mx[u],sz[v]);
}
mx[u]=max(mx[u],sum-sz[u]);
if(mx[u]<mx[rt])rt=u;
}
void calc(int u,int fa,P val){
ans+=M[opp(val)];
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])calc(v,u,val+a[v]);
}
void upd(int u,int fa,P val){
++M[val];
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])upd(v,u,val+a[v]);
}
void dfs(int u,int fa){
if(a[u].v==0)++ans;
vis[u]=1;M.clear();M[P{0}]=1;
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v]){calc(v,u,a[u]+a[v]);upd(v,u,a[v]);}
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v]){
sum=sz[v];rt=0;
csize(v,u);csize(rt,-1);dfs(rt,u);
}
}
int main(){
pwr3[0]=1;
for(int i=1;i<=30;i++)pwr3[i]=3*pwr3[i-1];
while(scanf("%d",&n)!=EOF){
ans=tot=0;
for(int i=1;i<=n;i++)head[i]=vis[i]=a[i].v=0;
scanf("%d",&k);
for(int i=1;i<=k;i++)scanf("%d",p+i);
for(int i=1;i<=n;i++){
scanf("%lld",&x);
for(int j=1,c=0;j<=k;j++){
for(c=0;x%p[j]==0;c++)x/=p[j];
a[i].v+=c%3*pwr3[j-1];
}
}
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
sum=n;rt=0;mx[rt]=1<<30;
csize(1,-1);csize(rt,-1);dfs(rt,-1);
printf("%lld\n",ans);
}
return 0;
}
点击查看F题代码
#include<bits/stdc++.h>
using namespace std;
const int N=5e4+5;
int T,n,m,k,d,ans,cnt[N],sum,rt,sz[N],mx[N],vis[N],now0,now;
int head[N],nxt[N<<1],ver[N<<1],val[N<<1],tot;
void add(int u,int v,int w){
ver[++tot]=v;val[tot]=w;
nxt[tot]=head[u];head[u]=tot;
}
void csize(int u,int fa){
sz[u]=1;mx[u]=0;
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])
csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
mx[u]=max(mx[u],sum-sz[u]);
if(mx[u]<mx[rt])rt=u;
}
struct node{int id,dis;};
vector<node> a[N];
bool operator <(const node&a,const node&b){return a.dis<b.dis;}
void calc(int u,int fa,int dist){
a[now0].push_back(node{now,dist});
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])calc(v,u,dist+val[i]);
}
void dfs(int u){
vis[now0=u]=1;a[u].push_back(node{0,0});
for(int i=head[u],v;i;i=nxt[i])
if(!vis[v=ver[i]])now=v,calc(v,u,val[i]);
sort(a[u].begin(),a[u].end());
for(int i=head[u],v;i;i=nxt[i])
if(!vis[v=ver[i]]){
sum=sz[v];rt=0;
csize(v,-1);csize(rt,-1);
dfs(rt);
}
}
void solve1(int u,int fa){
int t=a[u].size();
for(int l=0;l<t;l++)cnt[a[u][l].id]=0;
for(int l=0,r=t-1;l<t;l++){
while(r>=0&&a[u][l].dis+a[u][r].dis>=d)++cnt[a[u][r].id],--r;
ans+=t-1-r-cnt[a[u][l].id];
}
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa)solve1(v,u);
}
vector<int> D;
void solve2(int u,int fa,int dis){
D.push_back(dis);
ans-=upper_bound(D.begin(),D.end(),dis-d)-D.begin();
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa)solve2(v,u,dis+val[i]);
D.pop_back();
}
int main(){
mx[0]=1<<30;
scanf("%d",&T);
while(T--){
tot=0;
for(int i=1;i<=n;i++)head[i]=vis[i]=0,a[i].clear();
scanf("%d%d%d",&n,&m,&k);
for(int i=1,u,v,w;i<n;i++){
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);add(v,u,w);
}
sum=n;rt=0;
csize(1,-1);csize(rt,-1);dfs(rt);
int L=1,R=5e8,res=-1;
while(L<=R){
d=L+R>>1;
ans=0;solve1(m,-1);ans/=2;solve2(m,-1,0);
if(ans>=k)res=d,L=d+1;
else R=d-1;
}
if(res==-1)printf("NO\n");
else printf("%d\n",res);
}
return 0;
}
点击查看G题代码
#include<bits/stdc++.h>
using namespace std;
const int N=5e4+5;
int n,k,col[N],f[N],x[N],cnt,sum,rt,sz[N],mx[N],vis[N];long long ans;
int head[N],nxt[N<<1],ver[N<<1],tot;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
void csize(int u,int fa){
sz[u]=1;mx[u]=0;
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])
csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
mx[u]=max(mx[u],sum-sz[u]);
if(mx[u]<mx[rt])rt=u;
}
void calc(int u,int fa,int st){
int t=((1<<k)-1)^st;ans+=f[t];
for(int i=st;i;i=(i-1)&st)ans+=f[i^t];
x[++cnt]=st;
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])calc(v,u,st|(1<<col[v]));
}
void dfs(int u){
vis[u]=1;f[x[cnt=1]=(1<<col[u])]=1;
for(int i=head[u],v;i;i=nxt[i])
if(!vis[v=ver[i]]){
int cnt0=cnt+1;calc(v,u,(1<<col[u])|(1<<col[v]));
for(int j=cnt0;j<=cnt;j++)f[x[j]]++;
}
for(int i=0;i<(1<<k);i++)f[i]=0;
for(int i=head[u],v;i;i=nxt[i])
if(!vis[v=ver[i]]){
sum=sz[v];rt=0;
csize(v,u);csize(rt,u);dfs(rt);
}
}
int main(){
while(scanf("%d%d",&n,&k)!=EOF){
ans=tot=0;
for(int i=1;i<=n;i++)head[i]=vis[i]=0;
for(int i=1;i<=n;i++)scanf("%d",col+i),--col[i];
for(int i=1,u,v;i<n;i++){scanf("%d%d",&u,&v);add(u,v);add(v,u);}
if(k==1){printf("%lld\n",1ll*n*n);continue;}
sum=n;mx[rt=0]=1<<30;
csize(1,-1);csize(rt,-1);dfs(rt);
printf("%lld\n",2*ans);
}
return 0;
}
点击查看H题代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
int T,n,m,k,d,p[N],q[N],sum,rt,sz[N],mx[N],vis[N],now;long long ans;
int head[N],nxt[N<<1],ver[N<<1],tot;
void adde(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
void csize(int u,int fa){
sz[u]=1;mx[u]=0;
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])
csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
mx[u]=max(mx[u],sum-sz[u]);
if(mx[u]<mx[rt])rt=u;
}
struct node{int min,max;}a[N];
bool operator <(const node&a,const node&b){
return a.max==b.max?a.min<b.min:a.max<b.max;
}
int c[N];
void modify(int x,int v){for(;x<=m;x+=x&-x)c[x]+=v;}
int query(int x){int res=0;for(;x;x-=x&-x)res+=c[x];return res;}
void calc(int u,int fa,int mn,int mx){
a[++k]={mn,mx};
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])calc(v,u,min(mn,p[v]),max(mx,p[v]));
}
void dfs(int u){
vis[u]=1;a[k=1]={p[u],p[u]};
for(int i=head[u],v;i;i=nxt[i])
if(!vis[v=ver[i]]){
int k0=k+1;
calc(v,u,min(p[u],p[v]),max(p[u],p[v]));
sort(a+k0,a+k+1);
for(int j=k0;j<=k;j++){
int tmp=lower_bound(q+1,q+m+1,q[a[j].max]-d)-q-1;
if(a[j].min>tmp)ans-=j-k0-query(tmp);modify(a[j].min,1);
}
for(int j=k0;j<=k;j++)modify(a[j].min,-1);
}
sort(a+1,a+k+1);
for(int j=1;j<=k;j++){
int tmp=lower_bound(q+1,q+m+1,q[a[j].max]-d)-q-1;
if(a[j].min>tmp)ans+=j-1-query(tmp);modify(a[j].min,1);
}
for(int j=1;j<=k;j++)modify(a[j].min,-1);
for(int i=head[u],v;i;i=nxt[i])
if(!vis[v=ver[i]]){
sum=sz[v];rt=0;
csize(v,-1);csize(rt,-1);
dfs(rt);
}
}
int main(){
mx[0]=1<<30;
scanf("%d",&T);
while(T--){
ans=tot=0;
for(int i=1;i<=n;i++)head[i]=vis[i]=0;
scanf("%d%d",&n,&d);
for(int i=1;i<=n;i++)scanf("%d",p+i),q[i]=p[i];
sort(q+1,q+n+1);m=unique(q+1,q+n+1)-q-1;
for(int i=1;i<=n;i++)p[i]=lower_bound(q+1,q+m+1,p[i])-q;
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
adde(u,v);adde(v,u);
}
sum=n;rt=0;
csize(1,-1);csize(rt,-1);dfs(rt);
printf("%lld\n",ans*2);
}
return 0;
}
点击查看I题代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1145,mod=1e9+7;
int Test,n,m,a[N];
struct Graph{
int head[N],nxt[N<<1],ver[N<<1],tot=0;
void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
void init(){tot=1;memset(head,0,sizeof(head));}
}T;
struct Point_Divide{
int sum,rt,sz[N],mx[N],vis[N];
int R[N],dfn,id[N],f[N][N],ans[N];
void csize(int u,int fa){
sz[u]=1;mx[u]=0;
for(int i=T.head[u],v;i;i=T.nxt[i])
if((v=T.ver[i])!=fa&&!vis[v])
csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
mx[u]=max(mx[u],sum-sz[u]);
if(mx[u]<mx[rt])rt=u;
}
void dfs2(int u,int fa){
id[++dfn]=u;
for(int i=T.head[u],v;i;i=T.nxt[i])
if((v=T.ver[i])!=fa&&!vis[v])dfs2(v,u);
R[u]=dfn;
}
void solve(int u){
dfn=0;dfs2(u,-1);
for(int i=0;i<=dfn+1;i++)for(int j=0;j<m;j++)f[i][j]=0;
f[1][a[u]]=1;
for(int i=2;i<=dfn;i++)
for(int j=0;j<m;j++){
f[i][j^a[id[i]]]=(f[i][j^a[id[i]]]+f[i-1][j])%mod;
f[R[id[i]]][j]=(f[R[id[i]]][j]+f[i-1][j])%mod;
}
for(int j=0;j<m;j++)ans[j]=(ans[j]+f[dfn][j])%mod;
}
void dfs(int u){
vis[u]=1;solve(u);
for(int i=T.head[u],v;i;i=T.nxt[i])
if(!vis[v=T.ver[i]]){sum=sz[v];rt=0;csize(v,-1);csize(rt,-1);dfs(rt);}
}
void work(){
memset(ans,0,sizeof(ans));
memset(vis,0,sizeof(vis));
mx[rt=0]=sum=n;
csize(1,-1);csize(rt,-1);dfs(rt);
}
}PD;
int main(){
scanf("%d",&Test);
while(Test--){
T.init();
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)scanf("%d",a+i);
for(int i=1,u,v;i<n;i++){
scanf("%d%d",&u,&v);
T.add(u,v);T.add(v,u);
}
PD.work();
printf("%d",PD.ans[0]);
for(int j=1;j<m;j++)printf(" %d",PD.ans[j]);
printf("\n");
}
return 0;
}
点击查看J题代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=1e5+5;
int T,n,k,cnt[N],cntl[N],sum,rt,sz[N],mx[N],vis[N],now0,now;ll length,ans;
int head[N],nxt[N<<1],ver[N<<1],tot;
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
return x*f;
}
inline void add(int u,int v){ver[++tot]=v;nxt[tot]=head[u];head[u]=tot;}
inline void csize(int u,int fa){
sz[u]=1;mx[u]=0;
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])
csize(v,u),sz[u]+=sz[v],mx[u]=max(mx[u],sz[v]);
mx[u]=max(mx[u],sum-sz[u]);
if(mx[u]<mx[rt])rt=u;
}
struct node{int id,dis;};
vector<node> a[N];
bool operator <(const node&a,const node&b){return a.dis<b.dis;}
inline void calc(int u,int fa,int dist){
a[now0].push_back(node{now,dist});
for(int i=head[u],v;i;i=nxt[i])
if((v=ver[i])!=fa&&!vis[v])calc(v,u,dist+1);
}
inline void dfs(int u){
vis[now0=u]=1;a[u].push_back(node{0,0});
for(int i=head[u],v;i;i=nxt[i])
if(!vis[v=ver[i]])now=v,calc(v,u,1);
sort(a[u].begin(),a[u].end());
for(int i=head[u],v;i;i=nxt[i])
if(!vis[v=ver[i]]){sum=sz[v];rt=0;csize(v,-1);csize(rt,-1);dfs(rt);}
}
inline void solve(int d){
for(int u=1;u<=n;u++){
int t=a[u].size();long long len=0;
for(int l=0;l<t;++l)cnt[a[u][l].id]=cntl[a[u][l].id]=0;
for(int l=0;l<t;++l){
++cnt[a[u][l].id];
cntl[a[u][l].id]+=a[u][l].dis;
len+=a[u][l].dis;
}
for(int l=0,r=t-1;l<t;++l){
while(r>=0&&a[u][l].dis+a[u][r].dis>d){
--cnt[a[u][r].id];
cntl[a[u][r].id]-=a[u][r].dis;
len-=a[u][r].dis;
--r;
}
int tmp=r+1-cnt[a[u][l].id];
ans+=tmp;length+=len-cntl[a[u][l].id]+1ll*tmp*a[u][l].dis;
}
}
}
int main(){
mx[0]=1<<30;
T=read();
while(T--){
tot=0;
for(int i=1;i<=n;++i)head[i]=vis[i]=0,a[i].clear();
n=read();k=read();
for(int i=1,u,v;i<n;++i){u=read();v=read();add(u,v);add(v,u);}
sum=n;rt=0;csize(1,-1);csize(rt,-1);dfs(rt);
int L=1,R=n,res=-1;
while(L<=R){
int mid=L+R>>1;ans=0;solve(mid);ans/=2;
if(ans>=k)res=mid,R=mid-1;else L=mid+1;
}
ans=length=0;solve(res);
printf("%lld\n",length/2-(ans/2-k)*res);
}
return 0;
}