## 1.换根dp
1.kamp
首先肯定它是换根dp
因为最后一次不用回到起点,所以先不去想别的,最后再减去一个最长链
设g\([i],\)为以i为根前往在这颗子树中的家的最小距离
\(f[i],\)为以i为起点,送回家的最小花费;
首先对于\(g[i],g[u]= g[to]+2\times w;\)
\(sz[i]\)表示在i这颗子树中有多少家
考虑dp:
\(1.sz[to]=0,\)那么\(f[to]=f[u]+2 \times w;\)
\(2.sz[to]=K,\)那么\(f[to]=g[to]\)
\(3.f[to]=f[u]\)
考虑维护最长链,因为之前已经维护好了以1位根的子树中的最长和次长,考虑换根情况下子树之外的最长链
如果to是最长链上的,那么这个链不就废了,所以还需要维护一个次长链
第一种情况下:\(L[to]=L[u]+w;\)之前在维护子树内链的时候,只有sz[i]!=0,才更新,并且L[u]中也包含了子树外的情况,所以正确
第二种情况下:不用维护,本来就是子树内
第三种情况:
u的最长链去更新to的最长链,必须保证to不在最长链上
u的次长链去更新to的最长链
u的最长链去更新to的次长链,必须保证to不在最长链上
u的次长链去更新to的次长链
考虑到u与to相连,所以id[i]记录一下最长链经过的第一个节点。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=2e6+10;
#define int long long
typedef double db;
int n,K;
int head[N],len,sz[N],g[N],f[N],L[N],l[N],id[N];
/*维护最长、次长链 以及最长链经过的第一个节点*/
bool vis[N];
struct node{
int from,to,next,w;
}e[N<<1];
int read(){
int res=0,l=1;char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')l=-1;
c=getchar();
}
while(c>='0'&&c<='9')res=res*10+c-'0',c=getchar();
return res*l;
}
struct Graph{
void add(int from,int to,int w){
e[++len].from=from;
e[len].to=to;
e[len].w=w;
e[len].next=head[from];
head[from]=len;
}
void dfs(int u,int fe){
if(vis[u])sz[u]=1;
for(int i=head[u];i;i=e[i].next){
int to=e[i].to;
if(to==fe)continue;
dfs(to,u);
int w=e[i].w;
if(sz[to]){
g[u]+=g[to]+2*w;
int NEW=L[to]+w;
if(NEW>=L[u]){
l[u]=L[u],L[u]=NEW,id[u]=to;
}else if(NEW>l[u])l[u]=NEW;
}
sz[u]+=sz[to];
}
}
void dp(int u,int fe){
for(int i=head[u];i;i=e[i].next){
int to=e[i].to;
if(to==fe)continue;
int w=e[i].w;
if(!sz[to])f[to]=f[u]+2*w,L[to]=L[u]+w;
else if(K-sz[to]){
f[to]=f[u];
if(id[u]!=to&&L[u]+w>=L[to]){l[to]=L[to];L[to]=L[u]+w;id[to]=u;}
else if(l[u]+w>=L[to]){l[to]=L[to];L[to]=l[u]+w;id[to]=u;}
else if(id[u]!=to&&L[u]+w>=l[to])l[to]=L[u]+w;
else if(l[u]+w>=l[to])l[to]=l[u]+w;
}else f[to]=g[to];
dp(to,u);
}
}
}G;
signed main(){
cin>>n>>K;
for(int i=1;i<n;++i){
int a,b,c;
cin>>a>>b>>c;
G.add(a,b,c);
G.add(b,a,c);
}
for(int i=1,a;i<=K;++i)cin>>a,vis[a]=1;
G.dfs(1,0);
f[1]=g[1];
G.dp(1,0);
for(int i=1;i<=n;++i)printf("%lld\n",f[i]-L[i]);
return 0;
}