堆是一种树形结构,树的根是堆顶,堆顶始终保持为所有元素的最优值。有大根堆和小根堆,大根堆的根节点是最大值,小根堆的根节点是最小值。堆一般用二叉树实现,称为二叉堆。
堆的存储方式
堆的操作
empty
返回堆是否为空
top
直接返回根节点的值,时间复杂度 \(O(1)\)
push
将新元素添加在数组最后面,若它比父节点小则不断与其父节点交换,使得堆重新满足父节点比子节点存储的数都要小(自下而上),时间复杂度 \(O(\log n)\)
pop
弹出根节点,并让堆依然符合原来的性质。首先交换根节点和数组中最后一个元素,再去掉最后一个元素。若新根节点比子节点大,则不断与较小子节点交换,直到重新满足条件(自上而下),时间复杂度 \(O(\log n)\)
例:P3378 【模板】堆
由此,给出二叉堆的模板实现:
参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int MAXN = 1e6 + 5;
int heap[MAXN], len;
void push(int x) {
heap[++len] = x;
int i = len;
while (i > 1 && heap[i] < heap[i / 2]) {
swap(heap[i], heap[i / 2]);
i /= 2;
}
}
void pop() {
heap[1] = heap[len--];
int i = 1;
while (i * 2 <= len) {
int son = i * 2;
if (son < len && heap[son + 1] < heap[son]) son++;
if (heap[son] < heap[i]) {
swap(heap[son], heap[i]);
i = son;
} else break;
}
}
int main()
{
int n;
scanf("%d", &n);
while (n--) {
int op;
scanf("%d", &op);
if (op == 1) {
int x;
scanf("%d", &x);
push(x);
} else if (op == 2) printf("%d\n", heap[1]);
else pop();
}
return 0;
}
例:P1177 【模板】排序
输入 \(n (n < 10^5)\) 个数字 \(a_i ()a_i < 10^9\),将其从小到大排序后输出。
分析:利用堆也是可以做排序的,先把所有的元素 push 进去,然后每次取出堆顶(最小值)输出并弹出堆顶,直到堆空为止,这种排序方法称为堆排序。
参考代码
#include <cstdio>
#include <algorithm>
using namespace std;
const int MAXN=100005;
struct Heap {
int a[MAXN],cnt;
void push(int x) { // 压入
a[++cnt]=x;
int i=cnt;
while (i>1 && a[i]<a[i/2]) {
swap(a[i/2],a[i]);
i/=2;
}
}
void pop() { // 删除
a[1]=a[cnt--];
int i=1;
while (i*2<=cnt) {
int son=i*2;
if (son<cnt && a[son+1]<a[son]) son++;
if (a[son]<a[i]) {
swap(a[son],a[i]);
i=son;
} else break;
}
}
int top() {
return a[1];
}
};
Heap h;
int main()
{
int n,x;
scanf("%d",&n);
for (int i=1;i<=n;i++) {
scanf("%d",&x);
h.push(x);
}
for (int i=1;i<=n;i++) {
printf("%d ",h.top());
h.pop();
}
return 0;
}
堆排序整体的时间复杂度是 \(O(n \log n)\),空间复杂度为 \(O(n)\)
优先队列
C++ 提供了优先队列这个数据结构,也就是 STL 中的 priority_queue
,底层就是由堆实现的。要使用优先队列,需要包含 queue
头文件,优先队列支持的基础操作如下:
priority_queue<int> q
新建一个保存int
型变量的优先队列q
,默认是大根堆priority_queue<int, vector<int>, greater<int>> q
新建一个小根堆q.top()
优先队列查询最大值(或者是最小值)q.pop()
将最大值(最小值)弹出队列q.push(x)
将x
加入优先队列
和大多数 STL 容器一样,可以使用 q.empty()
判断它是否为空,用 q.size()
获取它的大小。
例:P3378 【模板】堆
用 STL 的优先队列来写这道题代码更加简洁。
// STL 优先队列
#include <cstdio>
#include <algorithm>
#include <queue>
using namespace std;
priority_queue<int, vector<int>, greater<int>> q; // 小根堆
int main()
{
int n; scanf("%d", &n); // 操作次数
while (n--) {
int op, x; scanf("%d", &op);
if (op == 1) { scanf("%d", &x); q.push(x); }
else if (op == 2) printf("%d\n", q.top());
else q.pop();
}
return 0;
}
例:P2168 [NOI2015] 荷马史诗
一部《荷马史诗》中有 \(n(n \le 10^6)\) 种不同的单词,从 \(1\) 到 \(n\) 进行编号。其中第 \(i\) 种单词出现的总次数为 \(w_i(w_i \le 10^11)\)。现在要用 \(k\) 进制串 \(s_i\) 来替换第 \(i\) 种单词,使得其满足对于任意的 \(1 \le i,j \le n, i \ne j\),都有 \(s_i\) 不是 \(s_j\) 的前缀。请问如何选择 \(s_i\),才能使替换以后得到的新的《荷马史诗》长度最小。在确保总长度最小的情况下,还想知道最长的 \(s_i\) 的最短长度是多少?
解题思路
哈夫曼编码的变形。每次从堆中选出权重最小的 \(k\) 个结点,将其合并建边,然后放回堆中,直到建完哈夫曼树。例如,当各结点权重分别为 1、1、3、3、9、9,需要编码为三进制时,生成的哈夫曼树如下:
需要注意的是,每次合并都会减少 \(k-1\) 个结点,在合并最后一次的时候,如果可以合并的点的数量不足 \(k\) 个,靠近根结点的位置(短编码)反而没有被利用,所以需要在一开始补上 k-1-(n-1)%(k-1)
个权重为 \(0\) 的结点,把权重大的结点“推”到离根结点更近的位置。根据题目数据范围,答案需要 long long
类型。
参考代码
#include <cstdio>
#include <queue>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100005;
LL w[N];
struct Node {
LL val;
int depth;
};
struct NodeCompare { // 定义Node比较类
bool operator()(const Node &a, const Node &b) {
// 权重相同时,高度小的优先出队
return a.val != b.val ? a.val > b.val : a.depth > b.depth;
}
};
int main()
{
int n, k;
scanf("%d%d", &n, &k);
priority_queue<Node, vector<Node>, NodeCompare> q;
for (int i = 1; i <= n; i++) {
scanf("%lld", &w[i]);
q.push({w[i], 1}); // 读入结点(叶节点)
}
if ((n - 1) % (k - 1) != 0) { // 有一次合并结点数量不足k个
for (int i = 1; i <= k - 1 - (n - 1) % (k - 1); i++)
q.push({0, 1}); // 需要补若干个权重为0的结点
}
LL ans = 0;
while (q.size() != 1) {
LL sum = 0; int maxh = 0;
for (int i = 1; i <= k; i++) { // 从堆中取k个最小的
Node tmp = q.top(); q.pop();
sum += tmp.val; // 新结点加上子结点权重
maxh = max(maxh, tmp.depth); // 最大深度
}
ans += sum; // 更新总长度
q.push({sum, maxh + 1}); // 合并后的结点放回堆中
}
printf("%lld\n%lld\n", ans, q.top().depth - 1); // 编码长度是哈夫曼树的高度减1
return 0;
}
例:P2085 最小函数值
题目给定了若干个二次函数,由于 \(x\) 取的都是正整数,并且三个系数都为正整数,因此函数的取值单调递增且肯定大于 \(0\),要求这些函数生成的所有函数值中最小的 \(m\) 个。
朴素想法
暴力计算每个函数值
朴素的想法是对于每个函数都计算前 \(m\) 个取值,这样会得到 \(n \times m\) 个函数值,最小的 \(m\) 个函数值一定在这个范围内,用一个最大容量限定为 \(m\) 的小根堆始终维护最小的 \(m\) 个函数值,时间复杂度 \(O(nm \log m)\)
优化思路
注意函数的取值是单调递增的,因此实际上可以看作是给定 \(n\) 个排好序的数组,只不过数组并没有真正地存下来,而是给出了下标和值的对应关系。对于每个数组,它们的最小值所在的下标都是 \(1\),假设每个数组都有一个箭头指向 \(1\),需要在所有箭头指向的函数值中找到最小的那个,接下来最小的那个所处的数组的箭头向后移动,指向 \(2\),然后再和其他箭头关联的函数值比较,以此类推。这样一来箭头的后移只需要执行 \(m\) 次即可,而找最小函数值这个过程可以利用一个小根堆来提高效率,总体时间复杂度 \(O(m \log m)\)。
参考代码
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
const int N = 10005;
int a[N], b[N], c[N];
struct Node {
int idx, x, f;
};
struct NodeCompare {
bool operator()(const Node &lhs, const Node &rhs) const {
return lhs.f > rhs.f;
}
};
priority_queue<Node, vector<Node>, NodeCompare> q;
int fn(int idx, int x) {
return a[idx] * x * x + b[idx] * x + c[idx];
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 0; i < n; i++) {
scanf("%d%d%d", &a[i], &b[i], &c[i]);
}
for (int i = 0; i < n; i++) q.push({i, 1, fn(i, 1)});
for (int i = 0; i < m; i++) {
Node t = q.top();
q.pop();
printf("%d ", t.f);
q.push({t.idx, t.x + 1, fn(t.idx, t.x + 1)});
}
return 0;
}
例:P1631 序列合并
解题思路
可以发现,最小和一定是 \(A[1]+B[1]\),次小和是 \(\min (A[1]+B[2],A[2]+B[1])\),假设次小和是 \(A[2]+B[1]\),那么第三小和就是 \(A[1]+B[2],A[2]+B[2],A[3]+B[1]\) 三者之一。也就是说,当确定 \(A[i]+B[j]\) 为第 \(k\) 小和后,\(A[i+1]+B[j]\) 与 \(A[i]+B[j+1]\) 就加入了第 \(k+1\) 小和的备选答案集合。需要注意的是,\(A[1]+B[2]\) 与 \(A[2]+B[1]\) 都能产生 \(A[2]+B[2]\) 这个备选答案。
考虑到这一点,我们不妨把 \(A\) 和 \(B\) 两个序列的和看成 \(N\) 个有序数组,其中第一个数组为 \(A[1]+B[...]\),第二个数组为 \(A[2]+B[...]\),以此类推。这样一来,就相当于将这 \(N\) 个有序数组合并取出前 \(N\) 小的。因此可以先将 \(A[1]+B[1], A[2]+B[1], ..., A[N]+B[1]\) 这 \(N\) 种情况先加入堆中,若取出的堆顶元素来自于第 \(K\) 个数组,则将 \(A[K]+B[2]\) 这种情况继续放入堆中,直到取够前 \(N\) 种情况。时间复杂度 \(O(N \log N)\)。
参考代码
#include <cstdio>
#include <algorithm>
#include <queue>
using namespace std;
typedef long long LL;
const int N = 100005;
int a[N], b[N], ans[N];
struct Index {
int x, y;
};
struct IndexCompare {
bool operator()(const Index& idx1, const Index& idx2) const {
return a[idx1.x] + b[idx1.y] > a[idx2.x] + b[idx2.y];
}
};
int main()
{
int n; scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
for (int i = 1; i <= n; i++) scanf("%d", &b[i]);
priority_queue<Index, vector<Index>, IndexCompare> q;
for (int i = 1; i <= n; i++) q.push({i, 1});
for (int i = 1; i <= n; i++) {
Index tmp = q.top(); q.pop();
ans[i] = a[tmp.x] + b[tmp.y];
q.push({tmp.x, tmp.y + 1});
}
for (int i = 1; i <= n; i++) printf("%d%c", ans[i], i == n ? '\n' : ' ');
return 0;
}
对顶堆
如果把大根堆想成一个上宽下窄的三角形,把小根堆想成一个上窄下宽的三角形,那么对顶堆就可以具体地被想象成一个“陀螺”或者一个“沙漏”,通过这两个堆的上下组合,我们可以把一组数据分别加入到对顶堆中的大根堆和小根堆,以维护我们不同的需要。
根据数学中不等式的传递原理,假如一个集合 A 中的最小元素比另一个集合 B 中的最大元素还要大,那么就可以断定: A 中的所有元素都比 B 中元素大。所以,我们把小根堆“放在”大根堆“上面”,如果小根堆的堆顶元素比大根堆的堆顶元素大,那么小根堆的所有元素要比大根堆的所有元素大。
例如给定 \(N\) 个数字,求其前 \(i\) 个元素中第 \(K\) 小的那个元素。
我们可以这样解决问题:把大根堆的元素个数限制成 \(K\) 个,前 \(K\) 个元素入队之后,每个元素在压入堆之前先与堆顶元素比较,如果比堆顶元素大,就加入小根堆,如果没有的话,把大根堆的堆顶弹出至小根堆,将新元素加入大根堆,这样就维护出一个对顶堆。
同理,对顶堆还可以用于解决其他“第 \(K\) 小”的变形问题:比如求前 \(i\) 个元素的中位数等。
例:P1168 中位数
解题思路
使用两个堆,大根堆维护较小的数,小根堆维护较大的数。这样一来,小根堆的堆顶是较大的数中最小的,大根堆的堆顶是较小的数中最大的。
而求中位数只需要在保证两个堆中元素大小关系的同时,控制两个堆的大小尽可能平衡,这样其中一个堆的堆顶元素即为中位数。
参考代码
#include <cstdio>
#include <queue>
#include <vector>
using namespace std;
int main()
{
int n;
scanf("%d", &n);
priority_queue<int> big;
priority_queue<int, vector<int>, greater<int>> small;
for (int i = 1; i <= n; i++) {
int x;
scanf("%d", &x);
small.push(x);
if (i % 2 == 1) {
while (!big.empty() && small.top() < big.top()) {
int st = small.top();
small.pop();
int bt = big.top();
big.pop();
small.push(bt);
big.push(st);
}
int st = small.top();
small.pop();
big.push(st);
printf("%d\n", big.top());
}
}
return 0;
}
例:P1801 黑匣子
解题思路
控制对顶堆中的大根堆的元素数目伴随着 \(i\) 的增长而增长。
参考代码
#include <cstdio>
#include <queue>
#include <iostream>
using namespace std;
const int N = 200005;
int a[N];
priority_queue<int, vector<int>, greater<int>> h;
priority_queue<int> ans;
int main()
{
int m, n;
scanf("%d%d", &m, &n);
for (int i = 1; i <= m; i++) scanf("%d", &a[i]);
int pre = 0;
int idx = 0;
while (n--) {
int u;
scanf("%d", &u);
for (int i = pre + 1; i <= u; i++)
ans.push(a[i]);
while (ans.size() > idx) {
h.push(ans.top());
ans.pop();
}
ans.push(h.top());
h.pop();
printf("%d\n", ans.top());
pre = u;
idx++;
}
return 0;
}