CF908D-New Year and Arbitrary Arrangement
前言
不是这题为啥星 \(2200\) 啊,感觉做的很多 \(3000\) 左右的题都比这道题水吧。
简化题意
给定空字符串,每次在串尾加入 \(a\) 或 \(b\) ,各有一定概率。
若其中有 \(\ge k\) 个 \(ab\) 子序列 , 则停止加入。
问至加入结束时,含有 \(ab\) 子序列个数的期望值。
\(k \le 1000\)
题解
感觉一眼概率 \(dp\) .
状态设计的话, \(k\) 这么小,感觉就像是二维。
那么我们分析一下,如果 \(k = 1 , p_a = p_b = \frac{1}{2}\)
存在什么情况捏,分两类。
- 第一位是
a
: \(ab , aab , aaab , aaaaaaaaaaaaaaa \dots b\)
注意好像 \(a\) 可以有无限个。
- 第一位是
b
: \(bab , bbab , baab , \dots\)
我们发现这种情况其实就是第一位是 a
的情况前面不知道加几个 b
。
所以我们只用统计第一种情况即可。
我们发现只要不存在逆天 \(aaaaaaaaaaaa \dots\)
\(a\) 的个数应该是 \(k\) 之内。如果存在较多 \(a\) , 那较多 \(a\) 处应该在串尾。
那我们证一下,如果存在前面 \(k + 1\) 个 \(a\) , 那加一个 \(b\) 直接就停止加入了,所以其在串尾。
有些长得帅的小伙伴就问了,好像 \(b\) 的个数更为确定在 \(k\) 内吧,为什么不用 \(b\) 转移?
欸,其实我最早想的就是 \(b\) , 但是由于无法判断新的概率值,所以没法转移。
所以 \(dp\) 状态很明显了, \(dp_{i , j}\) 表示 \(i\) 个 \(a\) , \(j\) 个答案子序列的概率和。
\[dp_{i , j} = dp_{i - 1 , j} \times p_a + dp_{i , j - i} \times p_b \]好的,那概率有了,怎么统计答案?
呃呃呃这个时候就要想一想,如果每个位置都加 \(aaaaa \dots b\) , 那可能会重复的。
例: \(aabab\) 和 \(aab\)
那怎么整?
我们只需将 \(a\) 还没用满的时候,后面加个 \(b\) , \(a\) 用满后,在加 \(aaaaa \dots b\) .
那这个逆天长串答案怎么统计?
设原串中 \(a\) 有 \(x\) 个,显然答案为:
\[\begin{aligned} &= dp_{i , j} \times \left(\sum{p_b p_a^i \times (i + x)} \right) \\ &= dp_{i , j} \times \left(\frac{x p_b}{1 - p_a} + \frac{p_a + p_b}{(1 - p_a)^2}\right) \end{aligned}\]做完啦!!!
code
CODE
#include <bits/stdc++.h>
using namespace std ;
typedef long long ll ;
const int N = 1e3 + 10 ;
const int mod = 1e9 + 7 ;
ll k , p11 , p22 , p1 , p2 , dp[N][N] ;
inline ll Quick_Pow(ll a , ll b) {
ll ans = 1 ;
while (b) {
if (b & 1) ans = (ans * a) % mod ;
b >>= 1 , a = (a * a) % mod ;
}
return ans ;
}
ll ans = 0 ;
signed main() {
ios::sync_with_stdio(0) , cin.tie(0) , cout.tie(0) ;
cin >> k >> p11 >> p22 ;
p1 = (p11 * Quick_Pow(p11 + p22 , mod - 2)) % mod , p2 = (p22 * Quick_Pow(p11 + p22 , mod - 2)) % mod ;
dp[0][0] = 1 ;
for (int i = 1 ; i <= k ; ++ i) {
for (int j = 0 ; j < k ; ++ j) {
dp[i][j] = (dp[i - 1][j] * p1) % mod ;
if (j >= i) dp[i][j] = (dp[i][j] + dp[i][j - i] * p2 % mod) % mod ;
if (i != k && i + j >= k) ans = (ans + (((dp[i][j] * p2) % mod) * ((1ll * i + j) % mod))) % mod ;
}
}
ll sum1 , sum2 , ny ;
for (int i = 0 ; i < k ; ++ i) {
ny = Quick_Pow((1 - p1 + mod) % mod , mod - 2) ;
sum1 = ((k * p2) % mod * ny) % mod , sum2 = (((p1 * p2) % mod) * ((ny * ny) % mod)) % mod ;
ans = (ans + (dp[k][i] * (sum1 + sum2 + i)) % mod) % mod ;
}
ans = (ans * Quick_Pow((1 + mod - p2) % mod , mod - 2)) % mod ;
cout << ans ;
}