Alkesh

LeetCode - Count of Smaller Numbers After Self

Problem statement

Given an integer array nums, return an integer array counts where counts[i] is the number of smaller elements to the right of nums[i].

Problem statement taken from: https://leetcode.com/problems/count-of-smaller-numbers-after-self

Example 1:

Input: nums = [5, 2, 6, 1]
Output: [2, 1, 1, 0]
Explanation:
To the right of 5 there are 2 smaller elements (2 and 1).
To the right of 2 there is only 1 smaller element (1).
To the right of 6 there is 1 smaller element (1).
To the right of 1 there is 0 smaller element.

Example 2:

Input: nums = [-1]
Output: [0]

Example 3:

Input: nums = [-1, -1]
Output: [0, 0]

Constraints:

- 1 <= nums.length <= 10^5
- -10^4 <= nums[i] <= 10^4

Explanation

Brute force solution

A naive approach is to use nested loops. The outer loop selects each element from left to right. The inner loop iterates through all the elements to the right of the selected element and updates the count of elements smaller than it.

A C++ snippet of this approach is as below:

// Brute force: for every index i, scan everything to its right and count the
// values strictly smaller than nums[i].
// Time O(n^2); extra space O(n) for the returned vector.
vector<int> countSmaller(vector<int>& nums) {
    const int n = static_cast<int>(nums.size());

    // Size the result up front and zero it: writing ans[i] into an empty
    // vector would be undefined behaviour.
    vector<int> ans(n, 0);

    for (int i = 0; i < n; i++) {
        for (int j = i + 1; j < n; j++) {
            if (nums[j] < nums[i])
                ans[i]++;
        }
    }

    return ans;
}

The time complexity of this approach is O(n^2). The space complexity is O(n).

AVL tree

In this approach, we use a Self-balancing Binary Search Tree, an AVL tree.

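We iterate over the array from right to left and insert the elements one by one into the AVL tree, so when an element is inserted the tree contains exactly the elements to its right. Each node also stores the size of its subtree. While inserting a new key, we compare it with the root of the current subtree. If the key is greater than the root, then the root and all nodes in its left subtree are smaller than the key. In that case we add the size of the left subtree plus one (for the root itself) to the key's count and continue into the right subtree. Otherwise we continue into the left subtree without adding anything. We follow the same approach recursively at every node on the insertion path.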

A C++ snippet of this approach is as follows:

// Node of the size-augmented AVL tree used to count smaller elements.
struct node {
    int key;             // value stored in this node
    struct node* left;   // left child (NULL when absent)
    struct node* right;  // right child (NULL when absent)
    int height;          // height of the subtree rooted here; a new leaf has height 1
    int size;            // number of nodes in the subtree rooted here, this node included
};

// Cached height of the subtree rooted at `node`; an empty subtree is 0 high.
int height(struct node* node) {
    if (node == NULL)
        return 0;
    return node->height;
}

// Cached node count of the subtree rooted at `node`; an empty subtree has 0.
int size(struct node* node) {
    if (node == NULL)
        return 0;
    return node->size;
}

// Larger of two ints (when they are equal, b is returned).
int max(int a, int b) {
    if (a > b)
        return a;
    return b;
}

// Allocates a detached leaf holding `key`: no children, height 1, size 1.
// The memory comes from malloc (the caller owns it), and the malloc result is
// not checked for NULL here.
struct node* newNode(int key) {
    struct node* leaf = static_cast<struct node*>(malloc(sizeof(struct node)));
    leaf->left = NULL;
    leaf->right = NULL;
    leaf->key = key;
    leaf->size = 1;
    leaf->height = 1;
    return leaf;
}

// Right rotation around `y`: y's left child x becomes the subtree root, y
// becomes x's right child, and x's former right subtree (T2) is re-attached
// as y's left subtree. Returns the new subtree root x.
struct node* rightRotate(struct node* y) {
    struct node* x = y->left;
    struct node* T2 = x->right;

    x->right = y;
    y->left = T2;

    // y now sits below x, so its cached fields are refreshed first.
    y->height = max(height(y->left), height(y->right)) + 1;
    x->height = max(height(x->left), height(x->right)) + 1;

    y->size = size(y->left) + size(y->right) + 1;
    x->size = size(x->left) + size(x->right) + 1;

    // Return the new root (previously an undeclared name was returned).
    return x;
}

// Left rotation around `x`: x's right child becomes the subtree root, x
// becomes that child's left child, and the child's former left subtree is
// re-attached as x's right subtree. Returns the new subtree root.
struct node* leftRotate(struct node* x) {
    struct node* pivot = x->right;
    struct node* inner = pivot->left;

    x->right = inner;
    pivot->left = x;

    // x is now a child of pivot, so refresh x's cached fields before pivot's.
    x->height = 1 + max(height(x->left), height(x->right));
    x->size = 1 + size(x->left) + size(x->right);

    pivot->height = 1 + max(height(pivot->left), height(pivot->right));
    pivot->size = 1 + size(pivot->left) + size(pivot->right);

    return pivot;
}

// Balance factor height(left) - height(right); 0 for an empty subtree.
// insert() rebalances a node whose factor is above 1 or below -1.
int getBalance(struct node* node) {
    if (node == NULL)
        return 0;
    const int leftHeight = height(node->left);
    const int rightHeight = height(node->right);
    return leftHeight - rightHeight;
}

// Inserts `key` into the AVL subtree rooted at `node` and returns the new
// subtree root. Along the insertion path it adds to *count the number of keys
// already in the tree that are strictly smaller than `key`.
//
// Keys equal to a node's key are sent LEFT. The in-order sequence of the tree
// stays non-decreasing (rotations preserve it), so at every node:
//   - key <= node->key: the node and its right subtree are >= key, so nothing
//     is counted and we descend left;
//   - key >  node->key: the node and its left subtree are <= node->key < key,
//     so size(left) + 1 is counted and we descend right.
// Sending equal keys right would count duplicates as "smaller"
// (e.g. [-1, -1] would give [1, 0] instead of [0, 0]).
struct node* insert(struct node* node, int key, int* count) {
    if (node == NULL)
        return newNode(key);

    if (key <= node->key) {
        node->left = insert(node->left, key, count);
    } else {
        node->right = insert(node->right, key, count);
        *count += size(node->left) + 1;
    }

    node->height = max(height(node->left), height(node->right)) + 1;
    node->size = size(node->left) + size(node->right) + 1;

    const int balance = getBalance(node);

    // Pick the rotation case from the heavy child's balance factor instead of
    // comparing keys: with duplicate keys a key comparison can match neither
    // case and leave the node unbalanced.
    if (balance > 1) {
        if (getBalance(node->left) < 0)      // left-right case
            node->left = leftRotate(node->left);
        return rightRotate(node);            // left-left case
    }

    if (balance < -1) {
        if (getBalance(node->right) > 0)     // right-left case
            node->right = rightRotate(node->right);
        return leftRotate(node);             // right-right case
    }

    return node;
}

void constructLower(int nums[]) {
    int i, j;
    int n = nums.size();
    struct node* root = NULL;

    for (i = 0; i < n; i++)
        ans[i] = 0;

    for (i = n - 1; i >= 0; i--) {
        root = insert(root, nums[i], &ans[i]);
    }
}

The time complexity of this approach is O(n * log(n)). The space complexity is O(n).

Merge Sort

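The idea is to run merge sort on (value, index) pairs, sorting them in descending order of value. While two sorted halves are merged, suppose an element of the left half is greater than the current element of the right half. Then that right element and every element after it in the right half are smaller than it and appear to its right in the original array, so we add their number to the element's count (stored at its original index).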

Now that we know how merge sort works, let's look at the algorithm.

Algorithm

//countSmaller(nums)
- initialize vector pair v
  set n = nums.size
  set ans = vector<int> ans(n, 0)

- loop for i = 0; i < n; i++
  // push the element and the index
  - v.push_back({ nums[i], i })
- for end

- mergesort(v, ans, 0, n - 1)

- return ans

// mergesort(v, ans, i, j)
- if i < j
  - set mid = (i + j) / 2

  //recursively call left half of the array
  - mergesort(v, ans, i, mid)

  //recursively call right half of the array
  - mergesort(v, ans, mid + 1, j)

  // merge the array and sort the elements
  - merge(v, ans, i, mid, j)
- if end

// merge(v, ans, l, mid, h)
- initialize vector pair temp
  set i = l
  set j = mid + 1

- loop while i < mid + 1 && j <= h
  - if v[i].first > v[j].first
    // add up all the elements that are less than this element
    // and these elements are present in the 2nd half of the array
    - set ans[v[i].second] += (h - j + 1)

    - temp.push_back(v[i])

    - increment i, i++
  - else
    - temp.push_back(v[j])

    - increment j, j++
  - if end
- while end

- loop while i <= mid
  - temp.push_back(v[i])
  - i++
- while end

- loop while j <= h
  - temp.push_back(v[j])
  - j++
- while end

- loop for k = 0, i = l; i <= h; i++, k++
  - v[i] = temp[k]
- for end

The time complexity of this approach is O(n * log(n)). The space complexity is O(n).

Let's check our algorithm in C++, Golang, and JavaScript.

C++ solution

class Solution {
public:
    // Merges the descending runs v[l..mid] and v[mid+1..h] into one descending
    // run. Each entry is {value, original index}. When a left-run entry is
    // emitted ahead of the current right-run head, all h - right + 1 entries
    // still pending in the right run are strictly smaller and lay to its right
    // in the input, so they are credited to ans at its original index.
    void merge(vector<pair<int, int>> &v, vector<int> &ans, int l, int mid, int h) {
        vector<pair<int, int>> merged;
        int left = l;
        int right = mid + 1;

        while (left <= mid && right <= h) {
            if (v[right].first < v[left].first) {
                ans[v[left].second] += h - right + 1;
                merged.push_back(v[left++]);
            } else {
                // Ties take the right entry first, so equal values are never counted.
                merged.push_back(v[right++]);
            }
        }

        // Append whichever run still has entries left.
        merged.insert(merged.end(), v.begin() + left, v.begin() + mid + 1);
        merged.insert(merged.end(), v.begin() + right, v.begin() + h + 1);

        int pos = l;
        for (const auto &item : merged)
            v[pos++] = item;
    }

    // Sorts v[i..j] in descending order of value, accumulating counts in ans.
    void mergesort(vector<pair<int, int>> &v, vector<int> &ans, int i, int j) {
        if (i >= j)
            return;

        const int mid = (i + j) / 2;
        mergesort(v, ans, i, mid);
        mergesort(v, ans, mid + 1, j);
        merge(v, ans, i, mid, j);
    }

    // Returns, for every nums[i], how many elements to its right are smaller.
    vector<int> countSmaller(vector<int>& nums) {
        const int n = nums.size();
        vector<int> ans(n, 0);

        vector<pair<int, int>> v;
        v.reserve(n);
        for (int idx = 0; idx < n; ++idx)
            v.emplace_back(nums[idx], idx);

        mergesort(v, ans, 0, n - 1);
        return ans;
    }
};

Golang solution

// Pair couples an input value with its original index, so the index survives
// the reordering performed by mergeSort.
type Pair struct {
	Val   int // element value from nums
	Index int // position of the element in the original nums slice
}

// merge combines the descending runs v[l..mid] and v[mid+1..h] into a single
// descending run. When a left-run element beats the current right-run head,
// every element still pending in the right run is strictly smaller and came
// after it in the input, so h-j+1 is added to its count in ans.
func merge(v []Pair, ans []int, l, mid, h int) {
	merged := make([]Pair, 0, h-l+1)
	i, j := l, mid+1

	for i <= mid && j <= h {
		if v[j].Val < v[i].Val {
			ans[v[i].Index] += h - j + 1
			merged = append(merged, v[i])
			i++
			continue
		}
		// Ties take the right element first, so equal values are never counted.
		merged = append(merged, v[j])
		j++
	}

	merged = append(merged, v[i:mid+1]...)
	merged = append(merged, v[j:h+1]...)

	copy(v[l:h+1], merged)
}

// mergeSort sorts v[i..j] in descending order of Val and, through merge,
// accumulates in ans how many smaller elements lie to the right of each one.
func mergeSort(v []Pair, ans []int, i, j int) {
	if i >= j {
		return
	}
	mid := (i + j) / 2
	mergeSort(v, ans, i, mid)
	mergeSort(v, ans, mid+1, j)
	merge(v, ans, i, mid, j)
}

// countSmaller returns, for each nums[i], how many elements to its right are
// strictly smaller. Every value is tagged with its index, and the tagged
// values are merge-sorted in descending order while merge tallies the counts.
func countSmaller(nums []int) []int {
	ans := make([]int, len(nums))
	v := make([]Pair, 0, len(nums))
	for i, x := range nums {
		v = append(v, Pair{Val: x, Index: i})
	}
	mergeSort(v, ans, 0, len(v)-1)
	return ans
}

JavaScript solution

var merge = function(v, ans, l, mid, h) {
    let temp = [];
    let i = l, j = mid + 1;

    while(i < mid + 1 && j <= h) {
        if(j < v.length && v[i][0] > v[j][0]) {
            ans[v[i][1]] += (h - j + 1);
            temp.push(v[i]);
            i++;
        } else {
            temp.push(v[j]);
            j++;
        }
    }

    while (i <= mid)
        temp.push(v[i++]);

    while (j <= h)
        temp.push(v[j++]);

    for (let k = 0, i = l; i <= h; i++, k++)
        v[i] = temp[k];
}

// Sorts v[i..j] (entries are [value, originalIndex]) into descending order of
// value, letting merge() accumulate in ans how many smaller elements lie to the
// right of each entry. Ranges with fewer than two entries are left untouched.
var mergesort = function(v, ans, i, j) {
    if (!(i < j)) {
        return;
    }

    // Plain floor midpoint; any split with i <= mid < j yields the same result.
    const mid = Math.floor((i + j) / 2);

    mergesort(v, ans, i, mid);
    mergesort(v, ans, mid + 1, j);
    merge(v, ans, i, mid, j);
}

// Returns counts where counts[i] is the number of elements to the right of
// nums[i] that are strictly smaller. Each value is tagged with its index, and a
// descending merge sort over the tagged values lets merge() tally the counts.
var countSmaller = function(nums) {
    const n = nums.length;
    const v = Array.from(nums, (value, index) => [value, index]);
    const ans = new Array(n).fill(0);

    mergesort(v, ans, 0, n - 1);

    return ans;
};

Dry Run

Let's dry-run our algorithm for a few examples to see how the solution works.

Input: nums = [5, 2, 6, 1]

// countSmaller
Step 1: vector<pair<int, int>> v

        int n = nums.size()
              = 4

        vector<int> ans(n, 0)

Step 2: loop for int i = 0; i < n; i++
            v.push_back({nums[i], i});
        for end

        v will be [[5, 0], [2, 1], [6, 2], [1, 3]]

Step 3: mergesort(v, ans, 0, n - 1)
        mergesort(v, ans, 0, 3)

// mergesort(v, ans, 0, 3)
Step 4: if i < j
         0 < 3
         true

         mid = (i + j) / 2
             = (0 + 3) / 2
             = 1

         mergesort(v, ans, i, mid)
         mergesort(v, ans, 0, 1)

// mergesort(v, ans, 0, 1)
Step 5: if i < j
         0 < 1
         true

         mid = (i + j) / 2
             = (0 + 1) / 2
             = 0

         mergesort(v, ans, i, mid)
         mergesort(v, ans, 0, 0)

// mergesort(v, ans, 0, 0)
Step 6: if i < j
         0 < 0
         false

         we rollback to step 5

// mergesort(v, ans, 0, 1)
Step 7: if i < j
         0 < 1
         true

         mid = (i + j) / 2
             = (0 + 1) / 2
             = 0

         mergesort(v, ans, i, mid)
         mergesort(v, ans, 0, 0)    // evaluated in Step 6

         mergesort(v, ans, mid + 1, j)
         mergesort(v, ans, 0 + 1, 1)
         mergesort(v, ans, 1, 1)    // 1 < 1 is false, so it returns immediately

         merge(v, ans, i, mid, j)
         merge(v, ans, 0, 0, 1)     // merges [5] and [2]

         5 > 2, so ans[0] is increased by (h - j + 1) = 1 - 1 + 1 = 1
         v stays [[5, 0], [2, 1], [6, 2], [1, 3]]

// back in mergesort(v, ans, 0, 3), where mid = 1
Step 8: mergesort(v, ans, mid + 1, j)
        mergesort(v, ans, 2, 3)

// mergesort(v, ans, 2, 3)
         if i < j
          2 < 3
          true

         mid = (i + j) / 2
             = (2 + 3) / 2
             = 2

         mergesort(v, ans, 2, 2) and mergesort(v, ans, 3, 3) return immediately,
         so the single elements [6] and [1] are merged next:
         merge(v, ans, i, mid, j)
         merge(v, ans, 2, 2, 3)

// merge(v, ans, 2, 2, 3)
Step 9: vector<pair<int, int>> temp
        int i = l
              = 2
        int j = mid + 1
              = 2 + 1
              = 3

        loop while i < mid + 1 && j <= h
          2 < 2 + 1 && 3 <= 3
          true

          if v[i].first > v[j].first
             v[2].first > v[3].first
             6 > 1
             true

             ans[v[i].second] = ans[v[i].second] + (h - j + 1)
             ans[v[2].second] = ans[v[2].second] + 3 - 3 + 1
                              = 0 + 1
                              = 1
             ans[2] = 1

             temp.push_back(v[i])
             temp = [[6, 2]]

             i++
             i = 3

          if end

        loop while i < mid + 1 && j <= h
          3 < 2 + 1 && 3 <= 3
          false

        loop while i <= mid
          3 <= 2
          false

        loop while j <= h
          3 <= 3
          true

          temp.push_back(v[j])
          temp.push_back(v[3])
          temp = [[6, 2], [1, 3]]

          j++
          j = 4

        loop while j <= h
          4 <= 3
          false

        loop for int k = 0, i = l; i <= h;
          k = 0
          i = 2

          i <= h
          2 <= 3

          v[i] = temp[k]
          v[2] = temp[0]
               = [6, 2]

          i++
          i = 3

          k++
          k = 1

        loop for i <= h
          3 <= 3
          true

          k = 1
          i = 3

          v[i] = temp[k]
          v[3] = temp[1]
               = [1, 3]

          i++
          i = 4

          k++
          k = 2

        loop for i <= h
          4 <= 3
          false

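        Both halves are now merged (merging [5] and [2] added 1 to ans[0]):
        v[0..1] = [[5, 0], [2, 1]] and v[2..3] = [[6, 2], [1, 3]].

        We backtrack to mergesort(v, ans, 0, 3) and call merge(v, ans, 0, 1, 3):
          5 > 6 is false, so [6, 2] is taken from the right half and j becomes 3
          5 > 1, so ans[0] += 3 - 3 + 1 and ans[0] becomes 2
          2 > 1, so ans[1] += 3 - 3 + 1 and ans[1] becomes 1
          the remaining [1, 3] is copied over

        v becomes [[6, 2], [5, 0], [2, 1], [1, 3]].

We get the final result as [2, 1, 1, 0].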
Share this post!