题目
C - Sum of Number of Divisors of Product
定义一个合法的序列为:长度在 \([1, n]\) 间,且每个元素均为 \([1, m]\) 中整数的序列。
定义一个合法序列的权值为:令 \(X\) 为序列中所有元素的乘积,则权值为 \(X\) 的约数个数。
对所有 \(\sum_{k = 1}^n m^k\) 个合法序列,求它们的权值之和对 \(998244353\) 取模的值。
\(1 \leq n \leq 10^{18}\),\(1 \leq m \leq 16\)。
题解
看到约数个数,我们联想到通过分解质因数求约数个数的公式。
发现 \(m \leq 16\)。因此,有用的质因数只有 \(6\) 个,即 \(2, 3, 5, 7, 11, 13\)。
而即便如此,直接使用公式转移也非常困难。因此我们不妨思考一下这个公式的本质。
考虑对于一个数 \(x\),令 \(x = 2^{a_0} 3^{a_1} 5^{a_2} 7^{a_3} 11^{a_4} 13^{a_5}\),其中每个 \(a_i\) 均为非负整数。
\(x\) 的每个约数 \(y\) 都能被表示成 \(y = 2^{b_0} 3^{b_1} 5^{b_2} 7^{b_3} 11^{b_4} 13^{b_5}\),其中 \(0 \leq b_i \leq a_i\)。
根据质因数分解,我们不妨将 \(x\) 视作以下 \(6\) 个集合的组合:
- \(A_0 = \{ 1, 2, 2^2, 2^3, \cdots, 2^{a_0} \}\);
- \(A_1 = \{ 1, 3, 3^2, 3^3, \cdots, 3^{a_1} \}\);
- \(A_2 = \{ 1, 5, 5^2, 5^3, \cdots, 5^{a_2} \}\);
- \(A_3 = \{ 1, 7, 7^2, 7^3, \cdots, 7^{a_3} \}\);
- \(A_4 = \{ 1, 11, 11^2, 11^3, \cdots, 11^{a_4} \}\);
- \(A_5 = \{ 1, 13, 13^2, 13^3, \cdots, 13^{a_5} \}\)。
即,\(A_i\) 是 \(a_i + 1\) 个本质不同的元素的组合。称所有 \(A_i\) 为 \(x\) 对应的集合。
而 \(x\) 的每个约数都可以表示成:从每个 \(A_i\) 的 \(a_i + 1\) 个元素中抽取一个,将它们乘起来。
根据乘法原理,\(x\) 的约数个数即为 \((a_0 + 1)(a_1 + 1)(a_2 + 1)(a_3 + 1)(a_4 + 1)(a_5 + 1)\),这便是该公式的由来。
通过以上推导,我们不难发现:数 \(x\) 的约数个数等价于从每个 \(A_i\) 中抽取一个数,这么做的方案数。
对一个固定的合法序列,考虑从前往后依次向序列中加入元素,维护其乘积 \(X\) 对应的 \(6\) 个集合。
注意到最开始每个集合中都有且仅有 \(1\) 个元素。
发现在序列后加入一个数 \(s = 2^{c_0} 3^{c_1} 5^{c_2} 7^{c_3} 11^{c_4} 13^{c_5}\),相当于向每个集合 \(A_i\) 中加入 \(c_i\) 个元素。
由本题解的开头,添加完所有元素后再统计答案是困难的。而这么转化相当于为加入过程中统计答案创造了条件。我们不妨在加入每个元素时,若当前的集合还没有进行抽取操作,则决策该元素是否进行抽取操作。
据此,我们可以设计出如下状态压缩 DP:
令 \(f(i, S)\) 表示考虑长度为 \(i\) 的序列,若已经进行抽取操作的集合的下标构成集合 \(S\),则这么做的方案数。
转移时,决策在序列中的下一位填什么数,以及每个集合 \(A_i\) 是否进行决策。即,转移如下:
其中 \(a(j, k)\) 表示将 \(j\) 质因数分解的表达式中,第 \(k\) 个质因数(\(k \in [0, 5]\))的次数。
即,我们枚举要进行抽取操作的集合 \(S \setminus T\) 以及这一位要填的数 \(j\) 进行转移。
答案即为所有 \(f(i, \{0, 1, 2, 3, 4, 5\})\) 的和,其中 \(i\) 为正整数。
此时时间复杂度为 \(\Theta(3^6 \times 6mn)\)。由于 \(n \leq 10^{18}\),显然无法接受。
但我们发现,这个转移符合矩阵乘法优化的形式。在矩阵乘法时,我们可以记录每个可能的 \(S\) 的 \(f\) 值以及答案的前缀和。
使用矩阵乘法优化 DP 即可做到 \(\Theta((2^6)^3 \cdot \log n)\)。
代码
这里是朴素 DP 的代码
#include <cstdio>
const int N=1e5+3;
const int P=6,mod=998244353;
const int a[17][P]={{},
{0, 0, 0, 0, 0, 0},
{1, 0, 0, 0, 0, 0},
{0, 1, 0, 0, 0, 0},
{2, 0, 0, 0, 0, 0},
{0, 0, 1, 0, 0, 0},
{1, 1, 0, 0, 0, 0},
{0, 0, 0, 1, 0, 0},
{3, 0, 0, 0, 0, 0},
{0, 2, 0, 0, 0, 0},
{1, 0, 1, 0, 0, 0},
{0, 0, 0, 0, 1, 0},
{2, 1, 0, 0, 0, 0},
{0, 0, 0, 0, 0, 1},
{1, 0, 0, 1, 0, 0},
{0, 1, 1, 0, 0, 0},
{4, 0, 0, 0, 0, 0}
};
int m,f[N][1<<P];
long long n;
int main(){
// freopen("SNDP.in","r",stdin);
// freopen("SNDP.out","w",stdout);
int i,j,k,p,q,s1,ans=0;
scanf("%lld%d",&n,&m);
for(i=0;i<(1<<P);i++) f[0][i]=1;
for(i=1;i<=n;i++){
for(j=1;j<=m;j++)
for(p=0;p<(1<<P);p++)
for(q=p;true;q=(q-1)&p){
for(k=0,s1=1;k<P;k++)
if(q>>k&1) s1*=a[j][k];
f[i][p]=(f[i][p]+(long long)s1*f[i-1][p^q]%mod)%mod;
if(q==0) break;
}
ans=(ans+f[i][(1<<P)-1])%mod;
}
printf("%d",ans);
// fclose(stdin);
// fclose(stdout);
return 0;
}
这里是矩阵乘法优化 DP 的代码
#include <cstdio>
const int P=6,mod=998244353;
const int A[17][P]={{},
{0, 0, 0, 0, 0, 0},
{1, 0, 0, 0, 0, 0},
{0, 1, 0, 0, 0, 0},
{2, 0, 0, 0, 0, 0},
{0, 0, 1, 0, 0, 0},
{1, 1, 0, 0, 0, 0},
{0, 0, 0, 1, 0, 0},
{3, 0, 0, 0, 0, 0},
{0, 2, 0, 0, 0, 0},
{1, 0, 1, 0, 0, 0},
{0, 0, 0, 0, 1, 0},
{2, 1, 0, 0, 0, 0},
{0, 0, 0, 0, 0, 1},
{1, 0, 0, 1, 0, 0},
{0, 1, 1, 0, 0, 0},
{4, 0, 0, 0, 0, 0}
};
struct Matrix{
int m,n,a[(1<<P)+3][(1<<P)+3];
}mul,ans;
int m; long long n;
Matrix operator*(Matrix a,Matrix b){
int i,j,k,s1;
Matrix ans={a.m,b.n,{}};
for(i=0;i<a.m;i++)
for(j=0;j<b.n;j++)
for(k=0;k<a.n;k++){
s1=(long long)a.a[i][k]*b.a[k][j]%mod;
ans.a[i][j]=(ans.a[i][j]+s1)%mod;
}
return ans;
}
void get_pow(Matrix& ans,Matrix a,long long b){
while(b>0){
if(b&1) ans=ans*a;
a=a*a; b>>=1;
}
}
int main(){
// freopen("SNDP.in","r",stdin);
// freopen("SNDP.out","w",stdout);
int i,j,k,t,s1;
scanf("%lld%d",&n,&m);
ans.m=1,ans.n=mul.m=mul.n=(1<<P)+1;
for(i=0;i<(1<<P);i++)
ans.a[0][i]=1;
mul.a[(1<<P)-1][(1<<P)]=1;
mul.a[1<<P][(1<<P)]=1;
for(t=1;t<=m;t++)
for(i=0;i<(1<<P);i++)
for(j=i;true;j=(j-1)&i){
for(k=0,s1=1;k<P;k++)
if(j>>k&1) s1*=A[t][k];
mul.a[i^j][i]=(mul.a[i^j][i]+s1)%mod;
if(j==0) break;
}
get_pow(ans,mul,n);
s1=(ans.a[0][1<<P]+ans.a[0][(1<<P)-1])%mod;
printf("%d",(s1+mod-1)%mod);
// fclose(stdin);
// fclose(stdout);
return 0;
}