简要题意
给定一棵 \(n\) 个点的树,树有边权。
对每个点维护一个集合 \(S_u\),一开始集合均包含整数 \(123456789123456789\)。
设 \({\rm dis}_{a,b}\) 为树上两点 \(a\),\(b\) 的距离。
共 \(m\) 次操作,分为如下两种:
s t a b
: 设 \(f\) 为 \(s\),\(t\) 路径上的点集,对与 \(\forall t\in f\),给 \(t\) 所对应的集合里加入 \(a\times {\rm dis}_{s,t} + b\)。s t
: 设 \(f\) 为 \(s\),\(t\) 路径上的点集,求 \(\min\{\bigcup_{x\in f} S_x\}\) 的值。
其中 \(1\le n\le 10^5\)。
前置知识
- 树链剖分
- 李超线段树
解题方法
首先看到 \(a\times dis+b\),是一个一次函数的形式,不妨考虑先树剖一下,然后用永久化标记的李超线段树维护最大值。
横坐标自然是取个确定的距离标准。取每个点到根节点的距离 \(dis[i]\) 作为 \(i\) 的横坐标好了,这样对于同一条重链,横坐标还是递增的。
对于每次插入操作,先求出 \(s\) 和 \(t\) 的 \(lca\),预处理出每个点到根节点的距离 \(dis[u]\),然后化简一下式子。
- 对于 \(s\) 到 \(lca\) 的那条路径上的点 \(u\),增加的贡献是
- 对于 \(t\) 到 \(lca\) 的那条路径上的点,增加的贡献是
然后就可以插入了。
现在问题就剩如果维护区间最小值,容易发现对于一个区间内的线段,最小值无非就是从两个端点的地方取到,所以线段树上再维护一个区间最小值,每次向上传递,最后区间查询的时候,经过的区间上存的优势线段也能为答案作出贡献,算上就好了。重点在于怎样合并两条线段,具体可以分为以下三种情况:
-
如果新线段在当前节点的定义域中完全小于原来的线段。应舍去原来的线段,并用当前线段的两个端点更新最小值。
-
新线段在当前节点的定义域中完全大于原来的线段,应舍去这条新线段。
-
新加入的线段和原来的线段在当前的定义域中有交点。求一下两条线段的交点横坐标 \(x\)。如果新线段的左端点大于旧线段的左端点,或 \(x\le mid\),就往左下放;如果新线段的右端点小于旧线段的右交点,或 \(x>mid\),就往右下放。
李超树维护线段是 \(O(\log^2n)\) 的,再加上树剖跳链的复杂度,总的时间复杂度是 \(O(n\log^3n)\),应该是可以过的吧……
#include <iostream>
#define MAXN 100005
using namespace std;
const long long INF = 123456789123456789;
int n, m, u, v, w, op, s, t, tmp, tp;
struct edge{int w, to, nxt;}e[MAXN << 1];
int head[MAXN], cnt = 1;
int fa[MAXN], dep[MAXN], siz[MAXN], son[MAXN];
int dfn[MAXN], vis[MAXN], rnk[MAXN], tot;
long long dis[MAXN], k[MAXN << 1], b[MAXN << 1];
struct seg{
int ls, rs, tag;
long long minn;
}tree[MAXN << 2];
int read(){
int t = 1, x = 0;char ch = getchar();
while(!isdigit(ch)){if(ch == '-')t = -1;ch = getchar();}
while(isdigit(ch)){x = (x << 1)+ (x << 3)+ (ch ^ 48);ch = getchar();}
return x * t;
}
void write(long long x){
if(x < 0){putchar('-');x = -x;}
if(x >= 10)write(x / 10);
putchar(x % 10 ^ 48);
}
void add(int u, int v, int w){
cnt++;e[cnt].w = w;e[cnt].to = v;e[cnt].nxt = head[u];head[u] = cnt;
cnt++;e[cnt].w = w;e[cnt].to = u;e[cnt].nxt = head[v];head[v] = cnt;
}
long long calc(int id, int x){
return k[id] * dis[rnk[x]] + b[id];
}
void dfs1(int now, int fat, int deep){
dep[now] = deep;siz[now] = 1;fa[now] = fat;int maxson = -1;
for(int i = head[now] ; i != 0 ; i = e[i].nxt){
int v = e[i].to, w = e[i].w;
if(v != fat){
dis[v] = dis[now] + w;
dfs1(v, now, deep + 1);siz[now] += siz[v];
if(siz[v] > maxson){
maxson = siz[v];son[now] = v;
}
}
}
}
void dfs2(int now, int fat, int top){
tot++;dfn[now] = tot;rnk[tot] = now;vis[now] = top;
if(son[now] != 0){
dfs2(son[now], now, top);
for(int i = head[now] ; i != 0 ; i = e[i].nxt){
int v = e[i].to;
if(v != fat && v != son[now])dfs2(v, now, v);
}
}
}
int lca(int u, int v){
while(vis[u] != vis[v]){
if(dep[vis[u]] < dep[vis[v]])swap(u, v);
u = fa[vis[u]];
}
return dep[u] < dep[v] ? u : v;
}
void pushup(int node){
if(tree[node].ls == tree[node].rs)
tree[node].minn = min(calc(tree[node].tag, tree[node].ls),
calc(tree[node].tag, tree[node].rs));
else tree[node].minn = min(min(tree[node << 1].minn,
tree[node << 1 | 1].minn),
min(calc(tree[node].tag, tree[node].ls),
calc(tree[node].tag, tree[node].rs)));
}
void build(int node, int left, int right){
tree[node].ls = left;
tree[node].rs = right;
tree[node].minn = INF;
if(left != right){
int mid = (left + right) >> 1;
build(node << 1, left, mid);
build(node << 1 | 1, mid + 1, right);
}
}
void change(int node, int x){
int &y = tree[node].tag;
int mid = (tree[node].ls + tree[node].rs) >> 1;
if(calc(x, mid) < calc(y, mid))swap(x, y);
if(calc(x, tree[node].ls) < calc(y, tree[node].ls))change(node << 1, x);
if(calc(x, tree[node].rs) < calc(y, tree[node].rs))change(node << 1 | 1, x);
}
void update(int node, int left, int right, int x){
if(left <= tree[node].ls && tree[node].rs <= right)return change(node, x), pushup(node), void();
int mid = (tree[node].ls + tree[node].rs) >> 1;
if(left <= mid)update(node << 1, left, right, x);
if(right > mid)update(node << 1 | 1, left, right, x);
pushup(node);
}
long long query(int node, int left, int right){
if(left <= tree[node].ls && tree[node].rs <= right)return tree[node].minn;
int mid = (tree[node].ls + tree[node].rs) >> 1;
long long res = min(calc(tree[node].tag, max(left, tree[node].ls)),
calc(tree[node].tag, min(right, tree[node].rs)));
if(left <= mid)res = min(res, query(node << 1, left, right));
if(right > mid)res = min(res, query(node << 1 | 1, left, right));
return res;
}
void updtree(int u, int v, int x){
while (vis[u] != vis[v]){
if(dep[vis[u]] < dep[vis[v]])swap(u, v);
update(1, dfn[vis[u]], dfn[u], x);
u = fa[vis[u]];
}
if(dep[u] > dep[v])swap(u, v);
update(1, dfn[u], dfn[v], x);
}
long long quetree(int u, int v){
long long res = INF;
while (vis[u] != vis[v]){
if(dep[vis[u]] < dep[vis[v]])swap(u, v);
res = min(res, query(1, dfn[vis[u]], dfn[u]));
u = fa[vis[u]];
}
if(dep[u] > dep[v])swap(u, v);
res = min(res, query(1, dfn[u], dfn[v]));
return res;
}
int main(){
n = read();m = read();
for(int i = 1 ; i < n ; i ++)
u = read(),v = read(),w = read(),add(u, v, w);
dfs1(1, 0, 0);dfs2(1, 0, 1);b[0] = INF;build(1, 1, n);
while(m--){
op = read();
if(op == 1){
u = read();v = read();s = read();t = read();tmp = lca(u, v);
tp++;k[tp] = -s;b[tp] = t + 1LL * dis[u] * s;updtree(u, tmp, tp);
tp++;k[tp] = s;b[tp] = t + 1LL * dis[u] * s - 2LL * dis[tmp] * s;updtree(tmp, v, tp);
}else cout << quetree(read(), read()) << endl;
}
return 0;
}
标签:node,int,题解,线段,tree,P4069,now,dis
From: https://www.cnblogs.com/tsqtsqtsq/p/17818497.html