LeetCode - Kth Smallest Element in a Sorted Matrix
Problem statement
Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix.
Note that it is the kth smallest element in the sorted order, not the kth distinct element.
You must find a solution with a memory complexity better than O(n^2).
Problem statement taken from: https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix
Example 1:
Input: matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]], k = 8
Output: 13
Explanation: The elements in the matrix are [1, 5, 9, 10, 11, 12, 13, 13, 15], and the 8th smallest number is 13
Example 2:
Input: matrix = [[-5]], k = 1
Output: -5
Constraints:
- n == matrix.length == matrix[i].length
- 1 <= n <= 300
- -10^9 <= matrix[i][j] <= 10^9
- All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order.
- 1 <= k <= n^2
Follow up:
- Could you solve the problem with a constant memory (i.e., O(1) memory complexity)?
- Could you solve the problem in O(n) time complexity? The solution may be too advanced for an interview but you may find reading this paper fun.
Explanation
Brute force approach
The easiest approach is to use an additional 1D array. We iterate over the matrix row-wise or column-wise and store the elements in the array. We then sort the array and return the element at index k - 1.
A C++ snippet of the above approach is as follows:
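#include <algorithm>
#include <vector>
using namespace std;

int kthSmallest(vector<vector<int>>& matrix, int k) {
    int n = matrix.size();
    int m = matrix[0].size();
    vector<int> array;
    // Flatten the matrix into a 1D array.
    for(int i = 0; i < n; i++) {
        for(int j = 0; j < m; j++) {
            array.push_back(matrix[i][j]);
        }
    }
    // Sort and pick the kth smallest element (index k - 1).
    sort(array.begin(), array.end());
    return array[k - 1];
}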
The space complexity of the above approach is O(n * m). The time complexity is O(n * m * log(n * m)), dominated by the sort.
Priority Queue
C++ has a built-in priority queue data structure, std::priority_queue. Priority queues are implemented using heaps, and in a maximum priority queue (the default in C++), the maximum element is always at the top of the heap.
We store all the elements in the priority queue and then call pop n * m - k times. Each pop removes the top (largest) element from the heap, so after n * m - k pops the topmost element is the kth smallest element of the matrix.
A C++ snippet of the above approach is as follows:
#include <queue>
#include <vector>
using namespace std;

int kthSmallest(vector<vector<int>>& matrix, int k) {
    int n = matrix.size();
    int m = matrix[0].size();
    priority_queue<int> pq;  // max-heap by default
    for(int i = 0; i < n; i++) {
        for(int j = 0; j < m; j++) {
            pq.push(matrix[i][j]);
        }
    }
    // Remove the n * m - k largest elements.
    for(int i = 0; i < n * m - k; i++) {
        pq.pop();
    }
    return pq.top();
}
The space complexity of this approach is O(n * m). The time complexity is O(n * m * log(n * m)), since each push and pop costs O(log(n * m)).
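Both approaches above keep all n * m elements in memory, which does not meet the problem's requirement of memory better than O(n^2). One possible alternative, which is not part of the original write-up, is to treat each row as a sorted list and merge the rows with a min-heap that holds at most one entry per row. The sketch below assumes a non-empty square matrix and a valid k; the name kthSmallestMinHeap is hypothetical.

```cpp
#include <algorithm>
#include <functional>
#include <queue>
#include <tuple>
#include <vector>

// Hypothetical helper: k-way merge over the sorted rows using a min-heap.
// Assumes a non-empty n x n matrix and 1 <= k <= n * n.
int kthSmallestMinHeap(const std::vector<std::vector<int>>& matrix, int k) {
    using Entry = std::tuple<int, int, int>;  // (value, row, column)
    std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> heap;

    int n = static_cast<int>(matrix.size());
    // Seed the heap with the first element of each row.
    // Rows at index k or beyond cannot affect the kth smallest value.
    for (int i = 0; i < std::min(n, k); i++) {
        heap.emplace(matrix[i][0], i, 0);
    }

    // Pop the k - 1 smallest elements, replacing each one with its
    // right-hand neighbour from the same row (if there is one).
    for (int popped = 0; popped < k - 1; popped++) {
        Entry smallest = heap.top();
        heap.pop();
        int row = std::get<1>(smallest);
        int col = std::get<2>(smallest);
        if (col + 1 < n) {
            heap.emplace(matrix[row][col + 1], row, col + 1);
        }
    }
    // The heap's top is now the kth smallest element.
    return std::get<0>(heap.top());
}
```

Under these assumptions the heap never holds more than n entries, so the extra memory is O(n) and the running time is O(k * log(n)).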
Binary Search
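The input matrix is sorted row-wise and column-wise. We can use this property to reduce both the time and the space complexity of the algorithm.
The idea is to binary search over the range of values, from the smallest element matrix[0][0] to the largest element matrix[n - 1][n - 1], rather than over positions in the matrix. For the middle value mid of the current range, we count the number of elements that are less than or equal to mid. If the count is smaller than k, the kth smallest element must be greater than mid, so we search the upper half of the range. If the count is greater than or equal to k, the answer is at most mid, so we search the lower half. In the code below, the helper that does the counting is named getElementsGreaterThanOrEqualMid; despite its name, it returns the number of elements less than or equal to mid.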
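To see why a binary search over values works, it helps to look at how the count of elements less than or equal to a candidate value behaves: it never decreases as the candidate grows, and the answer is the smallest candidate whose count reaches k. The sketch below makes this concrete with a deliberately naive counter and a linear scan over candidates. Both countLessOrEqual and smallestCandidateReachingK are hypothetical names used only for illustration, and the scan is far too slow for the full value range allowed by the constraints.

```cpp
#include <vector>

// Hypothetical, deliberately naive counter: scans every cell and counts
// the values that are <= candidate.
int countLessOrEqual(const std::vector<std::vector<int>>& matrix, int candidate) {
    int count = 0;
    for (const auto& row : matrix) {
        for (int value : row) {
            if (value <= candidate) {
                count++;
            }
        }
    }
    return count;
}

// Hypothetical linear scan: returns the smallest candidate value whose
// count reaches k. Binary search finds the same value in fewer probes.
// Assumes a non-empty n x n matrix and 1 <= k <= n * n.
int smallestCandidateReachingK(const std::vector<std::vector<int>>& matrix, int k) {
    int n = static_cast<int>(matrix.size());
    int low = matrix[0][0];
    int high = matrix[n - 1][n - 1];
    for (int candidate = low; candidate < high; candidate++) {
        if (countLessOrEqual(matrix, candidate) >= k) {
            return candidate;
        }
    }
    // The count for the largest element is always n * n >= k.
    return high;
}
```

For the matrix in Example 1, the counts for the candidates 8, 12, 13 and 14 are 2, 6, 8 and 8. With k = 8, the smallest candidate whose count reaches 8 is 13, which is the expected output.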
Let's check the algorithm for this approach.
Algorithm
// kthSmallest function
- set n = matrix.size(), m = matrix[0].size()
  low = matrix[0][0], high = matrix[n - 1][m - 1]
  initialize mid, greaterThanOrEqualMid
- loop while low <= high
  - set mid = low + (high - low) / 2
  - set greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
  - if greaterThanOrEqualMid >= k
    - update high = mid - 1
  - else
    - update low = mid + 1
  - if end
- while end
- return low
// getElementsGreaterThanOrEqualMid function
- set count = 0
  initialize greaterThanMid
- loop for i = 0; i < n; i++
  - if matrix[i][0] > mid
    - return count
  - if matrix[i][n - 1] <= mid
    - count = count + n
    - continue
  - set greaterThanMid = 0
  - loop for j = n / 2; j >= 1; j /= 2
    - loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
      - update greaterThanMid = greaterThanMid + j
    - while end
  - for end
  - set count = count + greaterThanMid + 1
- for end
- return count
The time complexity of this approach is O(y * n * log(n)), where y = log(abs(matrix[0][0] - matrix[n - 1][n - 1])) is the number of binary search steps over the value range. The space complexity is O(1).
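The log(n) factor comes from searching each row separately. Because the columns are sorted as well, the count for a given mid can also be obtained with a single staircase walk that starts at the bottom-left corner and takes O(n) steps per probe. This is an alternative to the row-wise search described above, not the method used in this article; countWithStaircase is a hypothetical name and the sketch assumes a non-empty square matrix.

```cpp
#include <vector>

// Hypothetical alternative counter: returns how many elements are <= mid
// by walking a staircase from the bottom-left corner of an n x n matrix.
int countWithStaircase(const std::vector<std::vector<int>>& matrix, int mid) {
    int n = static_cast<int>(matrix.size());
    int row = n - 1;  // start at the bottom-left corner
    int col = 0;
    int count = 0;
    while (row >= 0 && col < n) {
        if (matrix[row][col] <= mid) {
            // Every element above matrix[row][col] in this column is also
            // <= mid, so the column contributes row + 1 elements.
            count += row + 1;
            col++;
        } else {
            // The current element is too large; move up one row.
            row--;
        }
    }
    return count;
}
```

Each step moves either one column right or one row up, so the walk takes at most 2 * n steps. Using it in place of the row-wise search would bring the overall time to O(y * n) while keeping O(1) extra space.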
C++ solution
#include <vector>
using namespace std;

class Solution {
public:
    // Despite its name, this returns the number of elements <= mid.
    int getElementsGreaterThanOrEqualMid(vector<vector<int>>& matrix, int n, int mid) {
        int count = 0;
        int greaterThanMid;

        for(int i = 0; i < n; i++) {
            // Rows are sorted top to bottom, so no later row can contribute.
            if(matrix[i][0] > mid) {
                return count;
            }

            // The whole row is <= mid.
            if(matrix[i][n - 1] <= mid) {
                count += n;
                continue;
            }

            // Jump search for the last index in row i whose value is <= mid.
            greaterThanMid = 0;
            for(int j = n / 2; j >= 1; j /= 2) {
                while(greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid) {
                    greaterThanMid += j;
                }
            }

            count += greaterThanMid + 1;
        }

        return count;
    }

    int kthSmallest(vector<vector<int>>& matrix, int k) {
        int n = matrix.size();
        int m = matrix[0].size();
        int low = matrix[0][0];
        int high = matrix[n - 1][m - 1];
        int mid, greaterThanOrEqualMid;

        // Binary search over the value range [low, high].
        while(low <= high) {
            mid = low + (high - low) / 2;
            greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid);

            if (greaterThanOrEqualMid >= k)
                high = mid - 1;
            else
                low = mid + 1;
        }

        return low;
    }
};
Golang solution
func getElementsGreaterThanOrEqualMid(matrix [][]int, n, mid int) int {
	count, greaterThanMid := 0, 0

	for i := 0; i < n; i++ {
		if matrix[i][0] > mid {
			return count
		}

		if matrix[i][n-1] <= mid {
			count += n
			continue
		}

		greaterThanMid = 0
		for j := n / 2; j >= 1; j /= 2 {
			for greaterThanMid+j < n && matrix[i][greaterThanMid+j] <= mid {
				greaterThanMid += j
			}
		}

		count += greaterThanMid + 1
	}

	return count
}

func kthSmallest(matrix [][]int, k int) int {
	n, m := len(matrix), len(matrix[0])
	low, high := matrix[0][0], matrix[n-1][m-1]
	mid, greaterThanOrEqualMid := 0, 0

	for low <= high {
		mid = low + (high-low)/2
		greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)

		if greaterThanOrEqualMid >= k {
			high = mid - 1
		} else {
			low = mid + 1
		}
	}

	return low
}
JavaScript solution
var getElementsGreaterThanOrEqualMid = function(matrix, n, mid) {
    let count = 0, greaterThanMid = 0;

    for(let i = 0; i < n; i++) {
        if(matrix[i][0] > mid) {
            return count;
        }

        if(matrix[i][n - 1] <= mid) {
            count += n;
            continue;
        }

        greaterThanMid = 0;
        // Use integer division so that j is always a valid array offset.
        for(let j = Math.floor(n / 2); j >= 1; j = Math.floor(j / 2)) {
            while(greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid) {
                greaterThanMid += j;
            }
        }

        count += greaterThanMid + 1;
    }

    return count;
}

var kthSmallest = function(matrix, k) {
    let n = matrix.length, m = matrix[0].length;
    let low = matrix[0][0], high = matrix[n - 1][m - 1];
    let mid, greaterThanOrEqualMid;

    while(low <= high) {
        mid = low + Math.floor((high - low) / 2);
        greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid);

        if(greaterThanOrEqualMid >= k) {
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }

    return low;
};
Dry Run
Let's dry-run our algorithm to see how the solution works.
Input: matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
k = 8
// kthSmallest function
Step 1: n = matrix.size()
= 3
m = matrix[0].size()
= 3
low = matrix[0][0]
= 1
high = matrix[n - 1][m - 1]
= matrix[3 - 1][3 - 1]
= matrix[2][2]
= 15
int mid, greaterThanOrEqualMid
Step 2: loop while low <= high
1 <= 15
true
mid = low + (high - low) / 2
= 1 + (15 - 1) / 2
= 1 + 7
= 8
greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
= getElementsGreaterThanOrEqualMid(matrix, 3, 8)
// getElementsGreaterThanOrEqualMid function
Step 3: count = 0
greaterThanMid
loop for i = 0; i < n; i++
if matrix[i][0] > mid
matrix[0][0] > 8
1 > 8
false
if matrix[i][n - 1] <= mid
matrix[0][2] <= 8
9 <= 8
false
greaterThanMid = 0
loop for j = n / 2; j >= 1; j /= 2
j = 3/2 = 1
j >= 1
1 >= 1
true
loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
0 + 1 < 3 && matrix[0][0 + 1] <= mid
1 < 3 && 5 <= 8
true
greaterThanMid = greaterThanMid + j
= 0 + 1
= 1
loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
1 + 1 < 3 && matrix[0][1 + 1] <= mid
2 < 3 && 9 <= 8
false
j = j / 2
= 1 / 2
= 0
loop for j = n / 2; j >= 1
0 >= 1
false
count = count + greaterThanMid + 1
= 0 + 1 + 1
= 2
i++
i = 1
loop for i < n
1 < 3
true
if matrix[i][0] > mid
matrix[1][0] > 8
10 > 8
true
return count
return 2
// kthSmallest function
Step 4: greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
= 2
if greaterThanOrEqualMid >= k
2 >= 8
false
else
low = mid + 1
= 8 + 1
= 9
Step 5: loop while low <= high
9 <= 15
true
mid = low + (high - low) / 2
= 9 + (15 - 9) / 2
= 9 + 3
= 12
greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
= getElementsGreaterThanOrEqualMid(matrix, 3, 12)
// getElementsGreaterThanOrEqualMid function
Step 6: count = 0
greaterThanMid
loop for i = 0; i < n; i++
0 < 3
true
if matrix[i][0] > mid
matrix[0][0] > 12
1 > 12
false
if matrix[i][n - 1] <= mid
matrix[0][2] <= 12
9 <= 12
true
count = count + n
= 0 + 3
= 3
continue
i++
i = 1
loop for i < n
1 < 3
true
if matrix[i][0] > mid
matrix[1][0] > 12
10 > 12
false
if matrix[i][n - 1] <= mid
matrix[1][2] <= 12
13 <= 12
false
greaterThanMid = 0
loop for j = n / 2; j >= 1; j /= 2
j = 3/2 = 1
j >= 1
1 >= 1
true
loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
0 + 1 < 3 && matrix[1][0 + 1] <= 12
1 < 3 && 11 <= 12
true
greaterThanMid = greaterThanMid + j
= 0 + 1
= 1
loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
1 + 1 < 3 && matrix[1][1 + 1] <= 12
2 < 3 && 13 <= 12
false
count = count + greaterThanMid + 1
= 3 + 1 + 1
= 5
i++
i = 2
loop for i < n
2 < 3
true
if matrix[i][0] > mid
matrix[2][0] > 12
12 > 12
false
if matrix[i][n - 1] <= mid
matrix[2][2] <= 12
15 <= 12
false
greaterThanMid = 0
loop for j = n / 2; j >= 1; j /= 2
j = 3/2 = 1
j >= 1
1 >= 1
true
loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
0 + 1 < 3 && matrix[2][0 + 1] <= 12
1 < 3 && 13 <= 12
false
j = j / 2
= 1 / 2
= 0
loop for j >= 1
0 >= 1
false
count = count + greaterThanMid + 1
= 5 + 0 + 1
= 6
i++
i = 3
loop for i < n
3 < 3
false
return count
return 6
// kthSmallest function
Step 7: greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
= 6
if greaterThanOrEqualMid >= k
6 >= 8
false
else
low = mid + 1
= 12 + 1
= 13
Step 8: loop while low <= high
13 <= 15
true
mid = low + (high - low) / 2
= 13 + (15 - 13) / 2
= 13 + 1
= 14
greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
= getElementsGreaterThanOrEqualMid(matrix, 3, 14)
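We repeat the above steps for the remaining iterations. For mid = 14 the count is 8, which is >= k, so high becomes 13. For mid = 13 the count is again 8, so high becomes 12. The loop ends because low (13) is now greater than high (12), and we return low = 13 as the answer.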
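The remaining iterations can also be checked mechanically. The sketch below replays the binary search on the same input and records low, high, mid and the count for every probe. To keep the example short it uses a plain per-cell counter instead of the helper from the solution; Probe and traceKthSmallest are hypothetical names.

```cpp
#include <vector>

// Hypothetical record of one binary search probe.
struct Probe {
    int low;
    int high;
    int mid;
    int count;  // number of elements <= mid
};

// Hypothetical tracing variant of kthSmallest: same search loop, but it
// appends one Probe per iteration. Assumes a non-empty n x n matrix.
int traceKthSmallest(const std::vector<std::vector<int>>& matrix, int k,
                     std::vector<Probe>& probes) {
    int n = static_cast<int>(matrix.size());
    int low = matrix[0][0];
    int high = matrix[n - 1][n - 1];

    while (low <= high) {
        int mid = low + (high - low) / 2;

        // Naive count of elements <= mid (the solution uses a faster helper).
        int count = 0;
        for (const auto& row : matrix) {
            for (int value : row) {
                if (value <= mid) {
                    count++;
                }
            }
        }

        probes.push_back({low, high, mid, count});

        if (count >= k) {
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }
    return low;
}
```

For the input above with k = 8, the recorded probes are (low, high, mid, count) = (1, 15, 8, 2), (9, 15, 12, 6), (13, 15, 14, 8) and (13, 13, 13, 8), and the function returns 13.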