好题,考场上想到做法了,没写出来,被薄纱了,记录一下。
主要是做的比较顺一下就想到了。
我们先转换一下 \(f\) 函数
\(f(S,a,b)=\sum\limits_{i=1}^k S_i\times (a^{k-i}-b^{k-i})\)
我们可以发现对于位数 \(>2\) 的,一定满足 \(a\le \frac {(x+1)}2\) ,因为如果不是的话 \(a^2-(a-1)^2=2a-1\) (当 \(b=a-1\) 时,显然是能达到的最小值),移一下项就有了。
所以我们发现当位数 \(>2\) 时, \(a\) 的实际有效位数很少,考虑分开考虑,这是经典套路,因为其实 \(a^{k-i}-b^{k-i}\) 是指数级增长,有点像进制,对于进制的题目,分开来计算是经典套路。
- 考虑位数 \(=2\) 的情况
就是求 \(S_1\times (a-b)=x\) 的情况有多少种,枚举每个 \(x\) 的因数,考虑枚举 \(S_1\) ,枚举完之后对于 \(b<10\) 的情况暴力整整,然后对于 \(b>=10\) 的情况再直接计算就好了(但是考场上枚举了 \(a-b\) 细节没处理好寄了),因为 \(S_k\) 是可以随便填的(因为 \(a^0-b^0=0\))
这一部分时间复杂度是 \(O(10\sqrt x)\)
- 考虑位数 \(>2\) 的情况
我们可以先枚举一下位数,因为显然位数是 \(\log x\) 级别的,然后暴力枚举 \(x\) ,再枚举有效的 \(b\) ,这部分的复杂度是一个常数极小的 \(x\sqrt x\) 。
然后如果我们再去 \(dp\) 时间复杂度承受不了,找性质。
\(1\times (a^{k}+b^{k})>min(10,b-1)(a^{k-1}-b^{k-1})\)
我们钦定右边最小值是 \(b-1\)
移一下项得到
\((a-b+1)a^{-1}-b^{k-1}\) 这个式子显然是大于 \(0\) 的,因为 \(a>b\) 。
然后如果右边最小值不是 \(b-1\) ,但是因为更大的 \(b-1\) 都没有左边大,那 \(10\) 显然也没有左边大。
那么也就是说对于前面的位数如果能填 \(x\) ,那么你就一定要填 \(x\) ,不然你后面无论填什么都凑不够前面损失的贡献,就像进制一样。
所以我们的 \(dp\) ,就转换成了一个判断问题,从 \(1\) 往 \(k\) 扫就好了。
时间复杂度 \(O(x\sqrt x\log^2 x)\) 常数极小,可以通过,当然实际上最开始的枚举位数不需要,因为对于一组 \(a,b\) ,他的可能贡献的答案位数是固定的。理论上可以做到 \(x\sqrt x\log x\) ,反正常数极小。
点击查看代码
#include<bits/stdc++.h>
#define int long long
typedef long long LL;
using namespace std;
const int MAXN=2e5+10,MODD=998244353;
LL n,x;
LL ans;
void calc(int q) {
for(LL a=q+1+(x/q);a<=n;++a) {
LL b=a-x/q;
if(b>=10) break;
ans=(ans+min(10ll,b))%MODD;
}
if(n-(x/q)>=10) {
ans=(ans+(10ll*(n-(x/q)-10+1))%MODD)%MODD;
}
}
void tp1() {
for(int i=1;i<=x;++i) {
if(x%i==0) {
if(i<10) calc(i);
}
}
}
LL pw[MAXN][23];
void tp2() {
for(int i=1;i<=x;++i) {
pw[i][2]=i;
}
for(int i=3;i<=22;++i) { //枚举位数
for(int j=1;j<x;++j) {//枚举 a
if(j>n) break;
LL res=1;
for(int q=2;q<=i;++q) {
res=res*j;
if(res>1e13) break;
}
if(res>1e13) break;
pw[j][i]=res;
for(int q=j-1;q>=1;--q) {//枚举 b
if(pw[j][i]-pw[q][i]>x) break;
LL pp=x,sf=0;
for(int w=i;w>=2;--w) {
if(pp/(pw[j][w]-pw[q][w])>=min(10ll,q)) sf=1;
pp%=(pw[j][w]-pw[q][w]);
}
if(sf||pp) continue;
ans+=min(q,10ll);
ans%=MODD;
}
}
}
}
signed main () {
scanf("%lld%lld",&n,&x);
tp1();
LL ls_ans=ans;
tp2();
printf("%lld\n",ans);
return 0;
}