2023冲刺国赛模拟20
越来越废物了。
A. 树染色
\(f_{x, 1 / 0}\) 表示考虑 \(x\) 子树内,第一条链为黑色/白色,不考虑第一条链在子树外方案数的答案。
转移时枚举第一条链来自哪个儿子的子树,再用组合数给各个子树的链定序。
code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it. Returns 0 if EOF is reached before a
// digit is seen (the previous version spun forever in that case).
// `c` is an int, not a char, so EOF stays distinguishable from byte 0xFF and
// isdigit() never receives a negative char value (undefined behaviour).
int read(){
    int x = 0;
    int c = getchar();
    while(c != EOF && !isdigit(c))c = getchar();
    if(c == EOF)return 0;
    do{x = x * 10 + (c ^ 48); c = getchar();}while(isdigit(c));
    return x;
}
// Array capacity for vertices/factorials, and the prime modulus used throughout.
constexpr int maxn = 500055, mod = 998244353;
// Binary exponentiation: returns x^y reduced modulo `mod` (y >= 0).
int qpow(int x, int y){
    long long acc = 1, sq = x;
    while(y){
        if(y & 1)acc = acc * sq % mod;
        sq = sq * sq % mod;
        y >>= 1;
    }
    return static_cast<int>(acc);
}
// Tree in forward-star form: head[u] is the index of u's most recently added
// edge (0 = none), and e[i].net links to the edge added before it.
int n, head[maxn], tot;
struct edge{int to, net, a, b;}e[maxn << 1 | 1];
// Appends the directed edge u -> v carrying the weight pair (a, b).
void add(int u, int v, int a, int b){
    ++tot;
    e[tot].to = v;
    e[tot].net = head[u];
    e[tot].a = a;
    e[tot].b = b;
    head[u] = tot;
}
// fac[i] = i! mod p and ifac[i] = (i!)^-1 mod p; both are filled in main().
int fac[maxn], ifac[maxn];
// Binomial coefficient C(n, m) mod p; requires 0 <= m <= n within the tables.
int c(int n, int m){
    long long res = fac[n];
    res = res * ifac[m] % mod;
    res = res * ifac[n - m] % mod;
    return static_cast<int>(res);
}
// f[x][1] / f[x][0]: per the write-up, the count for x's subtree when its first
//   chain is black / white, ignoring how that chain continues outside the
//   subtree. When the parent processes edge (parent, x), f[x][0] is scaled by
//   the edge's `a` and f[x][1] by its `b`.
// dep[x]: depth of x (vertex 1 is the root with depth 1, set in main).
// cnt[x]: number of leaves in x's subtree (a leaf counts itself).
int f[maxn][2], dep[maxn], cnt[maxn];
// Post-order DP over the subtree of x (fa = parent, 0 for the root).
// x's first chain comes from exactly one child v. Every other child u
// contributes (f[u][0] + f[u][1]) * dep[x]. The chains are ordered with
//   base * cnt[v]! * (cnt[x]-cnt[v])! * C(cnt[x]-1, cnt[v]-1),
// where base = prod over children of 1 / cnt[u]!.
// The "product over all other children" uses prefix/suffix products. The
// earlier version divided the full product by the child's own factor with a
// Fermat inverse, which silently yields 0 when that factor is 0 mod p (e.g. a
// leaf edge whose weights satisfy a + b == mod, if the input allows it).
// NOTE: recursion depth equals the tree height (up to n).
void solve(int x, int fa){
    std::vector<int> kids;              // children of x, in adjacency order
    int base = 1;                       // prod over children of 1 / cnt[v]!
    for(int i = head[x]; i; i = e[i].net){
        int v = e[i].to; if(v == fa)continue;
        dep[v] = dep[x] + 1;
        solve(v, x);
        cnt[x] += cnt[v];
        base = 1ll * base * ifac[cnt[v]] % mod;
        // attach edge (x, v): colour 0 weighted by a, colour 1 by b
        f[v][0] = 1ll * f[v][0] * e[i].a % mod;
        f[v][1] = 1ll * f[v][1] * e[i].b % mod;
        kids.push_back(v);
    }
    const int d = static_cast<int>(kids.size());
    // factor of a child that does NOT supply x's first chain
    auto other = [&](int v){
        return static_cast<int>(1ll * ((f[v][0] + f[v][1]) % mod) * dep[x] % mod);
    };
    // suf[j] = product of other(kids[j .. d-1])
    std::vector<int> suf(d + 1, 1);
    for(int j = d - 1; j >= 0; --j)suf[j] = 1ll * suf[j + 1] * other(kids[j]) % mod;
    int pre = 1;                        // product of other(kids[0 .. j-1])
    for(int j = 0; j < d; ++j){
        int v = kids[j];
        int coef = 1ll * pre * suf[j + 1] % mod * base % mod
                   * fac[cnt[v]] % mod * fac[cnt[x] - cnt[v]] % mod
                   * c(cnt[x] - 1, cnt[v] - 1) % mod;
        f[x][0] = (f[x][0] + 1ll * f[v][0] * coef) % mod;
        f[x][1] = (f[x][1] + 1ll * f[v][1] * coef) % mod;
        pre = 1ll * pre * other(v) % mod;
    }
    // leaf: both colours start at 1 and the subtree holds one leaf
    if(!cnt[x])f[x][0] = f[x][1] = cnt[x] = 1;
}
int main(){
freopen("treecolor.in","r",stdin);
freopen("treecolor.out","w",stdout);
n = read();
for(int i = 1; i < n; ++i){
int x = read(), y = read(), a = read(), b = read();
add(x, y, a, b); add(y, x, a, b);
}
fac[0] = ifac[0] = 1; for(int i = 1; i <= n; ++i)fac[i] = 1ll * fac[i - 1] * i % mod;
ifac[n] = qpow(fac[n], mod - 2); for(int i = n - 1; i >= 1; --i)ifac[i] = 1ll * ifac[i + 1] * (i + 1) % mod;
dep[1] = 1; solve(1, 0);
printf("%d\n",(f[1][0] + f[1][1]) % mod);
return 0;
}
B. 关路灯
预处理转向点:可以发现从每个点出发,只有 \(O(\log)\) 个转向点。
把每段行走拆成区间贡献,对其离线做扫描线,用树状数组维护。
code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
typedef pair<int, ll> pil;
typedef pair<ll, int> pli;
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it. Returns 0 if EOF is reached before a
// digit is seen (the previous version spun forever in that case).
// `c` is an int, not a char, so EOF stays distinguishable from byte 0xFF and
// isdigit() never receives a negative char value (undefined behaviour).
int read(){
    int x = 0;
    int c = getchar();
    while(c != EOF && !isdigit(c))c = getchar();
    if(c == EOF)return 0;
    do{x = x * 10 + (c ^ 48); c = getchar();}while(isdigit(c));
    return x;
}
const int maxn = 5e5 + 55;
// a[1..n]: coordinates. st[i][j]: max of the gaps st[0][j .. j+2^i-1], where
// st[0][j] = a[j] - a[j-1] and st[0][1] = INT_MAX (both set in main).
int n, m, a[maxn], st[20][maxn];
// ans[q]: answer to query q.
ll ans[maxn];
// ql[l] / qr[r]: (other endpoint, query id) for queries bucketed by left / right end.
vector<pii>ql[maxn], qr[maxn];
// pl / pr: sweep events (signed key, value) for left- / right-moving segments;
// a negative key cancels the event added with the matching positive key.
vector<pil>pl[maxn], pr[maxn];
// Starting at x, skips left over every index whose gap a[q] - a[q-1] is < lim.
// Returns the largest index p <= x with st[0][p] >= lim; the INT_MAX sentinel
// at st[0][1] stops the walk at 1. Binary lifting over 2^18 .. 2^0.
int queryl(int x, int lim){
    for(int lg = 18; lg >= 0; --lg){
        const int len = 1 << lg;
        const int from = x - len + 1;      // candidate window [from, x]
        if(from < 1)continue;
        if(st[lg][from] < lim)x = from - 1;
    }
    return x;
}
// Starting at x, skips right over every index whose gap a[q] - a[q-1] is <= lim.
// Returns the smallest index p >= x with st[0][p] > lim, or n + 1 if none.
int queryr(int x, int lim){
    for(int lg = 18; lg >= 0; --lg){
        const int len = 1 << lg;
        if(x + len - 1 > n)continue;       // window [x, x+len-1] would leave the array
        if(st[lg][x] <= lim)x += len;
    }
    return x;
}
// Component-wise sum of two (value sum, count) pairs, used by the Fenwick tree.
pli operator + (const pli &x, const pli &y){
    pli res(x);
    res.first += y.first;
    res.second += y.second;
    return res;
}
// Component-wise difference of two (value sum, count) pairs.
pli operator - (const pli &x, const pli &y){
    pli res(x);
    res.first -= y.first;
    res.second -= y.second;
    return res;
}
struct BIT{
pli t[maxn];
int lowbit(int x){return x & -x;}
void add(int x, pli val){while(x <= n)t[x] = t[x] + val, x += lowbit(x);}
pli query(int x){
pli ans;
while(x){ans = ans + t[x]; x -= lowbit(x);}
return ans;
}
pli query(int l, int r){return query(r) - query(l - 1);}
void clear(){memset(t, 0, sizeof(t));}
}T;
// Entry point for problem B. Reads n coordinates a[1..n] and m queries (l, r).
// The coordinates are presumably ascending — differences a[i] - a[i-1] are
// used as distances; confirm against the statement.
// Each answer starts at (r-l+1)*(a[r]-a[l]). It is then corrected by two
// offline sweeps over the walk segments recorded for every start position.
int main(){
freopen("light.in","r",stdin);
freopen("light.out","w",stdout);
n = read(), m = read();
for(int i = 1; i <= n; ++i)a[i] = read();
for(int i = 1; i <= m; ++i){
int l = read(), r = read();
ans[i] = 1ll * (r - l + 1) * (a[r] - a[l]);
// Queries with l != r are bucketed by both endpoints for the two sweeps below.
if(r != l){
ql[l].push_back(pii(r, i));
qr[r].push_back(pii(l, i));
}
}
// Sparse table of max gaps; the INT_MAX sentinel keeps queryl from passing index 1.
st[0][1] = INT_MAX;
for(int i = 2; i <= n; ++i)st[0][i] = a[i] - a[i - 1];
for(int i = 1; (1 << i) <= n; ++i)
for(int j = 1; j + (1 << i) - 1 <= n; ++j)
st[i][j] = max(st[i - 1][j], st[i - 1][j + (1 << (i - 1))]);
// Simulate the walk from every start i. [l, r] is the visited interval, p the
// current index and sum the distance walked so far. Each iteration jumps a
// whole run in one direction; per the write-up, a start has only O(log) turns.
for(int i = 1; i <= n; ++i){
int l = i, r = i, p = i;
ll sum = 0;
while(l != 1 && r != n){
// Distances to the nearest unvisited index on each side (ties go right).
int pre = a[p] - a[l - 1], nxt = a[r + 1] - a[p];
if(pre < nxt){
// Move left to np, the largest index <= l whose left gap is >= nxt.
// The event keyed by r is active while the left sweep is in [np, l-1].
// There, position q is reached after walking (sum + a[p]) - a[q].
int np = queryl(l, nxt);
pl[l - 1].push_back(pil(r, sum + a[p]));
pl[np - 1].push_back(pil(-r, sum + a[p]));
sum += a[p] - a[np];
l = p = np;
}else{
// Move right to np, the last index before the first gap > pre.
// The event keyed by l is active while the right sweep is in [r+1, np].
// There, position q is reached after walking (sum - a[p]) + a[q].
int np = queryr(r + 1, pre) - 1;
pr[r + 1].push_back(pil(l, sum - a[p]));
pr[np + 1].push_back(pil(-l, sum - a[p]));
sum += a[np] - a[p];
r = p = np;
}
}
}
// Sweep right ends ascending. For a query (l, i), add the walked distances of
// the active right-moving events whose key (current left bound) is in (l, n].
for(int i = 1; i <= n; ++i){
for(auto it : pr[i]){
if(it.first > 0)T.add(it.first, pli(it.second, 1));
else T.add(-it.first, pli(-it.second, -1));
}
for(auto it : qr[i]){
pli res = T.query(it.first + 1, n);
ans[it.second] += res.first + 1ll * res.second * a[i];
}
}
T.clear();
// Sweep left ends descending. For a query (i, r), add the walked distances of
// the active left-moving events whose key (current right bound) is in [1, r).
for(int i = n; i >= 1; --i){
for(auto it : pl[i])
if(it.first > 0)T.add(it.first, pli(it.second, 1));
else T.add(-it.first, pli(-it.second, -1));
for(auto it : ql[i]){
pli res = T.query(1, it.first - 1);
ans[it.second] += res.first - 1ll * res.second * a[i];
}
}
for(int i = 1; i <= m; ++i)printf("%lld\n",ans[i]);
return 0;
}
C. 树状数组
\(f_{i, j, k}\) 表示考虑了前 \(i\) 个位置,当前与 \(r\) 的不同的二进制位最高为 \(j\), \(j\) 及以下有 \(k\) 个 \(1\) 的方案数。
倒着推出转移系数 \(g_{i, j, k}\),表示从该状态出发填完后缀后最终合法的方案数。
对于询问把前后缀拼起来。
code
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it. Returns 0 if EOF is reached before a
// digit is seen (the previous version spun forever in that case).
// `c` is an int, not a char, so EOF stays distinguishable from byte 0xFF and
// isdigit() never receives a negative char value (undefined behaviour).
int read(){
    int x = 0;
    int c = getchar();
    while(c != EOF && !isdigit(c))c = getchar();
    if(c == EOF)return 0;
    do{x = x * 10 + (c ^ 48); c = getchar();}while(isdigit(c));
    return x;
}
// maxn: capacity for positions (assumes n + 1 < maxn); mod: answer modulus.
const int maxn = 1e5 + 55, mod = 998244353;
// n: pattern length; k: bit width of the tracked value; lim: the final value
//   must be <= lim (g[n+1] marks the states of 0..lim).
// id[i][j]: DP state "the highest bit (below k) where the value differs from
//   lim is i, and the value has j ones strictly below bit i". tot counts these
//   states. State 0 means the value equals lim on the low k bits.
// tr[state][b]: successor state after operation b (0 -> tr0, 1 -> tr1).
// f[i][state]: ways to fill s[1..i] ending in `state`, starting from value 0.
// g[i][state]: ways to fill s[i..n] from `state` so the final value is <= lim.
// NOTE(review): f and g each reserve 25*25 ints per row (~250 MB apiece at
//   maxn), yet only tot+1 = k(k+1)/2+1 columns are used. Confirm the memory
//   limit, or size the rows from tot at runtime.
int n, k, lim, id[25][25], tot, tr[25 * 25][2], f[maxn][25 * 25], g[maxn][25 * 25];
// Pattern s[1..n]: '0' or '1' fixes the operation; any other character allows both.
char s[maxn];
// x = (x + y) mod p, for x and y already reduced into [0, mod).
void add(int &x, int y){
    x += y;
    x -= x >= mod ? mod : 0;
}
// Lowest set bit of x (0 when x == 0).
int lowbit(int x){return x & (~x + 1);}
// Operation '0': clear the lowest set bit of x (same as x - lowbit(x)).
int tr0(int x){return x & (x - 1);}
// Operation '1': x + lowbit(((1 << k) - 1) ^ x). For 0 <= x < 2^k this sets
// the lowest zero bit of x among the low k bits, and leaves x unchanged when
// all k bits are already set.
int tr1(int x){
    const int mask = (1 << k) - 1;      // the k low bits
    const int holes = mask ^ x;
    return x + lowbit(holes);
}
int find(int x){
for(int i = k - 1; i >= 0; --i)if(((x ^ lim) >> i) & 1)return id[i][__builtin_popcount(x & ((1 << i) - 1))];
return 0;
}
// Entry point for problem C. Reads n, k, lim and the pattern s[1..n].
// The value starts at 0; a '0' applies tr0 and a '1' applies tr1. For each
// position i, prints the number (mod p) of completions with position i set
// to '0' whose final value is <= lim. It prints 0 when s[i] is forced to '1'.
int main(){
freopen("fenwick.in","r",stdin);
freopen("fenwick.out","w",stdout);
n = read(), k = read(); lim = read(); scanf("%s",s + 1);
// Number the states id[i][j] for 0 <= j <= i < k (index 0 is the "equal to lim" state).
for(int i = 0; i < k; ++i)
for(int j = 0; j <= i; ++j)id[i][j] = ++tot;
// Transitions depend only on (i, j). Removing or adding one low bit moves j.
// At the boundary j == 0 / j == i, apply the real operation to a canonical
// value: ((lim >> i) ^ 1) << i keeps lim's bits above i and flips bit i.
// It has zeros below i (for tr0), or all ones below i after OR-ing in (1 << i) - 1 (for tr1).
for(int i = 0; i < k; ++i)
for(int j = 0; j <= i; ++j){
tr[id[i][j]][0] = j ? id[i][j - 1] : find(tr0(((lim >> i) ^ 1) << i));
tr[id[i][j]][1] = j == i ? find(tr1((((lim >> i) ^ 1) << i) | ((1 << i) - 1))) : id[i][j + 1];
}
// find(lim) is state 0; its transitions come from lim itself.
tr[find(lim)][0] = find(tr0(lim));
tr[find(lim)][1] = find(tr1(lim));
// Forward DP over prefixes, starting from value 0.
f[0][find(0)] = 1;
for(int i = 1; i <= n; ++i)
for(int j = 0; j <= tot; ++j){
if(s[i] != '1')add(f[i][tr[j][0]], f[i - 1][j]);
if(s[i] != '0')add(f[i][tr[j][1]], f[i - 1][j]);
}
// Backward DP over suffixes: the accepting states are those of values 0..lim.
for(int i = 0; i <= lim; ++i)g[n + 1][find(i)] = 1;
for(int i = n; i >= 1; --i)
for(int j = 0; j <= tot; ++j){
if(s[i] != '1')add(g[i][j], g[i + 1][tr[j][0]]);
if(s[i] != '0')add(g[i][j], g[i + 1][tr[j][1]]);
}
// Join prefix and suffix counts with position i forced to '0'.
// NOTE: the inner loop variable `s` shadows the global pattern array s.
for(int i = 1; i <= n; ++i)
if(s[i] == '1')printf("0\n");
else{
int ans = 0;
for(int s = 0; s <= tot; ++s)add(ans, 1ll * f[i - 1][s] * g[i + 1][tr[s][0]] % mod);
printf("%d\n",ans);
}
return 0;
}