目录
- 1. 题目理解
- 2. 解题思路分析
- 2.1 方法一:最小堆(优先队列)
- 2.2 方法二:二分查找(值域)
- 3. 代码实现
- 3.1 方法一:最小堆实现
- 3.2 方法二:二分查找实现【⭐】
- 4. 两种方法对比与总结
- 4.1 为什么不能直接 flatten(扁平化)排序?
- 4.2 总结
1. 题目理解
输入:
- 一个 \(n \times n\) 的矩阵 matrix。
- 一个整数 k。
关键特性:
- 矩阵的每一行从左到右是升序排列的。
- 矩阵的每一列从上到下是升序排列的。
目标:
找到矩阵中所有元素排序后的第 \(k\) 小元素(注意:如果有重复元素,重复的也要算进去。例如 [1, 1, 2] 中第 2 小的是 1,第 3 小的是 2)。
限制条件:
- 内存复杂度必须优于 \(O(n^2)\)。这意味着我们不能简单地把所有元素拿出来排序(因为那样需要 \(O(n^2)\) 的空间存储所有元素)。
- \(n\) 最大为 300。
2. 解题思路分析
这道题主要有两种经典的解法,分别利用了堆(优先队列)和二分查找。
2.1 方法一:最小堆(优先队列)
核心思想:
既然每一行都是有序的,那么每一行的第一个元素一定是该行最小的。我们可以把这 \(n\) 行的“当前最小元素”都放入一个最小堆中。
- 初始时,将每一行的第一个元素放入堆中。
- 每次从堆中弹出最小的元素,这就是当前全局未访问元素中的最小值。
- 弹出后,将该元素所在行的下一个元素放入堆中(如果该行还有元素的话)。
- 重复上述操作 \(k\) 次,第 \(k\) 次弹出的元素就是答案。
图解:
假设矩阵如下,k=8:
- 堆初始化:[1, 10, 12] (每行第一个)
- 弹出 1 (第 1 小),推入 5。堆:[5, 10, 12]
- 弹出 5 (第 2 小),推入 9。堆:[9, 10, 12]
- 弹出 9 (第 3 小),该行无后续。堆:[10, 12]
- 弹出 10 (第 4 小),推入 11。堆:[11, 12]
- 弹出 11 (第 5 小),推入 13。堆:[12, 13]
- 弹出 12 (第 6 小),推入 13。堆:[13, 13]
- 弹出 13 (第 7 小),推入 15。堆:[13, 15]
- 弹出 13 (第 8 小) -> 返回 13
复杂度分析:
- 时间复杂度:\(O(k \log n)\)。我们需要执行 \(k\) 次弹出操作,每次堆调整需要 \(\log n\)(堆的大小最大为 \(n\))。
- 空间复杂度:\(O(n)\)。堆中最多存储 \(n\) 个元素(每行一个)。满足题目优于 \(O(n^2)\) 的要求。
2.2 方法二:二分查找(值域)
核心思想:
我们不是对“位置”进行二分,而是对“数值范围”进行二分。
- 矩阵中最小值是 matrix[0][0],最大值是 matrix[n-1][n-1]。答案一定在这个范围内。
- 我们猜测一个中间值 mid。
- 统计矩阵中有多少个元素 小于等于 mid。
- 如果数量 \(< k\),说明 mid 太小了,答案在右半部分 (left = mid + 1)。
- 如果数量 \(\ge k\),说明 mid 可能是答案,或者答案在左半部分 (right = mid - 1),我们需要记录 mid 并继续尝试更小的值。
- 当 left > right 时,记录的最后一次满足条件的 mid 即为答案。
关键难点:如何在 \(O(n)\) 时间内统计小于等于 mid 的元素个数?
利用矩阵行列有序的特性,我们可以从左下角开始搜索:
- 设当前位置为 (row, col),初始为 (n-1, 0)。
- 如果 matrix[row][col] mid:
- 说明当前元素太大了,需要找更小的数。
- 我们向上移动 (row -= 1)。
- 这样只需要走 \(2n\) 步即可完成统计。
复杂度分析:
- 时间复杂度:\(O(n \log(\text{max} - \text{min}))\)。二分查找的次数取决于数值范围,每次统计需要 \(O(n)\)。
- 空间复杂度:\(O(1)\)。只需要几个变量,不需要额外空间。这是最优的空间解法。
3. 代码实现
下面提供两种方法的 Python 3 代码。代码中包含了详细的注释。
3.1 方法一:最小堆实现
- from typing import List
- import heapq
- class Solution:
- def kthSmallest(self, matrix: List[List[int]], k: int) -> int:
- """
- 方法一:最小堆 (Priority Queue)
- 时间复杂度:O(k * log n)
- 空间复杂度:O(n)
- """
- n = len(matrix)
-
- # 最小堆,存储元组 (数值,行索引,列索引)
- # 初始化:将每一行的第一个元素放入堆中
- min_heap = []
- for r in range(n):
- # 推入 (值,行号,列号)
- heapq.heappush(min_heap, (matrix[r][0], r, 0))
-
- # 执行 k 次弹出操作
- # 第 k 次弹出的元素即为第 k 小的元素
- for _ in range(k):
- # 弹出当前堆中最小的元素
- val, r, c = heapq.heappop(min_heap)
-
- # 如果这是第 k 次弹出,直接返回
- # 注意:循环是从 0 到 k-1,所以当 _ == k-1 时是第 k 次
- if _ == k - 1:
- return val
-
- # 将该元素所在行的下一个元素推入堆中
- # 确保不越界
- if c + 1 < n:
- heapq.heappush(min_heap, (matrix[r][c + 1], r, c + 1))
-
- return -1 # 理论上不会执行到这里
复制代码 3.2 方法二:二分查找实现【⭐】
[code]from typing import Listclass Solution: def kthSmallest(self, matrix: List[List[int]], k: int) -> int: """ 方法二:二分查找 (Binary Search on Value) 时间复杂度:O(n * log(max - min)) 空间复杂度:O(1) """ n = len(matrix) # 确定二分查找的上下界 # 最小值是矩阵左上角,最大值是矩阵右下角 left, right = matrix[0][0], matrix[n - 1][n - 1] # 辅助函数:统计矩阵中小于等于 target 的元素个数 def countLessEqual(target: int) -> int: count = 0 # 从左下角开始搜索 row, col = n - 1, 0 # 只要还在矩阵范围内 while row >= 0 and col < n: if matrix[row][col] |