题意简述
一副共 \(n+m\) 张牌的扑克牌,其中 \(m\) 张是 joker,其余 \(n\) 张是普通牌。初始时整副牌都在牌堆里。每次从牌堆中等概率随机抽一张牌并拿走;如果抽到的是 joker,就把所有牌(包括这张 joker)放回牌堆并重新打乱。问抽到过全部 \(n\) 张普通牌所需的期望抽牌次数,对 \(M = 19260817\) 取模。
\(n \leq 10^8\),\(m \leq 10^{18}\)。
题目分析
概率期望类题目,考虑 DP,并且期望 DP 套路是从后往前递推。
显然应该可以状压 DP,但是其非常不利于后续优化。所以尝试使用线性 DP。
DP 记录值显然是当前状态到终态需要的期望抽牌次数。状态有哪些呢?牌堆、抽到过哪些牌了。如果朴素记录就是状压了,但是我们发现,\(n\) 张普通牌和 \(m\) 张 joker 并没有本质差别,是等价的,无非前者需要区分有没有被抽到过。
不妨用 \(f_{i,j}\) 表示:当前已经抽到过 \(i\) 张不同的普通牌,且牌堆里还有 \(j\) 张牌时,到达终态还需要的期望抽牌次数。首先要明确哪些状态是合法的。joker 一旦被抽到就会触发洗牌,所以牌堆里始终包含全部 \(m\) 张 joker。牌堆里还包含 \(n-i\) 张尚未抽到过的普通牌,因此 \(j \geq n+m-i\)。当 \(i+j \geq n+m\) 时,牌堆中多出来的 \(i+j-n-m\) 张是之前抽到过、后来又因洗牌被放回牌堆的普通牌。
明确好状态,就可以转移了。我们有 \(\frac{n-i}{j}\) 的概率,抽到一张全新的牌,转移到 \(f_{i+1,j-1}\);有 \(\frac{i+j-n-m}{j}\) 的概率,抽到一张抽到过的牌,转移到 \(f_{i,j-1}\);有 \(\frac{m}{j}\) 的概率,抽到 joker,转移到 \(f_{i,n+m}\)。验证一下,\(\frac{n-i}{j}+\frac{i+j-n-m}{j}+\frac{m}{j}=1\),没有问题。
\[\Large f_{i,j}={\textstyle \frac{n-i}{j}}f_{i+1,j-1}+{\textstyle \frac{i+j-n-m}{j}}f_{i,j-1}+{\textstyle \frac{m}{j}}f_{i,n+m}+1 \]边界 \(f_{n,j}=0\),答案 \(f_{0,n+m}\)。由于每一行的转移都依赖同一行末尾的 \(f_{i,n+m}\),行内存在环形依赖,无法直接递推,怎么办呢?
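我们可以把它看作二维平面内的随机游走:每一步走向左下 \(f_{i+1,j-1}\)、左边 \(f_{i,j-1}\),或者跳回行末 \(f_{i,n+m}\)。"跳回行末"是经典套路:设 \(f_{i,j}=k_{i,j}\cdot f_{i,n+m}+b_{i,j}\),从 \(j=n+m-i\) 一路推到 \(j=n+m\),就得到方程 \(f_{i,n+m}=k_{i,n+m}\cdot f_{i,n+m}+b_{i,n+m}\)。解得 \(f_{i,n+m}=\frac{b_{i,n+m}}{1-k_{i,n+m}}\),再回代即可求出整行 \(f_{i}\)。具体可以见文末代码。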
上述 DP 的时间和空间复杂度均为 \(\Theta(n^2)\),需要优化。经过打表发现,对固定的 \(i\),\(f_{i,j}\) 关于 \(j\) 成等差数列。
【404 not found】
作者太菜了,还不会证。
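我们设 \(f_{i,j}=\lambda_i+\mu_i\cdot(n+m-j)\)。只需任意两个 \(j\) 处的值,就能确定 \(\lambda_i, \mu_i\),也就确定了整行 \(f_{i}\)。为方便起见,取末两项 \(j=n+m\) 与 \(j=n+m-1\) 列方程。当 \(i=0\) 时,\(j=n+m-1\) 并不是合法状态;但由下面的结果可知,\(\lambda_0\) 的式子中 \(\mu_0\) 的系数为 \(\frac{i}{n-i}=0\),所以这并不影响答案。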
\[\begin{aligned} &\Large\left\{\begin{aligned} & f_{i,n+m} = {\textstyle \frac{n-i}{n+m}}f_{i+1,n+m-1}+{\textstyle\frac{i}{n+m}}f_{i,n+m-1}+{\textstyle\frac{m}{n+m}}f_{i,n+m}+1 \\ & f_{i,n+m-1}= {\textstyle \frac{n-i}{n+m-1}}f_{i+1,n+m-2}+{\textstyle\frac{i-1}{n+m-1}}f_{i,n+m-2}+{\textstyle\frac{m}{n+m-1}}f_{i,n+m}+1 \end{aligned}\right. \\\\ \Large\Rightarrow&\Large\left\{\begin{aligned} & \lambda_i = {\textstyle \frac{n-i}{n+m}}(\lambda_{i+1}+\mu_{i+1})+{\textstyle\frac{i}{n+m}}(\lambda_i+\mu_i)+{\textstyle\frac{m}{n+m}}\lambda_i+1 \\ & \lambda_i+\mu_i= {\textstyle \frac{n-i}{n+m-1}}(\lambda_{i+1}+2\mu_{i+1})+{\textstyle\frac{i-1}{n+m-1}}(\lambda_i+2\mu_i)+{\textstyle\frac{m}{n+m-1}}\lambda_i+1 \end{aligned}\right. \\\\ \Large\Rightarrow&\Large\left\{\begin{aligned} & \lambda_i={\textstyle\frac{i}{n-i}}\mu_i+\lambda_{i+1}+\mu_{i+1}+{\textstyle\frac{n+m}{n-i}} \\ & {\normalsize(n+m-2i+1)}\mu_i={\normalsize(i-n)}\lambda_i+{\normalsize n+m-1}+{\normalsize(n-i)}(\lambda_{i+1}+2\mu_{i+1}) \end{aligned}\right. \\\\ \Large\Rightarrow&\Large\left\{\begin{aligned} & \lambda_i={\normalsize\frac{i\cdot\mu_i}{n-i}+\lambda_{i+1}+\mu_{i+1}+\frac{n+m}{n-i}} \\ & \mu_i={\normalsize\frac{(n-i)\cdot\mu_{i+1}-1}{n+m-i+1}} \end{aligned}\right. \\\\ \end{aligned} \]于是可以 \(\mathcal{O}(n \log M)\),若 \(m = \mathcal{O}(n)\),则可以完全线性 \(\mathcal{O}(n)\)。边界 \(\lambda_n=\mu_n=0\),答案 \(\lambda_0\)。
代码
取模板子
namespace Mod_Int_Class {
// True iff `val` lies within [numeric_limits<T>::min(), numeric_limits<T>::max()].
// NOTE(review): this is a plain comparison, so mixing signed and unsigned
// types is subject to the usual arithmetic conversions.
template <typename T, typename _Tp>
constexpr bool in_range(_Tp val) {
    return std::numeric_limits<T>::min() <= val && val <= std::numeric_limits<T>::max();
}
// Trial-division primality test, usable in constant expressions (it gates
// Mod_Int::inv below). The bound `i <= val / i` is used instead of
// `i * i <= val` so it cannot overflow when `val` is near the maximum of `_Tp`.
template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
static constexpr inline bool is_prime(_Tp val) {
    if (val < 2) return false;
    for (_Tp i = 2; i <= val / i; ++i)
        if (val % i == 0)
            return false;
    return true;
}
// Integer modulo `_mod`, always stored as the canonical residue in [0, mod).
//   T: storage type for residues; must be able to represent `_mod`.
//   S: wider type used for products. NOTE(review): nothing checks that S can
//      hold (mod - 1) * (mod - 1), which _mul relies on.
template <auto _mod = 19260817, typename T = int, typename S = long long>
class Mod_Int {
    static_assert(in_range<T>(_mod), "mod must in the range of type T.");
    static_assert(std::is_integral<T>::value, "type T must be an integer.");
    static_assert(std::is_integral<S>::value, "type S must be an integer.");
public:
    // Declared first: it is used in a default template argument of inv(),
    // which (unlike member function bodies) is not a complete-class context.
    static constexpr T mod = _mod;

    // Zero; `val` has a default member initializer.
    constexpr Mod_Int() noexcept = default;
    // Reduces any integer into [0, mod); negative values wrap around.
    template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
    constexpr Mod_Int(_Tp v) noexcept {
        if constexpr (std::is_unsigned<_Tp>::value) {
            // Reduce without going through S: casting e.g. 2^64 - 1 to a
            // signed 64-bit S would give -1 and therefore the wrong residue.
            val = static_cast<T>(v % mod);
        } else {
            if (0 <= S(v) && S(v) < mod) val = v;
            else val = (S(v) % mod + mod) % mod;
        }
    }
    // The stored residue in [0, mod).
    constexpr T const& raw() const {
        return this -> val;
    }
    template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
    constexpr friend Mod_Int pow(Mod_Int a, _Tp p) {
        return a ^ p;
    }
    constexpr friend Mod_Int sub(Mod_Int a, Mod_Int b) {
        return a - b;
    }
    constexpr friend Mod_Int& tosub(Mod_Int& a, Mod_Int b) {
        return a -= b;
    }
    // Variadic sum / product helpers.
    constexpr friend Mod_Int add(Mod_Int a) { return a; }
    template <typename... args_t>
    constexpr friend Mod_Int add(Mod_Int a, args_t... args) {
        return a + add(args...);
    }
    constexpr friend Mod_Int mul(Mod_Int a) { return a; }
    template <typename... args_t>
    constexpr friend Mod_Int mul(Mod_Int a, args_t... args) {
        return a * mul(args...);
    }
    template <typename... args_t>
    constexpr friend Mod_Int& toadd(Mod_Int& a, args_t... b) {
        return a = add(a, b...);
    }
    template <typename... args_t>
    constexpr friend Mod_Int& tomul(Mod_Int& a, args_t... b) {
        return a = mul(a, b...);
    }
    // Modular inverse via Fermat's little theorem (a^(mod-2)); only
    // participates in overload resolution when `mod` is prime.
    template <T mod_v = mod, typename = std::enable_if_t<is_prime(mod_v)>>
    static constexpr inline T inv(T a) {
        assert(a != 0);
        return _pow(a, mod - 2);
    }
    // Returns a copy: `*this` is const here, so it cannot bind to a
    // non-const `Mod_Int&`.
    constexpr Mod_Int operator + () const {
        return *this;
    }
    constexpr Mod_Int operator - () const {
        return _sub(0, val);
    }
    constexpr Mod_Int inv() const {
        return inv(val);
    }
    constexpr friend inline Mod_Int operator + (Mod_Int a, Mod_Int b) {
        return _add(a.val, b.val);
    }
    constexpr friend inline Mod_Int operator - (Mod_Int a, Mod_Int b) {
        return _sub(a.val, b.val);
    }
    constexpr friend inline Mod_Int operator * (Mod_Int a, Mod_Int b) {
        return _mul(a.val, b.val);
    }
    constexpr friend inline Mod_Int operator / (Mod_Int a, Mod_Int b) {
        return _mul(a.val, inv(b.val));
    }
    // `^` is exponentiation, not XOR.
    template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
    constexpr friend inline Mod_Int operator ^ (Mod_Int a, _Tp p) {
        return _pow(a.val, p);
    }
    constexpr friend inline Mod_Int& operator += (Mod_Int& a, Mod_Int b) {
        return a = _add(a.val, b.val);
    }
    constexpr friend inline Mod_Int& operator -= (Mod_Int& a, Mod_Int b) {
        return a = _sub(a.val, b.val);
    }
    constexpr friend inline Mod_Int& operator *= (Mod_Int& a, Mod_Int b) {
        return a = _mul(a.val, b.val);
    }
    constexpr friend inline Mod_Int& operator /= (Mod_Int& a, Mod_Int b) {
        return a = _mul(a.val, inv(b.val));
    }
    template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
    constexpr friend inline Mod_Int& operator ^= (Mod_Int& a, _Tp p) {
        return a = _pow(a.val, p);
    }
    constexpr friend inline bool operator == (Mod_Int a, Mod_Int b) {
        return a.val == b.val;
    }
    constexpr friend inline bool operator != (Mod_Int a, Mod_Int b) {
        return a.val != b.val;
    }
    constexpr Mod_Int& operator ++ () {
        this -> val + 1 == mod ? this -> val = 0 : ++this -> val;
        return *this;
    }
    constexpr Mod_Int& operator -- () {
        this -> val == 0 ? this -> val = mod - 1 : --this -> val;
        return *this;
    }
    constexpr Mod_Int operator ++ (int) {
        Mod_Int res = *this;
        this -> val + 1 == mod ? this -> val = 0 : ++this -> val;
        return res;
    }
    constexpr Mod_Int operator -- (int) {
        Mod_Int res = *this;
        this -> val == 0 ? this -> val = mod - 1 : --this -> val;
        return res;
    }
    // The injected class name is used so these operators always match this
    // exact specialization (`Mod_Int<mod, T, S>` can name a different one
    // when the deduced type of `_mod` is not T).
    friend std::istream& operator >> (std::istream& is, Mod_Int& x) {
        T ipt{};
        return is >> ipt, x = ipt, is;
    }
    friend std::ostream& operator << (std::ostream& os, Mod_Int x) {
        return os << x.val;
    }
protected:
    T val = 0;
    // Helpers below assume their residue arguments are already in [0, mod).
    static constexpr inline T _add(T a, T b) {
        return a >= mod - b ? a + b - mod : a + b;
    }
    static constexpr inline T _sub(T a, T b) {
        return a < b ? a - b + mod : a - b;
    }
    static constexpr inline T _mul(T a, T b) {
        return static_cast<S>(a) * b % mod;
    }
    // Binary exponentiation.
    template <typename _Tp, typename = std::enable_if_t<std::is_integral<_Tp>::value>>
    static constexpr inline T _pow(T a, _Tp p) {
        // A negative exponent would never reach 0 under `p >>= 1` (an
        // arithmetic shift keeps -1 at -1), so the loop would not terminate.
        if constexpr (std::is_signed<_Tp>::value) assert(p >= 0);
        T res = 1;
        for (; p; p >>= 1, a = _mul(a, a))
            if (p & 1) res = _mul(res, a);
        return res;
    }
};
using mint = Mod_Int<>;
using mod_t = mint;
constexpr mint operator ""_m (unsigned long long x) {
    return mint(x);
}
constexpr mint operator ""_mod (unsigned long long x) {
    return mint(x);
}
}
using namespace Mod_Int_Class;
\(\mathcal{O}(n^2)\) 部分分 & 打表
#include <cassert>
#include <cstdio>
#include <iostream>
#include <limits>
#include <type_traits>
#include <vector>
using namespace std;
// Input (read in main): n normal cards and m jokers.
int n, m;
namespace $1 {
// The O(n^2) DP below is used when n <= 1000.
bool check() {
    return n <= 1000;
}
// Floating-point copy of the DP in solve(), dumping each row f_i
// (i = n-1 down to 0) so patterns can be spotted by eye (打表).
// It writes to stderr so that the answer printed by solve() is the only
// output on stdout. main redirects stdout to toad.out when XuYueming is
// not defined.
void print_table() {
    const long long total = 1LL * n + m;  // n + m without int overflow
    vector<vector<double>> g(n + 1, vector<double>(n + 1));
    for (int i = n - 1; i >= 0; --i) {
        vector<double> k(i + 1), b(i + 1);
        k[0] = 1. * m / (total - i);
        b[0] = 1. * (n - i) / (total - i) * g[i + 1][0] + 1;
        for (int j = 1; j <= i; ++j) {
            k[j] = 1. * m / (total - i + j)
                 + 1. * j / (total - i + j) * k[j - 1];
            b[j] = 1. * (n - i) / (total - i + j) * g[i + 1][j]
                 + 1. * j / (total - i + j) * b[j - 1] + 1;
        }
        g[i][i] = b[i] / (1 - k[i]);
        for (int j = 0; j < i; ++j)
            g[i][j] = k[j] * g[i][i] + b[j];
        for (int j = 0; j <= i; ++j)
            fprintf(stderr, "%.10lf ", g[i][j]);
        fputc('\n', stderr);
        // for (int j = 1; j <= i; ++j)
        //     fprintf(stderr, "%.10lf ", g[i][j] - g[i][j - 1]);
        // fputc('\n', stderr);
    }
}
// Exact DP modulo M.
// Row i (i distinct normal cards seen) is indexed by j in [0, i]. The pile
// then holds n + m - i + j cards, j of which are previously seen normal
// cards returned by a reshuffle; j = i is the full pile.
// Within a row, f[i][j] = k[j] * f[i][i] + b[j], so f[i][i] = b[i] / (1 - k[i]).
// Prints f[0][0] (nothing seen, full pile), then the debug table.
void solve() {
    const long long total = 1LL * n + m;  // n + m without int overflow
    vector<vector<mint>> f(n + 1, vector<mint>(n + 1));
    for (int i = n - 1; i >= 0; --i) {
        vector<mint> k(i + 1), b(i + 1);
        k[0] = 1_mod * m / (total - i);
        b[0] = 1_mod * (n - i) / (total - i) * f[i + 1][0] + 1;
        for (int j = 1; j <= i; ++j) {
            k[j] = 1_mod * m / (total - i + j)
                 + 1_mod * j / (total - i + j) * k[j - 1];
            b[j] = 1_mod * (n - i) / (total - i + j) * f[i + 1][j]
                 + 1_mod * j / (total - i + j) * b[j - 1] + 1;
        }
        f[i][i] = b[i] / (1 - k[i]);
        for (int j = 0; j < i; ++j)
            f[i][j] = k[j] * f[i][i] + b[j];
    }
    printf("%d\n", f[0][0].raw());
    print_table();
}
}
// The O(n log M) solution `$yzh` is listed after main. Declare its entry
// point here so this translation unit also compiles in that order.
namespace $yzh { void solve(); }
// Reads n and m. Uses the O(n^2) DP when it applies and the O(n log M)
// formula otherwise.
signed main() {
#ifndef XuYueming
    freopen("toad.in", "r", stdin);
    freopen("toad.out", "w", stdout);
#endif
    scanf("%d%d", &n, &m);
    if ($1::check()) return $1::solve(), 0;
    $yzh::solve();
    return 0;
}
\(\mathcal{O}(n \log M)\) 正解
namespace $yzh {
const int N = 1000010;
mint lambda[N], mu[N];
void solve() {
lambda[n] = mu[n] = 0;
for (int i = n - 1; i >= 0; --i) {
mu[i] = ((n - i) * mu[i + 1] - 1) / (n + m - i + 1);
lambda[i] = i * mu[i] / (n - i) + lambda[i + 1] + mu[i + 1] + 1_mod * (n + m) / (n - i);
}
printf("%d", lambda[0].raw());
}
}
卡常后
#pragma GCC optimize("Ofast", "inline", "fast-math", "unroll-loops")
#include <cstdio>
// Sizes and modulus for the whole program.
// Inv[x] is filled with the inverse of x modulo `mod` for every x up to
// n + m + 1, so that index must stay below 2 * N.
const int N = 1000010;
const int mod = 19260817;
int n, m;
int lambda, mu;
int Inv[N << 1];
// Sum of two residues in [0, mod), folded back into range with at most one
// subtraction (the comparison form avoids computing a + b before the check).
inline int add(int a, int b) {
    if (a >= mod - b) return a + b - mod;
    return a + b;
}
// Constant-optimised O(n + m) solution with a linear-time inverse table.
// `register` was dropped: that storage class is ill-formed in C++17.
signed main() {
    freopen("toad.in", "r", stdin);
    freopen("toad.out", "w", stdout);
    scanf("%d%d", &n, &m), Inv[1] = 1;
    // Linear inverses: x^{-1} = -(mod / x) * (mod % x)^{-1} (mod p).
    // NOTE: requires n + m + 1 < (N << 1); larger inputs would overflow Inv.
    for (int i = 2; i <= n + m + 1; ++i)
        Inv[i] = 1ll * (mod - Inv[mod % i]) * (mod / i) % mod;
    // The loop counter i stands for (n - i) of the write-up, so i = 1..n
    // walks the write-up's i from n - 1 down to 0:
    //   mu     <- (i * mu - 1) / (m + i + 1)
    //   lambda <- (n + m + (n - i) * mu) / i + lambda + mu_prev
    for (int i = 1; i <= n; ++i) {
        const int mu_prev = mu;
        mu = 1ll * add(1ll * i * mu % mod, mod - 1) * Inv[m + i + 1] % mod;
        lambda = add(1ll * add(n + m, 1ll * (n - i) * mu % mod) * Inv[i] % mod,
                     add(lambda, mu_prev));
    }
    printf("%d", lambda);
    return 0;
}