# LeetCode - Kth Smallest Element in a Sorted Matrix

## Problem statement

Given an `n x n` `matrix` where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix.

Note that it is the `kth` smallest element in the sorted order, not the `kth` distinct element.

You must find a solution with a memory complexity better than `O(n^2)`.

Problem statement taken from: https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix

Example 1:

``````
Input: matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]], k = 8
Output: 13
Explanation: The elements in the matrix are [1, 5, 9, 10, 11, 12, 13, 13, 15], and the 8th smallest number is 13
``````

Example 2:

``````
Input: matrix = [[-5]], k = 1
Output: -5
``````

Constraints:

``````
- n == matrix.length == matrix[i].length
- 1 <= n <= 300
- -10^9 <= matrix[i][j] <= 10^9
- All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order.
- 1 <= k <= n^2
``````

• Could you solve the problem with a constant memory (i.e., `O(1)` memory complexity)?

• Could you solve the problem in `O(n)` time complexity? The solution may be too advanced for an interview but you may find reading this paper fun.

### Explanation

#### Brute force approach

The easiest approach is to use an additional 1D array. We iterate over the matrix row-wise or column-wise and store the elements in the array. We then sort the array and return the element at index `k - 1`.

A C++ snippet of the above approach is as follows:

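``````
#include <algorithm>
#include <vector>
using namespace std;

int kthSmallest(vector<vector<int>>& matrix, int k) {
    int n = matrix.size();
    int m = matrix[0].size();
    vector<int> array;

    // Copy every element of the matrix into a flat array.
    for(int i = 0; i < n; i++) {
        for(int j = 0; j < m; j++) {
            array.push_back(matrix[i][j]);
        }
    }

    sort(array.begin(), array.end());

    // k is 1-based, so the kth smallest element is at index k - 1.
    return array[k - 1];
}
``````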

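Sorting dominates the running time, so the time complexity of the above approach is `O(n * m * log(n * m))`. The space complexity is `O(n * m)`, which does not meet the memory requirement of the problem.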

#### Priority Queue

C++ provides a built-in priority queue, `priority_queue`. Priority queues are implemented using heaps, and in a maximum priority queue, the maximum element is always at the top of the heap.

We store all the elements in the priority queue, so the maximum element is at the top. We then call `pop`, which removes the top element from the heap, `n * m - k` times. After that, exactly `k` elements remain, and the topmost element is the kth smallest element of the matrix.

A C++ snippet of the above approach is as follows:

``````
#include <queue>
#include <vector>
using namespace std;

int kthSmallest(vector<vector<int>>& matrix, int k) {
    int n = matrix.size();
    int m = matrix[0].size();
    priority_queue<int> pq;  // max-heap by default

    for(int i = 0; i < n; i++) {
        for(int j = 0; j < m; j++) {
            pq.push(matrix[i][j]);
        }
    }

    // Remove the n * m - k largest elements.
    for(int i = 0; i < n * m - k; i++) {
        pq.pop();
    }

    return pq.top();
}
``````

Pushing all `n * m` elements into the heap costs `O(n * m * log(n * m))` time, and the queue needs `O(n * m)` space.
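The heap does not have to hold the whole matrix. As a minimal sketch of a different heap-based technique (a k-way merge of the sorted rows, not the approach described above), the version below keeps at most one entry per row in a min-heap, so it needs `O(n)` extra memory and `O(k * log(n))` time. The function name `kthSmallestWithMinHeap` is hypothetical and chosen only for this illustration.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <tuple>
#include <vector>

// Hypothetical helper (not part of the original solutions): k-way merge of the
// sorted rows using a min-heap that holds at most one entry per row.
int kthSmallestWithMinHeap(const std::vector<std::vector<int>>& matrix, int k) {
    const int n = static_cast<int>(matrix.size());
    assert(n > 0 && k >= 1 && k <= n * n);

    // Each entry is (value, row, column); std::greater turns the default
    // max-heap into a min-heap ordered by value first.
    using Entry = std::tuple<int, int, int>;
    std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> heap;

    // Seed the heap with the first element of every row.
    for (int row = 0; row < n; ++row) {
        heap.emplace(matrix[row][0], row, 0);
    }

    // Remove the k - 1 smallest elements, replacing each removed element
    // with its right-hand neighbour from the same row, if there is one.
    for (int removed = 0; removed < k - 1; ++removed) {
        const Entry smallest = heap.top();
        heap.pop();
        const int row = std::get<1>(smallest);
        const int col = std::get<2>(smallest);
        if (col + 1 < n) {
            heap.emplace(matrix[row][col + 1], row, col + 1);
        }
    }

    // The top of the heap is now the kth smallest element.
    return std::get<0>(heap.top());
}
```

For the matrix in Example 1, `kthSmallestWithMinHeap(matrix, 8)` returns `13`. This sketch only relies on the rows being sorted; it does not use the column ordering.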

#### Binary Search

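The input matrix is sorted row-wise and column-wise. We can use this property to reduce the time and space complexity of the algorithm.

The idea is to binary search over the range of values rather than over indices. The smallest element is `matrix[0][0]` and the largest is `matrix[n - 1][n - 1]`, so the answer lies between them. We take the middle value `mid` of this range and count the number of elements in the matrix that are less than or equal to `mid`. If the count is smaller than `k`, the answer must be greater than `mid`, so we search the upper half of the range. Otherwise, `mid` may still be the answer, so we search the lower half. The loop ends at the smallest value whose count reaches `k`, which is always an element of the matrix.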
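To see why searching over values works, it helps to look at the count on its own. The sketch below uses a hypothetical helper, `countLessOrEqual`, that scans every cell. It is deliberately naive (`O(n^2)` time) and only meant to show that the count never decreases as the value grows.

```cpp
#include <cassert>
#include <vector>

// Hypothetical helper for illustration: counts the elements <= value by
// scanning every cell (O(n^2) time), trading speed for clarity.
int countLessOrEqual(const std::vector<std::vector<int>>& matrix, int value) {
    assert(!matrix.empty());
    int count = 0;
    for (const auto& row : matrix) {
        for (int element : row) {
            if (element <= value) {
                ++count;
            }
        }
    }
    return count;
}
```

For the matrix in Example 1, the counts for the values 8, 12, 13 and 14 are 2, 6, 8 and 8. The smallest value whose count reaches `k = 8` is 13, which is the expected answer.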

Let's check the algorithm for this approach.

#### Algorithm

``````
// kthSmallest function
- set n = matrix.size(), m = matrix[0].size()
  low = matrix[0][0], high = matrix[n - 1][m - 1]
  initialize mid, greaterThanOrEqualMid

- loop while low <= high
  - set mid = low + (high - low) / 2

  - set greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)

  - if greaterThanOrEqualMid >= k
    - update high = mid - 1
  - else
    - update low = mid + 1
  - if end

- while end

- return low

// getElementsGreaterThanOrEqualMid function
// (despite its name, it counts the elements that are <= mid)
- set count = 0
  initialize greaterThanMid

- loop for i = 0; i < n; i++
  - if matrix[i][0] > mid
    - return count

  - if matrix[i][n - 1] <= mid
    - count = count + n
    - continue

  - set greaterThanMid = 0

  - loop for j = n / 2; j >= 1; j /= 2
    - loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
      - update greaterThanMid = greaterThanMid + j
    - while end
  - for end

  - set count = count + greaterThanMid + 1
- for end

- return count
``````

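The time complexity of this approach is `O(y * n * log(n))`, where `y = log(abs(matrix[0][0] - matrix[n - 1][n - 1]))` is the number of binary search iterations over the value range and `n * log(n)` is the cost of counting across all rows. The space complexity is `O(1)`.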
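The counting step can also be done without a search inside each row. The sketch below is an alternative, not the method used in the solutions that follow: it walks a "staircase" from the bottom-left corner and counts the elements less than or equal to `mid` in `O(n)` time, which would bring the overall complexity down to `O(y * n)`. The name `countLessOrEqualStaircase` is hypothetical.

```cpp
#include <cassert>
#include <vector>

// Hypothetical alternative to the per-row search: counts elements <= mid in
// O(n) by walking a "staircase" from the bottom-left corner.
int countLessOrEqualStaircase(const std::vector<std::vector<int>>& matrix, int mid) {
    const int n = static_cast<int>(matrix.size());
    assert(n > 0);
    int row = n - 1;  // start at the bottom-left corner
    int col = 0;
    int count = 0;

    while (row >= 0 && col < n) {
        if (matrix[row][col] <= mid) {
            // Columns are sorted, so rows 0..row of this column are all <= mid.
            count += row + 1;
            ++col;
        } else {
            // Rows are sorted, so everything to the right in this row is > mid.
            --row;
        }
    }
    return count;
}
```

Each step either moves one column to the right or one row up, so the loop runs at most `2 * n` times. For the matrix in Example 1 it returns 2 for `mid = 8` and 6 for `mid = 12`.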

#### C++ solution

``````
#include <vector>
using namespace std;

class Solution {
public:
    // Despite its name, this returns the number of elements that are <= mid.
    int getElementsGreaterThanOrEqualMid(vector<vector<int>>& matrix, int n, int mid) {
        int count = 0;
        int greaterThanMid;

        for(int i = 0; i < n; i++) {
            // Rows are sorted by their first element, so no later row can contribute.
            if(matrix[i][0] > mid) {
                return count;
            }

            // The whole row is <= mid.
            if(matrix[i][n - 1] <= mid) {
                count += n;
                continue;
            }

            greaterThanMid = 0;

            // Jump search for the last index in this row whose value is <= mid.
            for(int j = n / 2; j >= 1; j /= 2) {
                while(greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid) {
                    greaterThanMid += j;
                }
            }

            count += greaterThanMid + 1;
        }

        return count;
    }

    int kthSmallest(vector<vector<int>>& matrix, int k) {
        int n = matrix.size();
        int m = matrix[0].size();
        int low = matrix[0][0];
        int high = matrix[n - 1][m - 1];
        int mid, greaterThanOrEqualMid;

        // Binary search over the value range [low, high].
        while(low <= high) {
            mid = low + (high - low) / 2;

            greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid);

            if(greaterThanOrEqualMid >= k)
                high = mid - 1;
            else
                low = mid + 1;
        }

        return low;
    }
};
``````

#### Golang solution

``````
func getElementsGreaterThanOrEqualMid(matrix [][]int, n, mid int) int {
    count, greaterThanMid := 0, 0

    for i := 0; i < n; i++ {
        if matrix[i][0] > mid {
            return count
        }

        if matrix[i][n - 1] <= mid {
            count += n
            continue
        }

        greaterThanMid = 0

        for j := n / 2; j >= 1; j /= 2 {
            for greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid {
                greaterThanMid += j
            }
        }

        count += greaterThanMid + 1
    }

    return count
}

func kthSmallest(matrix [][]int, k int) int {
    n, m := len(matrix), len(matrix[0])
    low, high := matrix[0][0], matrix[n - 1][m - 1]
    mid, greaterThanOrEqualMid := 0, 0

    for low <= high {
        mid = low + (high - low) / 2

        greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)

        if greaterThanOrEqualMid >= k {
            high = mid - 1
        } else {
            low = mid + 1
        }
    }

    return low
}
``````

#### JavaScript solution

``````
var getElementsGreaterThanOrEqualMid = function(matrix, n, mid) {
    let count = 0, greaterThanMid = 0;

    for(let i = 0; i < n; i++) {
        if(matrix[i][0] > mid) {
            return count;
        }

        if(matrix[i][n - 1] <= mid) {
            count += n;
            continue;
        }

        greaterThanMid = 0;

        // JavaScript division is floating point, so floor the step size;
        // a fractional j would index a column that does not exist.
        for(let j = Math.floor(n / 2); j >= 1; j = Math.floor(j / 2)) {
            while(greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid) {
                greaterThanMid += j;
            }
        }

        count += greaterThanMid + 1;
    }

    return count;
};

var kthSmallest = function(matrix, k) {
    let n = matrix.length, m = matrix[0].length;
    let low = matrix[0][0], high = matrix[n - 1][m - 1];
    let mid, greaterThanOrEqualMid;

    while(low <= high) {
        mid = low + Math.floor((high - low) / 2);

        greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid);

        if(greaterThanOrEqualMid >= k) {
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }

    return low;
};
``````

#### Dry Run

Let's dry-run our algorithm to see how the solution works.

``````
Input: matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
k = 8

// kthSmallest function
Step 1: n = matrix.size()
= 3
m = matrix[0].size()
= 3
low = matrix[0][0]
= 1
high = matrix[n - 1][m - 1]
= matrix[3 - 1][3 - 1]
= matrix[2][2]
= 15
int mid, greaterThanOrEqualMid

Step 2: loop while low <= high
1 <= 15
true

mid = low + (high - low) / 2
= 1 + (15 - 1) / 2
= 1 + 7
= 8

greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
= getElementsGreaterThanOrEqualMid(matrix, 3, 8)

// getElementsGreaterThanOrEqualMid function
Step 3: count = 0
greaterThanMid

loop for i = 0; i < n; i++
if matrix[i][0] > mid
matrix[0][0] > 8
1 > 8
false

if matrix[i][n - 1] <= mid
matrix[0][2] <= 8
9 <= 8
false

greaterThanMid = 0

loop for j = n / 2; j >= 1; j /= 2
j = 3/2 = 1
j >= 1
1 >= 1
true

loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
0 + 1 < 3 && matrix[0][0 + 1] <= mid
1 < 3 && 5 <= 8
true

greaterThanMid = greaterThanMid + j
= 0 + 1
= 1

loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
1 + 1 < 3 && matrix[0][1 + 1] <= mid
2 < 3 && 9 <= 8
false

j = j / 2
= 1 / 2
= 0

loop for j = n / 2; j >= 1
0 >= 1
false

count = count + greaterThanMid + 1
= 0 + 1 + 1
= 2

i++
i = 1

loop for i < n
1 < 3
true

if matrix[i][0] > mid
matrix[1][0] > 8
10 > 8
true

return count

return 2

// kthSmallest function
Step 4: greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
= 2

if greaterThanOrEqualMid >= k
2 >= 8
false
else
low = mid + 1
= 8 + 1
= 9

Step 5: loop while low <= high
9 <= 15
true

mid = low + (high - low) / 2
= 9 + (15 - 9) / 2
= 9 + 3
= 12

greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
= getElementsGreaterThanOrEqualMid(matrix, 3, 12)

// getElementsGreaterThanOrEqualMid function
Step 6: count = 0
greaterThanMid

loop for i = 0; i < n; i++
0 < 3
true

if matrix[i][0] > mid
matrix[0][0] > 12
1 > 12
false

if matrix[i][n - 1] <= mid
matrix[0][2] <= 12
9 <= 12
true

count = count + n
= 0 + 3
= 3

continue

i++
i = 1

loop for i < n
1 < 3
true

if matrix[i][0] > mid
matrix[1][0] > 12
10 > 12
false

if matrix[i][n - 1] <= mid
matrix[1][2] <= 12
13 <= 12
false

greaterThanMid = 0

loop for j = n / 2; j >= 1; j /= 2
j = 3/2 = 1
j >= 1
1 >= 1
true

loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
0 + 1 < 3 && matrix[1][0 + 1] <= 12
1 < 3 && 11 <= 12
true

greaterThanMid = greaterThanMid + j
= 0 + 1
= 1

loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
1 + 1 < 3 && matrix[1][1 + 1] <= 12
2 < 3 && 13 <= 12
false

count = count + greaterThanMid + 1
= 3 + 1 + 1
= 5

i++
i = 2

loop for i < n
2 < 3
true

if matrix[i][0] > mid
matrix[2][0] > 12
12 > 12
false

if matrix[i][n - 1] <= mid
matrix[2][2] <= 12
15 <= 12
false

greaterThanMid = 0

loop for j = n / 2; j >= 1; j /= 2
j = 3/2 = 1
j >= 1
1 >= 1
true

loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
0 + 1 < 3 && matrix[2][0 + 1] <= 12
1 < 3 && 13 <= 12
false

j = j / 2
= 1 / 2
= 0

loop for j >= 1
0 >= 1
false

count = count + greaterThanMid + 1
= 5 + 0 + 1
= 6

i++
i = 3

loop for i < n
3 < 3
false

return count
return 6

// kthSmallest function
Step 7: greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
= 6

if greaterThanOrEqualMid >= k
6 >= 8
false
else
low = mid + 1
= 12 + 1
= 13

Step 8: loop while low <= high
13 <= 15
true

mid = low + (high - low) / 2
= 13 + (15 - 13) / 2
= 13 + 1
= 14

greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
= getElementsGreaterThanOrEqualMid(matrix, 3, 14)

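The call returns 8.

// kthSmallest function
Step 9: if greaterThanOrEqualMid >= k
8 >= 8
true

high = mid - 1
= 14 - 1
= 13

Step 10: loop while low <= high
13 <= 13
true

mid = low + (high - low) / 2
= 13 + (13 - 13) / 2
= 13

greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, 3, 13)
= 8

if greaterThanOrEqualMid >= k
8 >= 8
true

high = mid - 1
= 13 - 1
= 12

Step 11: loop while low <= high
13 <= 12
false

return low

We return the answer as 13.
``````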
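The values visited in the dry run can be checked with a small tracing variant of the binary search. This is a sketch rather than the solution above: the name `kthSmallestTraced` and the `trace` parameter are hypothetical, and the per-row count uses `std::upper_bound` in place of the jump search, which yields the same counts on sorted rows.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical tracing variant: returns the answer and records every
// (mid, count) pair that the binary search evaluates.
int kthSmallestTraced(const std::vector<std::vector<int>>& matrix, int k,
                      std::vector<std::pair<int, int>>& trace) {
    const int n = static_cast<int>(matrix.size());
    assert(n > 0 && k >= 1 && k <= n * n);
    int low = matrix[0][0];
    int high = matrix[n - 1][n - 1];

    while (low <= high) {
        const int mid = low + (high - low) / 2;

        // std::upper_bound returns the first element > mid in a sorted row,
        // so its offset is the number of elements <= mid in that row.
        int count = 0;
        for (const auto& row : matrix) {
            count += static_cast<int>(
                std::upper_bound(row.begin(), row.end(), mid) - row.begin());
        }
        trace.emplace_back(mid, count);

        if (count >= k) {
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }
    return low;
}
```

For the input used in the dry run, the recorded pairs are `(8, 2)`, `(12, 6)`, `(14, 8)` and `(13, 8)`, and the function returns `13`.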