Disclaimer
These measurements were taken rather casually and may well be biased. Treat them as a rough reference only.
Test subject
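With \(n=1000\), the benchmark measures how fast an \(n\times n\) matrix is multiplied by another \(n\times n\) matrix when the entries are atcoder::modint998244353.
Matrix entries are generated with: mt19937 rng{1}
XOR of all entries of the correct result matrix: 6597111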
Compile options: g++ $< -o $@ -O2 -std=c++14 -static
Compiler version: g++ (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
System: Linux LAPTOP-VTBPQCQP 5.15.153.1-microsoft-standard-WSL2 #1 SMP Fri Mar 29 23:14:13 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux
Test code
#!/bin/env pypy3
import os
src_head = """#include <bits/stdc++.h>
#include "atcoder/modint"
using namespace std;
using LL = long long;
using mint = atcoder::modint998244353;
int n = 1000;
"""
src_tail = """
int main() { mt19937 rng{1}; init(); for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { a[i][j] = rng(); } } for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { b[i][j] = rng(); } } mul(); int ret = 0; for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { ret ^= c[i][j].val(); } } cout << ret << endl; return 0; } """
def test(code, name, order):
    with open("main.cpp", "w") as file:
        print(src_head, file=file)
        print(code.format(order[0], order[1], order[2]), file=file)
        print(src_tail, file=file)
    print(name, order)
    os.system("./test.sh")
array_array = """
mint a[1000][1000], b[1000][1000], c[1000][1000];
void init() {{ }}
void mul() {{
for (int {0} = 0; {0} < n; {0}++)
for (int {1} = 0; {1} < n; {1}++)
for (int {2} = 0; {2} < n; {2}++)
c[i][j] += a[i][k] * b[k][j];
}}"""
for o in ["ijk", "ikj", "jik", "jki", "kij", "kji"]:
    test(array_array, "mint[][]", o)
#!/bin/bash
make main || exit 1
(time ./main) 2>&1 | awk 'NR==4 {print $0}'
(time ./main) 2>&1 | awk 'NR==4 {print $0}'
(time ./main) 2>&1 | awk 'NR==4 {print $0}'
(time ./main) 2>&1 | awk 'NR==4 {print $0}'
(time ./main) 2>&1 | awk 'NR==4 {print $0}'
Plain native arrays
Code
mint a[1000][1000], b[1000][1000], c[1010][1010];
void mul() {
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            for (int k = 0; k < n; k++) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
}
Results
mint[][] ijk
user 0m1.901s
user 0m1.881s
user 0m1.953s
user 0m1.924s
user 0m1.921s
mint[][] ikj
user 0m1.405s
user 0m1.418s
user 0m1.417s
user 0m1.420s
user 0m1.413s
mint[][] jik
user 0m1.851s
user 0m1.843s
user 0m1.838s
user 0m1.851s
user 0m1.809s
mint[][] jki
user 0m2.549s
user 0m2.796s
user 0m3.098s
user 0m2.451s
user 0m2.441s
mint[][] kij
user 0m1.448s
user 0m1.461s
user 0m1.432s
user 0m1.438s
user 0m1.441s
mint[][] kji
user 0m2.368s
user 0m2.301s
user 0m2.312s
user 0m2.343s
user 0m2.393s
Conclusion
With plain native arrays, the loop orders \(i, k, j\) and \(k, i, j\) run fastest.
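A common explanation for this ranking, offered here as an assumption rather than something these timings prove, is memory access order: with \(j\) innermost, both b[k][j] and c[i][j] are walked along a row with stride 1, and a[i][k] stays fixed for the whole inner loop. The sketch below illustrates that shape. It uses a hypothetical `mul_mod` helper and plain 32-bit integers in place of atcoder::modint998244353 so that it compiles without the AtCoder Library; `Matrix` and `mul_ikj` are likewise illustrative names, not part of the benchmark.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr std::uint32_t MOD = 998244353;

// Hypothetical stand-in for mint multiplication: (x * y) mod MOD.
inline std::uint32_t mul_mod(std::uint32_t x, std::uint32_t y) {
    return static_cast<std::uint32_t>(static_cast<std::uint64_t>(x) * y % MOD);
}

using Matrix = std::vector<std::vector<std::uint32_t>>;

// Multiply two n x n matrices whose entries are already reduced mod MOD,
// iterating in i, k, j order.
Matrix mul_ikj(const Matrix& a, const Matrix& b) {
    const std::size_t n = a.size();
    assert(b.size() == n);
    Matrix c(n, std::vector<std::uint32_t>(n, 0));
    for (std::size_t i = 0; i < n; i++) {
        for (std::size_t k = 0; k < n; k++) {
            const std::uint32_t aik = a[i][k];  // fixed for the whole inner loop
            for (std::size_t j = 0; j < n; j++) {
                // b[k][j] and c[i][j] are both read with stride 1.
                // Both summands are below MOD, so the sum fits in 32 bits.
                const std::uint32_t s = c[i][j] + mul_mod(aik, b[k][j]);
                c[i][j] = s >= MOD ? s - MOD : s;
            }
        }
    }
    return c;
}
```

By contrast, an order with \(i\) or \(k\) innermost steps through at least one matrix column by column, which would fit the slower jki and kji timings above.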
Array of vector
Code
void mul() {
    for (int i = 0; i < n; i++) c[i].resize(n);
    for (int i = 0; i < n; i++) {
        for (int k = 0; k < n; k++) {
            for (int j = 0; j < n; j++) {
                c[i][j] += a[i][k] * b[k][j];
            }
        }
    }
}
Results
vector[] ijk
user 0m2.241s
user 0m2.186s
user 0m2.173s
user 0m2.232s
user 0m2.137s
vector[] ikj
user 0m1.314s
user 0m1.271s
user 0m1.273s
user 0m1.279s
user 0m1.272s
vector[] jik
user 0m2.130s
user 0m2.144s
user 0m2.122s
user 0m2.121s
user 0m2.110s
vector[] jki
user 0m3.463s
user 0m3.498s
user 0m4.141s
user 0m4.225s
user 0m3.928s
vector[] kij
user 0m1.305s
user 0m1.275s
user 0m1.264s
user 0m1.287s
user 0m1.256s
vector[] kji
user 0m3.164s
user 0m3.174s
user 0m3.231s
user 0m3.184s
user 0m3.677s
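Conclusion
In the two fastest orders, \(i, k, j\) and \(k, i, j\), the array of vector beats the plain array (roughly 1.27 s versus 1.41 s). In the other four orders it is slower than the plain array.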
Array of valarray
Code
valarray<mint> a[1000], b[1000], c[1000];
void init() { for (int i = 0; i < n; i++) a[i].resize(n), b[i].resize(n), c[i].resize(n); }
void mul() {
    for (int {0} = 0; {0} < n; {0}++)
        for (int {1} = 0; {1} < n; {1}++)
            c[i] += a[i][k] * b[k];
}
Results
valarray[] ik
user 0m1.325s
user 0m1.319s
user 0m1.322s
user 0m1.332s
user 0m1.332s
valarray[] ki
user 0m1.331s
user 0m1.310s
user 0m1.323s
user 0m1.344s
user 0m1.323s
vector of valarray
Code
vector<valarray<mint>> a, b, c;
void init() { a.resize(n, valarray<mint>(n)); c.resize(n, valarray<mint>(n)); b.resize(n, valarray<mint>(n)); }
void mul() {
    for (int {0} = 0; {0} < n; {0}++)
        for (int {1} = 0; {1} < n; {1}++)
            c[i] += a[i][k] * b[k];
}
Results
vector<valarray> ik
user 0m1.261s
user 0m1.290s
user 0m1.273s
user 0m1.276s
user 0m1.270s
vector<valarray> ki
user 0m1.334s
user 0m1.321s
user 0m1.311s
user 0m1.310s
user 0m1.312s
vector of vector
Code
vector<vector<mint>> a, b, c;
void init() {{ a.resize(n,vector<mint>(n)); c.resize(n,vector<mint>(n)); b.resize(n,vector<mint>(n)); }}
void mul() {{
    for (int {0} = 0; {0} < n; {0}++)
        for (int {1} = 0; {1} < n; {1}++)
            for (int {2} = 0; {2} < n; {2}++)
                c[i][j] += a[i][k] * b[k][j];
}}
Results
vector<vector> ijk
user 0m2.218s
user 0m2.170s
user 0m2.273s
user 0m2.206s
user 0m2.226s
vector<vector> ikj
user 0m1.262s
user 0m1.230s
user 0m1.254s
user 0m1.245s
user 0m1.240s
vector<vector> jik
user 0m2.156s
user 0m2.143s
user 0m2.167s
user 0m2.158s
user 0m2.175s
vector<vector> jki
user 0m3.277s
user 0m3.171s
user 0m3.080s
user 0m3.225s
user 0m3.149s
vector<vector> kij
user 0m1.273s
user 0m1.263s
user 0m1.274s
user 0m1.287s
user 0m1.288s
vector<vector> kji
user 0m3.351s
user 0m3.021s
user 0m2.867s
user 0m2.879s
user 0m2.824s
Flattened native array (naive implementation)
Code
mint a[1000000], b[1000000], c[1000000];
void init() { }
void mul() {
    for (int {0} = 0; {0} < n; {0}++)
        for (int {1} = 0; {1} < n; {1}++)
            for (int {2} = 0; {2} < n; {2}++)
                c[i * n + j] += a[i * n + k] * b[k * n + j];
}
Results
mint[i * n + j] ijk
user 0m3.086s
user 0m3.122s
user 0m3.082s
user 0m3.052s
user 0m3.258s
mint[i * n + j] ikj
user 0m2.122s
user 0m2.106s
user 0m2.112s
user 0m2.098s
user 0m2.092s
mint[i * n + j] jik
user 0m3.084s
user 0m3.128s
user 0m3.088s
user 0m3.122s
user 0m3.076s
mint[i * n + j] jki
user 0m3.519s
user 0m3.773s
user 0m3.571s
user 0m3.515s
user 0m3.487s
mint[i * n + j] kij
user 0m2.160s
user 0m2.130s
user 0m2.179s
user 0m2.188s
user 0m2.180s
mint[i * n + j] kji
user 0m3.174s
user 0m3.202s
user 0m3.217s
user 0m3.306s
user 0m3.330s
Flattened native array (targeted optimization)
ijk
user 0m2.769s
user 0m2.688s
user 0m2.701s
user 0m2.678s
user 0m2.737s
ikj
user 0m1.281s
user 0m1.271s
user 0m1.256s
user 0m1.273s
user 0m1.271s
kij
user 0m1.869s
user 0m1.847s
user 0m1.875s
user 0m1.826s
user 0m1.788s
The other orders were not tested.
Flattened vector
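The code for this optimized variant is not shown. Since the summary notes that the flattened array only does well when written with explicit pointer increments, the \(i, k, j\) kernel may have looked roughly like the sketch below. This is a guess at the technique, not the measured code: `mul_flat_ikj` and `mul_mod` are hypothetical names, and plain 32-bit integers stand in for mint.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

constexpr std::uint32_t MOD = 998244353;

// Hypothetical stand-in for mint multiplication: (x * y) mod MOD.
inline std::uint32_t mul_mod(std::uint32_t x, std::uint32_t y) {
    return static_cast<std::uint32_t>(static_cast<std::uint64_t>(x) * y % MOD);
}

// c += a * b for row-major n x n matrices stored in flat arrays of n * n
// entries, all already reduced mod MOD. No index multiplication happens
// in the inner loops; every address is reached by incrementing a pointer.
void mul_flat_ikj(const std::uint32_t* a, const std::uint32_t* b,
                  std::uint32_t* c, std::size_t n) {
    assert(a != nullptr && b != nullptr && c != nullptr);
    for (std::size_t i = 0; i < n; i++) {
        std::uint32_t* const c_row = c + i * n;  // start of row i of c
        const std::uint32_t* a_ptr = a + i * n;  // walks a[i][0..n)
        const std::uint32_t* b_row = b;          // start of row k of b
        for (std::size_t k = 0; k < n; k++, a_ptr++, b_row += n) {
            const std::uint32_t aik = *a_ptr;
            std::uint32_t* c_ptr = c_row;
            const std::uint32_t* b_ptr = b_row;
            for (std::size_t j = 0; j < n; j++, c_ptr++, b_ptr++) {
                // Both summands are below MOD, so the sum fits in 32 bits.
                const std::uint32_t s = *c_ptr + mul_mod(aik, *b_ptr);
                *c_ptr = s >= MOD ? s - MOD : s;
            }
        }
    }
}
```

Compared with the naive c[i * n + j] form, this version computes each row offset once per outer iteration instead of once per inner iteration. Whether that accounts for the gap between the two sets of timings is not established by the data here.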
ikj
user 0m1.358s
user 0m1.377s
user 0m1.330s
user 0m1.323s
user 0m1.405s
Flattened valarray (slice)
ikj
Note that slice_array does not overload scalar multiplication, which is frustrating.
user 0m1.420s
user 0m1.404s
user 0m1.486s
user 0m1.472s
user 0m1.406s
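slice_array supports compound assignment from a valarray but has no scalar times slice_array product, so the row of b has to be materialized as a temporary valarray before it can be scaled. The measured code is not shown; the sketch below is one possible workaround under that assumption. `Mint` is a small hypothetical stand-in for atcoder::modint998244353, and `mul_slice_ikj` is an illustrative name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <valarray>

constexpr std::uint32_t MOD = 998244353;

// Hypothetical minimal modular integer, just enough for valarray to use.
struct Mint {
    std::uint32_t v = 0;
    Mint() = default;
    Mint(std::uint32_t x) : v(x % MOD) {}
    Mint& operator+=(const Mint& o) {
        v += o.v;  // both operands are below MOD, so this fits in 32 bits
        if (v >= MOD) v -= MOD;
        return *this;
    }
    Mint& operator*=(const Mint& o) {
        v = static_cast<std::uint32_t>(static_cast<std::uint64_t>(v) * o.v % MOD);
        return *this;
    }
};

// c += a * b for row-major n x n matrices stored in flat valarrays.
void mul_slice_ikj(const std::valarray<Mint>& a, const std::valarray<Mint>& b,
                   std::valarray<Mint>& c, std::size_t n) {
    assert(a.size() == n * n && b.size() == n * n && c.size() == n * n);
    for (std::size_t i = 0; i < n; i++) {
        for (std::size_t k = 0; k < n; k++) {
            const Mint aik = a[i * n + k];
            // Subscripting a const valarray with a slice yields a copy of
            // row k, which can then be scaled in place.
            std::valarray<Mint> row(b[std::slice(k * n, n, 1)]);
            row *= aik;
            // slice_array provides += with a valarray on the right-hand side.
            c[std::slice(i * n, n, 1)] += row;
        }
    }
}
```

The temporary row costs an extra allocation and copy per (i, k) pair, which might be part of why this variant trails the other \(i, k, j\) versions slightly.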
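Summary
The \(i, k, j\) order lives up to its reputation. Within that order, vector<valarray<mint>>, vector<vector<mint>>, and the flattened native array (but only when written with explicit pointer increments and decrements) all do fairly well.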
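I do not know why the measurements come out this way. Wasn't vector<vector<mint>> supposed to suffer from non-contiguous storage? So what should a matrix actually be wrapped in when solving matrix problems?
As a seasoned competitor in the art of imagination, the reader should have no trouble imagining the answer.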
From: https://www.cnblogs.com/caijianhong/p/18324995