\(O(nd^2)\)
考虑 \(f(i,j,k)\) 表示 dp 到第 \(i\) 维,距离 \(p,q\) 曼哈顿距离 \(j,k\) 的方案数。
考虑朴素转移:
设 \(dis=|p_{i+1}-q_{i+1}|\)。
\[\begin{aligned} f(i+1,j+t,k+dis-t)&\gets f(i,j,k)&(0\leq t\leq dis)\quad &(1)\\ f(i+1,j+d+t,k+t)&\gets f(i,j,k)&(t>0)\quad &(2)\\ f(i+1,j+t,k+d+t)&\gets f(i,j,k)&(t>0)\quad &(3)\\ \end{aligned} \]但是复杂度是 \(O(nd^3)\) 的,优化一下应该可以做到 \(O((n+d)d^2)\)。无法通过。
注意到复杂度瓶颈在区间加,考虑差分优化:
令 \(gs(i,s,mn)\) 为 \(f(i,s-k,k)(k\geq mn)\) 的加法标记,\(g(i,d,mn)\) 为 \(f(i,k-d,k)(k\geq mn)\) 的加法标记。
于是原式改写为:
\[\begin{aligned} g(i,j,k)&=g(i,j,k)+g(i,j,k-1)\\ gs(i,j,k)&=gs(i,j,k)+gs(i,j,k-1)\\ f(i,j,k)&=gs(i,j+k,k)+g(i,j-k,k)\\ gs(i+1,j+k+dis,k)&\gets f(i,j,k)\\ gs(i+1,j+k+dis,k+dis+1)&\gets -f(i,j,k)\\ g(i+1,j+d-k,k+1)&\gets f(i,j,k)\\ g(i+1,j-d-k,k+1)&\gets f(i,j,k)\\ \end{aligned} \]答案即为 \(\sum_{i=0}^d\sum_{j=0}^df(n,i,j)\)。
能过,但不够优秀?
\(O(n^2d)\)
我们考虑分开计数:
如果有 \(j\) 维的坐标在两点之间,即通过 \((1)\) 转移,\(n-j\) 维通过 \((2)(3)\) 转移,那么方案数就是两者之积。
注意到在两点之外对距离的贡献一定是 \(dis_p\gets t,dis_q\gets |p_i-q_i|+t\) 或相反。
先不考虑 \(t\)(超出部分),即在两点之外对距离的贡献是只给一边的距离加上 \(|p_i-q_i|\)。
通过简单 dp,可以求出 \((1)\) 的方案数:
令 \(f(i,j,k)\) 表示 dp 到第 \(i\) 维,有 \(j\) 维超出两点,到 \(p\) 的距离为 \(k\) 的方案数。
于是有转移:
\[\begin{aligned} f(i+1,j,k+t)&\gets f(i,j,k) &(0\le t\le dis)\\ f(i+1,j+1,k)&\gets f(i,j,k)\\ f(i+1,j+1,k+dis)&\gets f(i,j,k) \end{aligned} \]差分优化即可。
考虑超出部分的计数:
设 \(sum=\sum_{i=1}^ndis_i\)。
则考虑枚举(dp 定义中的)\(j,k\),则剩余没用到的距离 \(ri=d-k,rj=d-sum+k\)。
继续枚举 \(T=\sum t\)(超出部分大小)是多少,根据插板法,可以把 \(T\) 非空分配到 \(j\) 维,于是有 \(\binom{T-1}{j-1}\) 种方案。
于是有 \(Ans=\left(\sum_{k=0}^df(n,0,k)\right)+\sum_{j=1}^n\sum_{k=0}^d\sum_{T=0}^{\min(ri,rj)}f(n,j,k)\dbinom{T-1}{j-1}\)。
\(O(nd^2)\),复杂度不对啊?考虑上指标求和。然后变成了:
\(Ans=\left(\sum_{k=0}^df(n,0,k)\right)+\sum_{j=1}^n\sum_{k=0}^df(n,j,k)\dbinom{T}{j}\)。
于是这里就不是瓶颈了。
复杂度 \(O(n^2d)\)。瓶颈在 dp。
\(O(nd^2)\)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 105, D = 1005, p = 998244353;
int f[2][D][D], a[N], b[N], n, d;
int g[2][D * 2][D], g2[2][D * 2][D];
inline int mod(int x) {return x >= p ? x - p : x;}
signed main()
{
ios::sync_with_stdio(0);cin.tie(0);
cin >> n >> d;
for(int i = 1; i <= n; i ++) cin >> a[i];
for(int i = 1; i <= n; i ++) cin >> b[i];
g2[0][0][0] = 1;
for(int i = 0; i <= n; i ++)
{
int i1 = i & 1, i2 = i1 ^ 1;
memset(f[i2], 0, sizeof f[i2]);
memset(g[i2], 0, sizeof g[i2]);
memset(g2[i2], 0, sizeof g2[i2]);
int dis = abs(a[i + 1] - b[i + 1]);
for(int j = 0; j <= d; j ++)
{
for(int k = 0; k <= d; k ++)
{
f[i1][j][k] = mod(g2[i1][j + k][k] + g[i1][j - k + d][k]);
if(j + k + dis <= d * 2)
{
g2[i2][j + k + dis][k + 1] = mod(g2[i2][j + k + dis][k + 1] + f[i1][j][k]);
if(k + dis <= d) g2[i2][j + k + dis][k + dis] = mod(g2[i2][j + k + dis][k + dis] - f[i1][j][k] + p);
}
if(j + dis - k <= d) g[i2][j + dis - k + d][k] = mod(g[i2][j + dis - k + d][k] + f[i1][j][k]);
if(j - dis - k >= -d && k + dis <= d) g[i2][j - dis - k + d][k + dis] = mod(g[i2][j - dis - k + d][k + dis] + f[i1][j][k]);
}
}
for(int j = 0; j <= d * 2; j ++)
for(int k = 1; k <= d; k ++)
{
g2[i2][j][k] = mod(g2[i2][j][k] + g2[i2][j][k - 1]);
g[i2][j][k] = mod(g[i2][j][k] + g[i2][j][k - 1]);
}
}
ll ans = 0;
for(int i = 0; i <= d; i ++)
for(int j = 0; j <= d; j ++)
ans += f[n & 1][i][j];
cout << ans % p;
return 0;
}
\(O(n^2d)\)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1005, D = 1005, p = 998244353;
int a[N], b[N], n, d;
int f[2][N][D], g[2][N][D];
int fac[D], ifac[D];
ll qpow(ll a, ll b)
{
if(!b) return 1;
return ((b & 1) ? a : 1ll) * qpow(a * a % p, b >> 1) % p;
}
void init()
{
fac[0] = ifac[0] = 1;
for(int i = 1; i < D; i ++) fac[i] = 1ll * fac[i - 1] * i % p;
ifac[D - 1] = qpow(fac[D - 1], p - 2);
for(int i = D - 2; i >= 1; i --) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % p;
}
int C(int a, int b)
{
if(a == b) return 1;
if(a < 0 || b < 0 || a < b) return 0;
return 1ll * fac[a] * ifac[b] % p * ifac[a - b] % p;
}
inline int mod(int x) {return x >= p ? x - p : x;}
signed main()
{
ios::sync_with_stdio(0);cin.tie(0);
init();
cin >> n >> d;
int sum = 0;
for(int i = 1; i <= n; i ++) cin >> a[i];
for(int i = 1; i <= n; i ++) cin >> b[i], sum += abs(a[i] - b[i]);
g[1][0][0] = 1, g[1][0][1] = -1;
// int cnt;
for(int i = 1; i <= n + 1; i ++)
{
int i1 = i & 1, i2 = i1 ^ 1;
memset(f[i2], 0, sizeof f[i2]);
memset(g[i2], 0, sizeof g[i2]);
int dis = abs(a[i] - b[i]);
for(int j = 0; j <= i; j ++)
for(int k = 0; k <= d; k ++)
{
if(k) g[i1][j][k] = mod(g[i1][j][k] + g[i1][j][k - 1]);
f[i1][j][k] = mod(f[i1][j][k] + g[i1][j][k]);
f[i2][j + 1][k] = mod(f[i2][j + 1][k] + f[i1][j][k]);
if(k + dis <= d) f[i2][j + 1][k + dis] = mod(f[i2][j + 1][k + dis] + f[i1][j][k]);
g[i2][j][k] = mod(g[i2][j][k] + f[i1][j][k]);
if(k + dis <= d) g[i2][j][k + dis + 1] = mod(g[i2][j][k + dis + 1] - f[i1][j][k] + p);
// cnt ++;
}
}
// cerr << cnt;
int ans = 0;
for(int i = 0; i <= d; i ++)
if(sum - i <= d)
ans = (ans + f[(n + 1) & 1][0][i]) % p;
for(int i = 1; i <= n; i ++)
for(int j = 0; j <= d; j ++)
{
int ri = d - j, rj = d - sum + j;
if(rj < 0 || rj > d) continue;
int s = C(min(ri, rj), i);
ans = (ans + 1ll * f[(n + 1) & 1][i][j] * s) % p;
}
cout << ans;
return 0;
}
标签:gs,int,题解,sum,return,gets,ABC265F,dis
From: https://www.cnblogs.com/adam01/p/18327144