比较有趣的一个题。
考虑一个弱化版,算 colorful 序列个数。有一个 \(O(nK)\) 的 dp,大概就是设 \(f_{i, j}\) 为考虑到第 \(i\) 个数,当前最长互不相同后缀长度为 \(j\)。
转移考虑若往后面填一个在这 \(j\) 个数以外的数就能使 \(j \gets j + 1\),因此 \(f_{i, j + 1} \gets f_{i - 1, j} \times (K - j)\);或者填这 \(j\) 个数中的一个,有 \(\forall k \in [1, j], f_{i, k} \gets f_{i - 1, j}\),这个前缀和优化即可。
我们的目标是让中途 \(j\) 至少一次变为 \(K\),不妨容斥变为让中途的 \(j\) 都 \(\le K - 1\),变成了算不 colorful 序列的个数,再用总序列数减去即可。
回到原题,仍然考虑容斥,把答案转化成计算 \(a\) 在不 colorful 序列中出现次数,再用总出现次数 \(K^{n - m} \times (n - m + 1)\) 减去它。
先特判掉 \(a\) 已经是 colorful 序列的情况。如果 \(a\) 中的数互不相同,一个观察是任意一个值域 \([1, K]\) 长度为 \(m\) 且数互不相同的序列答案都一样,所以可以转化成算不 colorful 序列的“最长互不相同后缀长度 \(\ge m\) 的位置数”的和,结果再除以值域 \([1, K]\) 长度为 \(m\) 且数互不相同的序列的个数(即 \(\frac{K!}{(K - m)!}\))。
算不 colorful 序列的“最长互不相同后缀长度 \(\ge m\) 的位置数”的和,可以在上文 dp 的基础上再记一个 \(g_{i, j}\) 表示考虑到第 \(i\) 个数且当前最长互不相同后缀长度为 \(j\),要算的那个东西的和。有和 \(f\) 一样的转移:
- \(g_{i, j + 1} \gets (g_{i - 1, j} + [j + 1 \ge m] f_{i - 1, j}) \times (K - j)\);
- \(\forall k \in [1, j], g_{i, k} \gets g_{i - 1, j} + [k \ge m] f_{i - 1, j}\)。
\(g\) 的转移也可以前缀和优化。
然后来考虑 \(a\) 中有相同元素的情况。一个观察是左右侧填数是互不影响,相互独立的(因为不可能存在跨过 \(a\) 的子段是 \(1 \sim K\) 的排列)。既然两边独立那么分别 dp 即可。
设 \(f_{i, j}\) 为 \(a\) 左侧序列长度为 \(i\),当前最长互不相同前缀为 \(j\) 的方案数(每次往最左边加数),设 \(g_{i, j}\) 为 \(a\) 右侧序列长度为 \(i\),当前最长互不相同后缀为 \(j\) 的方案数(每次往最右边加数)。再设 \(a\) 最长互不相同前缀长度为 \(p\),最长互不相同后缀长度为 \(q\),初值有 \(f_{0, p} = g_{0, q} = 1\)。转移和上文 \(f\) 的转移一致。答案即为 \(\sum\limits_{i = 0}^{n - m} (\sum\limits_{j = 1}^{K - 1} f_{i, j}) \times (\sum\limits_{j = 1}^{K - 1} g_{n - m - i, j})\)。
这样这题就做完了。总时间复杂度 \(O(nK)\)。
code
// Problem: F - Colorful Sequences
// Contest: AtCoder - AtCoder Regular Contest 100
// URL: https://atcoder.jp/contests/arc100/tasks/arc100_d
// Memory Limit: 1024 MB
// Time Limit: 2000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mkp make_pair
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef double db;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 25050;
const int N = 25000;
const int maxm = 410;
const ll mod = 1000000007;
inline ll qpow(ll b, ll p) {
ll res = 1;
while (p) {
if (p & 1) {
res = res * b % mod;
}
b = b * b % mod;
p >>= 1;
}
return res;
}
ll n, m, K, a[maxn], fac[maxn], ifac[maxn];
bool vis[maxn];
namespace Sub1 {
ll f[maxn][maxm], g[maxn][maxm];
void solve() {
f[1][1] = K;
g[1][1] = (m == 1 ? K : 0);
for (int i = 2; i <= n; ++i) {
ll sf = 0, sg = 0;
for (int j = K - 1; j; --j) {
sf = (sf + f[i - 1][j]) % mod;
sg = (sg + g[i - 1][j]) % mod;
f[i][j] = sf;
g[i][j] = sg;
if (j >= m) {
g[i][j] = (g[i][j] + sf) % mod;
}
}
for (int j = 1; j < i && j < K; ++j) {
// printf("g[%d][%d] = %lld\n", i - 1, j, g[i - 1][j]);
// for (int k = 1; k <= j; ++k) {
// f[i][k] = (f[i][k] + f[i - 1][j]) % mod;
// g[i][k] = (g[i][k] + g[i - 1][j] + (k >= m ? f[i - 1][j] : 0)) % mod;
// }
f[i][j + 1] = (f[i][j + 1] + f[i - 1][j] * (K - j)) % mod;
g[i][j + 1] = (g[i][j + 1] + (g[i - 1][j] + (j + 1 >= m ? f[i - 1][j] : 0)) * (K - j)) % mod;
}
// for (int j = 1; j < K; ++j) {
// printf("%d %d %lld %lld\n", i, j, f[i][j], g[i][j]);
// }
}
ll ans = 0;
for (int i = 1; i < K; ++i) {
// printf("%d %lld\n", i, g[n][i]);
ans = (ans + g[n][i]) % mod;
}
ans = ans * fac[K - m] % mod * ifac[K] % mod;
// printf("ans: %lld\n", ans);
ans = (qpow(K, n - m) * (n - m + 1) % mod - ans + mod) % mod;
printf("%lld\n", ans);
}
}
namespace Sub2 {
ll f[maxn][maxm], g[maxn][maxm];
bool vis[maxn];
void solve() {
int c1 = 0, c2 = 0;
for (int i = 1; i <= m; ++i) {
if (vis[a[i]]) {
break;
}
++c1;
vis[a[i]] = 1;
}
mems(vis, 0);
for (int i = m; i; --i) {
if (vis[a[i]]) {
break;
}
++c2;
vis[a[i]] = 1;
}
f[0][c1] = 1;
for (int i = 1; i <= n; ++i) {
ll s = 0;
for (int j = K - 1; j; --j) {
s = (s + f[i - 1][j]) % mod;
f[i][j] = s;
}
for (int j = 1; j < K; ++j) {
f[i][j + 1] = (f[i][j + 1] + f[i - 1][j] * (K - j)) % mod;
}
}
g[0][c2] = 1;
for (int i = 1; i <= n; ++i) {
ll s = 0;
for (int j = K - 1; j; --j) {
s = (s + g[i - 1][j]) % mod;
g[i][j] = s;
}
for (int j = 1; j < K; ++j) {
g[i][j + 1] = (g[i][j + 1] + g[i - 1][j] * (K - j)) % mod;
}
}
ll ans = 0;
for (int i = 0; i <= n - m; ++i) {
ll x = 0, y = 0;
for (int j = 1; j < K; ++j) {
x = (x + f[i][j]) % mod;
y = (y + g[n - m - i][j]) % mod;
// printf("%d %d %lld\n", i, j, f[i][j]);
}
ans = (ans + x * y) % mod;
}
ans = (qpow(K, n - m) * (n - m + 1) % mod - ans + mod) % mod;
printf("%lld\n", ans);
}
}
void solve() {
fac[0] = 1;
for (int i = 1; i <= N; ++i) {
fac[i] = fac[i - 1] * i % mod;
}
ifac[N] = qpow(fac[N], mod - 2);
for (int i = N - 1; ~i; --i) {
ifac[i] = ifac[i + 1] * (i + 1) % mod;
}
scanf("%lld%lld%lld", &n, &K, &m);
bool fl = 1;
for (int i = 1; i <= m; ++i) {
scanf("%lld", &a[i]);
fl &= (!vis[a[i]]);
vis[a[i]] = 1;
}
for (int i = 1; i <= m - K + 1; ++i) {
for (int j = 1; j <= K; ++j) {
vis[j] = 0;
}
for (int j = i; j < i + K; ++j) {
vis[a[j]] = 1;
}
bool fl = 1;
for (int j = 1; j <= K && fl; ++j) {
fl &= vis[j];
}
if (fl) {
printf("%lld\n", qpow(K, n - m) * (n - m + 1) % mod);
return;
}
}
if (fl) {
Sub1::solve();
} else {
Sub2::solve();
}
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}