Problem of the day: Matrix median

So here’s an interesting problem: Given an NxM integer matrix in which each row is sorted, find the overall median of the matrix assuming N*M is odd. For example,

Given matrix =
[1, 5, 7]
[4, 10, 11]
[8, 11, 12]
The sorted array is [1, 4, 5, 7, 8, 10, 11, 11, 12]
and the median is 8.

It’s interesting because the obvious solutions are sub-optimal and the optimal solution is non-obvious. Here are a few obvious solutions:

  • Copy the matrix values into an array, sort it and return the middle element – Space complexity: O(N*M), Time complexity: O(N*M*log NM)
  • Copy the matrix values into an array and use a selection algorithm (e.g. quickselect, or median of medians for a worst-case guarantee) to find the median – Space complexity: O(N*M), Time complexity: O(N*M)
  • Modify the sort procedure/selection algorithm to work on the matrix directly instead of a 1-D array (a matrix stored contiguously in row-major order is effectively a 1-D array) – This wouldn’t require any extra memory, although it rearranges the input, but the time complexity would still be O(N*M*log NM) with sorting and O(N*M) with selection.
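For comparison, the second obvious solution can be sketched with the standard library's std::nth_element, which does a selection in linear time on average. This is a minimal sketch, assuming a non-empty matrix with N*M odd. The name medianBySelection is hypothetical.

```cpp
#include <algorithm>
#include <vector>

// Baseline (hypothetical helper): flatten the matrix into a copy, which costs
// O(N*M) extra space, then let std::nth_element put the median at the middle
// index (linear time on average). Assumes a non-empty matrix with N*M odd.
int medianBySelection(const std::vector<std::vector<int>>& A) {
    std::vector<int> flat;
    for (const auto& row : A)
        flat.insert(flat.end(), row.begin(), row.end());
    auto mid = flat.begin() + flat.size() / 2;
    std::nth_element(flat.begin(), mid, flat.end());
    return *mid;
}
```

This version never uses the fact that the rows are sorted, and that is exactly the waste the rest of the post removes.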

All of these solutions are solving a more general problem by throwing away the information that each row in the matrix is sorted. We should be able to reduce the time complexity by using this information.

Consider this insight: Given an integer x, we can count the number of matrix elements ≤ x in O(N*log M) by doing a binary search on each row. x is the median if x is an element of the matrix and the number of matrix elements ≤ x equals 1 + N*M/2 (with integer division, so 1 + N*M/2 = (N*M + 1)/2). Let mn and mx be the minimum and maximum elements of the matrix respectively; then the boolean function

f(x) = count(elements ≤ x) ≥ 1 + N*M/2, where x ∈ [mn, mx]

is monotonic in x: once it becomes true, it stays true for every larger x. The smallest x for which it is true is the median. At that x, the number of elements ≤ x is 1 + N*M/2 (or more if x is repeated, which is not a problem). And since f(x - 1) is false, at least one element is ≤ x but not ≤ x - 1, i.e. equal to x. So x is an element of the matrix.
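The counting step and the predicate can be sketched as below. This sketch assumes every row is sorted in non-decreasing order. The helper name countLessEqual is hypothetical, and f mirrors the definition above.

```cpp
#include <algorithm>
#include <vector>

// count(elements <= x): std::upper_bound returns the first element > x in a
// sorted row, so its offset from begin() is the number of elements <= x.
// One binary search per row gives O(N log M) in total.
long long countLessEqual(const std::vector<std::vector<int>>& A, int x) {
    long long count = 0;
    for (const auto& row : A)
        count += std::upper_bound(row.begin(), row.end(), x) - row.begin();
    return count;
}

// f(x) = count(elements <= x) >= 1 + N*M/2 (integer division).
// Monotonic: once it is true for some x, it is true for every larger x.
bool f(const std::vector<std::vector<int>>& A, int x) {
    long long total = 0;  // N*M, summed so rows may be treated uniformly
    for (const auto& row : A) total += row.size();
    return countLessEqual(A, x) >= 1 + total / 2;
}
```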

Let’s work through the same example. Here N*M = 9, so the threshold is 1 + 9/2 = 5.

x     count(elements ≤ x)    f(x) = count ≥ 5
1     1                      false
2     1                      false
3     1                      false
4     2                      false
5     3                      false
6     3                      false
7     4                      false
8     5                      true
9     5                      true
10    6                      true
11    8                      true
12    9                      true

x = 8 is the point at which f(x) first becomes true, so it is the median.
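The table can be generated mechanically by scanning x upward from mn and stopping at the first x where f(x) holds. This is only a sketch to check the claim, because the scan costs O((mx - mn) * N * log M) and depends on the value range rather than the matrix size. The name medianByScan is hypothetical, and the sketch assumes non-empty sorted rows with N*M odd.

```cpp
#include <algorithm>
#include <vector>

// Scans x = mn, mn+1, ..., mx and returns the first x with
// count(elements <= x) >= (N*M + 1) / 2. Correct but slow.
int medianByScan(const std::vector<std::vector<int>>& A) {
    int mn = A[0].front(), mx = A[0].back();
    long long total = 0;
    for (const auto& row : A) {
        mn = std::min(mn, row.front());  // rows are sorted: min is in column 0
        mx = std::max(mx, row.back());   // and max is in the last column
        total += row.size();
    }
    const long long desired = (total + 1) / 2;
    // long long loop variable so ++x cannot overflow when mx == INT_MAX
    for (long long x = mn; x <= mx; ++x) {
        long long count = 0;
        for (const auto& row : A)
            count += std::upper_bound(row.begin(), row.end(), (int)x) - row.begin();
        if (count >= desired) return (int)x;
    }
    return mx;  // not reached: f(mx) is always true
}
```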

It’s rare to find a solution with a binary search within a binary search. We can run a discrete binary search on the value range [mn, mx] to find the first x for which f(x) becomes true. Each step costs O(N*log M) and the search takes O(log(mx - mn)) steps, so the overall time complexity is O(N*log M*log(mx - mn)) = O(32*N*log M), since for 32-bit integers mx - mn < 2^32.
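The "first x where a monotonic predicate becomes true" search is a reusable pattern. Here is a generic sketch, assuming pred is monotonic on [lo, hi] (false...false true...true) and pred(hi) is true. The name firstTrue is hypothetical. Each iteration halves the range, so it makes roughly log2(hi - lo) calls to pred, which is at most about 32 for 32-bit int values.

```cpp
#include <cstdint>

// Returns the smallest x in [lo, hi] with pred(x) true, assuming pred is
// monotonic and pred(hi) is true. 64-bit arithmetic keeps hi - lo from
// overflowing for any pair of 32-bit ints.
template <typename Pred>
std::int64_t firstTrue(std::int64_t lo, std::int64_t hi, Pred pred) {
    while (lo < hi) {
        std::int64_t mid = lo + (hi - lo) / 2;  // lo <= mid < hi
        if (pred(mid))
            hi = mid;      // mid might be the answer; keep it in range
        else
            lo = mid + 1;  // the answer lies strictly above mid
    }
    return lo;
}
```

For the median, pred would be the f(x) defined earlier, with lo = mn and hi = mx.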

Here’s the solution:

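#include <algorithm>
#include <vector>
using namespace std;

// Median of an n x m matrix whose rows are sorted in non-decreasing order;
// assumes n * m is odd.
int findMedian(vector<vector<int> > &A) {
    int mn = A[0][0], mx = A[0][0], n = A.size(), m = A[0].size();
    // Rows are sorted, so the global min/max sit in the first/last column.
    for (int i = 0; i < n; ++i) {
        if (A[i][0] < mn) mn = A[i][0];
        if (A[i][m-1] > mx) mx = A[i][m-1];
    }
    int desired = (n * m + 1) / 2;  // 1-based rank of the median
    // Discrete binary search for the smallest x with count(elements <= x) >= desired.
    while (mn < mx) {
        // 64-bit arithmetic so mx - mn cannot overflow for extreme values.
        int mid = (int)(mn + ((long long)mx - mn) / 2);
        int place = 0;  // number of elements <= mid
        for (int i = 0; i < n; ++i)
            place += upper_bound(A[i].begin(), A[i].end(), mid) - A[i].begin();
        if (place < desired)
            mn = mid + 1;
        else
            mx = mid;
    }
    return mn;
}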

6 thoughts on “Problem of the day: Matrix median”

  1. I think you can speed this up by using the fact that the overall median for the whole matrix must lie within the range of the individual row medians (e.g. [5,11] in the example above, where the row medians are 5, 10 and 11).

    Also, as an alternate approach, merge sort pairs of rows. Without loss of generality, assume N is even. After one pass of this, you have an N/2 x 2M matrix of sorted rows. Repeat the process until you have a one-dimensional sorted array, at which point you can pluck out the median.

    • As for the first idea, you are right. The median search range can be reduced from [mn..mx] to [min row median..max row median] but it doesn’t help the time complexity.

      I don’t know if we can implement the second approach without additional space and even then it will have far worse complexity.

  2. I agree that the first approach won’t help the complexity analysis. But if it helps the practical run time, then it’s worth doing, and I think it will. Regarding the second approach, I agree it will have far worse complexity if we use the complexity of the general case for merge sort. But here we are always merging two sorted arrays, so I think we just have to make one linear pass of length 2M for each merge.

    More generally, sometimes an algorithm or method with worse complexity is better in practice compared to one with better complexity.

    • I agree about the first approach. Determining the median range doesn’t take many operations. It’s certainly faster than computing min and max in the whole matrix. I just saw that I was computing global min and max by comparing each element. I’ve updated it.

      Median min-max range can be used in this way:

      int mn = INT_MAX, mx = INT_MIN, n = A.size(), m = A[0].size();
      for (int i = 0; i < n; ++i) {
          int median = A[i][m/2];
          if (median < mn)
              mn = median;
          if (median > mx)
              mx = median;
      }
      

      I don’t agree about the second approach, though. Even if we could sort the whole matrix in place by merging the 1st row with the 2nd, then the first 2 rows with the 3rd, and so on, this has complexity (M + M) + (2M + M) + … + ((N-1)M + M) = M*((N-1) + N*(N-1)/2) = O(M*N*N), which is too much. Of course, sometimes an algorithm with worse complexity beats one with better complexity in practice on inputs of real importance, but here this approach would even have a larger constant and would be far more difficult to implement.

    • We don’t need to check whether x is in the matrix, and it’s count(elements ≤ x) ≥ 1 + (N*M)/2. I think this is the most straightforward explanation: there’s always at least one matrix element in the range [mn..mx] (true at the beginning and ensured by the way the bounds are updated in the loop), and we break out of the loop when mn == mx, so mn must exist in the matrix.
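Combining the ideas from the comments above, here is a self-contained sketch that searches [min row median, max row median] instead of [mn, mx]. It assumes, as the problem does, that N*M is odd, so M is odd and A[i][m/2] is each row's exact median. The name findMedianRowRange is hypothetical. As noted in the replies, the asymptotic cost is unchanged; only the starting range shrinks.

```cpp
#include <algorithm>
#include <vector>

// Variant of findMedian that narrows the value range to
// [min row median, max row median] before the binary search.
// Assumes every row is sorted and N*M is odd (hence M is odd).
int findMedianRowRange(const std::vector<std::vector<int>>& A) {
    const int n = A.size(), m = A[0].size();
    int lo = A[0][m / 2], hi = A[0][m / 2];
    for (int i = 1; i < n; ++i) {
        lo = std::min(lo, A[i][m / 2]);
        hi = std::max(hi, A[i][m / 2]);
    }
    const long long desired = (1LL * n * m + 1) / 2;
    // Same discrete binary search: smallest x with count(elements <= x) >= desired.
    while (lo < hi) {
        int mid = (int)(lo + ((long long)hi - lo) / 2);
        long long place = 0;  // count(elements <= mid)
        for (const auto& row : A)
            place += std::upper_bound(row.begin(), row.end(), mid) - row.begin();
        if (place < desired)
            lo = mid + 1;
        else
            hi = mid;
    }
    return lo;
}
```

The narrowed range is safe. Each row has (M-1)/2 elements below its median, so fewer than 1 + N*M/2 elements lie below the smallest row median. Each row has (M+1)/2 elements ≤ its median, so at least 1 + N*M/2 elements are ≤ the largest row median.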
