妙妙题 #dp

转换一下 \(\sum a_i^2\):\(a_i^2\) 是两次独立操作都得到第 \(i\) 种输出序列的方案对数,所以 \(\sum a_i^2\) 等价于独立操作 \(2\) 次、最后得到的输出序列相同的(有序)方案对数


\(dp_{k,i,j}\) 表示两次操作各进行了 \(k\) 轮且至今输出完全相同时,第一次从上管道取了 \(i\) 个、第二次从上管道取了 \(j\) 个(即各自从下管道取了 \(k-i\)、\(k-j\) 个)的方案数

转移按两次操作本轮分别取上管道还是下管道分 \(4\) 种情况(要求两次取出的珠子相同),暴力转移即可;第一维用滚动数组,答案为 \(dp_{n+m,n,n}\)


// Author: xiaruize
// Address marker for the rough "Memory use" report printed at the end of main;
// it is paired with end_of_memory_use, which is defined after the large globals.
bool start_of_memory_use;
#include <bits/stdc++.h>
using namespace std;
// CPU-clock reading captured while globals are being initialised.
clock_t start_clock{clock()};
// Debug-printing helpers used by the debug()/debugArr() macros. Every print()
// overload writes a readable rendering of its argument to std::cerr:
//   - bool as T/F, char as 'c', std::string as "s", other scalars as-is;
//   - ranges and stack/queue/priority_queue as {a,b,...};
//   - pair/tuple as (a,b,...);
//   - vector<vector<T>> and 2-D arrays one indexed row per line.
namespace __DEBUG_UTIL__
{
	using namespace std;
	/* Primitive Datatypes Print */
	void print(const char *x) { cerr << x; }
	void print(bool x) { cerr << (x ? "T" : "F"); }
	void print(char x) { cerr << '\'' << x << '\''; }
	void print(signed short int x) { cerr << x; }
	void print(unsigned short int x) { cerr << x; }
	void print(signed int x) { cerr << x; }
	void print(unsigned int x) { cerr << x; }
	void print(signed long int x) { cerr << x; }
	void print(unsigned long int x) { cerr << x; }
	void print(signed long long int x) { cerr << x; }
	void print(unsigned long long int x) { cerr << x; }
	void print(float x) { cerr << x; }
	void print(double x) { cerr << x; }
	void print(long double x) { cerr << x; }
	void print(string x) { cerr << '\"' << x << '\"'; }
	template <size_t N>
	void print(bitset<N> x) { cerr << x; }
	void print(vector<bool> v)
	{ /* Overloaded this because stl optimizes vector<bool> by using
		  _Bit_reference instead of bool to conserve space. */
		int f = 0;
		cerr << '{';
		for (auto &&i : v)
			cerr << (f++ ? "," : "") << (i ? "T" : "F");
		cerr << "}";
	}
	/* Templates Declarations to support nested datatypes */
	template <typename T>
	void print(T &&x);
	template <typename T>
	void print(vector<vector<T>> mat);
	template <typename T, size_t N, size_t M>
	void print(T (&mat)[N][M]);
	template <typename F, typename S>
	void print(pair<F, S> x);
	template <typename T, size_t N>
	struct Tuple;
	template <typename T>
	struct Tuple<T, 1>;
	template <typename... Args>
	void print(tuple<Args...> t);
	template <typename... T>
	void print(priority_queue<T...> pq);
	template <typename T>
	void print(stack<T> st);
	template <typename T>
	void print(queue<T> q);
	/* Template Datatypes Definitions */
	template <typename T>
	void print(T &&x)
	{
		/*  This works for every container that supports range-based loop
			i.e. vector, set, map, oset, omap, dequeue */
		int f = 0;
		cerr << '{';
		for (auto &&i : x)
			cerr << (f++ ? "," : ""), print(i);
		cerr << "}";
	}
	template <typename T>
	void print(vector<vector<T>> mat)
	{
		int f = 0;
		cerr << "\n~~~~~\n";
		for (auto &&i : mat)
		{
			cerr << setw(2) << left << f++, print(i), cerr << "\n";
		}
		cerr << "~~~~~\n";
	}
	template <typename T, size_t N, size_t M>
	void print(T (&mat)[N][M])
	{
		int f = 0;
		cerr << "\n~~~~~\n";
		for (auto &&i : mat)
		{
			cerr << setw(2) << left << f++, print(i), cerr << "\n";
		}
		cerr << "~~~~~\n";
	}
	template <typename F, typename S>
	void print(pair<F, S> x)
	{
		cerr << '(';
		print(x.first);
		cerr << ',';
		print(x.second);
		cerr << ')';
	}
	// Prints the elements of a tuple, comma-separated, by recursing on the arity.
	template <typename T, size_t N>
	struct Tuple
	{
		static void printTuple(T t)
		{
			Tuple<T, N - 1>::printTuple(t);
			cerr << ",", print(get<N - 1>(t));
		}
	};
	template <typename T>
	struct Tuple<T, 1>
	{
		static void printTuple(T t) { print(get<0>(t)); }
	};
	template <typename... Args>
	void print(tuple<Args...> t)
	{
		cerr << "(";
		Tuple<decltype(t), sizeof...(Args)>::printTuple(t);
		cerr << ")";
	}
	// The adaptors below take their argument by value and drain the copy.
	template <typename... T>
	void print(priority_queue<T...> pq)
	{
		int f = 0;
		cerr << '{';
		while (!pq.empty())
			cerr << (f++ ? "," : ""), print(pq.top()), pq.pop();
		cerr << "}";
	}
	template <typename T>
	void print(stack<T> st)
	{
		int f = 0;
		cerr << '{';
		while (!st.empty())
			cerr << (f++ ? "," : ""), print(st.top()), st.pop();
		cerr << "}";
	}
	template <typename T>
	void print(queue<T> q)
	{
		int f = 0;
		cerr << '{';
		while (!q.empty())
			cerr << (f++ ? "," : ""), print(q.front()), q.pop();
		cerr << "}";
	}
	/* Printer functions */
	void printer(const char *) {} /* Base Recursive */
	// Prints "name = value" for each argument, separated by " ||" and ended by "]\n".
	// `names` is the stringified macro argument list (e.g. "a, f(b, c)").
	template <typename T, typename... V>
	void printer(const char *names, T &&head, V &&...tail)
	{
		/* Using && to capture both lvalues and rvalues */
		int i = 0;
		// Find the comma ending this argument's name, ignoring commas nested in ()/<>/{}.
		for (size_t bracket = 0; names[i] != '\0' and (names[i] != ',' or bracket != 0); i++)
		{
			if (names[i] == '(' or names[i] == '<' or names[i] == '{')
				bracket++;
			else if ((names[i] == ')' or names[i] == '>' or names[i] == '}') and bracket != 0)
				bracket--; // guarded: an unmatched '>' (e.g. "a > b") must not wrap the size_t
		}
		cerr.write(names, i) << " = ";
		print(head);
		if (sizeof...(tail))
			cerr << " ||", printer(names + i + 1, tail...);
		else
			cerr << "]\n";
	}
	/* PrinterArr */
	void printerArr(const char *) {} /* Base Recursive */
	// Arguments come in (array, length) pairs; prints "name = {a0,a1,...}" for each.
	template <typename T, typename... V>
	void printerArr(const char *names, T arr[], size_t N, V... tail)
	{
		size_t ind = 0;
		// Echo the array's name, then skip over its length expression.
		for (; names[ind] and names[ind] != ','; ind++)
			cerr << names[ind];
		for (ind++; names[ind] and names[ind] != ','; ind++)
			;
		cerr << " = {";
		for (size_t i = 0; i < N; i++)
			cerr << (i ? "," : ""), print(arr[i]);
		cerr << "}";
		if (sizeof...(tail))
			cerr << " ||", printerArr(names + ind + 1, tail...);
		else
			cerr << "]\n";
	}
}
// debug(...) / debugArr(...) print "LINE: [name = value || ...]" to stderr through
// __DEBUG_UTIL__. They expand to nothing unless LOCAL is defined (e.g. -DLOCAL).
// Any earlier definition is dropped first so no macro-redefinition diagnostic fires.
#undef debug
#undef debugArr
#ifdef LOCAL
#define debug(...) std::cerr << __LINE__ << ": [", __DEBUG_UTIL__::printer(#__VA_ARGS__, __VA_ARGS__)
#define debugArr(...) std::cerr << __LINE__ << ": [", __DEBUG_UTIL__::printerArr(#__VA_ARGS__, __VA_ARGS__)
#else
#define debug(...)
#define debugArr(...)
#endif
// #define int long long
/* ---- type shorthands ---- */
#define ull unsigned long long
#define pii pair<int, int>
#define pis pair<int, string>
/* ---- container / pair shorthands ---- */
#define ALL(c) (c).begin(), (c).end()
#define sz(c) int((c).size())
#define pb push_back
#define mk make_pair
#define fir first
#define sec second
/* ---- fixed answer printers ---- */
#define Yes cout << "Yes" << endl
#define YES cout << "YES" << endl
#define No cout << "No" << endl
#define NO cout << "NO" << endl
/* ---- memory fill and inclusive loops ---- */
#define mms(buf, val) memset(buf, val, sizeof(buf))
#define rep(var, lo, hi) for (int var = (lo); var <= (hi); ++var)
#define per(var, hi, lo) for (int var = (hi); var >= (lo); --var)
// int-only maximum. As a non-template it is preferred over std::max (visible via
// `using namespace std`) for two int arguments; arguments of other arithmetic
// types are converted to int (possibly narrowing) instead of failing deduction.
int max(int a, int b)
{
	return a > b ? a : b;
}
// int-only minimum. As a non-template it is preferred over std::min (visible via
// `using namespace std`) for two int arguments; arguments of other arithmetic
// types are converted to int (possibly narrowing) instead of failing deduction.
int min(int a, int b)
{
	return a < b ? a : b;
}
// "Infinity" sentinel. The 32-bit pattern is used because `int` is 32-bit here;
// 0x3f3f3f3f3f3f3f3f would only fit with the disabled `#define int long long`.
const int INF = 0x3f3f3f3f;
const int MOD = 1024523; // the answer is printed modulo this value
// Array capacity. Strings are read 1-based with a terminator, so lengths up to
// N - 2 fit; NOTE(review): presumably the constraints are n, m <= 500.
const int N = 5e2 + 10;

int n, m;        // lengths of the upper (s) and lower (t) bead strings
char s[N], t[N]; // bead strings, stored from index 1
int dp[2][N][N]; // rolling DP layers indexed by round parity (see solve)

void solve()
	cin >> n >> m;
	cin >> (s + 1) >> (t + 1);
	reverse(s + 1, s + n + 1);
	reverse(t + 1, t + m + 1);
	dp[0][0][0] = 1;
	rep(r, 1, n + m)
		mms(dp[r & 1], 0);
		rep(i, 0, n)
			rep(j, 0, n)
				if (s[i] == s[j] && i && j)
					dp[r & 1][i][j] += dp[(r & 1) ^ 1][i - 1][j - 1];
				if (r - j > 0 && s[i] == t[r - j] && i)
					dp[r & 1][i][j] += dp[(r & 1) ^ 1][i - 1][j];
				if (r - i > 0 && t[r - i] == s[j] && j)
					dp[r & 1][i][j] += dp[(r & 1) ^ 1][i][j - 1];
				if (r - j > 0 && r - i > 0 && t[r - i] == t[r - j])
					dp[r & 1][i][j] += dp[(r & 1) ^ 1][i][j];
				dp[r & 1][i][j] %= MOD;
	cout << dp[(n + m) & 1][n][n] << endl;

bool end_of_memory_use;

signed main()
	int testcase = 1;
	// cin >> testcase;
	while (testcase--)
	cerr << "Memory use:" << (&end_of_memory_use - &start_of_memory_use) / 1024.0 / 1024.0 << "MiB" << endl;
	cerr << "Time use:" << (double)clock() / CLOCKS_PER_SEC * 1000.0 << "ms" << endl;
	return 0;

