首页 > 其他分享 >【模板】任意模数多项式乘法:三模 NTT

【模板】任意模数多项式乘法:三模 NTT

时间:2024-04-10 21:48:51浏览次数:28  
标签:const int rhs NTT return 模数 vector 三模 modint

前置知识

https://www.cnblogs.com/caijianhong/p/template-crt.html

https://www.cnblogs.com/caijianhong/p/template-fft.html

题目描述

任意模数多项式乘法

solution

首先我们打开 https://blog.miskcoo.com/2014/07/fft-prime-table 这篇文章找到 \(998244353\) 附近的几个质数:

  • \(167772161 = 5 \times 2^{25} + 1\) 的原根为 \(g = 3\)
  • \(469762049 = 7 \times 2^{26} + 1\) 的原根为 \(g = 3\)
  • \(998244353 = 119 \times 2^{23} + 1\) 的原根为 \(g = 3\)
  • \(1004535809 = 479 \times 2^{21} + 1\) 的原根为 \(g = 3\)
  • \(2013265921 = 15 \times 2^{27} + 1\) 的原根为 \(g = 31\)

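因为卷积结果在取模之前的值域为 \(p^2\max(n, m)\),大约为 \(10^{24}\),可以选择三个质数,使它们的乘积 \(>10^{24}\)。比如说可以选 \(998244353\) 与它前面的两个质数,刚好原根都是 \(3\)(下文代码实际选用的是 \(469762049\)、\(998244353\)、\(1004535809\) 这三个质数,原根同样都是 \(3\),乘积约为 \(4.7 \times 10^{26}\))。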

然后以三个质数为模数分别跑三次 NTT 求卷积。

最后 CRT 合并答案,就是正常的中国剩余定理。

实现

点击查看代码

#include <bits/stdc++.h>
using namespace std;
// Debug logging switch: when LOCAL is defined, debug(...) forwards its
// arguments to fprintf(stderr, ...); otherwise it expands to a no-op.
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
// In non-LOCAL builds every `endl` token is replaced by the string "\n".
#define endl "\n"
#define debug(...) void(0)
#endif
// Shorthand for long long.
typedef long long LL;
// SFINAE helper: names `int` when T is an integral type and is ill-formed
// otherwise, so templates using it only accept integral arguments.
template <class T>
using must_int = enable_if_t<is_integral<T>::value, int>;

// Integer modulo the compile-time constant `umod`.
// The stored value `v` is kept in [0, umod). Division uses the exponent
// mod - 2, i.e. it relies on umod being prime. The add/subtract wrap-around
// checks rely on 2 * umod fitting in an unsigned.
template <unsigned umod>
struct modint {
  static constexpr int mod = umod;
  unsigned v;

  modint() : v(0) {}

  // Reduce any integral value (negative values included) into [0, mod).
  template <class T, must_int<T> = 0>
  modint(T x) {
    x %= mod;
    // The addition is done on the promoted type, not on T, so narrow
    // integer types do not truncate the result.
    if (x < 0)
      v = x + mod;
    else
      v = x;
  }

  // Value as a plain int.
  friend int raw(const modint &self) { return self.v; }

  friend ostream &operator<<(ostream &os, const modint &self) {
    os << raw(self);
    return os;
  }

  modint operator+() const { return *this; }
  modint operator-() const {
    modint negated;
    negated -= *this;
    return negated;
  }

  modint &operator+=(const modint &rhs) {
    const unsigned sum = v + rhs.v;
    v = sum >= umod ? sum - umod : sum;
    return *this;
  }
  modint &operator-=(const modint &rhs) {
    // Unsigned wrap-around: a "negative" difference shows up as a value
    // >= umod, and adding umod wraps it back into range.
    const unsigned diff = v - rhs.v;
    v = diff >= umod ? diff + umod : diff;
    return *this;
  }
  modint &operator*=(const modint &rhs) {
    const unsigned long long product =
        static_cast<unsigned long long>(v) * rhs.v;
    v = product % umod;
    return *this;
  }
  modint &operator/=(const modint &rhs) {
    assert(rhs.v);
    const modint inverse = qpow(rhs, mod - 2);
    return *this *= inverse;
  }

  // a^b by binary exponentiation (b integral).
  template <class T, must_int<T> = 0>
  friend modint qpow(modint a, T b) {
    modint result = 1;
    while (b) {
      if (b & 1) result *= a;
      a *= a;
      b >>= 1;
    }
    return result;
  }

  friend modint operator+(modint lhs, const modint &rhs) {
    lhs += rhs;
    return lhs;
  }
  friend modint operator-(modint lhs, const modint &rhs) {
    lhs -= rhs;
    return lhs;
  }
  friend modint operator*(modint lhs, const modint &rhs) {
    lhs *= rhs;
    return lhs;
  }
  friend modint operator/(modint lhs, const modint &rhs) {
    lhs /= rhs;
    return lhs;
  }

  bool operator==(const modint &rhs) const { return v == rhs.v; }
  bool operator!=(const modint &rhs) const { return !(v == rhs.v); }
};
// Smallest power of two that is >= n; returns 1 for every n <= 1.
// The n <= 1 case is handled separately because it would otherwise evaluate
// __builtin_clz(0) (undefined) or shift by 32 (undefined for a 32-bit int).
int glim(int n) {
  if (n <= 1) return 1;
  return 1 << (32 - __builtin_clz(n - 1));
}
// Count of trailing zero bits of n, via __builtin_ctz (whose result is
// undefined for 0, so callers must pass a non-zero value).
int bitctz(int n) {
  const unsigned bits = static_cast<unsigned>(n);
  return __builtin_ctz(bits);
}
// In-place number-theoretic transform of `a` over the modulus of `mint`.
//   a  : values to transform; the index arithmetic below relies on a.size()
//        being a power of two.
//   op : -1 additionally reverses a[1..] and scales every entry by 1/size
//        (the inverse transform); any other value runs only the forward pass.
// The base 3 is used as the generator for every modulus.
template <class mint>
void ntt(vector<mint> &a, int op) {
  // wns[j] = 3^((mod - 1) / 2^(j + 1)); filled once per mint type.
  static vector<mint> wns;
  if (wns.empty()) {
    int exp = mint::mod - 1;
    while (exp % 2 == 0) {
      exp >>= 1;
      wns.push_back(qpow(mint(3), exp));
    }
  }

  const int n = a.size();

  // Bit-reversal permutation. `rev` is updated incrementally: stepping from
  // pos - 1 to pos flips the low (ctz(pos) + 1) bits of pos, which
  // corresponds to flipping the top (ctz(pos) + 1) bits of the reversed
  // index.
  int rev = 0;
  for (int pos = 1; pos < n; pos++) {
    rev ^= n - (1 << (bitctz(n) - bitctz(pos) - 1));
    debug("i = %d, r = %d\n", pos, rev);
    if (rev > pos) swap(a[rev], a[pos]);
  }

  // Butterfly passes; w[0..half) holds the powers of this pass's root.
  vector<mint> w(n);
  for (int half = 1, len = 2; len <= n; half <<= 1, len <<= 1) {
    const mint wn = wns[bitctz(half)];
    w[0] = 1;
    for (int i = 1; i < half; i++) w[i] = w[i - 1] * wn;
    for (int start = 0; start < n; start += len) {
      for (int j = 0; j < half; j++) {
        const mint lo = a[start + j];
        const mint hi = a[start + j + half] * w[j];
        a[start + j] = lo + hi;
        a[start + j + half] = lo - hi;
      }
    }
  }

  if (op != -1) return;

  // Inverse step: reverse everything but a[0], then divide by n.
  reverse(a.begin() + 1, a.end());
  const mint inv_n = mint(1) / n;
  for (int i = 0; i < n; i++) a[i] *= inv_n;
}
// Product of the polynomials `a` and `b` (coefficient vectors over `mint`)
// computed with ntt().
// Returns a.size() + b.size() - 1 coefficients. When that length is not
// positive (both inputs empty, or one empty and the other a single
// coefficient) an empty vector is returned, so no negative length is ever
// passed to glim() or resize().
template <class mint>
vector<mint> convolution(vector<mint> a, vector<mint> b) {
  // Computed in int after the addition so two empty inputs give -1 rather
  // than relying on a size_t wrap-around.
  const int rlen = static_cast<int>(a.size() + b.size()) - 1;
  if (rlen <= 0) return {};
  const int len = glim(rlen);  // transform length: power of two >= rlen
  a.resize(len), ntt(a, 1);
  b.resize(len), ntt(b, 1);
  for (int i = 0; i < len; i++) a[i] *= b[i];  // pointwise product
  ntt(a, -1), a.resize(rlen);
  return a;
}

主要的部分在这里,实现比较逆天,供参考。其中那些巨大数字由 python 程序生成:

#!/usr/bin/env python3
"""Generate the CRT constants for the three-modulus NTT C++ code.

Prints, in order:
  1. a C++ user-defined literal ``_ubi`` that parses a decimal string into
     a 128-bit integer,
  2. the C++ expression combining the three residues ``ret1[i]``,
     ``ret2[i]``, ``ret3[i]`` with the CRT coefficients modulo ``M``,
  3. a self-check value: combining the residues (1463, 1463, 1463) must
     give back 1463.

Requires Python 3.8+ for ``pow(x, -1, mod)``.
"""
n = 3
a = [998244353, 1004535809, 469762049]  # the three NTT moduli
M = a[0] * a[1] * a[2]
m = [M // a[i] for i in range(n)]
t = [pow(m[i], -1, a[i]) for i in range(n)]  # inverse of m[i] modulo a[i]
# coe[i] is 1 modulo a[i] and 0 modulo the other two moduli.
coe = [m[i] * t[i] % M for i in range(n)]
typ = "unsigned __int128"
prog = f"""
{typ} operator\"\"_ubi(const char* str) {{
  int len = strlen(str);
  {typ} x = 0;
  for (int i = 0; i < len; i++) x = x * 10 + str[i] - '0';
  return x;
}}
"""
print(prog)
print(*[f"({coe[i]}_ubi * raw(ret{i + 1}[i]) % {M}_ubi)" for i in range(n)], sep = " + ")

# Self-check using the coefficients computed above (not pasted literals), so
# it really validates what was just printed.
residue = 1463
check = sum(coe[i] * residue % M for i in range(n)) % M
assert check == residue, "CRT coefficients failed the self-check"
print(check)

注意要先彻底模完 \(M\) 再去模 \(p\):三项相加之后的和可能达到 \(3M\),必须先对 \(M\) 取模还原出 \([0, M)\) 范围内的真实值,再对 \(p\) 取模,否则答案会出错。

// Raw user-defined literal: turns the decimal digit string of an integer
// literal (e.g. 123_ubi) into an __int128, so constants too wide for the
// built-in integer literals can be written directly in source.
__int128 operator""_ubi(const char *str) {
  __int128 value = 0;
  for (const char *cursor = str; *cursor != '\0'; ++cursor) {
    value = value * 10 + *cursor - '0';
  }
  return value;
}
// Debug helper: writes the decimal form of x and a line break to cerr when
// LOCAL is defined; does nothing otherwise.
// The digits are produced one at a time because iostreams have no __int128
// overload. Unlike splitting the value into two 18-digit chunks, this keeps
// the leading zeros of the low part and works for the whole __int128 range,
// negative values included.
void print(__int128 x) {
#ifdef LOCAL
  const bool negative = x < 0;
  // Negate in unsigned arithmetic so the most negative value is handled too.
  unsigned __int128 magnitude =
      negative ? -static_cast<unsigned __int128>(x)
               : static_cast<unsigned __int128>(x);
  string digits;
  do {
    digits.push_back(static_cast<char>('0' + static_cast<int>(magnitude % 10)));
    magnitude /= 10;
  } while (magnitude != 0);
  if (negative) digits.push_back('-');
  reverse(digits.begin(), digits.end());
  cerr << digits << endl;
#else
  (void)x;  // intentionally unused outside LOCAL builds
#endif
}
// Polynomial product of `a` and `b` with every coefficient reduced modulo p,
// for an arbitrary modulus p.
// The product is computed three times, modulo 998244353, 1004535809 and
// 469762049, and the three residues are merged by the Chinese remainder
// theorem into the value modulo M (the product of the three primes) before
// the final reduction modulo p.
// Returns a.size() + b.size() - 1 coefficients, or an empty vector when that
// length is not positive.
vector<int> convolution_fun(const vector<int> &a, const vector<int> &b, int p) {
  static constexpr unsigned mods[] = {998244353, 1004535809, 469762049};
  typedef modint<mods[0]> mint1;
  typedef modint<mods[1]> mint2;
  typedef modint<mods[2]> mint3;

  // Computed in int after the addition so two empty inputs give -1 rather
  // than relying on a size_t wrap-around.
  const int rlen = static_cast<int>(a.size() + b.size()) - 1;
  if (rlen <= 0) return {};

  const auto ret1 = convolution(vector<mint1>(a.begin(), a.end()),
                                vector<mint1>(b.begin(), b.end()));
  const auto ret2 = convolution(vector<mint2>(a.begin(), a.end()),
                                vector<mint2>(b.begin(), b.end()));
  const auto ret3 = convolution(vector<mint3>(a.begin(), a.end()),
                                vector<mint3>(b.begin(), b.end()));

  // The _ubi literals are parsed from their digit strings at run time, so
  // evaluate each once here instead of on every loop iteration.
  const __int128 M = 471064322751194440790966273_ubi;     // product of mods
  const __int128 coe1 = 401276874248923522479908641_ubi;  // pairs with mods[0]
  const __int128 coe2 = 185347017962817624218731910_ubi;  // pairs with mods[1]
  const __int128 coe3 = 355504753290647734883291996_ubi;  // pairs with mods[2]

  vector<int> ret(rlen);
  for (int i = 0; i < rlen; i++) {
    // Reduce fully modulo M first: the sum of the three terms can exceed M,
    // and only the value in [0, M) is the true coefficient.
    const __int128 combined = (coe1 * raw(ret1[i]) % M +
                               coe2 * raw(ret2[i]) % M +
                               coe3 * raw(ret3[i]) % M) %
                              M;
    ret[i] = static_cast<int>(combined % p);
  }
  return ret;
}
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);
#endif
  int n, m, p;
  cin >> n >> m >> p;
  vector<int> a(n + 1);
  for (int i = 0; i <= n; i++) cin >> a[i];
  vector<int> b(m + 1);
  for (int i = 0; i <= m; i++) cin >> b[i];
  for (int x : convolution_fun(a, b, p)) cout << x << " ";
  cout << endl;
  return 0;
}

标签:const,int,rhs,NTT,return,模数,vector,三模,modint
From: https://www.cnblogs.com/caijianhong/p/18127492

相关文章

  • JSX.Element 和 React.ElementType的区别是什么?
    在React和TypeScript中,JSX.Element和React.ElementType代表了两种不同的概念:JSX.Element:JSX.Element是一个类型,表示由JSX编译后生成的实际React元素对象。当你在React应用中使用JSX编写组件时,每一个JSX表达式都会编译为一个JSX.Element对象。例如:constMyComponent:React.......
  • 公钥私钥和模数指数相互转换
    pem格式公钥私钥读取解析公钥私钥pem格式加解密示例根据私钥pem生成模数和指数NED生成模数和指数NED的公钥私钥NED导出pem格式#include<SylixOS.h>#include<stdio.h>#include<crypto.h>#include<mbedtls/ssl.h>#include<mbedtls/platform.h>......
  • WPF GroupBox Expander ExpandDirection="Down" Expander.HeaderTemplate Expander.C
    //xaml<Windowx:Class="WpfApp43.MainWindow"xmlns="http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x="http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d="http://schemas.mic......
  • JAVA注解-ElementType详解
    ava中元注解(用来标识注解的注解)有四个: @Retention@Target@Document@Inherited; @Retention:注解的保留位置@Retention(RetentionPolicy.SOURCE)   //注解仅存在于源码中,在class字节码文件中不包含@Retention(RetentionPolicy.......
  • 备战蓝桥杯第三模块之二分模版+二分练题
    前言因本系列是为了蓝桥杯前几天快速过知识点所准备,所以有些部分会很简洁模版整数二分intMid(inta,intb){intl=1;intr=1e9+10;while(l<r){intmid=l+r>>2;if(a/mid<=b)//这儿一般情况下看需要用mid判断什么条件r=mid;elsel=......
  • devexpress schedulerControl Gantt View 使用
    usingSystem;usingSystem.Collections.Generic;usingSystem.ComponentModel;usingSystem.Data;usingSystem.Drawing;usingSystem.Text;usingSystem.Linq;usingSystem.Windows.Forms;usingDevExpress.XtraEditors;usingF.Studio.DevExpressUI;usingDevExp......
  • 【做题纪要】衡实初三模拟测试三
    本来以为打完最多能拿\(120\)分所以没打,事实上自己做法能拿\(170\)分也就能到rk1,血亏本次模拟赛不知道怎么拼出来的,一共4道题有3道题需要文件输出,最后出现了9道题的题解都没写代码,凑合着看,思路肯定是能过的(吧?)网格图这个题一眼过去可以用暴力bfs来打,复杂度\(O(n^2k^2)\)可......
  • 数据规模:如何处理大规模数据集
    1.背景介绍1.1数据的爆炸式增长随着互联网的普及和物联网的发展,数据量呈现出爆炸式增长。据统计,全球每天产生的数据量已经达到2.5亿GB,预计到2025年,全球数据总量将达到175ZB。这些数据包括社交媒体、电子商务、物联网设备、科学研究等各个领域的数据。如何有效地处理和分......
  • 【代码更新】SPI时序——AD数模数转换
    【代码更新】SPI时序——AD数模数转换AD芯片手册:https://www.ti.com.cn/cn/lit/ds/symlink/ads8558.pdf?ts=1709473143911&ref_url=https%253A%252F%252Fwww.ti.com.cn%252Fproduct%252Fcn%252FADS8558下图是芯片需要配置信号的时序图寄存器配置时序: 项目硬件连接图:......
  • python列表、集合、字典转换要点以及查找速度区别,如何在大规模数据中实现快速查找
    1.list与set的区别与优缺点:循环速度:list最适合做固定长度的遍历,而且有顺序。set是无序的,list转换为set会乱序,若用set给list去重,转化为list时须用原list的index排序:new_list.sort(key=old_list.index)。所以这种循环尽量用list查询速度:set>list,set查询的key都是ha......