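In the previous three chapters, we analyzed ten classic sorting algorithms. In practice, though, we rarely need to write a sort by hand, because the Java standard library already provides sorting functions. A single call to Arrays.sort() sorts an array, for example: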
import java.util.Arrays;

int[] arr = new int[]{5, 3, 2, 1, 4};
Arrays.sort(arr);
// Prints: [1, 2, 3, 4, 5]
System.out.println(Arrays.toString(arr));
So which sorting algorithm does the Java source code actually use? In this chapter, we'll find out.
Note: this chapter is based on JDK 11.0.8. Readers are encouraged to follow along with the source code.
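Arrays.sort() does not rely on a single sorting algorithm. The DualPivotQuicksort class it delegates to for primitive arrays (whose listed authors are Vladimir Yaroslavskiy, Jon Bentley, and Josh Bloch) analyzes the size and characteristics of the input and applies different algorithms to different inputs.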
The Arrays class has many sort overloads:
void sort (int[])
void sort (int[], int, int)
void sort (long[])
void sort (long[], int, int)
void sort (short[])
void sort (short[], int, int)
void sort (char[])
void sort (char[], int, int)
void sort (byte[])
void sort (byte[], int, int)
void sort (float[])
void sort (float[], int, int)
void sort (double[])
void sort (double[], int, int)
void sort (Object[])
void sort (Object[], int, int)
void sort (T[], Comparator)
void sort (T[], int, int, Comparator)
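As a quick illustration of two of these overloads, here is a small sketch; the class and method names (SortOverloadsDemo, sortMiddle, sortByLength) are made up for this example. Note that the public range overload takes a half-open range [fromIndex, toIndex), whereas the internal DualPivotQuicksort.sort shown later takes an inclusive right index (Arrays.sort(int[], int, int) passes toIndex - 1).

```java
import java.util.Arrays;
import java.util.Comparator;

class SortOverloadsDemo {
    // sort(int[], int, int): sorts only the half-open range [fromIndex, toIndex).
    // Here the first and last elements stay where they are (assumes input.length >= 2).
    static int[] sortMiddle(int[] input) {
        int[] a = input.clone();
        Arrays.sort(a, 1, a.length - 1);
        return a;
    }

    // sort(T[], Comparator): sorts objects with a caller-supplied ordering,
    // here by string length.
    static String[] sortByLength(String[] input) {
        String[] s = input.clone();
        Arrays.sort(s, Comparator.comparingInt(String::length));
        return s;
    }
}
```

For example, sortMiddle(new int[]{9, 5, 3, 7, 1, 0}) returns [9, 1, 3, 5, 7, 0]: only indices 1 through 4 are sorted.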
These sort functions fall into two groups:
- Sorting primitive types (int, long, short, char, byte, float, double)
- Sorting non-primitive types (Object, T)
Primitive arrays are sorted by calling the corresponding DualPivotQuicksort.sort() function.
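Non-primitive arrays are sorted with TimSort or merge sort. Before JDK 1.7, merge sort was the default; since JDK 1.7, TimSort is the default, but the legacy merge sort can still be selected by setting the JVM option -Djava.util.Arrays.useLegacyMergeSort=true.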
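One commonly cited reason for using a stable algorithm (TimSort or merge sort) for objects is that two objects can compare as equal yet still be distinguishable, so preserving their input order matters; for primitives, equal values are indistinguishable and stability is irrelevant. The Javadoc of Arrays.sort(T[], Comparator) guarantees a stable sort. The sketch below (class and field names are hypothetical) sorts students by score only and shows that ties keep their original order.

```java
import java.util.Arrays;
import java.util.Comparator;

class StableObjectSortDemo {
    // A minimal record-like class: a name plus a score.
    static final class Student {
        final String name;
        final int score;
        Student(String name, int score) { this.name = name; this.score = score; }
        @Override public String toString() { return name + ":" + score; }
    }

    // Sorts by score only. Because the object sort is stable, students
    // with equal scores stay in the order they had in the input.
    static String sortByScore(Student[] input) {
        Student[] s = input.clone();
        Arrays.sort(s, Comparator.comparingInt((Student st) -> st.score));
        return Arrays.toString(s);
    }
}
```

With the input Ann:90, Bob:80, Cid:90, Dan:80, the result is [Bob:80, Dan:80, Ann:90, Cid:90]: Bob stays ahead of Dan and Ann stays ahead of Cid.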
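For simplicity, this chapter only analyzes Arrays.sort(int[] a). The approach for the other primitive types is similar, and sorting of non-primitive types is not covered here.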
Entry Point
Here is Arrays.sort(int[] a):
public static void sort(int[] a) {
DualPivotQuicksort.sort(a, 0, a.length - 1, null, 0, 0);
}
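DualPivotQuicksort stands for dual-pivot quicksort, whereas the quicksort algorithm introduced earlier is single-pivot. As the name suggests, dual-pivot quicksort picks two pivots in each round and partitions the array into three regions, so each round puts two pivots into their final positions. It is usually faster than single-pivot quicksort.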
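To make the "two pivots, three regions" idea concrete, here is a minimal dual-pivot quicksort in the classic textbook form, using the first and last elements as pivots. This is a simplified sketch, not the JDK code: DualPivotQuicksort instead takes a[e2] and a[e4] from five sampled elements, switches to insertion sort on small ranges, and adds further optimizations. The class name DualPivotSketch is hypothetical.

```java
class DualPivotSketch {
    static void sort(int[] a) { sort(a, 0, a.length - 1); }

    // Sorts a[left..right] (inclusive bounds, like DualPivotQuicksort).
    static void sort(int[] a, int left, int right) {
        if (left >= right) return;
        // Use the two end elements as pivots, ordered so that p1 <= p2.
        if (a[left] > a[right]) swap(a, left, right);
        int p1 = a[left], p2 = a[right];
        int less = left + 1;   // a[left+1 .. less-1] < p1
        int great = right - 1; // a[great+1 .. right-1] > p2
        int k = less;          // a[less .. k-1] in [p1, p2]; a[k .. great] unexamined
        while (k <= great) {
            if (a[k] < p1) {
                swap(a, k, less);
                less++;
            } else if (a[k] > p2) {
                // Find an element from the right that does not belong to the right part.
                while (k < great && a[great] > p2) great--;
                swap(a, k, great);
                great--;
                if (a[k] < p1) {
                    swap(a, k, less);
                    less++;
                }
            }
            k++;
        }
        // Move both pivots into their final positions.
        less--;
        great++;
        swap(a, left, less);
        swap(a, right, great);
        // Recurse into the three regions: < p1, between the pivots, > p2.
        sort(a, left, less - 1);
        sort(a, less + 1, great - 1);
        sort(a, great + 1, right);
    }

    private static void swap(int[] a, int i, int j) { int t = a[i]; a[i] = a[j]; a[j] = t; }
}
```

After each partition step, the two pivots sit at indices less and great and never move again, which is what "placing two pivots per round" means.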
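However, DualPivotQuicksort does not use dual-pivot quicksort alone: depending on the size of the input and how structured it is, it chooses among several sorting algorithms.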
DualPivotQuicksort
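Stepping into DualPivotQuicksort.sort(int[] a, int left, int right, int[] work, int workBase, int workLen), here is the function in full, followed by the private sort(int[] a, int left, int right, boolean leftmost) overload it calls: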
/**
* Sorts the specified range of the array using the given
* workspace array slice if possible for merging
*
* @param a the array to be sorted
* @param left the index of the first element, inclusive, to be sorted
* @param right the index of the last element, inclusive, to be sorted
* @param work a workspace array (slice)
* @param workBase origin of usable space in work array
* @param workLen usable size of work array
*/
static void sort(int[] a, int left, int right,
int[] work, int workBase, int workLen) {
// Use Quicksort on small arrays
if (right - left < QUICKSORT_THRESHOLD) {
sort(a, left, right, true);
return;
}
/*
* Index run[i] is the start of i-th run
* (ascending or descending sequence).
*/
int[] run = new int[MAX_RUN_COUNT + 1];
int count = 0; run[0] = left;
// Check if the array is nearly sorted
for (int k = left; k < right; run[count] = k) {
if (a[k] < a[k + 1]) { // ascending
while (++k <= right && a[k - 1] <= a[k]);
} else if (a[k] > a[k + 1]) { // descending
while (++k <= right && a[k - 1] >= a[k]);
for (int lo = run[count] - 1, hi = k; ++lo < --hi; ) {
int t = a[lo]; a[lo] = a[hi]; a[hi] = t;
}
} else { // equal
for (int m = MAX_RUN_LENGTH; ++k <= right && a[k - 1] == a[k]; ) {
if (--m == 0) {
sort(a, left, right, true);
return;
}
}
}
/*
* The array is not highly structured,
* use Quicksort instead of merge sort.
*/
if (++count == MAX_RUN_COUNT) {
sort(a, left, right, true);
return;
}
}
// Check special cases
// Implementation note: variable "right" is increased by 1.
if (run[count] == right++) { // The last run contains one element
run[++count] = right;
} else if (count == 1) { // The array is already sorted
return;
}
// Determine alternation base for merge
byte odd = 0;
for (int n = 1; (n <<= 1) < count; odd ^= 1);
// Use or create temporary array b for merging
int[] b; // temp array; alternates with a
int ao, bo; // array offsets from 'left'
int blen = right - left; // space needed for b
if (work == null || workLen < blen || workBase + blen > work.length) {
work = new int[blen];
workBase = 0;
}
if (odd == 0) {
System.arraycopy(a, left, work, workBase, blen);
b = a;
bo = 0;
a = work;
ao = workBase - left;
} else {
b = work;
ao = 0;
bo = workBase - left;
}
// Merging
for (int last; count > 1; count = last) {
for (int k = (last = 0) + 2; k <= count; k += 2) {
int hi = run[k], mi = run[k - 1];
for (int i = run[k - 2], p = i, q = mi; i < hi; ++i) {
if (q >= hi || p < mi && a[p + ao] <= a[q + ao]) {
b[i + bo] = a[p++ + ao];
} else {
b[i + bo] = a[q++ + ao];
}
}
run[++last] = hi;
}
if ((count & 1) != 0) {
for (int i = right, lo = run[count - 1]; --i >= lo;
b[i + bo] = a[i + ao]
);
run[++last] = right;
}
int[] t = a; a = b; b = t;
int o = ao; ao = bo; bo = o;
}
}
/**
* Sorts the specified range of the array by Dual-Pivot Quicksort.
*
* @param a the array to be sorted
* @param left the index of the first element, inclusive, to be sorted
* @param right the index of the last element, inclusive, to be sorted
* @param leftmost indicates if this part is the leftmost in the range
*/
private static void sort(int[] a, int left, int right, boolean leftmost) {
int length = right - left + 1;
// Use insertion sort on tiny arrays
if (length < INSERTION_SORT_THRESHOLD) {
if (leftmost) {
/*
* Traditional (without sentinel) insertion sort,
* optimized for server VM, is used in case of
* the leftmost part.
*/
for (int i = left, j = i; i < right; j = ++i) {
int ai = a[i + 1];
while (ai < a[j]) {
a[j + 1] = a[j];
if (j-- == left) {
break;
}
}
a[j + 1] = ai;
}
} else {
/*
* Skip the longest ascending sequence.
*/
do {
if (left >= right) {
return;
}
} while (a[++left] >= a[left - 1]);
/*
* Every element from adjoining part plays the role
* of sentinel, therefore this allows us to avoid the
* left range check on each iteration. Moreover, we use
* the more optimized algorithm, so called pair insertion
* sort, which is faster (in the context of Quicksort)
* than traditional implementation of insertion sort.
*/
for (int k = left; ++left <= right; k = ++left) {
int a1 = a[k], a2 = a[left];
if (a1 < a2) {
a2 = a1; a1 = a[left];
}
while (a1 < a[--k]) {
a[k + 2] = a[k];
}
a[++k + 1] = a1;
while (a2 < a[--k]) {
a[k + 1] = a[k];
}
a[k + 1] = a2;
}
int last = a[right];
while (last < a[--right]) {
a[right + 1] = a[right];
}
a[right + 1] = last;
}
return;
}
// Inexpensive approximation of length / 7
int seventh = (length >> 3) + (length >> 6) + 1;
/*
* Sort five evenly spaced elements around (and including) the
* center element in the range. These elements will be used for
* pivot selection as described below. The choice for spacing
* these elements was empirically determined to work well on
* a wide variety of inputs.
*/
int e3 = (left + right) >>> 1; // The midpoint
int e2 = e3 - seventh;
int e1 = e2 - seventh;
int e4 = e3 + seventh;
int e5 = e4 + seventh;
// Sort these elements using insertion sort
if (a[e2] < a[e1]) { int t = a[e2]; a[e2] = a[e1]; a[e1] = t; }
if (a[e3] < a[e2]) { int t = a[e3]; a[e3] = a[e2]; a[e2] = t;
if (t < a[e1]) { a[e2] = a[e1]; a[e1] = t; }
}
if (a[e4] < a[e3]) { int t = a[e4]; a[e4] = a[e3]; a[e3] = t;
if (t < a[e2]) { a[e3] = a[e2]; a[e2] = t;
if (t < a[e1]) { a[e2] = a[e1]; a[e1] = t; }
}
}
if (a[e5] < a[e4]) { int t = a[e5]; a[e5] = a[e4]; a[e4] = t;
if (t < a[e3]) { a[e4] = a[e3]; a[e3] = t;
if (t < a[e2]) { a[e3] = a[e2]; a[e2] = t;
if (t < a[e1]) { a[e2] = a[e1]; a[e1] = t; }
}
}
}
// Pointers
int less = left; // The index of the first element of center part
int great = right; // The index before the first element of right part
if (a[e1] != a[e2] && a[e2] != a[e3] && a[e3] != a[e4] && a[e4] != a[e5]) {
/*
* Use the second and fourth of the five sorted elements as pivots.
* These values are inexpensive approximations of the first and
* second terciles of the array. Note that pivot1 <= pivot2.
*/
int pivot1 = a[e2];
int pivot2 = a[e4];
/*
* The first and the last elements to be sorted are moved to the
* locations formerly occupied by the pivots. When partitioning
* is complete, the pivots are swapped back into their final
* positions, and excluded from subsequent sorting.
*/
a[e2] = a[left];
a[e4] = a[right];
/*
* Skip elements, which are less or greater than pivot values.
*/
while (a[++less] < pivot1);
while (a[--great] > pivot2);
/*
* Partitioning:
*
* left part center part right part
* +--------------------------------------------------------------+
* | < pivot1 | pivot1 <= && <= pivot2 | ? | > pivot2 |
* +--------------------------------------------------------------+
* ^ ^ ^
* | | |
* less k great
*
* Invariants:
*
* all in (left, less) < pivot1
* pivot1 <= all in [less, k) <= pivot2
* all in (great, right) > pivot2
*
* Pointer k is the first index of ?-part.
*/
outer:
for (int k = less - 1; ++k <= great; ) {
int ak = a[k];
if (ak < pivot1) { // Move a[k] to left part
a[k] = a[less];
/*
* Here and below we use "a[i] = b; i++;" instead
* of "a[i++] = b;" due to performance issue.
*/
a[less] = ak;
++less;
} else if (ak > pivot2) { // Move a[k] to right part
while (a[great] > pivot2) {
if (great-- == k) {
break outer;
}
}
if (a[great] < pivot1) { // a[great] <= pivot2
a[k] = a[less];
a[less] = a[great];
++less;
} else { // pivot1 <= a[great] <= pivot2
a[k] = a[great];
}
/*
* Here and below we use "a[i] = b; i--;" instead
* of "a[i--] = b;" due to performance issue.
*/
a[great] = ak;
--great;
}
}
// Swap pivots into their final positions
a[left] = a[less - 1]; a[less - 1] = pivot1;
a[right] = a[great + 1]; a[great + 1] = pivot2;
// Sort left and right parts recursively, excluding known pivots
sort(a, left, less - 2, leftmost);
sort(a, great + 2, right, false);
/*
* If center part is too large (comprises > 4/7 of the array),
* swap internal pivot values to ends.
*/
if (less < e1 && e5 < great) {
/*
* Skip elements, which are equal to pivot values.
*/
while (a[less] == pivot1) {
++less;
}
while (a[great] == pivot2) {
--great;
}
/*
* Partitioning:
*
* left part center part right part
* +----------------------------------------------------------+
* | == pivot1 | pivot1 < && < pivot2 | ? | == pivot2 |
* +----------------------------------------------------------+
* ^ ^ ^
* | | |
* less k great
*
* Invariants:
*
* all in (*, less) == pivot1
* pivot1 < all in [less, k) < pivot2
* all in (great, *) == pivot2
*
* Pointer k is the first index of ?-part.
*/
outer:
for (int k = less - 1; ++k <= great; ) {
int ak = a[k];
if (ak == pivot1) { // Move a[k] to left part
a[k] = a[less];
a[less] = ak;
++less;
} else if (ak == pivot2) { // Move a[k] to right part
while (a[great] == pivot2) {
if (great-- == k) {
break outer;
}
}
if (a[great] == pivot1) { // a[great] < pivot2
a[k] = a[less];
/*
* Even though a[great] equals to pivot1, the
* assignment a[less] = pivot1 may be incorrect,
* if a[great] and pivot1 are floating-point zeros
* of different signs. Therefore in float and
* double sorting methods we have to use more
* accurate assignment a[less] = a[great].
*/
a[less] = pivot1;
++less;
} else { // pivot1 < a[great] < pivot2
a[k] = a[great];
}
a[great] = ak;
--great;
}
}
}
// Sort center part recursively
sort(a, less, great, false);
} else { // Partitioning with one pivot
/*
* Use the third of the five sorted elements as pivot.
* This value is inexpensive approximation of the median.
*/
int pivot = a[e3];
/*
* Partitioning degenerates to the traditional 3-way
* (or "Dutch National Flag") schema:
*
* left part center part right part
* +-------------------------------------------------+
* | < pivot | == pivot | ? | > pivot |
* +-------------------------------------------------+
* ^ ^ ^
* | | |
* less k great
*
* Invariants:
*
* all in (left, less) < pivot
* all in [less, k) == pivot
* all in (great, right) > pivot
*
* Pointer k is the first index of ?-part.
*/
for (int k = less; k <= great; ++k) {
if (a[k] == pivot) {
continue;
}
int ak = a[k];
if (ak < pivot) { // Move a[k] to left part
a[k] = a[less];
a[less] = ak;
++less;
} else { // a[k] > pivot - Move a[k] to right part
while (a[great] > pivot) {
--great;
}
if (a[great] < pivot) { // a[great] <= pivot
a[k] = a[less];
a[less] = a[great];
++less;
} else { // a[great] == pivot
/*
* Even though a[great] equals to pivot, the
* assignment a[k] = pivot may be incorrect,
* if a[great] and pivot are floating-point
* zeros of different signs. Therefore in float
* and double sorting methods we have to use
* more accurate assignment a[k] = a[great].
*/
a[k] = pivot;
}
a[great] = ak;
--great;
}
}
/*
* Sort left and right parts recursively.
* All elements from center part are equal
* and, therefore, already sorted.
*/
sort(a, left, less - 1, leftmost);
sort(a, great + 1, right, false);
}
}
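This code is everything this chapter covers. The sorting algorithms involved are TimSort, insertion sort, pair insertion sort, dual-pivot quicksort, and the Dutch national flag problem. As the first article in this chapter, this one only outlines the overall structure, to help readers see how the pieces fit together. The details will be covered one at a time in the following articles.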
Start with the opening lines:
static void sort(int[] a, int left, int right,
int[] work, int workBase, int workLen) {
// Use Quicksort on small arrays
if (right - left < QUICKSORT_THRESHOLD) {
sort(a, left, right, true);
return;
}
//...
}
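This checks the size of the input. QUICKSORT_THRESHOLD is a constant with the value 286, presumably determined empirically through extensive testing. When right - left is less than 286 (that is, the range holds at most 286 elements), sorting is delegated to sort(int[] a, int left, int right, boolean leftmost), which contains the actual dual-pivot quicksort implementation.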
So which algorithm is used for larger inputs, when right - left is 286 or more? Isn't the DualPivotQuicksort.sort(int[] a, int left, int right, int[] work, int workBase, int workLen) function we are reading a dual-pivot quicksort?
With that question in mind, let's keep reading:
int[] run = new int[MAX_RUN_COUNT + 1];
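This declares an int array named run with length MAX_RUN_COUNT + 1, where MAX_RUN_COUNT is the constant 67. A run is a concept from TimSort: once the range exceeds 286 elements, the code tries to sort it with a TimSort-like algorithm.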
TimSort
TimSort was proposed by Tim Peters in 2002. It is an improved variant of merge sort.
Recall the idea of merge sort: repeatedly merge two sorted subarrays until the whole array is sorted.
To obtain the sorted subarrays, merge sort keeps splitting the array in half until each subarray has a single element, which is trivially sorted.
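For comparison with TimSort below, a minimal top-down merge sort that follows exactly this description (halve until one element remains, then merge) might look like this; the class name is hypothetical.

```java
class TopDownMergeSort {
    static void sort(int[] a) {
        if (a.length < 2) return;
        int[] buf = new int[a.length]; // scratch space shared by all merges
        sort(a, buf, 0, a.length);
    }

    // Sorts a[lo, hi) by halving until each piece has one element.
    private static void sort(int[] a, int[] buf, int lo, int hi) {
        if (hi - lo < 2) return; // a one-element range is already sorted
        int mid = (lo + hi) >>> 1;
        sort(a, buf, lo, mid);
        sort(a, buf, mid, hi);
        merge(a, buf, lo, mid, hi);
    }

    // Merges the sorted halves a[lo, mid) and a[mid, hi).
    private static void merge(int[] a, int[] buf, int lo, int mid, int hi) {
        System.arraycopy(a, lo, buf, lo, hi - lo);
        int p = lo, q = mid;
        for (int i = lo; i < hi; i++) {
            if (q >= hi || (p < mid && buf[p] <= buf[q])) a[i] = buf[p++];
            else a[i] = buf[q++];
        }
    }
}
```

Note that the splitting here ignores the data entirely: an already sorted array is still cut down to single elements and merged back up.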
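TimSort improves the way merge sort produces subarrays. Its main idea is to scan the array once and split it into a number of ascending (non-decreasing) subarrays, each called a run. Once the split is done, the runs are merged pairwise.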
While scanning, if TimSort encounters a descending segment, it reverses that segment so that it becomes ascending.
A few examples:
- For the array [1,4,2,3], TimSort splits it into two runs: [1,4] and [2,3].
- For the array [3,4,5,1], TimSort splits it into two runs: [3,4,5] and [1].
- For the array [3,2,1,4,5], TimSort splits it into two runs: [1,2,3] and [4,5], where the first run is obtained by reversing [3,2,1].
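The run splitting can be sketched as follows. It mirrors the scanning loop in DualPivotQuicksort (descending runs are detected with >= and reversed in place, and a block of equal values forms its own run) but leaves out the two bail-out checks discussed below. The method name splitRuns and the returned list of run boundaries are assumptions for illustration.

```java
import java.util.ArrayList;
import java.util.List;

class RunSplitter {
    // Splits a into ascending runs, reversing descending runs in place.
    // Returns the start index of every run, followed by a.length.
    static List<Integer> splitRuns(int[] a) {
        List<Integer> bounds = new ArrayList<>();
        int n = a.length;
        int k = 0;
        while (k < n) {
            bounds.add(k);
            if (k == n - 1) {                       // a single trailing element is its own run
                k++;
            } else if (a[k] < a[k + 1]) {           // ascending
                while (++k < n && a[k - 1] <= a[k]) { }
            } else if (a[k] > a[k + 1]) {           // descending: find the end, then reverse
                int lo = k;
                while (++k < n && a[k - 1] >= a[k]) { }
                reverse(a, lo, k - 1);
            } else {                                // equal values
                while (++k < n && a[k - 1] == a[k]) { }
            }
        }
        bounds.add(n);
        return bounds;
    }

    // Reverses a[lo..hi] in place.
    private static void reverse(int[] a, int lo, int hi) {
        while (lo < hi) {
            int t = a[lo]; a[lo++] = a[hi]; a[hi--] = t;
        }
    }
}
```

For [3,2,1,4,5], splitRuns returns [0, 3, 5] and leaves the array as [1, 2, 3, 4, 5], i.e. the runs [1,2,3] and [4,5].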
An attentive reader might ask: is this way of splitting always better than merge sort's?
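Not necessarily. Counterexamples are easy to find: TimSort splits the array [5,2,6,3,7,1] into [2,5], [3,6], [1,7], and every one of these runs requires a reversal. The runs are no longer than merge sort's smallest subarrays, so the scan gains nothing and the reversals are extra work compared with merge sort's splitting.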
So why is TimSort considered an improvement over merge sort? Where does the gain come from?
The answer is that TimSort is very fast on partially sorted arrays, and real-world data is often partially sorted. For example:
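- A school grade has several classes, and each class's score list is already sorted; the lists must be combined into a ranking for the whole grade.
- A mall tallies product sales: each store's sales figures are already sorted, and they must be combined into an overall best-seller ranking.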
TimSort is well suited to these scenarios, because the whole array splits into only a few runs, and merging them completes the sort. We call an array that splits into only a few runs "highly structured".
In sort(int[] a, int left, int right, int[] work, int workBase, int workLen), if the array is highly structured (that is, it splits into only a few runs), a TimSort-like algorithm is used; otherwise, sort(int[] a, int left, int right, boolean leftmost) is called.
Why only "TimSort-like"?
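Because the full TimSort algorithm does more than split the array into ascending runs. To keep merging efficient, it applies extra splitting rules (such as extending short runs to a minimum length) so that runs end up with similar lengths, which avoids copying a long "tail" when a short run is merged with a much longer one. The code here simply splits the array into ascending runs and starts merging, so it is effectively a simplified version of TimSort.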
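Putting the two steps together, a simplified "split into runs, then merge adjacent runs pairwise" sort (often called natural merge sort) can be sketched like this. Like the JDK's simplified variant, it has no minimum run length and no merge-balancing rules. Unlike the JDK, which alternates between a and a work array to avoid copying back, this sketch copies into a buffer before each merge for clarity, and for brevity it treats a leading block of equal values as part of an ascending run. The class name is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

class NaturalMergeSort {
    static void sort(int[] a) {
        int n = a.length;
        // Step 1: split into non-decreasing runs, reversing descending ones.
        List<Integer> bounds = new ArrayList<>(); // run starts, then n
        int k = 0;
        while (k < n) {
            bounds.add(k);
            int lo = k++;
            if (k < n && a[lo] > a[k]) {              // descending run
                while (k < n && a[k - 1] >= a[k]) k++;
                reverse(a, lo, k - 1);
            } else {                                  // non-decreasing run
                while (k < n && a[k - 1] <= a[k]) k++;
            }
        }
        bounds.add(n);

        // Step 2: merge adjacent runs pairwise until a single run remains.
        int[] buf = new int[n];
        while (bounds.size() > 2) {
            List<Integer> next = new ArrayList<>();
            int i = 0;
            for (; i + 2 < bounds.size(); i += 2) {
                merge(a, buf, bounds.get(i), bounds.get(i + 1), bounds.get(i + 2));
                next.add(bounds.get(i));
            }
            if (i + 1 < bounds.size()) {
                next.add(bounds.get(i)); // odd run out: carried over to the next round
            }
            next.add(n);
            bounds = next;
        }
    }

    // Merges the sorted ranges a[lo, mid) and a[mid, hi).
    private static void merge(int[] a, int[] buf, int lo, int mid, int hi) {
        System.arraycopy(a, lo, buf, lo, hi - lo);
        int p = lo, q = mid;
        for (int i = lo; i < hi; i++) {
            if (q >= hi || (p < mid && buf[p] <= buf[q])) a[i] = buf[p++];
            else a[i] = buf[q++];
        }
    }

    // Reverses a[lo..hi] in place.
    private static void reverse(int[] a, int lo, int hi) {
        while (lo < hi) {
            int t = a[lo]; a[lo++] = a[hi]; a[hi--] = t;
        }
    }
}
```

On a highly structured input such as [1, 2, 3, 9, 8, 7, 4, 5, 6], step 1 finds just three runs, so only two merges are needed in total.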
While splitting the array into runs, two conditions abandon the TimSort-like path and fall back to sort(int[] a, int left, int right, boolean leftmost). Here are the two exit points:
int count = 0; run[0] = left;
// Check if the array is nearly sorted
for (int k = left; k < right; run[count] = k) {
if (a[k] < a[k + 1]) { // ascending
while (++k <= right && a[k - 1] <= a[k]);
} else if (a[k] > a[k + 1]) { // descending
while (++k <= right && a[k - 1] >= a[k]);
for (int lo = run[count] - 1, hi = k; ++lo < --hi; ) {
int t = a[lo]; a[lo] = a[hi]; a[hi] = t;
}
} else { // equal
for (int m = MAX_RUN_LENGTH; ++k <= right && a[k - 1] == a[k]; ) {
if (--m == 0) {
sort(a, left, right, true);
return;
}
}
}
/*
* The array is not highly structured,
* use Quicksort instead of merge sort.
*/
if (++count == MAX_RUN_COUNT) {
sort(a, left, right, true);
return;
}
}
The logic of this code: iterate over the array with index k and compare adjacent elements:
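- If the values are ascending, record this run.
- If the values are descending, record this run and reverse that part of the array.
- If the values are equal, record this run and check whether the number of consecutive equal values has reached MAX_RUN_LENGTH (more precisely, whether MAX_RUN_LENGTH adjacent pairs in a row are equal). If it has, stop using TimSort and call sort(int[] a, int left, int right, boolean leftmost) instead.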
MAX_RUN_LENGTH is the constant 33. Why switch to sort(int[] a, int left, int right, boolean leftmost) once about 33 equal values appear in a row?
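Because sort(int[] a, int left, int right, boolean leftmost) has dedicated optimizations for inputs with many equal elements (for example, partitioning schemes that gather values equal to a pivot into one region), so when there are many consecutive equal values it is expected to be somewhat faster than the TimSort-like path.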
The second exit point depends on count. Each time a run is recorded, count is incremented; once count reaches MAX_RUN_COUNT, TimSort is abandoned and sort(int[] a, int left, int right, boolean leftmost) is called instead.
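As mentioned above, TimSort only pays off on highly structured arrays. MAX_RUN_COUNT is 67: if the number of runs reaches this value, the array is considered not highly structured and therefore unsuitable for TimSort.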
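The two bail-out conditions can be isolated in a small sketch that reports whether the scan would carry on to the merge phase. It copies the scanning loop from the listing above with JDK 11's constant values, but leaves out the preceding size check (right - left < QUICKSORT_THRESHOLD) and the run bookkeeping; the class and method names are hypothetical.

```java
class StructureCheck {
    // Constant values as found in JDK 11's DualPivotQuicksort.
    static final int MAX_RUN_COUNT = 67;
    static final int MAX_RUN_LENGTH = 33;

    // Returns true if the run scan would reach the merge phase, false if it
    // would bail out to the quicksort path. Works on a copy, because the
    // scan reverses descending runs in place.
    static boolean wouldMerge(int[] input) {
        int[] a = input.clone();
        int right = a.length - 1;
        int count = 0;
        for (int k = 0; k < right; ) {
            if (a[k] < a[k + 1]) {                       // ascending
                while (++k <= right && a[k - 1] <= a[k]) { }
            } else if (a[k] > a[k + 1]) {                // descending: scan, then reverse
                int lo = k;
                while (++k <= right && a[k - 1] >= a[k]) { }
                for (int i = lo, j = k - 1; i < j; i++, j--) {
                    int t = a[i]; a[i] = a[j]; a[j] = t;
                }
            } else {                                     // equal values
                for (int m = MAX_RUN_LENGTH; ++k <= right && a[k - 1] == a[k]; ) {
                    if (--m == 0) {
                        return false;                    // exit 1: too many equal values in a row
                    }
                }
            }
            if (++count == MAX_RUN_COUNT) {
                return false;                            // exit 2: too many runs
            }
        }
        return true;
    }
}
```

A sorted or reverse-sorted array of 300 elements passes the check, while a zigzag such as 0, 1, 0, 1, ... (runs of length 2) or an array of 300 identical values does not.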
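That completes the first step. To summarize: when the array has more than 286 elements, contains no long stretch of consecutive equal values, and is highly structured, it is sorted with a TimSort-like algorithm, a refinement of merge sort that outperforms plain merge sort on highly structured arrays. Otherwise, sort(int[] a, int left, int right, boolean leftmost) is called.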
Insertion Sort & Pair Insertion Sort
Next, let's look at the implementation of sort(int[] a, int left, int right, boolean leftmost):
private static void sort(int[] a, int left, int right, boolean leftmost) {
int length = right - left + 1;
// Use insertion sort on tiny arrays
if (length < INSERTION_SORT_THRESHOLD) {
if (leftmost) {
/*
* Traditional (without sentinel) insertion sort,
* optimized for server VM, is used in case of
* the leftmost part.
*/
for (int i = left, j = i; i < right; j = ++i) {
int ai = a[i + 1];
while (ai < a[j]) {
a[j + 1] = a[j];
if (j-- == left) {
break;
}
}
a[j + 1] = ai;
}
} else {
/*
* Skip the longest ascending sequence.
*/
do {
if (left >= right) {
return;
}
} while (a[++left] >= a[left - 1]);
/*
* Every element from adjoining part plays the role
* of sentinel, therefore this allows us to avoid the
* left range check on each iteration. Moreover, we use
* the more optimized algorithm, so called pair insertion
* sort, which is faster (in the context of Quicksort)
* than traditional implementation of insertion sort.
*/
for (int k = left; ++left <= right; k = ++left) {
int a1 = a[k], a2 = a[left];
if (a1 < a2) {
a2 = a1; a1 = a[left];
}
while (a1 < a[--k]) {
a[k + 2] = a[k];
}
a[++k + 1] = a1;
while (a2 < a[--k]) {
a[k + 1] = a[k];
}
a[k + 1] = a2;
}
int last = a[right];
while (last < a[--right]) {
a[right + 1] = a[right];
}
a[right + 1] = last;
}
return;
}
// ...
}
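Once again, the first step is a length check: if the length is less than INSERTION_SORT_THRESHOLD (whose value is 47), insertion sort or pair insertion sort is used, because insertion sort performs well on small inputs.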
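The leftmost branch above packs traditional insertion sort into a compact loop. The sketch below does the same thing on the same inclusive range a[left..right], written in a more conventional form; the class and method names are hypothetical.

```java
class InsertionSortSketch {
    // Sorts a[left..right] (inclusive) with traditional insertion sort,
    // checking the left bound explicitly, as the leftmost branch must.
    static void insertionSort(int[] a, int left, int right) {
        for (int i = left + 1; i <= right; i++) {
            int ai = a[i];      // element to insert into the sorted prefix a[left..i-1]
            int j = i - 1;
            // Shift larger elements one slot to the right.
            while (j >= left && a[j] > ai) {
                a[j + 1] = a[j];
                j--;
            }
            a[j + 1] = ai;
        }
    }
}
```

For example, calling insertionSort on {9, 4, 3, 8, 1, 0} with left = 1 and right = 4 yields {9, 1, 3, 4, 8, 0}; elements outside the range are untouched.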
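Pair insertion sort takes two elements per round from the unsorted part and inserts both into the sorted part in front of them; in the context of quicksort it is somewhat faster than ordinary insertion sort. Which variant runs depends on the leftmost flag: the leftmost part of the range uses traditional insertion sort with an explicit left-bound check, while every other part first skips its ascending prefix and then uses pair insertion sort. This is safe because the element just to the left of such a part (from the adjoining part) is no larger than any element inside it, so it acts as a sentinel and the left-bound check can be dropped.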
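The core idea of pair insertion sort, two elements per pass with the larger one inserted first, can be sketched as follows. To keep the sketch self-contained, it checks the left bound explicitly instead of relying on a sentinel, and it omits the JDK's initial skip over the ascending prefix, so it illustrates the technique rather than reproducing the JDK branch. The class and method names are hypothetical.

```java
class PairInsertionSortSketch {
    // Sorts a[left..right] (inclusive), taking two elements per pass.
    static void pairInsertionSort(int[] a, int left, int right) {
        int i = left + 1;
        for (; i < right; i += 2) {            // invariant: a[left..i-1] is sorted
            int a1 = a[i], a2 = a[i + 1];
            if (a1 < a2) { int t = a1; a1 = a2; a2 = t; } // ensure a1 >= a2
            int j = i - 1;
            // Insert the larger element first, shifting by two slots
            // because both a[i] and a[i + 1] are free.
            while (j >= left && a[j] > a1) {
                a[j + 2] = a[j];
                j--;
            }
            a[j + 2] = a1;
            // Continue from where a1 stopped and insert the smaller element,
            // now shifting by one slot.
            while (j >= left && a[j] > a2) {
                a[j + 1] = a[j];
                j--;
            }
            a[j + 1] = a2;
        }
        // With an odd number of unsorted elements, one is left over.
        if (i == right) {
            int last = a[right];
            int j = right - 1;
            while (j >= left && a[j] > last) {
                a[j + 1] = a[j];
                j--;
            }
            a[j + 1] = last;
        }
    }
}
```

The saving comes from the second insertion: the smaller element resumes the scan where the larger one stopped instead of starting again from the end of the sorted part.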
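If the length is 47 or more, quicksort is used: dual-pivot quicksort when the five sampled elements a[e1] to a[e5] are all distinct, and otherwise a single-pivot three-way partition (the Dutch national flag scheme).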
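The single-pivot branch in the listing above ("Partitioning with one pivot") keeps values less than, equal to, and greater than the pivot in three regions. The textbook form of this three-way partition is sketched below. It uses the middle element as the pivot (an assumption for simplicity; the JDK uses the sampled median a[e3]) and plain recursion without insertion-sort cutoffs; the class name is hypothetical.

```java
class ThreeWayPartitionSketch {
    // Sorts a[left..right] (inclusive) with single-pivot three-way quicksort.
    static void sort(int[] a, int left, int right) {
        if (left >= right) return;
        int pivot = a[(left + right) >>> 1];
        int less = left, k = left, great = right;
        // Invariants: a[left..less-1] < pivot, a[less..k-1] == pivot,
        //             a[great+1..right] > pivot, a[k..great] not yet examined.
        while (k <= great) {
            if (a[k] < pivot) {
                swap(a, k++, less++);
            } else if (a[k] > pivot) {
                swap(a, k, great--); // k stays: the swapped-in value is unexamined
            } else {
                k++;
            }
        }
        // All elements equal to the pivot are already in their final place.
        sort(a, left, less - 1);
        sort(a, great + 1, right);
    }

    private static void swap(int[] a, int i, int j) { int t = a[i]; a[i] = a[j]; a[j] = t; }
}
```

Because the whole block of pivot-equal values is excluded from recursion, inputs with many duplicates shrink quickly, which is the property the JDK relies on for such data.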
With that, we have a rough outline of how Arrays.sort() works.
Summary
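- When sorting arrays of primitive types, Arrays.sort() delegates to DualPivotQuicksort.sort();
- When the array has more than 286 elements, contains no long stretch of consecutive equal values, and is "highly structured", a TimSort-like algorithm is used;
- When the array length is less than INSERTION_SORT_THRESHOLD (that is, 47), insertion sort or pair insertion sort is used;
- Otherwise, dual-pivot quicksort is used.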