215. 数组中的第 K 个最大元素

LeetCode 原题链接

1. 题目描述

给定一个整数数组 nums 和一个整数 k,返回数组中第 k 个最大的元素。

注意这里要找的是排序后的第 k 大元素,不是第 k 个不同的元素。也就是说,重复元素会参与排名。

示例 1:

输入:nums = [3, 2, 1, 5, 6, 4], k = 2
输出:5
解释:降序排列后是 [6, 5, 4, 3, 2, 1],第 2 大元素是 5。

示例 2:

输入:nums = [3, 2, 3, 1, 2, 4, 5, 5, 6], k = 4
输出:4
解释:降序排列后是 [6, 5, 5, 4, 3, 3, 2, 2, 1],第 4 大元素是 4。

2. 题型判断

本题要求在无序数组中找到第 k 个最大的元素,属于典型的 Top K 问题。

直观做法是先排序,再取下标 k - 1 的元素,时间复杂度是 O(n log n)。但题目只需要第 k 大,不需要整个数组有序,因此可以用快速选择。快速选择每次通过一次分区确定一个元素的最终排名,再只递归处理可能包含答案的一侧,平均时间复杂度可以降到 O(n)

3. 核心思路

使用两路快速选择,分区时按照从大到小的顺序摆放元素:

  • 随机选一个 pivot
  • 把大于等于 pivot 的元素放到左侧。
  • 把小于 pivot 的元素放到右侧。
  • 分区结束后,pivot 的位置 p 就表示它在当前区间中的降序排名。

如果 p 正好是目标下标 k - 1,说明 nums[p] 就是第 k 大元素。

如果 p > k - 1,说明第 k 大在左半部分;如果 p < k - 1,说明第 k 大在右半部分。每次只保留一边继续查找。

4. 步骤图解

下面以 nums = [3, 2, 1, 5, 6, 4]k = 2 为例,演示快速选择如何通过分区逐步缩小搜索区间。

图中展示的是两路快速选择的分区过程:先把 pivot 暂存在区间末尾,遍历时把大于等于 pivot 的元素换到左侧,最后再把 pivot 放回它的最终位置。

快速选择查找第 K 个最大元素步骤图解

5. 变量与区间含义

  • targetIndex:第 k 大元素在降序数组中的下标,也就是 k - 1
  • left:当前查找区间的左边界。
  • right:当前查找区间的右边界。
  • pivotIndex:随机选择的基准值下标。
  • storeIndex:下一个大于等于 pivot 的元素应该放入的位置。
  • p:分区结束后,pivot 所在的位置。

快速选择维护的是闭区间 [left, right]。答案一定在这个区间内,每次分区后根据 ptargetIndex 的关系缩小区间。

6. 推进规则

分区规则:

  1. 随机选择 pivotIndex,并把它交换到 right 位置。
  2. 遍历 [left, right - 1],注意 right 位置暂存的是 pivot,不参与循环。
  3. 如果 nums[i] >= pivot,说明它应该排在 pivot 左边,将它交换到 storeIndex,然后 storeIndex++
  4. 遍历结束后,把 pivotright 交换回 storeIndex
  5. 此时 storeIndex 左边的元素都大于等于 pivot,右边的元素都小于 pivot

查找规则:

  • 如果 p === targetIndex,直接返回 nums[p]
  • 如果 p > targetIndex,说明答案在左侧,更新 right = p - 1
  • 如果 p < targetIndex,说明答案在右侧,更新 left = p + 1

7. 边界条件与易错点

  • k 是从 1 开始计数的,数组下标从 0 开始,所以目标下标是 k - 1
  • 分区按降序处理,所以判断条件是 nums[i] >= pivot,不是 nums[i] <= pivot
  • for 循环只遍历到 right - 1,因为 right 位置暂存的是 pivot,最后要单独归位。
  • 数组中可能有大量重复元素,两路分区会缩小很慢;这种情况可以看后面的三路快速选择优化版。
  • 快速选择会修改数组顺序。如果不希望修改原数组,可以先复制一份数组再处理。
  • 随机选择 pivot 可以降低遇到最坏情况的概率。

8. 代码实现

/**
 * @param {number[]} nums
 * @param {number} k
 * @return {number}
 */
var findKthLargest = function (nums, k) {
  // 第 k 大元素在降序数组中的下标是 k - 1。
  // 例如第 1 大对应下标 0,第 2 大对应下标 1。
  const targetIndex = k - 1;

  // 快速选择只在当前闭区间 [left, right] 内查找答案。
  let left = 0;
  let right = nums.length - 1;

  while (left <= right) {
    // partition 会把一个 pivot 放到它在降序排列中的正确位置,
    // 并返回这个位置下标 p。
    const p = partition(nums, left, right);

    if (p === targetIndex) {
      // pivot 的排名正好是第 k 大,直接返回。
      return nums[p];
    } else if (p > targetIndex) {
      // p 太靠右,说明第 k 大在 pivot 左边更大的那一部分。
      right = p - 1;
    } else {
      // p 太靠左,说明左边元素数量不够,第 k 大在 pivot 右边。
      left = p + 1;
    }
  }
};

function partition(nums, left, right) {
  // 随机选择 pivot,可以降低数组已经有序时退化成 O(n^2) 的概率。
  const pivotIndex = left + Math.floor(Math.random() * (right - left + 1));
  [nums[pivotIndex], nums[right]] = [nums[right], nums[pivotIndex]];

  const pivot = nums[right];

  // storeIndex 表示下一个“大于等于 pivot 的元素”应该放入的位置。
  // 循环过程中,[left, storeIndex - 1] 都是已经放好的较大元素。
  let storeIndex = left;

  for (let i = left; i < right; i++) {
    // 本题找第 k 大,所以按降序分区:
    // 大于等于 pivot 的元素放左边,小于 pivot 的元素留在右边。
    if (nums[i] >= pivot) {
      [nums[storeIndex], nums[i]] = [nums[i], nums[storeIndex]];
      storeIndex++;
    }
  }

  // 把 pivot 放到 storeIndex。
  // 此时左边都 >= pivot,右边都 < pivot,pivot 的降序位置已经确定。
  [nums[storeIndex], nums[right]] = [nums[right], nums[storeIndex]];
  return storeIndex;
}

9. 其他解法

解法一:排序

最直观的做法是把数组按降序排序,然后返回下标 k - 1 的元素。

这种写法最短,也很适合先解释题意,但它会把整个数组都排好序,而题目只需要第 k 大元素,所以时间复杂度比快速选择更高。

var findKthLargest = function (nums, k) {
  // 降序排序后,第 k 大元素就在下标 k - 1。
  nums.sort((a, b) => b - a);
  return nums[k - 1];
};

解法二:手写快排

也可以手写快速排序,把数组整体按降序排好,再返回 nums[k - 1]

这种解法和内置排序的思想一样,都会排序整个数组。它的优势是能练习快排分区逻辑,但对本题来说不如快速选择高效,因为快速选择只处理可能包含第 k 大的一侧,快排则需要继续排序两侧。

var findKthLargest = function (nums, k) {
  quickSort(nums, 0, nums.length - 1);
  return nums[k - 1];
};

function quickSort(nums, left, right) {
  if (left >= right) {
    return;
  }

  const pivotIndex = partition(nums, left, right);

  quickSort(nums, left, pivotIndex - 1);
  quickSort(nums, pivotIndex + 1, right);
}

function partition(nums, left, right) {
  // 随机选择 pivot,减少有序数组导致退化的概率。
  const pivotIndex = left + Math.floor(Math.random() * (right - left + 1));
  [nums[pivotIndex], nums[right]] = [nums[right], nums[pivotIndex]];

  const pivot = nums[right];
  let storeIndex = left;

  for (let i = left; i < right; i++) {
    // 降序排序:比 pivot 大的元素放左边。
    if (nums[i] > pivot) {
      [nums[storeIndex], nums[i]] = [nums[i], nums[storeIndex]];
      storeIndex++;
    }
  }

  [nums[storeIndex], nums[right]] = [nums[right], nums[storeIndex]];
  return storeIndex;
}

解法三:三路快速选择

三路快速选择是对两路分区的优化。它每次选一个 pivot,把当前区间分成三段:

  • 左侧:大于 pivot
  • 中间:等于 pivot
  • 右侧:小于 pivot

如果 targetIndex 落在中间等值区间里,就可以直接返回。它特别适合数组中有大量重复元素的情况,因为可以一次跳过所有等于 pivot 的元素。

var findKthLargest = function (nums, k) {
  const targetIndex = k - 1;
  let left = 0;
  let right = nums.length - 1;

  while (left <= right) {
    const [lt, gt] = partition(nums, left, right);

    if (targetIndex < lt) {
      right = lt - 1;
    } else if (targetIndex > gt) {
      left = gt + 1;
    } else {
      return nums[targetIndex];
    }
  }
};

function partition(nums, left, right) {
  const pivotIndex = left + Math.floor(Math.random() * (right - left + 1));
  const pivot = nums[pivotIndex];

  let lt = left;
  let i = left;
  let gt = right;

  while (i <= gt) {
    if (nums[i] > pivot) {
      [nums[lt], nums[i]] = [nums[i], nums[lt]];
      lt++;
      i++;
    } else if (nums[i] < pivot) {
      [nums[i], nums[gt]] = [nums[gt], nums[i]];
      gt--;
    } else {
      i++;
    }
  }

  return [lt, gt];
}

解法四:最小堆

维护一个大小为 k 的最小堆。堆里始终保存当前见过的前 k 大元素,堆顶就是这 k 个元素里最小的那个,也就是当前的第 k 大候选值。

遍历数组时:

  • 如果堆的大小小于 k,直接入堆。
  • 如果当前元素大于堆顶,说明它应该进入前 k 大,弹出堆顶后再入堆。
  • 如果当前元素小于等于堆顶,说明它进不了前 k 大,跳过。

遍历结束后,堆顶就是第 k 大元素。

var findKthLargest = function (nums, k) {
  const heap = new MinHeap();

  for (const num of nums) {
    if (heap.size() < k) {
      heap.push(num);
    } else if (num > heap.peek()) {
      heap.pop();
      heap.push(num);
    }
  }

  return heap.peek();
};

class MinHeap {
  constructor() {
    this.heap = [];
  }

  size() {
    return this.heap.length;
  }

  peek() {
    return this.heap[0];
  }

  push(value) {
    this.heap.push(value);
    this.shiftUp(this.heap.length - 1);
  }

  pop() {
    const top = this.heap[0];
    const last = this.heap.pop();

    if (this.heap.length > 0) {
      this.heap[0] = last;
      this.shiftDown(0);
    }

    return top;
  }

  shiftUp(index) {
    while (index > 0) {
      const parent = Math.floor((index - 1) / 2);

      if (this.heap[parent] <= this.heap[index]) {
        break;
      }

      [this.heap[parent], this.heap[index]] = [this.heap[index], this.heap[parent]];
      index = parent;
    }
  }

  shiftDown(index) {
    const n = this.heap.length;

    while (true) {
      let smallest = index;
      const left = index * 2 + 1;
      const right = index * 2 + 2;

      if (left < n && this.heap[left] < this.heap[smallest]) {
        smallest = left;
      }

      if (right < n && this.heap[right] < this.heap[smallest]) {
        smallest = right;
      }

      if (smallest === index) {
        break;
      }

      [this.heap[index], this.heap[smallest]] = [this.heap[smallest], this.heap[index]];
      index = smallest;
    }
  }
}

最小堆适合数据流场景:如果数字一个个到来,不能一次性拿到完整数组,就可以持续维护大小为 k 的堆。

解法五:计数法

如果题目给出的数值范围很小,也可以统计每个数字出现次数,然后从大到小累计数量,累计到 k 时返回当前数字。

这种方法在值域很小时很快,但如果数值范围很大,例如从 -10^910^9,就不适合直接开数组计数。

var findKthLargest = function (nums, k) {
  const countMap = new Map();

  for (const num of nums) {
    countMap.set(num, (countMap.get(num) || 0) + 1);
  }

  const values = [...countMap.keys()].sort((a, b) => b - a);

  for (const value of values) {
    k -= countMap.get(value);

    if (k <= 0) {
      return value;
    }
  }
};

这里用 Map 后仍然需要对不同的值排序,所以复杂度取决于不同数字的个数。如果值域小到可以直接开计数数组,就可以进一步减少排序成本。

10. 复杂度分析

两路快速选择:

  • 时间复杂度:平均 O(n)。每次分区只处理当前区间,并且平均只继续查找一半区间。
  • 最坏时间复杂度:O(n^2)。如果每次选到的 pivot 都非常偏,或者重复元素很多导致区间缩小很慢,会退化成单边查找。
  • 空间复杂度:O(1)。迭代写法只使用常数个额外变量。

排序写法:

  • 时间复杂度:O(n log n)
  • 空间复杂度:取决于 JavaScript 引擎的排序实现。

手写快排:

  • 平均时间复杂度:O(n log n)
  • 最坏时间复杂度:O(n^2)。随机选择 pivot 可以降低退化概率。
  • 空间复杂度:平均 O(log n),主要来自递归调用栈;最坏情况下是 O(n)

三路快速选择:

  • 平均时间复杂度:O(n)
  • 最坏时间复杂度:O(n^2)。随机选择 pivot 可以降低退化概率;重复元素很多时通常比两路分区更稳。
  • 空间复杂度:O(1)。这里使用迭代写法,只需要常数级额外变量。

最小堆:

  • 时间复杂度:O(n log k)。堆的大小最多是 k,每次入堆或出堆需要 O(log k)
  • 空间复杂度:O(k)

计数法:

  • 使用 Map 加排序时,时间复杂度是 O(n + m log m),其中 m 是不同数字的个数。
  • 空间复杂度是 O(m)