Alkesh

LeetCode - Kth Smallest Element in a Sorted Matrix

Problem statement

Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

You must find a solution with a memory complexity better than O(n^2).

Problem statement taken from: https://leetcode.com/problems/kth-smallest-element-in-a-sorted-matrix

Example 1:

Input: matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]], k = 8
Output: 13
Explanation: The elements in the matrix are [1, 5, 9, 10, 11, 12, 13, 13, 15], and the 8th smallest number is 13

Example 2:

Input: matrix = [[-5]], k = 1
Output: -5

Constraints:

- n == matrix.length == matrix[i].length
- 1 <= n <= 300
- -10^9 <= matrix[i][j] <= 10^9
- All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order.
- 1 <= k <= n^2

Follow up:

  • Could you solve the problem with a constant memory (i.e., O(1) memory complexity)?

  • Could you solve the problem in O(n) time complexity? The solution may be too advanced for an interview but you may find reading this paper fun.

Explanation

Brute force approach

The easiest approach is to use an additional 1D array. We iterate over the matrix row-wise or column-wise and copy every element into the array. We then sort the array and return the element at index k - 1.

A C++ snippet of the above approach is as follows:

int kthSmallest(vector<vector<int>>& matrix, int k) {
    int n = matrix.size();
    int m = matrix[0].size();
    vector<int> array;

    for(int i = 0; i < n; i++) {
        for(int j = 0; j < m; j++) {
            array.push_back(matrix[i][j]);
        }
    }

    sort(array.begin(), array.end());

    return array[k - 1];
}

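Sorting dominates the running time, so the time complexity of the above approach is O(n * m * log(n * m)). The space complexity is O(n * m) for the extra array, which does not satisfy the requirement of using less than O(n^2) memory.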
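
Since only one position of the sorted order is needed, a full sort is not strictly required. The sketch below swaps in std::nth_element, a standard-library selection routine, in place of std::sort. The function name kthSmallestBySelection is hypothetical and chosen only for illustration. It still copies the whole matrix, so the extra memory remains O(n * m).

```cpp
#include <algorithm>
#include <vector>

// Hypothetical helper name, used only for this illustration.
// Flattens the matrix and selects the kth smallest value without
// fully sorting the copy.
int kthSmallestBySelection(const std::vector<std::vector<int>>& matrix, int k) {
    std::vector<int> values;

    for (const auto& row : matrix) {
        values.insert(values.end(), row.begin(), row.end());
    }

    // After this call, values[k - 1] holds the element that would sit at
    // index k - 1 if the whole array were sorted.
    std::nth_element(values.begin(), values.begin() + (k - 1), values.end());

    return values[k - 1];
}
```

For the matrix in Example 1 and k = 8, this returns 13.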

Priority Queue

C++ has a built-in priority queue data structure. Priority queues are implemented using heaps, and in a max priority queue the maximum element is always at the top of the heap.

We push all the elements into the priority queue, so the maximum element sits at the top. We then call pop, which removes the top element from the heap, n * m - k times. After these n * m - k removals, exactly k elements remain, and the topmost one is the kth smallest element of the matrix.

A C++ snippet of the above approach is as follows:

int kthSmallest(vector<vector<int>>& matrix, int k) {
    int n = matrix.size();
    int m = matrix[0].size();
    priority_queue<int> pq;

    for(int i = 0; i < n; i++) {
        for(int j = 0; j < m; j++) {
            pq.push(matrix[i][j]);
        }
    }

    for(int i = 0; i < n * m - k; i++){
        pq.pop();
    }

    return pq.top();

}

Pushing all n * m elements into the heap takes O(n * m * log(n * m)) time, and the queue needs O(n * m) space.
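
A heap can also be used without storing the whole matrix. The sketch below is a different technique from the one above: it treats each sorted row as a list and merges the rows with a min-heap that holds at most one entry per row, so the extra memory is O(n). The name kthSmallestByMerge and the (value, row, column) tuple layout are assumptions made for this illustration.

```cpp
#include <functional>
#include <queue>
#include <tuple>
#include <vector>

// Hypothetical helper name, used only for this illustration.
// Merges the sorted rows with a min-heap that holds one entry per row.
int kthSmallestByMerge(const std::vector<std::vector<int>>& matrix, int k) {
    using Entry = std::tuple<int, int, int>; // (value, row, column)
    std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> minHeap;

    int n = static_cast<int>(matrix.size());

    // Seed the heap with the first (smallest) element of every row.
    for (int i = 0; i < n; i++) {
        minHeap.emplace(matrix[i][0], i, 0);
    }

    // Remove the k - 1 smallest elements. Each removed element is replaced
    // by its right neighbour in the same row, if there is one.
    for (int popped = 0; popped < k - 1; popped++) {
        Entry top = minHeap.top();
        minHeap.pop();

        int row = std::get<1>(top);
        int col = std::get<2>(top);

        if (col + 1 < static_cast<int>(matrix[row].size())) {
            minHeap.emplace(matrix[row][col + 1], row, col + 1);
        }
    }

    // The heap top is now the kth smallest element.
    return std::get<0>(minHeap.top());
}
```

This version runs in O((n + k) * log(n)) time, since every heap operation works on at most n entries.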

Binary Search

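The input matrix is sorted row-wise and column-wise. We can use this property to reduce both the time and the space complexity of the algorithm.

The idea is to binary search over the range of values rather than over positions in the matrix. The smallest value is matrix[0][0] and the largest is matrix[n - 1][n - 1], and we take the midpoint of this range as the point of reference. We count the number of elements that are less than or equal to the midpoint. If the count is smaller than k, the answer must lie in the upper half of the range; otherwise we continue with the lower half. In the code below this count is computed by getElementsGreaterThanOrEqualMid which, despite its name, returns the number of elements less than or equal to mid.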

Let's check the algorithm for this approach.

Algorithm

// kthSmallest function
- set n = matrix.size(), m = matrix[0].size()
      low = matrix[0][0], high = matrix[n - 1][m - 1]
  initialize mid, greaterThanOrEqualMid

- loop while low <= high
  - set mid = low + (high - low) / 2

  - set greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)

  - if greaterThanOrEqualMid >= k
    - update high = mid - 1
  - else
    - update low = mid + 1
  - if end

- while end

- return low

// getElementsGreaterThanOrEqualMid function
- set count = 0
  initialize greaterThanMid

- loop for i = 0; i < n; i++
  - if matrix[i][0] > mid
    - return count

  - if matrix[i][n - 1] <= mid
    - count = count + n
    - continue

  - set greaterThanMid = 0

  - loop for j = n / 2; j >= 1; j /= 2
    - loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
      - update greaterThanMid = greaterThanMid + j
    - while end
  - for end

  - set count = count + greaterThanMid + 1
- for end

- return count

The time complexity of this approach is O(y * n * log(n)), where y = log(abs(matrix[0][0] - matrix[n - 1][n - 1])). The space complexity is O(1).
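
The per-row search in the counting helper is not the only option. Because the columns are sorted as well, the number of elements less than or equal to mid can also be gathered in a single staircase walk from the bottom-left corner, which takes O(n) per call instead of O(n * log(n)). The following is a minimal sketch of that alternative, assuming a square matrix; countLessOrEqual is a hypothetical name and is not part of the solutions below.

```cpp
#include <vector>

// Hypothetical helper name, used only for this illustration.
// Counts the elements <= target by walking from the bottom-left corner.
int countLessOrEqual(const std::vector<std::vector<int>>& matrix, int target) {
    int n = static_cast<int>(matrix.size());
    int row = n - 1;
    int col = 0;
    int count = 0;

    while (row >= 0 && col < n) {
        if (matrix[row][col] <= target) {
            // Every element above this one in the same column is also
            // <= target, so the column contributes row + 1 elements.
            count += row + 1;
            col++;
        } else {
            // This element is too large; move up to a smaller one.
            row--;
        }
    }

    return count;
}
```

For the matrix in Example 1, countLessOrEqual returns 2 for target 8 and 6 for target 12. With this helper the overall time complexity becomes O(y * n).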

C++ solution

class Solution {
public:
    // Despite its name, this helper returns the number of elements <= mid.
    int getElementsGreaterThanOrEqualMid(vector<vector<int>>& matrix, int n, int mid) {
        int count = 0;
        int greaterThanMid;

        for(int i = 0; i < n; i++) {
            if(matrix[i][0] > mid) {
                return count;
            }

            if(matrix[i][n - 1] <= mid) {
                count += n;
                continue;
            }

            greaterThanMid = 0;

            for(int j = n / 2; j >= 1; j /= 2) {
                while(greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid) {
                    greaterThanMid += j;
                }
            }

            count += greaterThanMid + 1;
        }

        return count;
    }

    int kthSmallest(vector<vector<int>>& matrix, int k) {
        int n = matrix.size();
        int m = matrix[0].size();
        int low = matrix[0][0];
        int high = matrix[n - 1][m - 1];
        int mid, greaterThanOrEqualMid;

        while(low <= high) {
            mid = low + (high - low) / 2;

            greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid);

            if (greaterThanOrEqualMid >= k)
                high = mid - 1;
            else
                low = mid + 1;
        }

        return low;
    }
};

Golang solution

func getElementsGreaterThanOrEqualMid(matrix [][]int, n, mid int) int {
    count, greaterThanMid := 0, 0

    for i := 0; i < n; i++ {
        if matrix[i][0] > mid {
            return count
        }

        if matrix[i][n - 1] <= mid {
            count += n
            continue
        }


        greaterThanMid = 0

        for j := n / 2; j >= 1; j /= 2 {
            for greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid {
                greaterThanMid += j
            }
        }

        count += greaterThanMid + 1
    }

    return count
}

func kthSmallest(matrix [][]int, k int) int {
    n, m := len(matrix), len(matrix[0])
    low, high := matrix[0][0], matrix[n - 1][m - 1]
    mid, greaterThanOrEqualMid := 0, 0

    for low <= high {
        mid = low + (high - low) / 2

        greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)

        if greaterThanOrEqualMid >= k {
            high = mid - 1
        } else {
            low = mid + 1
        }
    }

    return low
}

JavaScript solution

var getElementsGreaterThanOrEqualMid = function(matrix, n, mid) {
    let count = 0, greaterThanMid = 0;

    for(let i = 0; i < n; i++) {
        if(matrix[i][0] > mid) {
            return count;
        }

        if(matrix[i][n - 1] <= mid) {
            count += n;
            continue;
        }

        greaterThanMid = 0;

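        // Use integer steps: plain division would give fractional indices for odd n.
        for(let j = Math.floor(n / 2); j >= 1; j = Math.floor(j / 2)) {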
            while(greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid) {
                greaterThanMid += j;
            }
        }

        count += greaterThanMid + 1;
    }

    return count;
}

var kthSmallest = function(matrix, k) {
    let n = matrix.length, m = matrix[0].length;
    let low = matrix[0][0], high = matrix[n - 1][m - 1];
    let mid, greaterThanOrEqualMid;

    while(low <= high) {
        mid = low + Math.floor((high - low) / 2);

        greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid);

        if(greaterThanOrEqualMid >= k) {
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }

    return low;
};

Dry Run

Let's dry-run our algorithm to see how the solution works.

Input: matrix = [[1, 5, 9], [10, 11, 13], [12, 13, 15]]
       k = 8

// kthSmallest function
Step 1: n = matrix.size()
          = 3
        m = matrix[0].size()
          = 3
        low = matrix[0][0]
            = 1
        high = matrix[n - 1][m - 1]
             = matrix[3 - 1][3 - 1]
             = matrix[2][2]
             = 15
        int mid, greaterThanOrEqualMid

Step 2: loop while low <= high
          1 <= 15
          true

          mid = low + (high - low) / 2
              = 1 + (15 - 1) / 2
              = 1 + 7
              = 8

          greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
                                = getElementsGreaterThanOrEqualMid(matrix, 3, 8)

// getElementsGreaterThanOrEqualMid function
Step 3: count = 0
        greaterThanMid

        loop for i = 0; i < n; i++
          if matrix[i][0] > mid
             matrix[0][0] > 8
             1 > 8
             false

          if matrix[i][n - 1] <= mid
             matrix[0][2] <= 8
             9 <= 8
             false

          greaterThanMid = 0

          loop for j = n / 2; j >= 1; j /= 2
            j = 3/2 = 1
            j >= 1
            1 >= 1
            true

            loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
              0 + 1 < 3 && matrix[0][0 + 1] <= mid
              1 < 3 && 5 <= 8
              true

                greaterThanMid = greaterThanMid + j
                               = 0 + 1
                               = 1

            loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
              1 + 1 < 3 && matrix[0][1 + 1] <= mid
              2 < 3 && 9 <= 8
              false

            j = j / 2
              = 1 / 2
              = 0

          loop for j = n / 2; j >= 1
            0 >= 1
            false

          count = count + greaterThanMid + 1
                = 0 + 1 + 1
                = 2

          i++
          i = 1

        loop for i < n
          1 < 3
          true

          if matrix[i][0] > mid
             matrix[1][0] > 8
             10 > 8
             true

             return count

             return 2

// kthSmallest function
Step 4: greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
                              = 2

        if greaterThanOrEqualMid >= k
          2 >= 8
          false
        else
          low = mid + 1
              = 8 + 1
              = 9

Step 5: loop while low <= high
          9 <= 15
          true

          mid = low + (high - low) / 2
              = 9 + (15 - 9) / 2
              = 9 + 3
              = 12

          greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
                                = getElementsGreaterThanOrEqualMid(matrix, 3, 12)

// getElementsGreaterThanOrEqualMid function
Step 6: count = 0
        greaterThanMid

        loop for i = 0; i < n; i++
          0 < 3
          true

            if matrix[i][0] > mid
              matrix[0][0] > 12
              1 > 12
              false

            if matrix[i][n - 1] <= mid
              matrix[0][2] <= 12
              9 <= 12
              true

              count = count + n
                    = 0 + 3
                    = 3

              continue

          i++
          i = 1

        loop for i < n
          1 < 3
          true

          if matrix[i][0] > mid
            matrix[1][0] > 12
            10 > 12
            false

          if matrix[i][n - 1] <= mid
            matrix[1][2] <= 12
            13 <= 12
            false

          greaterThanMid = 0

          loop for j = n / 2; j >= 1; j /= 2
            j = 3/2 = 1
            j >= 1
            1 >= 1
            true

            loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
              0 + 1 < 3 && matrix[1][0 + 1] <= 12
              1 < 3 && 11 <= 12
              true

                greaterThanMid = greaterThanMid + j
                               = 0 + 1
                               = 1

            loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
              1 + 1 < 3 && matrix[1][1 + 1] <= 12
              2 < 3 && 13 <= 12
              false

          count = count + greaterThanMid + 1
                = 3 + 1 + 1
                = 5

          i++
          i = 2

        loop for i < n
          2 < 3
          true

          if matrix[i][0] > mid
            matrix[2][0] > 12
            12 > 12
            false

          if matrix[i][n - 1] <= mid
            matrix[2][2] <= 12
            15 <= 12
            false

          greaterThanMid = 0

          loop for j = n / 2; j >= 1; j /= 2
            j = 3/2 = 1
            j >= 1
            1 >= 1
            true

            loop while greaterThanMid + j < n && matrix[i][greaterThanMid + j] <= mid
              0 + 1 < 3 && matrix[2][0 + 1] <= 12
              1 < 3 && 13 <= 12
              false

            j = j / 2
              = 1 / 2
              = 0

          loop for j >= 1
            0 >= 1
            false

          count = count + greaterThanMid + 1
                = 5 + 0 + 1
                = 6

          i++
          i = 3

        loop for i < n
          3 < 3
          false

        return count
        return 6

// kthSmallest function
Step 7: greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
                              = 6

        if greaterThanOrEqualMid >= k
          6 >= 8
          false
        else
          low = mid + 1
              = 12 + 1
              = 13

Step 8: loop while low <= high
          13 <= 15
          true

          mid = low + (high - low) / 2
              = 13 + (15 - 13) / 2
              = 13 + 1
              = 14

          greaterThanOrEqualMid = getElementsGreaterThanOrEqualMid(matrix, n, mid)
                                = getElementsGreaterThanOrEqualMid(matrix, 3, 14)

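We repeat the above steps until low becomes greater than high. For mid = 14 the count is 8, which is >= k, so high becomes 13. For mid = 13 the count is again 8, so high becomes 12. The loop then ends because low (13) > high (12), and we return low, which is 13.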
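
The dry run can be checked mechanically. The sketch below records the (mid, count) pair of every binary-search iteration. The function name traceKthSmallest is hypothetical, and the per-row count uses std::upper_bound from the standard library in place of the hand-written jump search, which is an assumption made to keep the example short.

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Hypothetical helper name, used only for this illustration.
// Returns the (mid, count) pair seen in each binary-search iteration.
std::vector<std::pair<int, int>> traceKthSmallest(
        const std::vector<std::vector<int>>& matrix, int k) {
    int n = static_cast<int>(matrix.size());
    int low = matrix[0][0];
    int high = matrix[n - 1][n - 1];
    std::vector<std::pair<int, int>> steps;

    while (low <= high) {
        int mid = low + (high - low) / 2;
        int count = 0;

        for (const auto& row : matrix) {
            // upper_bound points at the first element > mid, so its
            // offset equals the number of elements <= mid in this row.
            count += static_cast<int>(
                std::upper_bound(row.begin(), row.end(), mid) - row.begin());
        }

        steps.emplace_back(mid, count);

        if (count >= k) {
            high = mid - 1;
        } else {
            low = mid + 1;
        }
    }

    return steps;
}
```

For the example matrix with k = 8, the recorded pairs are (8, 2), (12, 6), (14, 8) and (13, 8), which matches the steps traced above.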