快速排序
Partition
的含义是让某个基准元素
归位
排序算法——快速排序(Quicksort)基准值的三种选取和优化方法
- 左侧基准
- 右侧先走
- 指针小于(严格)
i < j
- 值等于(非严格)
a[j] >= pivot
,a[i] <= pivot
如果值比较的时候是严格小于或者大于,那么遇到相等数值的时候,指针是无法移动的
1,2,3,1,1
^ ^
| |
int partition(vector<int> a, int l, int r) {
if (l >= r) return;
/* int mid = (l + r) / 2;
swap(a[mid], a[l]); */
int pivot = a[l];
int i = l, j = r;
while (l < r) {
while (a[j] >= pivot && i < j) j--;
while (a[i] <= pivot && i < j) i++;
swap(a[i], a[j]);
}
swap(a[i], a[l]);
return i;
}
void quicksort(vector<int> a) {
qs(a, 0, a.size() - 1);
}
void qs(vector<int> a, int l, int r) {
if (l >= r) return;
int p = partition(a, l, r);
qs(a, p + 1, r);
qs(a, l, p - 1);
}
以下分析以pivot = a[l]
为例
为什么指针相遇之后需要交换,而不是覆盖
如果不交换,直接把 pivot 放入相关位置的话,会有一个元素a[i]
被覆盖掉
那么就需要交换,关键问题在于,a[i]
可以直接交换吗
这个和下一个为什么右边的先移动这个问题有关
为什么右边的先移动
在 Wiki 当中,算法发明者的原始实现版本就不是右边先移动的
另外,国外版本通常以最右侧值作为枢值,也就要左边先移动了
有一些解释,但是根本原因,还是在于 Partition 函数的作用
我们扫描数组的最终目的,是找到一个位置,安放基准值。
更准确地说,是把基准值和某个值交换位置,这个交换不可以破坏 ij 已扫描过的区间有序性
其实不管哪一边先走,都可以满足如下语义
如果
i != j
,a[l:i]
(闭区间) 所有元素 <= a[l]
如果i != j
,a[j:r]
(闭区间) 所有元素 >= a[l]
但是一旦指针相遇,语义就不确定了
如果让 j 先走,相遇的时候有两种情况
while (l < r) {
// [1] j向左碰到了i,此时i这处一定是检验过的
// 因此满足 a[i](即a[j]) <= a[l]
while (a[j] <= pivot && i < j) j--;
// [2] i向右碰到了j,我们假设这种情况可以成立
// 那么此时j已经停在了一个 a[j]<=a[l] 的地方
// 那么此时如果指针相撞,i == j
// a[i] <= a[l] 的条件自然是满足的
while (a[i] <= pivot && i < j) i++;
swap(a[i], a[j]);
}
我们也可以让左边先走
只不过这个时候需要单独验证,最后的 pivot 位置是否满足要求
下面我们用 Python 脚本验证一下(便于打印)
import numpy as np
import random
from tqdm import trange
a = [3, 1, 2, 5, 6, 1, 7, 3, 4, 2, 6, 8, 1, 3, 2, 6, 1]
SHOW = False
def show_ptr(a, l, r, pos):
_list = [str(_) for _ in a]
_list[pos] = f'"{_list[pos]}"'
_list[l] = f'[{_list[l]}'
_list[r] = f'{_list[r]}]'
print(' '.join(_list))
def show_change(a, l, r, i, j):
_list = [str(_) for _ in a]
_list[i] = f'({_list[i]})'
_list[j] = f'({_list[j]})'
_list[l] = f'[{_list[l]}'
_list[r] = f'{_list[r]}]'
print(' '.join(_list))
def qs_right_first(a, l, r):
if l >= r:
return
pivot = a[l]
i = l
j = r
while i < j:
while i < j and a[j] >= pivot:
j -= 1
while i < j and a[i] <= pivot:
i += 1
if SHOW:
show_change(a, l, r, i, j)
a[i], a[j] = a[j], a[i]
# # wrong
# pivot, a[i] = a[i], pivot
# right
a[l], a[i] = a[i], a[l]
if SHOW:
show_ptr(a, l, r, i)
qs_right_first(a, i+1, r)
qs_right_first(a, l, i-1)
def qs_left_first(a, l, r):
if l >= r:
return
pivot = a[l]
i = l
j = r
while i < j:
while i < j and a[i] <= pivot:
i += 1
while i < j and a[j] >= pivot:
j -= 1
if SHOW:
show_change(a, l, r, i, j)
a[i], a[j] = a[j], a[i]
# if i,j does not meet exchange is unproblematic
# However, if they meet, we need to check
if a[i] > pivot:
i = i-1
if SHOW:
print(
f'cannot put pivot a[{l}] = {a[l]} at a[{i +1 }] = {a[i + 1]}, i--')
show_ptr(a, l, r, i)
# move the pivot to its location
a[l], a[i] = a[i], a[l]
if SHOW:
show_ptr(a, l, r, i)
qs_right_first(a, i+1, r)
qs_right_first(a, l, i-1)
if __name__ == "__main__":
K = 10
BOUND = 20
random.seed(781935)
SHOW = True
a = [random.randint(0, BOUND) for _ in range(K)]
a = np.array(a)
# a = np.array([7,1,8,3,5])
_a = a.copy()
# qs_right_first(_a, 0, len(_a) - 1)
qs_left_first(_a, 0, len(_a) - 1)
a.sort()
print((a == _a).all())
K = 1000
BOUND = 50
random.seed(781935)
SHOW = False
ans = True
for i in trange(500):
a = [random.randint(0, BOUND) for _ in range(K)]
a = np.array(a)
_a = a.copy()
# qs_right_first(_a, 0, len(_a) - 1)
qs_left_first(_a, 0, len(_a) - 1)
a.sort()
ans = np.logical_and(ans, ((a == _a).all()))
print('qs_left_first', ans)
[17 10 3 5 10 10 10 0 (18) (12)]
# 18,12 交换之后,i先走一步,撞到了j(在len-1处)
[17 10 3 5 10 10 10 0 12 ((18))]
# 此时从while退出,不知道这个位置能不能作为pivot的归位处,因此需要进行判断
cannot put pivot a[0] = 17 at a[9] = 18, i--
# 向后一步,这一片区域都是i扫过的区域,因此pivot一定可以归位
[17 10 3 5 10 10 10 0 "12" 18]
[12 10 3 5 10 10 10 0 "17" 18]
[12 10 3 5 10 10 10 ((0))] 17 18
[0 10 3 5 10 10 10 "12"] 17 18
[((0)) 10 3 5 10 10 10] 12 17 18
["0" 10 3 5 10 10 10] 12 17 18
0 [10 3 ((5)) 10 10 10] 12 17 18
0 [5 3 "10" 10 10 10] 12 17 18
0 5 3 10 [((10)) 10 10] 12 17 18
0 5 3 10 ["10" 10 10] 12 17 18
0 5 3 10 10 [((10)) 10] 12 17 18
0 5 3 10 10 ["10" 10] 12 17 18
0 [5 ((3))] 10 10 10 10 12 17 18
0 [3 "5"] 10 10 10 10 12 17 18
更进一步,是否可以任意选取枢轴值
- 以枢轴值为标准扫描
- 枢轴值归位
分析
i,j 相遇,左右两侧的序列一定是满足要求的
最后返回的也一定是i=j
这个位置,其他所有位置的值都已经通过了检验
另外,任意取pivot
,a[pivot_ind]
一定是在原位的
- 因为移动只对应
i != j
的情况,而i,j
不会在pivot
处停留 - 而如果相撞在
pivot_ind
处,自然也没有影响
不过,这个位置的值本身和pivot
的关系,没有通过检验
- 如果相等,
i
这个位置就自然满足要求 - 如果不相等,我们需要把
a[i]
换成pivot
,因为我们的扫描过程是以枢轴值为标准的
a[i]
需要和pivot
进行交换
代码
pivot is a[9] = 2
[(6) 1 (1) 2 7 8 3 7 9 2]
[1 1 ((6)) 2 7 8 3 7 9 2]
a[2] = 6 can be safely swapped
[1 1 "6" 2 7 8 3 7 9 "2"]
[1 1 "2" 2 7 8 3 7 9 6]
---
pivot is a[5] = 8
1 1 2 [2 7 8 3 7 (9) (6)]
1 1 2 [2 7 8 3 7 6 ((9))]
smaller pivot a[5] = 8 at the left of a[9] = 9, i = max(l, i-1)
1 1 2 [2 7 "8" 3 7 "6" 9]
1 1 2 [2 7 6 3 7 "8" 9]
---
pivot is a[5] = 6
1 1 2 [2 (7) 6 (3) 7] 8 9
1 1 2 [2 3 6 ((7)) 7] 8 9
smaller pivot a[5] = 6 at the left of a[6] = 7, i = max(l, i-1)
1 1 2 [2 3 ""6"" 7 7] 8 9
1 1 2 [2 3 "6" 7 7] 8 9
---
pivot is a[6] = 7
1 1 2 2 3 6 [7 ((7))] 8 9
a[7] = 7 can be safely swapped
1 1 2 2 3 6 ["7" "7"] 8 9
1 1 2 2 3 6 [7 "7"] 8 9
---
pivot is a[3] = 2
1 1 2 [2 ((3))] 6 7 7 8 9
smaller pivot a[3] = 2 at the left of a[4] = 3, i = max(l, i-1)
1 1 2 [""2"" 3] 6 7 7 8 9
1 1 2 ["2" 3] 6 7 7 8 9
---
pivot is a[0] = 1
[1 ((1))] 2 2 3 6 7 7 8 9
a[1] = 1 can be safely swapped
["1" "1"] 2 2 3 6 7 7 8 9
[1 "1"] 2 2 3 6 7 7 8 9
---
================pass================