A simple approach to segment trees

A segment tree is a tree data structure that allows aggregation queries and updates over array intervals in logarithmic time. As I see it, there are three major use cases for segment trees:

  1. Static segment trees: This is probably the most common use case. We preprocess an array of N elements to construct a segment tree in O(N). We can then query aggregates over any range/segment of the array in O(log N).
  2. Segment tree with point updates: This lets us update array values one at a time in O(log N) while still maintaining the segment tree structure. Queries over any range still take O(log N).
  3. Segment tree with range updates: This lets us update a whole range of array elements at once. Done naively, a range update costs O(N) in the worst case, but problem-specific optimizations and lazy propagation typically give huge improvements (lazy propagation brings range updates down to O(log N)). Queries over any range still take O(log N).

In this post, I’ll cover the first two use cases because they go together: given a static segment tree, it is very easy to add point update capability to it. I’ll leave the third use case as the subject of a future blog post. I intend this post to be a practical introduction to segment trees rather than a theoretical description, so it focuses on how a segment tree can be divided into components, how each component works, and how the problem-specific logic can be separated from the underlying data structure. We’ll build a template for a segment tree and then apply it to several problems to see how cleanly the problem-specific logic can be kept out of the template.

Structure of a segment tree
Let’s understand what a segment tree looks like. Each node in a segment tree stores aggregate statistics for some range/segment of an array, and the leaf nodes store aggregate statistics for individual array elements. Although a segment tree is a tree, it is stored in an array, similar to a heap. If the input array had 2^n elements (i.e., the number of elements were a power of 2), the segment tree over it would look something like this:
[Figure: Structure of a segment tree]
In the figure, each node shows the segment of the input array for which it is responsible, and the number outside a node is its index in the segment tree array. If the array size N is a power of 2, the segment tree has 2*N-1 nodes. Storing the root at index 1 (and leaving index 0 unused) simplifies finding the indices of children: a node at index i has its left and right children at 2*i and 2*i+1 respectively. Thus, for an input array of size N, an array of size 2*N is required to store the segment tree.
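A small sketch of this index arithmetic, assuming the root is stored at index 1 as described. The helper names leftChild, rightChild and parent are illustrative only; the post’s own code computes 2 * stIndex and 2 * stIndex + 1 inline.

```cpp
#include <cassert>

// Heap-style indexing for a segment tree stored in an array.
// The root lives at index 1; index 0 is deliberately left unused.
inline int leftChild(int i)  { return 2 * i; }
inline int rightChild(int i) { return 2 * i + 1; }

// Every node except the root has its parent at i / 2 (integer division).
inline int parent(int i) {
    assert(i > 1);
    return i / 2;
}
```

Because integer division drops the low bit, both 2*i and 2*i+1 map back to i, so moving down and back up the tree never needs explicit parent pointers.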

In practice, however, N is usually not a power of 2, so we find the smallest power of 2 that is greater than or equal to N, call it x, and allocate an array of size 2*x to store the segment tree. The following procedure calculates the size of the array required to store a segment tree for an input array of size N:

int getSegmentTreeSize(int N) {
    int size = 1;
    // find the smallest power of 2 that is >= N
    for (; size < N; size <<= 1);
    return size << 1;
}
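One way to convince yourself that 2*x slots are always enough is to recurse exactly the way the tree is built (splitting [lo, hi] at mid = (lo + hi) / 2) and record the largest index reached. The sketch below repeats getSegmentTreeSize() and adds a hypothetical helper, maxIndexUsed, for that purpose.

```cpp
#include <algorithm>
#include <cassert>

// Same sizing rule as above: the smallest power of 2 that is >= N, doubled.
int getSegmentTreeSize(int N) {
    assert(N >= 1);
    int size = 1;
    for (; size < N; size <<= 1);
    return size << 1;
}

// Hypothetical helper: the largest segment tree index touched when the
// segment [lo, hi] stored at stIndex is split recursively down to leaves.
int maxIndexUsed(int stIndex, int lo, int hi) {
    if (lo == hi)
        return stIndex;
    int mid = (lo + hi) / 2;
    return std::max(maxIndexUsed(2 * stIndex, lo, mid),
                    maxIndexUsed(2 * stIndex + 1, mid + 1, hi));
}
```

Since x is less than 2*N, the returned size is always below 4*N, which is why many implementations simply allocate 4*N nodes instead of computing x.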

We’ll try to separate the implementation of the underlying data structure from the problem specific logic. For this purpose, let us define a structure for a segment tree node:
struct SegmentTreeNode {
    // variables to store aggregate statistics and
    // any other information required to merge these
    // aggregate statistics to form parent nodes

    void assignLeaf(T value) {
        // T is the type of input array element
        // Given the value of an input array element,
        // build aggregate statistics for this leaf node
    }

    void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
        // merge the aggregate statistics of left and right
        // children to form the aggregate statistics of
        // their parent node
    }

    V getValue() {
        // V is the type of the required aggregate statistic
        // return the value of required aggregate statistic
        // associated with this node
    }
};
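As an illustration of filling in this skeleton, suppose (purely for the sake of example, not one of the post’s problems) that the aggregate we want is the minimum of a range of ints, so T = int and V = int:

```cpp
#include <algorithm>
#include <cassert>

// Illustrative node: range minimum over an int array.
struct SegmentTreeNode {
    int minValue;

    // A leaf's minimum is just the element itself.
    void assignLeaf(int value) {
        minValue = value;
    }

    // A parent's minimum is the smaller of its children's minima.
    void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
        minValue = std::min(left.minValue, right.minValue);
    }

    int getValue() {
        return minValue;
    }
};
```

Nothing in this struct knows about indices or recursion; that part lives entirely in the tree code.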

Building a segment tree
We can build a segment tree recursively in a depth-first manner, starting at the root node (which represents the whole input array) and working our way down to the leaves (which represent individual input array elements). Once both children of a node have been built, we merge their aggregate statistics to form the parent node.

void buildTree(T arr[], int stIndex, int lo, int hi) {
    if (lo == hi) {
        nodes[stIndex].assignLeaf(arr[lo]);
        return;
    }
    int left = 2 * stIndex, right = left + 1, mid = (lo + hi) / 2;
    buildTree(arr, left, lo, mid);
    buildTree(arr, right, mid + 1, hi);
    nodes[stIndex].merge(nodes[left], nodes[right]);
}

Here I’ve assumed that the type of the input array elements is T, and nodes is the segment tree array. stIndex is the index of the current node in the segment tree array, and lo and hi delimit the range/segment of the input array this node is responsible for. We build the whole segment tree with a single call to buildTree(arr, 1, 0, N-1), where N is the size of the input array arr. The time complexity of this procedure is O(N), assuming that assignLeaf() and merge() work in O(1): the recursion creates N leaves and N-1 internal nodes and visits each of them exactly once.
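To see where the O(N) bound comes from, here is a sketch of buildTree() specialised to range sums over plain std::vector storage, with hypothetical counters (BuildStats) added. A full binary tree with N leaves has exactly N-1 internal nodes, so the leaf step should run N times and the merge step N-1 times for any N, not just powers of 2.

```cpp
#include <cassert>
#include <vector>

// Hypothetical counters, used only to observe how much work the build does.
struct BuildStats {
    int leaves = 0;
    int merges = 0;
};

// buildTree() specialised to range sums; tree[stIndex] plays the role of nodes[stIndex].
void buildSumTree(const std::vector<long long>& arr, std::vector<long long>& tree,
                  int stIndex, int lo, int hi, BuildStats& stats) {
    if (lo == hi) {                             // corresponds to assignLeaf()
        tree[stIndex] = arr[lo];
        ++stats.leaves;
        return;
    }
    int left = 2 * stIndex, right = left + 1, mid = (lo + hi) / 2;
    buildSumTree(arr, tree, left, lo, mid, stats);
    buildSumTree(arr, tree, right, mid + 1, hi, stats);
    tree[stIndex] = tree[left] + tree[right];   // corresponds to merge()
    ++stats.merges;
}
```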

Querying the segment tree
Suppose we want to query the aggregate statistic associated with the segment [lo, hi]. We can do this recursively as follows:

// V is the type of the required aggregate statistic
V getValue(int lo, int hi) {
    SegmentTreeNode result = getValue(1, 0, N-1, lo, hi);
    return result.getValue();
}

// nodes[stIndex] is responsible for the segment [left, right]
// and we want to query for the segment [lo, hi]
SegmentTreeNode getValue(int stIndex, int left, int right, int lo, int hi) {
    if (left == lo && right == hi)
        return nodes[stIndex];
    int mid = (left + right) / 2;
    if (lo > mid)
        return getValue(2*stIndex+1, mid+1, right, lo, hi);
    if (hi <= mid)
        return getValue(2*stIndex, left, mid, lo, hi);
    SegmentTreeNode leftResult = getValue(2*stIndex, left, mid, lo, mid);
    SegmentTreeNode rightResult = getValue(2*stIndex+1, mid+1, right, mid+1, hi);
    SegmentTreeNode result;
    result.merge(leftResult, rightResult);
    return result;
}

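This procedure is similar to the one used for building the segment tree, except that we cut off the recursion as soon as we reach a node whose segment exactly matches the requested one. Whenever the requested range straddles mid, it is split into [lo, mid] and [mid+1, hi] before recursing, so the range passed to a child always lies inside that child’s segment. This is why the exact test left == lo && right == hi is enough to detect a fully covered node. At most four nodes are visited on each level of the tree, so the complexity of this procedure is O(log N), assuming merge() works in O(1).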
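The sketch below specialises the query to range sums (querySum and build are illustrative names, and plain std::vector storage is an assumption) and adds a visit counter. The assert inside querySum encodes the invariant that makes the exact-match test work: the requested range always stays inside the current node’s segment.

```cpp
#include <cassert>
#include <vector>

// The query procedure specialised to range sums. `visits` is a
// hypothetical counter used only to observe how many nodes are touched.
long long querySum(const std::vector<long long>& tree, int stIndex,
                   int left, int right, int lo, int hi, int& visits) {
    ++visits;
    assert(left <= lo && lo <= hi && hi <= right);  // [lo, hi] stays inside the node
    if (left == lo && right == hi)
        return tree[stIndex];
    int mid = (left + right) / 2;
    if (lo > mid)
        return querySum(tree, 2 * stIndex + 1, mid + 1, right, lo, hi, visits);
    if (hi <= mid)
        return querySum(tree, 2 * stIndex, left, mid, lo, hi, visits);
    return querySum(tree, 2 * stIndex, left, mid, lo, mid, visits) +
           querySum(tree, 2 * stIndex + 1, mid + 1, right, mid + 1, hi, visits);
}

// Builds the sum tree so the query can be exercised on its own.
void build(const std::vector<long long>& arr, std::vector<long long>& tree,
           int stIndex, int lo, int hi) {
    if (lo == hi) {
        tree[stIndex] = arr[lo];
        return;
    }
    int mid = (lo + hi) / 2;
    build(arr, tree, 2 * stIndex, lo, mid);
    build(arr, tree, 2 * stIndex + 1, mid + 1, hi);
    tree[stIndex] = tree[2 * stIndex] + tree[2 * stIndex + 1];
}
```

Comparing querySum against a brute-force sum over every range of a few small arrays is a cheap way to test a query like this one. Since at most four nodes are touched per level, visits should stay within 4 * (depth + 1), where depth is the height of the tree.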

Updating the segment tree
The two procedures above, building the segment tree and querying it, are sufficient for the first use case: a static segment tree. As it happens, the second use case, point updates, doesn’t require many changes. In fact, we don’t have to change the problem-specific logic at all: no changes to the structure SegmentTreeNode are required.
We just need to add a procedure for updating the segment tree. It is very similar to buildTree(), the only difference being that it follows a single path down the tree (the one that leads to the leaf being updated) and then comes back up, re-merging the parent nodes along that same path.

// We want to update the value associated with index in the input array
void update(int index, T value) {
    update(1, 0, N-1, index, value);
}

// nodes[stIndex] is responsible for segment [lo, hi]
void update(int stIndex, int lo, int hi, int index, T value) {
    if (lo == hi) {
        nodes[stIndex].assignLeaf(value);
        return;
    }
    int left = 2 * stIndex, right = left + 1, mid = (lo + hi) / 2;
    if (index <= mid)
        update(left, lo, mid, index, value);
    else
        update(right, mid+1, hi, index, value);
    nodes[stIndex].merge(nodes[left], nodes[right]);
}

Clearly, the complexity of this operation is O(log N), assuming that assignLeaf() and merge() work in O(1).
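Here is a compact sketch of a range-sum tree with build, query and point update together, so that updates can be checked against a brute-force array. The names SumTree, update and query are illustrative, and std::vector storage is an assumption rather than the post’s code.

```cpp
#include <cassert>
#include <vector>

// Illustrative range-sum segment tree with point updates.
struct SumTree {
    int n;
    std::vector<long long> tree;

    explicit SumTree(const std::vector<long long>& arr)
        : n(static_cast<int>(arr.size())), tree(4 * arr.size()) {
        assert(!arr.empty());
        build(arr, 1, 0, n - 1);
    }
    void update(int index, long long value) {
        assert(0 <= index && index < n);
        update(1, 0, n - 1, index, value);
    }
    long long query(int lo, int hi) {
        assert(0 <= lo && lo <= hi && hi < n);
        return query(1, 0, n - 1, lo, hi);
    }

private:
    void build(const std::vector<long long>& arr, int st, int lo, int hi) {
        if (lo == hi) { tree[st] = arr[lo]; return; }
        int mid = (lo + hi) / 2;
        build(arr, 2 * st, lo, mid);
        build(arr, 2 * st + 1, mid + 1, hi);
        tree[st] = tree[2 * st] + tree[2 * st + 1];
    }
    // Walk down the single root-to-leaf path for `index`, then
    // re-merge every node on that path on the way back up.
    void update(int st, int lo, int hi, int index, long long value) {
        if (lo == hi) { tree[st] = value; return; }
        int mid = (lo + hi) / 2;
        if (index <= mid) update(2 * st, lo, mid, index, value);
        else update(2 * st + 1, mid + 1, hi, index, value);
        tree[st] = tree[2 * st] + tree[2 * st + 1];
    }
    long long query(int st, int left, int right, int lo, int hi) {
        if (left == lo && right == hi) return tree[st];
        int mid = (left + right) / 2;
        if (lo > mid) return query(2 * st + 1, mid + 1, right, lo, hi);
        if (hi <= mid) return query(2 * st, left, mid, lo, hi);
        return query(2 * st, left, mid, lo, mid) +
               query(2 * st + 1, mid + 1, right, mid + 1, hi);
    }
};
```

For example, for the array {5, 1, 4, 2, 3}, query(1, 3) is 7; after update(2, 10), query(0, 4) becomes 21.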

Segment Tree template
Let’s put all this together to complete the template for a segment tree.

// T is the type of input array elements
// V is the type of required aggregate statistic
// The problem-specific SegmentTreeNode must be defined before this class.
template<class T, class V>
class SegmentTree {
    SegmentTreeNode* nodes;  // segment tree array, root at index 1
    int N;                   // size of the input array
public:
    SegmentTree(T arr[], int N) {
        this->N = N;
        nodes = new SegmentTreeNode[getSegmentTreeSize(N)];
        buildTree(arr, 1, 0, N-1);
    }
    ~SegmentTree() {
        delete[] nodes;
    }
    // The class owns raw memory, so copying it would cause a double delete.
    SegmentTree(const SegmentTree&) = delete;
    SegmentTree& operator=(const SegmentTree&) = delete;

    // aggregate statistic for the segment [lo, hi] of the input array
    V getValue(int lo, int hi) {
        SegmentTreeNode result = getValue(1, 0, N-1, lo, hi);
        return result.getValue();
    }
    // replace the input array element at index with value
    void update(int index, T value) {
        update(1, 0, N-1, index, value);
    }
private:
    void buildTree(T arr[], int stIndex, int lo, int hi) {
        if (lo == hi) {
            nodes[stIndex].assignLeaf(arr[lo]);
            return;
        }
        int left = 2 * stIndex, right = left + 1, mid = (lo + hi) / 2;
        buildTree(arr, left, lo, mid);
        buildTree(arr, right, mid + 1, hi);
        nodes[stIndex].merge(nodes[left], nodes[right]);
    }
    SegmentTreeNode getValue(int stIndex, int left, int right, int lo, int hi) {
        if (left == lo && right == hi)
            return nodes[stIndex];
        int mid = (left + right) / 2;
        if (lo > mid)
            return getValue(2*stIndex+1, mid+1, right, lo, hi);
        if (hi <= mid)
            return getValue(2*stIndex, left, mid, lo, hi);
        SegmentTreeNode leftResult = getValue(2*stIndex, left, mid, lo, mid);
        SegmentTreeNode rightResult = getValue(2*stIndex+1, mid+1, right, mid+1, hi);
        SegmentTreeNode result;
        result.merge(leftResult, rightResult);
        return result;
    }
    int getSegmentTreeSize(int N) {
        int size = 1;
        for (; size < N; size <<= 1);
        return size << 1;
    }
    void update(int stIndex, int lo, int hi, int index, T value) {
        if (lo == hi) {
            nodes[stIndex].assignLeaf(value);
            return;
        }
        int left = 2 * stIndex, right = left + 1, mid = (lo + hi) / 2;
        if (index <= mid)
            update(left, lo, mid, index, value);
        else
            update(right, mid+1, hi, index, value);
        nodes[stIndex].merge(nodes[left], nodes[right]);
    }
};
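For reference, here is one way to assemble everything into a single self-contained file. It is a variation of the template above rather than the post’s exact code: the node type is passed as a template parameter (Node) instead of relying on a global SegmentTreeNode, and storage is a std::vector, so no destructor or copy handling is needed. SumNode is a hypothetical range-sum node used only to exercise the template.

```cpp
#include <cassert>
#include <vector>

// Variation of the template: Node supplies assignLeaf(), merge() and getValue().
template<class Node, class T, class V>
class SegmentTree {
    std::vector<Node> nodes;  // segment tree array, root at index 1
    int N;                    // size of the input array
public:
    explicit SegmentTree(const std::vector<T>& arr) : N(static_cast<int>(arr.size())) {
        assert(N > 0);
        int size = 1;
        while (size < N) size <<= 1;   // smallest power of 2 >= N
        nodes.resize(2 * size);
        buildTree(arr, 1, 0, N - 1);
    }
    V getValue(int lo, int hi) {
        assert(0 <= lo && lo <= hi && hi < N);
        return getValue(1, 0, N - 1, lo, hi).getValue();
    }
    void update(int index, T value) {
        assert(0 <= index && index < N);
        update(1, 0, N - 1, index, value);
    }
private:
    void buildTree(const std::vector<T>& arr, int st, int lo, int hi) {
        if (lo == hi) { nodes[st].assignLeaf(arr[lo]); return; }
        int mid = (lo + hi) / 2;
        buildTree(arr, 2 * st, lo, mid);
        buildTree(arr, 2 * st + 1, mid + 1, hi);
        nodes[st].merge(nodes[2 * st], nodes[2 * st + 1]);
    }
    Node getValue(int st, int left, int right, int lo, int hi) {
        if (left == lo && right == hi) return nodes[st];
        int mid = (left + right) / 2;
        if (lo > mid) return getValue(2 * st + 1, mid + 1, right, lo, hi);
        if (hi <= mid) return getValue(2 * st, left, mid, lo, hi);
        Node leftResult = getValue(2 * st, left, mid, lo, mid);
        Node rightResult = getValue(2 * st + 1, mid + 1, right, mid + 1, hi);
        Node result;
        result.merge(leftResult, rightResult);
        return result;
    }
    void update(int st, int lo, int hi, int index, T value) {
        if (lo == hi) { nodes[st].assignLeaf(value); return; }
        int mid = (lo + hi) / 2;
        if (index <= mid) update(2 * st, lo, mid, index, value);
        else update(2 * st + 1, mid + 1, hi, index, value);
        nodes[st].merge(nodes[2 * st], nodes[2 * st + 1]);
    }
};

// Problem-specific part (illustrative): range sums over ints.
struct SumNode {
    long long sum = 0;
    void assignLeaf(int value) { sum = value; }
    void merge(SumNode& left, SumNode& right) { sum = left.sum + right.sum; }
    long long getValue() { return sum; }
};
```

With this variation, a problem only supplies its own node type, for example SegmentTree<SumNode, int, long long> for range sums.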

We shall now see how this template can be used to solve different problems without changing the tree implementation, and how the structure SegmentTreeNode is implemented differently for each problem.

The first problem we’ll look at is GSS1. This problem asks for the maximum subarray sum within each queried range of an array. My objective here is not to explain how to solve this problem, but rather to demonstrate how easily it can be implemented with the above template at hand.
As it turns out, we need to store 4 values in each segment tree node to be able to merge two child nodes into their parent:

  1. Maximum sum of a subarray, starting at the leftmost index of this range
  2. Maximum sum of a subarray, ending at the rightmost index of this range
  3. Maximum sum of any subarray in this range
  4. Sum of all elements in this range

The SegmentTreeNode for this problem looks as follows:

struct SegmentTreeNode {
    int prefixMaxSum, suffixMaxSum, maxSum, sum;

    void assignLeaf(int value) {
        prefixMaxSum = suffixMaxSum = maxSum = sum = value;
    }

    void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
        sum = left.sum + right.sum;
        prefixMaxSum = max(left.prefixMaxSum, left.sum + right.prefixMaxSum);
        suffixMaxSum = max(right.suffixMaxSum, right.sum + left.suffixMaxSum);
        maxSum = max(prefixMaxSum, max(suffixMaxSum, max(left.maxSum,
                 max(right.maxSum, left.suffixMaxSum + right.prefixMaxSum))));
    }

    int getValue() {
        return maxSum;
    }
};

The complete solution for this problem can be viewed here.
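Because merge() only ever combines adjacent ranges and is associative, a quick way to sanity-check a node definition, independent of the tree, is to fold it over a range from left to right and compare the result with brute force. The sketch below does that for the GSS1 node; foldRange and bruteMaxSubarray are hypothetical helpers, not part of the post’s solution. It also checks that prefixMaxSum is never smaller than sum: the whole range is itself a prefix, which is why sum does not need its own term in the maxSum expression.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using std::max;

// The GSS1 node, repeated so this sketch compiles on its own.
struct SegmentTreeNode {
    int prefixMaxSum, suffixMaxSum, maxSum, sum;
    void assignLeaf(int value) {
        prefixMaxSum = suffixMaxSum = maxSum = sum = value;
    }
    void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
        sum = left.sum + right.sum;
        prefixMaxSum = max(left.prefixMaxSum, left.sum + right.prefixMaxSum);
        suffixMaxSum = max(right.suffixMaxSum, right.sum + left.suffixMaxSum);
        maxSum = max(prefixMaxSum, max(suffixMaxSum, max(left.maxSum,
                 max(right.maxSum, left.suffixMaxSum + right.prefixMaxSum))));
    }
    int getValue() { return maxSum; }
};

// Hypothetical helper: merge the leaves of arr[lo..hi] strictly left to right.
// merge() writes into *this while reading its arguments, so a separate
// `combined` node is used as the target.
SegmentTreeNode foldRange(const std::vector<int>& arr, int lo, int hi) {
    assert(lo <= hi);
    SegmentTreeNode acc;
    acc.assignLeaf(arr[lo]);
    for (int i = lo + 1; i <= hi; ++i) {
        SegmentTreeNode leaf, combined;
        leaf.assignLeaf(arr[i]);
        combined.merge(acc, leaf);
        acc = combined;
    }
    return acc;
}

// Brute-force maximum sum of a non-empty subarray of arr[lo..hi].
int bruteMaxSubarray(const std::vector<int>& arr, int lo, int hi) {
    int best = arr[lo];
    for (int i = lo; i <= hi; ++i) {
        int running = 0;
        for (int j = i; j <= hi; ++j) {
            running += arr[j];
            best = max(best, running);
        }
    }
    return best;
}
```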

The second problem we’ll look at is GSS3, which is very similar to GSS1, the only difference being that it also asks for updates to individual array elements while still answering maximum subarray sum queries. Here we can see the advantage of separating problem-specific logic from the segment tree implementation: this problem requires no changes to the template and even uses the same SegmentTreeNode as GSS1. The complete solution for this problem can be viewed here.

The third problem we’ll look at, BRCKTS, looks very different from the first two, but the differences are only superficial, since we’ll be able to solve it using the same structure. This problem gives a string of parentheses (opening and closing), requires updates to individual parentheses (changing an opening parenthesis to a closing one or vice versa), and asks whether the whole string is a correct parenthesization.
As it turns out, we need only 2 things in each segment tree node:

  1. The number of unmatched opening parentheses in this range
  2. The number of unmatched closing parentheses in this range

The SegmentTreeNode for this problem looks as follows:

struct SegmentTreeNode {
    int unmatchedOpenParans, unmatchedClosedParans;

    void assignLeaf(char paranthesis) {
        if (paranthesis == '(')
            unmatchedOpenParans = 1, unmatchedClosedParans = 0;
        else
            unmatchedOpenParans = 0, unmatchedClosedParans = 1;
    }

    void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
        int newMatches = min(left.unmatchedOpenParans, right.unmatchedClosedParans);
        unmatchedOpenParans = right.unmatchedOpenParans + left.unmatchedOpenParans - newMatches;
        unmatchedClosedParans = left.unmatchedClosedParans + right.unmatchedClosedParans - newMatches;
    }

    bool getValue() {
        return unmatchedOpenParans == 0 && unmatchedClosedParans == 0;
    }
};

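Note that merge() only lets unmatched opening parentheses of the left child cancel unmatched closing parentheses of the right child, so the order of the parentheses is taken into account, not just their counts: a string like )( has one of each but is still reported as incorrect. The complete solution for this problem can be viewed here.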
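The same left-to-right fold can be used to exercise the BRCKTS node on its own, without building a tree. In this sketch, isBalanced is a hypothetical helper; the node is copied from above with its original identifiers.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
using std::min;

// The BRCKTS node, repeated so this sketch compiles on its own.
struct SegmentTreeNode {
    int unmatchedOpenParans, unmatchedClosedParans;
    void assignLeaf(char paranthesis) {
        if (paranthesis == '(')
            unmatchedOpenParans = 1, unmatchedClosedParans = 0;
        else
            unmatchedOpenParans = 0, unmatchedClosedParans = 1;
    }
    void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
        // Only opens on the left can be matched by closes on the right.
        int newMatches = min(left.unmatchedOpenParans, right.unmatchedClosedParans);
        unmatchedOpenParans = right.unmatchedOpenParans + left.unmatchedOpenParans - newMatches;
        unmatchedClosedParans = left.unmatchedClosedParans + right.unmatchedClosedParans - newMatches;
    }
    bool getValue() {
        return unmatchedOpenParans == 0 && unmatchedClosedParans == 0;
    }
};

// Hypothetical helper: merge the whole string left to right.
bool isBalanced(const std::string& s) {
    assert(!s.empty());
    SegmentTreeNode acc;
    acc.assignLeaf(s[0]);
    for (std::string::size_type i = 1; i < s.size(); ++i) {
        SegmentTreeNode leaf, combined;
        leaf.assignLeaf(s[i]);
        combined.merge(acc, leaf);
        acc = combined;
    }
    return acc.getValue();
}
```

For instance, isBalanced("(()())") is true, while isBalanced("())(") is false even though the string has two parentheses of each kind.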

The final problem we’ll look at in this post is KGSS. For each query range, this problem asks for the maximum sum of two elements at distinct positions within that range, and it also requires updates to individual array elements. As it turns out, we only need to store 2 things in each segment tree node:

  1. The maximum value in this range
  2. The second maximum value in this range

The SegmentTreeNode for this problem looks as follows:

struct SegmentTreeNode {
    int maxNum, secondMaxNum;

    void assignLeaf(int num) {
        maxNum = num;
        // -1 means "no second element yet"; this relies on the
        // array values being non-negative
        secondMaxNum = -1;
    }

    void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
        maxNum = max(left.maxNum, right.maxNum);
        secondMaxNum = min(max(left.maxNum, right.secondMaxNum), max(right.maxNum, left.secondMaxNum));
    }

    int getValue() {
        return maxNum + secondMaxNum;
    }
};

The complete solution for this problem can be viewed here.
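The KGSS node can be checked the same way. This sketch assumes, like the node itself, that array values are non-negative so that -1 works as a sentinel (worth confirming against the problem statement). foldRange and brutePairSum are hypothetical helpers; the brute force tries every pair of distinct positions.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>
using std::max;
using std::min;

// The KGSS node, repeated so this sketch compiles on its own.
struct SegmentTreeNode {
    int maxNum, secondMaxNum;
    void assignLeaf(int num) {
        maxNum = num;
        secondMaxNum = -1;   // sentinel: assumes non-negative values
    }
    void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
        maxNum = max(left.maxNum, right.maxNum);
        secondMaxNum = min(max(left.maxNum, right.secondMaxNum),
                           max(right.maxNum, left.secondMaxNum));
    }
    int getValue() { return maxNum + secondMaxNum; }
};

// Hypothetical helper: merge the leaves of arr[lo..hi] left to right.
SegmentTreeNode foldRange(const std::vector<int>& arr, int lo, int hi) {
    assert(lo < hi);   // a pair needs at least two elements
    SegmentTreeNode acc;
    acc.assignLeaf(arr[lo]);
    for (int i = lo + 1; i <= hi; ++i) {
        SegmentTreeNode leaf, combined;
        leaf.assignLeaf(arr[i]);
        combined.merge(acc, leaf);
        acc = combined;
    }
    return acc;
}

// Brute force: largest arr[i] + arr[j] with lo <= i < j <= hi.
int brutePairSum(const std::vector<int>& arr, int lo, int hi) {
    int best = -1;
    for (int i = lo; i <= hi; ++i)
        for (int j = i + 1; j <= hi; ++j)
            best = max(best, arr[i] + arr[j]);
    return best;
}
```

Equal values are handled correctly: merging two leaves holding 7 gives maxNum = 7 and secondMaxNum = 7, so the pair sum is 14.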

I hope this post served as a gentle introduction to segment trees. I look forward to feedback, suggestions for possible improvements, and ideas for a future post on segment trees with lazy propagation.

Continue to Part 2

68 thoughts on “A simple approach to segment trees”

  1. Good work! Keep it up. I think you should use slightly simpler language and demonstrate the complete solution of at least one problem. The rest is great.


  3. Hi Kartik,
    First of all thanks for writing a tutorial which explains clearly everything there is to know about Segment Trees. Now onto my question which is somewhat related to the syntax which you have used for your solution to GSS1.

    Line 12 reads : void merge(SegmentTreeNode& left, SegmentTreeNode& right)
    Why have you passed the addresses of left and right SegmentTreeNode when just writing ‘SegmentTreeNode left’ and ‘SegmentTreeNode right’ also works?

    • Writing just SegmentTreeNode also works but is less efficient because a copy of the whole node will be made. Since we call the merge() method so many times, it makes sense to do this optimization.

    • Time constraints on this problem are pretty tight and Java is pretty slow compared to C++. There are only 15 accepted solutions in Java for this problem.

      I’m afraid it may not be possible to solve this question, under the given time constraints, with the separation of SegmentTreeNode from SegmentTree and generic tree operations. You may have to resort to using several plain arrays to store the data required by segment tree and operate on them directly.

  4. In GSS1 solution, why don’t we consider ‘sum’ as one of the variables while calculating the maxsum.
    i.e. maxSum = max(max(prefixMaxSum, max(suffixMaxSum, max(left.maxSum, max(right.maxSum, left.suffixMaxSum + right.prefixMaxSum)))),sum); instead of maxSum = max(prefixMaxSum, max(suffixMaxSum, max(left.maxSum, max(right.maxSum, left.suffixMaxSum + right.prefixMaxSum))));

    The ‘sum’ indicates the total sum of that particular range, and I think it might also be a candidate for the maximum contiguous sum of that range.

  5. I am having problems understanding the “segment tree node” part. Can you please help me out? Other than that, the tutorial is great. Thanks!

    • There are two parts to a segment tree:
      1. A problem independent part: the basic data structure, how the tree is constructed, queried and updated.
      2. A problem specific part: what data is stored in each node of the tree, how data from children nodes are combined to form the parent node.

      Usually people don’t differentiate between the two, keep the data in multiple arrays and reimplement buildTree(), query() and update() functions for each problem.

      But we can separate these two parts, keep the code for the data structure in its own class and supply the problem specific logic in a SegmentTreeNode class. I chose the functions implemented inside SegmentTreeNode which can be used for many such problems. These functions and the data stored inside the node have to be redefined for each new problem.

      I hope this cleared some things.

  6. This blog is amazing. Finally understood Segment tree basics!
    Update — Try discussing Problem QSET from CodeChef Jan Challenge in this blog. It will be a nice addition!

  7. Thanks a lot for this post, best tutorial on Segment Trees.
    Although, I’m a little confused in BRCKTS problem.
    Are you just checking that the number of opening brackets should be equal to number of closing brackets, for correct order? Don’t we also need to check if the opening bracket appears before the closing bracket ?

    • It’s not that hard. We just need to find out what information we must store in each node to be able to combine it in parent nodes. For this problem, it turns out that we need 4 things: sum, prefixMaxSum, suffixMaxSum and maxSum. The merge() method combines these things coming from child nodes (left and right half-segments) to form the parent node (combined segment).

      The sum would just be the sum of two segments. prefixMaxSum is the max of left.prefixMaxSum or left.sum + right.prefixMaxSum. This is intuitive, we can either take the max left sum of left segment or sum of left segment + max left sum of right segment. Similarly for suffixMaxSum.

      maxSum can only be the max of prefixMaxSum, suffixMaxSum, left.maxSum, right.maxSum or left.suffixMaxSum + right.prefixMaxSum.

      Do let me know if anything is unclear.

      • Actually I cannot convince myself that the recursive path the query() is taking, will lead to correct answer.

        • The query() method, or the getValue() method in the presented code, does not take a single root to leaf path down the tree but is supposed to summarize the whole tree to find out whether it is a correct parenthesization or not. Please do take a look at the complete solution, whose link is provided below the SegmentTreeNode code.

  8. why does your query function have if (left == lo && right == hi)? Shouldn’t it be if(left>=lo&&right<=hi), because we include the contribution if the segment lies in the given range?
