简要题意
给你一个 \(n\) 个元素的集合,它由前 \(n\) 个正整数构成。你需要求出它有多少个非空子集,满足若 \(x\) 在这个子集中, \(2x,3x\) 不能在子集中。由于答案可能很大,你只需要对 \(10^9+1\) 取模即可。
\(1 \leq n \leq 10^5\)
思路
这道题的构造思想非常清奇。我们可以构造一个矩阵:
\[A=\begin{bmatrix} 1 & 2 & 4 & 8 & \cdots\\ 3 & 6 & 12 & 24 & \cdots\\ 9 & 18 & 36 & 72 & \cdots\\ 27 & 54 & 108 & 216 & \cdots\\ 81 & 162 & 324 & 648 & \cdots \\ \cdots & \cdots & \cdots & \cdots & \cdots \end{bmatrix} \]具体来说:
\[A_{i,j}=\begin{cases} & 1 & i=1,j=1\\ & 3\cdot f_{i-1,j} & i\neq 1,j=1\\ & 2\cdot f_{i,j-1} & \text{otherwise} \end{cases} \]这样子我们就将原问题转化成了给出一个矩阵,如果你选择 \((i,j)\),就不能选择 \((i-1,j)\) 和 \((i,j-1)\),求方案数。这个问题可以使用状压 DP 解决。
我们设 \(f_{i,S}\) 为考虑到第 \(i\) 行,这一行选择 \(S\) 中的元素的方案数。不难发现:
\[f_{i,S}=\begin{cases} & \operatorname{valid}(S) & i=1 \\ & \sum\limits_{T\cup S=\emptyset,\operatorname{valid}(T)}{f_{i-1,T}} & \text{otherwise} \end{cases} \]其中 \(\operatorname{valid}(S)\) 是指选择该行中 \(S\) 中的元素是否合法,也就是两两是否相邻。用状态压缩的话可以简单地这样实现:
\[\operatorname{valid}(S)=S\&(S>>1)?0:1 \]当然左移也可以。其实原理就是将原本一样的位错开,相邻的进行与运算。
最后注意这个表不是所有元素都会覆盖到(具体来说,只会覆盖到 \(\forall i,j\in \mathbb{N},2^{i}3^{j}\))。所以我们如果遇到了一个没有被之前覆盖到的元素,我们需要将它设为 \(f_{1,1}\) 重新生成矩阵 \(A\),并重新 DP,最后按照乘法原理(因为这些都可以同时选)将结果累乘。
然后这道题就做完了。最后提醒大家一句,位运算优先级比较低,建议大家勤添括号。
代码
点击查看代码
#include <bits/stdc++.h>
#define int long long
#define valid(x) (x&(x>>1)?0:1)
using namespace std;
const int mod = 1e9+1;
int M(const int x){return (x%mod+mod)%mod;}
const int N = 1e5+5;
int n,vis[N],a[25][25],col[N],f[25][1000005],final,ans=1;
inline void init(int x){
for(int i=1;i<=11;i++){
if(i==1) a[i][1]=x;
else a[i][1]=a[i-1][1]*3;
if(a[i][1]>n) break;
vis[a[i][1]]=1;col[i]=1;final=i;
for(int j=2;j<=18;j++){
a[i][j]=a[i][j-1]<<1;
if(a[i][j]>n) break;
col[i]=j;vis[a[i][j]]=1;
}
}
}
inline int dp(int x){
for(int i=0;i<(1<<col[1]);i++){
f[1][i]=valid(i);
}
for(int i=2;i<=final;i++){
for(int j=0;j<(1<<col[i]);j++){
if(!valid(j)) continue;
f[i][j]=0;
for(int k=0;k<(1<<col[i-1]);k++){
if(valid(k) && ((k&j) == 0)) f[i][j]=M(f[i][j]+f[i-1][k]);
}
}
}
int ret=0;
for(int i=0;i<(1<<col[final]);i++) ret=M(ret+f[final][i]);
return ret;
}
signed main(){
cin>>n;
for(int i=1;i<=n;i++){
if(vis[i]) continue;
init(i);ans=M(ans*dp(i));
}
cout<<ans;
return 0;
}