题面
将 \(n\) 个整数(可能为负)不调整顺序分成 \(k\) 个非空的段,编号 \(1 \dots k\),使得 \(\sum\limits_{i = 1}^{n} a_{i} \cdot f\left(i\right)\) 最大。
我们规定 \(f\left(i\right)\) 表示第 \(i\) 个数属于的段编号。
心路历程
Round 1
\[\begin{aligned} \sum\limits_{i = 1}^{n} a_{i} \cdot f\left(i\right) \to a_{1}+ a_{2} \cdot f\left(2\right) + a_{3} \cdot f\left(3\right) + \dots + a_{n} \cdot f\left(n\right) \end{aligned} \]由于相同的段中元素的 \(f\left(\right)\) 相同,提取公因数,得:
\[a_{1} + a_{2} + \dots + a_{i} + \left(a_{i + 1} + a_{i + 2} + \dots + a_{j}\right) \times 2 + \dots + (a_{k} + a_{k + 1} + \dots + a_{n}) \times k \]拆括号,得:
\[a_{1} + a_{2} + \dots + a_{i}+ 2a_{i + 1} + 2a_{i + 2} + \dots + 2a_{j} + \dots + ka_{k}+ ka_{k + 1} + \dots + ka_n \]思路中断。
Round 2
考虑分段求和,\(sum_i\) 为第 \(i\) 段的和(考虑前缀和,后文 \(s_i\) 即前缀和数组),即:
\[ans = \sum\limits_{i = 1}^{k} i \times sum_{i} \]把\(\sum\limits\) 展开,得:
\[ans = s_{1} + 2\left(s_{{2}} - s_1 \right) + \dots + \left(k - 1\right)\left(s_{k - 1} - s_{k - 2}\right) + k \left(s_{k}- s_{k - 1} \right) \]继续化简,得:
\[\begin{aligned} ans &= s_{1} - 2s_{1} + 2s_{2} - 3s_{2} + \dots + \left(k - 1\right)s_{k - 1} - ks_{k - 1} + ks_{k} \\ &= -\left(s_1+s_2+ \dots +s_{k-1}\right)+ks_{k}\\ &= ks_{k}- \sum\limits_{i = 1}^{k - 1}s_i \end{aligned} \]时间复杂度 \(O(n)\)
空间复杂度 \(O(n)\)
跑样例后发现WA了
核心代码:
// Round 2 attempt (WA on the sample): prints k*s[k] - (s[1] + ... + s[k-1]),
// i.e. it treats segment i as consisting of element i alone. The real answer
// must involve s[n], because the last segment always extends to a[n].
read(n, k);
for(int i = 1; i <= n; ++i) {
read(a[i]);
s[i] = s[i - 1] + a[i];  // prefix sums, s[0] == 0
}
for(int i = 1; i < k; ++i) tot += s[i];
cout << k * s[k] - tot;
Round DEBUG
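检查推导,发现推导的第二步不合理:把第 \(i\) 段的和写成 \(s_{i} - s_{i - 1}\),相当于默认第 \(i\) 段恰好只含第 \(i\) 个元素。但实际上每段的端点是需要我们选择的,而且最后一段必须延伸到 \(a_{n}\),所以答案里应当出现 \(s_{n}\) 而不是 \(s_{k}\)。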
Round 3
再推导一遍。
设前 \(k - 1\) 段的右端点依次为 \(x_{1}, x_{2}, \dots, x_{k - 1}\),满足 \(1 \le x_{1} < x_{2} < \cdots < x_{k - 1} \le n - 1\),则
\[\begin{aligned} \sum\limits_{i = 1}^{n} a_{i} \cdot f\left(i\right) &= \left(a_{1} + \cdots + a_{x_{1}}\right) \cdot 1 + \left(a_{x_{1} + 1} + \cdots + a_{x_{2}}\right) \cdot 2 + \cdots + \left(a_{x_{k - 1} + 1} + \cdots + a_{n}\right) \cdot k \\ &= s_{x_{1}} + 2\left(s_{x_{2}} - s_{x_{1}}\right) + \cdots + \left(k - 1\right)\left(s_{x_{k - 1}} - s_{x_{k - 2}}\right) + k\left(s_{n} - s_{x_{k - 1}}\right) \\ &= s_{x_{1}} + 2s_{x_{2}} - 2s_{x_{1}} + \cdots + \left(k - 1\right)s_{x_{k - 1}} - \left(k - 1\right)s_{x_{k - 2}} + ks_{n} - ks_{x_{k - 1}} \\ &= ks_{n} - s_{x_{1}} - s_{x_{2}} - \cdots - s_{x_{k - 1}} \end{aligned}\]
想要让这个式子最大,只需要让 \(s_{x_{1}} + s_{x_{2}} + \cdots + s_{x_{k - 1}}\) 最小即可。所以对 \(s_{1}, s_{2}, \dots, s_{n - 1}\) 排序(\(s_{n}\) 一定出现在答案中,不参与排序),取最小的 \(k - 1\) 个即可。注意排序的是前缀和数组而不是原序列,所以并不违反题面“不调整顺序”的要求:任意选出 \(k - 1\) 个不同的下标,把它们从小到大排列后,总能作为一组合法的分段点。
\[\begin{aligned} \therefore \text{答案} = ks_{n} - \sum\limits_{i = 1}^{k - 1} s_{i} \quad \left(\text{其中 } s_{1}, \dots, s_{n - 1} \text{ 已排序}\right) \end{aligned}\]时间复杂度 \(O(n \log n)\),瓶颈在排序。
核心代码:
// Round 3 attempt: answer = k*s[n] minus the k-1 smallest of s[1..n-1].
// The formula is right but this still got WA; per the following DEBUG round,
// the values overflowed int here.
read(n, k);
for(int i = 1; i <= n; ++i) read(a[i]);
for(int i = 1; i <= n; ++i) s[i] = s[i - 1] + a[i];
sort(s + 1, s + n);  // sorts s[1..n-1] only; s[n] stays in place
for(int i = 1; i < k; ++i) tot += s[i];
write(k * s[n] - tot);
WA
Round DEBUG
算了一下数据范围,\(k \cdot s_n\) 和累加的 tot 都可能超出 int 的表示范围(爆 int),需要改用 long long。
十年OI一场空,不开long long见祖宗
AC Code
#pragma region
#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <type_traits>
#include <vector>
using namespace std;
// Short aliases for the 64-bit signed integer type used throughout.
using ll = long long int;
using LL = long long;
// Largest 32-bit signed int value (equal to INT_MAX); not used by the visible code.
constexpr int INF = 0x7FFFFFFF;
// Speeds up iostream I/O: drops C/C++ stream synchronisation and unties
// std::cin / std::cout from other streams.
// NOTE(review): once synchronisation is off, do not mix std::cin/std::cout
// with getchar/putchar on the same stream.
inline void IOS() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);
}
/// Reads one (optionally negative) integer from stdin using getchar().
/// Any non-digit characters before the number are skipped; a '-' anywhere in
/// that skipped prefix makes the result negative (original behaviour kept).
/// If EOF is reached before any digit, x stays 0 instead of looping forever.
/// @param x destination; always overwritten.
template <typename T>
void read(T &x) {
    x = 0;
    // Keep getchar()'s result as int: a char cannot represent EOF distinctly,
    // and passing a negative char value to isdigit is undefined behaviour.
    int c = std::getchar();
    bool negative = false;
    for (; c != EOF && !std::isdigit(c); c = std::getchar())
        negative |= c == '-';
    for (; std::isdigit(c); c = std::getchar())
        x = x * 10 + (c - '0');
    if (negative) x = -x;
}
/// Reads several values in the order given, e.g. read(n, k).
template <typename T, typename ... Args>
void read(T& a, Args&... args) {
    read(a);
    // Elements of a braced initializer list are evaluated left to right,
    // so the remaining arguments are read in the order they were passed.
    using expand = int[];
    (void)expand{0, (read(args), 0)...};
}
/// Writes an integer to stdout with putchar, followed by the separator `ed`.
/// When T is char, the character itself is written and `ed` is skipped.
/// @param x  value to print (a built-in integer type, or char).
/// @param ed character written after the number; defaults to '\n'.
template <typename T>
void write(T x, char ed = '\n') {
    // std::decay<T> is a trait class; the decayed type is its nested ::type.
    // Comparing the trait class itself with char was always false, so a char
    // argument used to be printed as its numeric code.
    if (std::is_same<typename std::decay<T>::type, char>::value) {
        std::putchar(x);
        return;
    }
    const bool negative = x < 0;
    if (negative) std::putchar('-');
    // Peel digits off without negating x, so the most negative value of a
    // signed type (e.g. LLONG_MIN) cannot overflow. 40 slots cover even
    // 128-bit integers (at most 39 decimal digits); the old buffer held 29.
    short digits[40];
    int count = 0;
    do {
        const T rem = x % 10;  // in (-10, 10); negative when x is negative
        digits[count++] = static_cast<short>(negative ? -rem : rem);
        x /= 10;
    } while (x != 0);
    while (count > 0) std::putchar('0' + digits[--count]);
    std::putchar(ed);
}
/// Writes each argument through the single-value write() with its default
/// '\n' separator; e.g. for two integer variables, write(a, b) prints
/// a, newline, b, newline.
/// NOTE(review): the parameters are non-const lvalue references, so this
/// overload is not viable when any argument is a temporary. A call such as
/// write(n, k * 2) therefore resolves to write(T x, char ed) and emits k * 2
/// as a raw separator character instead of a number. Pass named variables,
/// or call write() once per value.
template <typename T, typename ... Args>
void write(T& a, Args&... args) {
write(a), write(args...);
}
#pragma endregion
// Capacity of the 1-based arrays below (indices 1..n plus slack).
constexpr int N = 500000 + 10;
void Solve();
// Number of elements and number of segments.
long long n, k;
// a[1..n] holds the input; s[i] = a[1] + ... + a[i], with s[0] == 0.
std::vector<long long> a(N);
long long s[N];
// Accumulator for the prefix sums that get subtracted from k * s[n].
long long tot = 0;
// Entry point: configure stream I/O, then solve the single input case.
int main() {
    IOS();
    Solve();
}
void Solve_throw() {
read(n, k);
for(ll i = 1; i <= n; ++i) {
read(a[i]);
s[i] = s[i - 1] + a[i];
}
sort(s + 1, s + n);
for(ll i = 1; i < k; ++i) tot += s[i];
cout << k * s[n] - tot;
}
void Solve() {
read(n, k);
for(int i = 1; i <= n; ++i) read(a[i]);
for(int i = 1; i <= n; ++i) s[i] = s[i - 1] + a[i];
sort(s + 1, s + n);
for(int i = 1; i < k; ++i) tot += s[i];
write(k * s[n] - tot);
}