点分治
板子题,记一下 $ mod \ 3 $ 意义下余数分别为 $ 1 $ $ 2 $ $ 0 $ 的个数,合并时统计即可;
板子题,开个二元组记录一下权值和边数即可;
板子题,和第一题类似,只不过开个树状数组记录一下前缀和,然后就解决了;
这题。。。我TM调了三个小时,结果学校OJ上还得卡常!!!
只要用上线段树等数据结构,学校OJ就过不去
顺便发泄一下自己的情绪:上次模拟赛T2开了
27个线段树,常数确实有些大,但是Luogu上过了,学校OJ上就咋都过不去;
你可能会说,学校OJ咋能和Luogu比呢?
但今天上午,一道虚树的入门题,时限2s,在Luogu上跑刚过1s,结果在学校OJ上直接TLE俩点,经过我严谨的时间复杂度分析,大约是2e8+1e7,我就不理解了,就TM多这么100ms咋就跑不过去了?(其实可能确实是我菜,整不出题解的优秀复杂度),跟题解一比,多了个线段树的复杂度,整的我现在打比赛都不敢用线段树,但就是每场比赛都TM能想出来用线段树的卡常做法,结果今天上午这题经过_lhx_和cpa一个多小时的大力卡常才勉强过;
现在能力没咋提升,倒是卡常进步不少;
回归正题;
其实思路不难,但细节太多了。。。
首先,路径还是能拆分成两类:经过根的和不经过根的;
所以可以点分治;
首先将每个点的儿子按大小排序,因为这样我们就可以比较方便的处理到根的路径颜色相同的子树们;
然后进行点分治,我们开两个线段树,把与当前路径颜色相同的放进一个线段树,不同的放进一个线段树,然后正常跑就行;
注意线段树的清空,可以直接在根节点上打懒标记;
然后就是一些细节,不说了,可以看代码;
点击查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
int n, m, l, r;
int c[500005];
int rt, sum;
struct sss{
int t, ne, w;
}e[1000005];
int h[1000005], cnt;
void add(int u, int v, int ww) {
e[++cnt].t = v;
e[cnt].ne = h[u];
e[cnt].w = ww;
h[u] = cnt;
}
vector<pair<int, int> > v[200005];
struct sas{
int dis, sum;
}dis[200005], rem[200005], po[200005];
int maxp[1000005], siz[1000005], dep[1000005];
bool vis[1000005];
int ans;
namespace seg{
inline int ls(int x) {
return x << 1;
}
inline int rs(int x) {
return x << 1 | 1;
}
struct asa{
int l, r, ma, lz;
}tr[2][900005];
inline void push_up(int s, int id) {
tr[s][id].ma = max(tr[s][ls(id)].ma, tr[s][rs(id)].ma);
}
inline void push_down(int s, int id) {
if(tr[s][id].lz != 0) {
tr[s][ls(id)].lz = tr[s][id].lz;
tr[s][rs(id)].lz = tr[s][id].lz;
tr[s][ls(id)].ma = tr[s][id].lz;
tr[s][rs(id)].ma = tr[s][id].lz;
tr[s][id].lz = 0;
}
}
void bt(int s, int id, int l, int r) {
tr[s][id].l = l;
tr[s][id].r = r;
if (l == r) {
tr[s][id].ma = -0x3f3f3f3f;
tr[s][id].lz = 0;
return;
}
int mid = (l + r) >> 1;
bt(s, ls(id), l, mid);
bt(s, rs(id), mid + 1, r);
push_up(s, id);
}
inline void clear(int s) {
tr[s][1].lz = -0x3f3f3f3f;
tr[s][1].ma = -0x3f3f3f3f;
}
int ask(int s, int id, int l, int r) {
if (tr[s][id].l >= l && tr[s][id].r <= r) {
return tr[s][id].ma;
}
push_down(s, id);
int mid = (tr[s][id].l + tr[s][id].r) >> 1;
if (r <= mid) return ask(s, ls(id), l, r);
else if (l > mid) return ask(s, rs(id), l, r);
else return max(ask(s, ls(id), l, mid), ask(s, rs(id), mid + 1, r));
}
void add(int s, int id, int pos, int d) {
if (tr[s][id].l == tr[s][id].r) {
tr[s][id].ma = max(tr[s][id].ma, d);
tr[s][id].lz = 0;
return;
}
push_down(s, id);
int mid = (tr[s][id].l + tr[s][id].r) >> 1;
if (pos <= mid) add(s, ls(id), pos, d);
else add(s, rs(id), pos, d);
push_up(s, id);
}
}
void get_rt(int x, int f) {
siz[x] = 1;
maxp[x] = 0;
for (int i = h[x]; i; i = e[i].ne) {
int u = e[i].t;
if (u == f || vis[u]) continue;
get_rt(u, x);
siz[x] += siz[u];
maxp[x] = max(maxp[x], siz[u]);
}
maxp[x] = max(maxp[x], sum - siz[x]);
if (maxp[rt] > maxp[x]) rt = x;
}
void get_dis(int x, int f, int pre) {
dep[x] = dep[f] + 1;
if (dep[x] > r) return;
dis[x].dis = dep[x];
for (int i = h[x]; i; i = e[i].ne) {
int u = e[i].t;
if (vis[u] || u == f) continue;
dis[u].sum = dis[x].sum; //注意继承的问题;
if (e[i].w != pre && e[i].w) dis[u].sum += c[e[i].w]; //注意判断;
get_dis(u, x, e[i].w);
}
}
void dfs(int x, int f) {
if (dis[x].dis == 0) return;
rem[++rem[0].sum] = sas{dis[x].dis, dis[x].sum};
for (int i = h[x]; i; i = e[i].ne) {
int u = e[i].t;
if (u == f || vis[u]) continue;
dfs(u, x);
}
}
void calc(int x) {
int color = 0;
int o = 0;
dep[x] = 0;
for (int i = h[x]; i; i = e[i].ne) {
int u = e[i].t;
if (vis[u]) continue;
if (color == 0) {
color = e[i].w;
} else if (color != e[i].w) {
color = e[i].w;
seg::clear(1);
for (int j = 1; j <= o; j++) {
seg::add(0, 1, po[j].dis, po[j].sum);
}
o = 0;
}
rem[0].sum = 0;
dis[u].sum = c[e[i].w];
get_dis(u, x, e[i].w);
dfs(u, x);
for (int j = 1; j <= rem[0].sum; j++) {
if (rem[j].dis > r) continue;
if (rem[j].dis >= l && rem[j].dis <= r) {
ans = max(ans, rem[j].sum);
}
if (rem[j].dis == r) continue;
if (rem[j].dis == 0) continue;
int aa = seg::ask(0, 1, max(0, l - rem[j].dis), r - rem[j].dis);
int bb = seg::ask(1, 1, max(0, l - rem[j].dis), r - rem[j].dis);
bb -= c[e[i].w];
ans = max(ans, max(rem[j].sum + aa, rem[j].sum + bb));
}
for (int j = 1; j <= rem[0].sum; j++) {
if (rem[j].dis == 0) continue;
o++;
po[o].dis = rem[j].dis;
po[o].sum = rem[j].sum;
}
for (int j = 1; j <= rem[0].sum; j++) {
if (rem[j].dis == 0) continue;
seg::add(1, 1, rem[j].dis, rem[j].sum);
}
}
seg::clear(0);
seg::clear(1);
}
void solve(int x) {
vis[x] = true;
calc(x);
for (int i = h[x]; i; i = e[i].ne) {
int u = e[i].t;
if (vis[u]) continue;
rt = 0;
maxp[rt] = 0x3f3f3f3f;
sum = siz[u];
get_rt(u, 0);
solve(rt);
}
}
int main() {
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin >> n >> m >> l >> r;
for (int i = 1; i <= m; i++) {
cin >> c[i];
}
int x, y, w;
for (int i = 1; i <= n - 1; i++) {
cin >> x >> y >> w;
v[x].push_back({w, y});
v[y].push_back({w, x});
}
for (int i = 1; i <= n; i++) {
sort(v[i].begin(), v[i].end());
}
for (int i = 1; i <= n; i++) {
for (int j = 0; j < v[i].size(); j++) {
add(i, v[i][j].second, v[i][j].first);
}
}
seg::bt(0, 1, 0, n);
seg::bt(1, 1, 0, n);
ans = -0x3f3f3f3f;
rt = 0;
maxp[rt] = 0x3f3f3f3f;
sum = n;
get_rt(1, 0);
solve(rt);
cout << ans;
return 0;
}
貌似题解有单调队列的优秀做法,但我不会,有兴趣的可以去看看;
走了,去卡常了;