A simple approach to segment trees, part 2

This post is a continuation of a previous post on segment trees. If you haven’t read that post already, it would be a good idea to read it first.

Let me do a quick recap. The previous post described three use cases for segment trees (persistent/static, point updates and range updates) and explained the first two, leaving the last as the subject matter for this post. This post describes how to support range updates in segment trees, how lazy propagation works, and which optimizations it makes possible.

I tried to come up with a single template for segment trees which supported all 3 use cases, had lazy propagation and all the optimizations that come with it, but finally I decided it wasn’t worth the effort for the following reasons:

  • The more general I made the template, the slower it ran. Most problems which require segment trees with lazy propagation have very strict time constraints on online judges and I had to jump through hoops and do some serious gymnastics to get the solutions to pass.
  • Supporting different flavors of lazy propagation and various optimizations places huge requirements on the SegmentTreeNode (described in the previous post), requiring it to implement as many as 10 functions for any problem. I guess there comes a point when the cure becomes more harmful than the disease.
  • Pretty much every new problem required me to generalize the template in some way or add more functions to SegmentTreeNode. This became unusable pretty soon.

For these reasons, I decided that different ideas in lazy propagation and various optimizations that come with it are hard to combine, to say the least, and probably not worth the effort. So I’ll take a different approach. I’ll start with a partial template and fill bits and pieces of it to demonstrate different ideas.

struct SegmentTreeNode {
int start, end; // this node is responsible for the segment [start...end]
// variables to store aggregate statistics and
// any other information required to merge these
// aggregate statistics to form parent nodes
void assignLeaf(InputType value) {
// InputType is the type of input array element
// Given the value of an input array element,
// build aggregate statistics for this leaf node
}
void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
// merge the aggregate statistics of left and right
// children to form the aggregate statistics of
// their parent node
}
OutputType query() {
// OutputType is the type of the required aggregate statistic
// return the value of required aggregate statistic
// associated with this node
}
};
template<class InputType, class UpdateType, class OutputType>
class SegmentTree {
SegmentTreeNode* nodes;
int N;
public:
SegmentTree(InputType arr[], int N) {
this->N = N;
nodes = new SegmentTreeNode[getSegmentTreeSize(N)];
buildTree(arr, 1, 0, N-1);
}
~SegmentTree() {
delete[] nodes;
}
// get the value associated with the segment [start...end]
OutputType query(int start, int end) {
SegmentTreeNode result = query(1, start, end);
return result.query();
}
// range update: update the range [start...end] by value
// Exactly what is meant by an update is determined by the
// problem statement and that logic is captured in segment tree node
void update(int start, int end, UpdateType value) {
update(1, start, end, value);
}
private:
void buildTree(InputType arr[], int stIndex, int start, int end) {
// nodes[stIndex] is responsible for the segment [start...end]
nodes[stIndex].start = start, nodes[stIndex].end = end;
if (start == end) {
// a leaf node is responsible for a segment containing only 1 element
nodes[stIndex].assignLeaf(arr[start]);
return;
}
int mid = (start + end) / 2,
leftChildIndex = 2 * stIndex,
rightChildIndex = leftChildIndex + 1;
buildTree(arr, leftChildIndex, start, mid);
buildTree(arr, rightChildIndex, mid + 1, end);
nodes[stIndex].merge(nodes[leftChildIndex], nodes[rightChildIndex]);
}
int getSegmentTreeSize(int N) {
int size = 1;
for (; size < N; size <<= 1);
return size << 1;
}
SegmentTreeNode query(int stIndex, int start, int end) {
// we'll fill this in later
}
void update(int stIndex, int start, int end, UpdateType value) {
// we'll fill this in later
}
};

We now only have to describe how to fill in the update() and query() methods.

The previous post gave us a way to update a single element in O(log N) time. We can use it to achieve a range update in O(N log N) by calling the point update function once for each element in the given range. But we can do better! We can in fact do a range update in O(N). The range update function is very similar to the buildTree() function.

void update(int stIndex, int start, int end, UpdateType value) {
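// stop only at a leaf node; start == end alone could also be true at an internal node
if (nodes[stIndex].start == nodes[stIndex].end) {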
nodes[stIndex].applyUpdate(value);
return;
}
int mid = (nodes[stIndex].start + nodes[stIndex].end) / 2,
leftChildIndex = 2 * stIndex,
rightChildIndex = leftChildIndex + 1;
if (start > mid)
update(rightChildIndex, start, end, value);
else if (end <= mid)
update(leftChildIndex, start, end, value);
else {
update(leftChildIndex, start, mid, value);
update(rightChildIndex, mid+1, end, value);
}
nodes[stIndex].merge(nodes[leftChildIndex], nodes[rightChildIndex]);
}

This along with the query() function defined in the previous post allows us to do range updates in O(N) and range queries in O(log N). Note that this requires us to add a function void applyUpdate(UpdateType value) to our SegmentTreeNode.

Let’s see it in use on a sample problem.

Problem:

Given an array A of N floating point values, support two operations on any range A[a..b] (0<=a<=b<N):

  • replace each A[i] (a<=i<=b) by sqrt(A[i])
  • find sum A[a]+…+A[b]

Solution: The SegmentTreeNode for this problem looks like this:

struct SegmentTreeNode {
int start, end; // this node is responsible for the segment [start...end]
double total;
void assignLeaf(double value) {
total = value;
}
void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
total = left.total + right.total;
}
double query() {
return total;
}
// the value of the update is dummy in this case
void applyUpdate(bool value) {
total = sqrt(total);
}
};

The complete solution for this problem is available here.
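
To see all the pieces working together, here is a minimal self-contained sketch of the same idea with the template collapsed into a single class. The names NaiveSqrtSumTree, sqrtRange() and sum() are made up for this illustration and are not part of the template above, and the query() shown is only an assumption of what a plain range query looks like, not a copy of the one from the previous post.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Illustrative only: "range sqrt, range sum" on doubles with a linear-time
// range update and a logarithmic-time range query.
class NaiveSqrtSumTree {
    struct Node { int start, end; double total; };
    std::vector<Node> nodes;

    void pull(int i) { nodes[i].total = nodes[2 * i].total + nodes[2 * i + 1].total; }

    void build(const std::vector<double>& a, int i, int start, int end) {
        nodes[i].start = start; nodes[i].end = end;
        if (start == end) { nodes[i].total = a[start]; return; }
        int mid = (start + end) / 2;
        build(a, 2 * i, start, mid);
        build(a, 2 * i + 1, mid + 1, end);
        pull(i);
    }

    // Visits every leaf in [start, end], then rebuilds the sums on the way up.
    void update(int i, int start, int end) {
        if (nodes[i].start == nodes[i].end) {
            nodes[i].total = std::sqrt(nodes[i].total);
            return;
        }
        int mid = (nodes[i].start + nodes[i].end) / 2;
        if (start > mid) update(2 * i + 1, start, end);
        else if (end <= mid) update(2 * i, start, end);
        else { update(2 * i, start, mid); update(2 * i + 1, mid + 1, end); }
        pull(i);
    }

    // Stops as soon as a node's segment matches the requested range exactly.
    double query(int i, int start, int end) const {
        if (nodes[i].start == start && nodes[i].end == end) return nodes[i].total;
        int mid = (nodes[i].start + nodes[i].end) / 2;
        if (start > mid) return query(2 * i + 1, start, end);
        if (end <= mid) return query(2 * i, start, end);
        return query(2 * i, start, mid) + query(2 * i + 1, mid + 1, end);
    }

public:
    explicit NaiveSqrtSumTree(const std::vector<double>& a) : nodes(4 * a.size()) {
        assert(!a.empty());
        build(a, 1, 0, (int)a.size() - 1);
    }
    void sqrtRange(int a, int b) { update(1, a, b); }
    double sum(int a, int b) const { return query(1, a, b); }
};
```

Because sqrtRange() has to touch every leaf in the range, a single update over the whole array costs O(N), which is exactly the cost the rest of this post tries to avoid.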

For most problems, however, we can’t get away with an update function having linear complexity and must do better. This is where lazy propagation comes in. The basic idea of lazy propagation is to hold off propagating updates down the tree and propagate them only when absolutely necessary. Lazy propagation can be built into either the update() or the query() function, but here I’ll describe it only for update(). There are several reasons why we might be able to hold off updates in internal nodes:

  • Some updates are not really required and can be thrown away. Doing nothing is awesome!
  • For some problems, updates can be applied to internal nodes (segments/ranges) as opposed to leaves (actual array elements).
  • For some problems, updates can be accumulated in internal nodes until they cross a threshold and only afterwards must they be propagated.

Let’s see some examples of these ideas on actual problems.

Problem (GSS4):

Given an array A of N integers, support two operations on any range A[a..b] (0<=a<=b<N):

  • replace each A[i] (a<=i<=b) by floor(sqrt(A[i]))
  • find sum A[a]+…+A[b]

Solution:

There is only a small difference between this problem and the previous one: all array elements are guaranteed to be integers, even after applying updates. This gives us an opportunity for optimization. There are only so many times you can take the square root before a number reduces to 1, and once a number reduces to 1, it stays 1 forever. So we can throw out updates for any node whose elements are all 1. This problem demonstrates the first idea: we can sometimes throw away updates.
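
To get a feel for how quickly repeated square roots collapse a number, here is a small sketch that counts the updates needed before a value reaches 1. The helper names isqrt() and updatesUntilOne() are hypothetical, and the 10^18 upper bound is an assumption made for this illustration.

```cpp
#include <cassert>
#include <cmath>

// Exact floor(sqrt(x)), assuming x <= 10^18 so that (r + 1) * (r + 1)
// cannot overflow. The two loops correct any floating-point rounding error.
unsigned long long isqrt(unsigned long long x) {
    assert(x <= 1000000000000000000ULL);
    unsigned long long r = (unsigned long long)std::sqrt((long double)x);
    while (r * r > x) --r;
    while ((r + 1) * (r + 1) <= x) ++r;
    return r;
}

// Number of floor-sqrt updates applied to x (x >= 1) before it becomes 1.
int updatesUntilOne(unsigned long long x) {
    assert(x >= 1);
    int steps = 0;
    while (x > 1) {
        x = isqrt(x);
        ++steps;
    }
    return steps;
}
```

Under that assumption, updatesUntilOne(1000000000000000000ULL) evaluates to 6 (10^18, 10^9, 31622, 177, 13, 3, 1), so each element can be changed by only a handful of updates before every further update on it becomes a no-op.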

We can modify the update() function such that at each internal node, it decides whether to propagate an update or throw it out. The query() function remains unchanged.

void update(int stIndex, int start, int end, UpdateType value) {
if (nodes[stIndex].start == start && nodes[stIndex].end == end) {
lazyPropagatePendingUpdateToSubtree(stIndex, value);
return;
}
int mid = (nodes[stIndex].start + nodes[stIndex].end) >> 1,
leftChildIndex = stIndex << 1,
rightChildIndex = leftChildIndex + 1;
if (start > mid)
update(rightChildIndex, start, end, value);
else if (end <= mid)
update(leftChildIndex, start, end, value);
else {
update(leftChildIndex, start, mid, value);
update(rightChildIndex, mid+1, end, value);
}
nodes[stIndex].merge(nodes[leftChildIndex], nodes[rightChildIndex]);
}
void lazyPropagatePendingUpdateToSubtree(int stIndex, UpdateType value) {
nodes[stIndex].addUpdate(value);
if (!nodes[stIndex].isPropagationRequired())
return;
if (nodes[stIndex].start == nodes[stIndex].end) {
nodes[stIndex].applyPendingUpdate();
return;
}
UpdateType pendingUpdate = nodes[stIndex].getPendingUpdate();
nodes[stIndex].clearPendingUpdate();
int mid = (nodes[stIndex].start + nodes[stIndex].end) >> 1,
leftChildIndex = stIndex << 1,
rightChildIndex = leftChildIndex + 1;
lazyPropagatePendingUpdateToSubtree(leftChildIndex, pendingUpdate);
lazyPropagatePendingUpdateToSubtree(rightChildIndex, pendingUpdate);
nodes[stIndex].merge(nodes[leftChildIndex], nodes[rightChildIndex]);
}

I’ve tried to make the update() function as general as I could so that it works for several similar problems, without requiring any change. However, this requires 5 new functions to be implemented in SegmentTreeNode, although they are all 1-liners. The SegmentTreeNode for this problem looks like this:

struct SegmentTreeNode {
int start, end; // this node is responsible for the segment [start...end]
ll total; // ll is shorthand for long long (typedef long long ll;)
bool pendingUpdate;
SegmentTreeNode() : total(0), pendingUpdate(false) {}
void assignLeaf(ll value) {
total = value;
}
void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
total = left.total + right.total;
}
ll query() {
return total;
}
// For this particular problem, propagation is not required
// if all elements in this segment are 1's
bool isPropagationRequired() {
return total > end-start+1;
}
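void applyPendingUpdate() {
// sqrt() works in floating point, so correct the result if rounding moved it off by one
ll root = (ll) sqrt((double) total);
while (root * root > total) --root;
while ((root + 1) * (root + 1) <= total) ++root;
total = root;
pendingUpdate = false;
}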
// For this particular problem, the value of the update is dummy
// and is just an instruction to square root the leaf value
void addUpdate(bool value) {
pendingUpdate = true;
}
// returns a dummy value
bool getPendingUpdate() {
return true;
}
void clearPendingUpdate() {
pendingUpdate = false;
}
};

The complete solution for this problem is available here. Note that I had to define the nodes[] array outside the SegmentTree template. The time limit on this problem is too strict to allow memory allocation/deallocation per test case.
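
For readers who want something to compile right away, here is a compact self-contained sketch of the same "throw away updates" idea. It is not the solution linked above: the class SqrtSumTree and its methods are hypothetical names, the pending-update bookkeeping is folded into a single recursive helper, and it assumes every element is a positive integer no larger than 10^18.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

typedef long long ll;

// Illustrative only: range floor-sqrt and range sum on positive integers.
class SqrtSumTree {
    struct Node { int start, end; ll total; };
    std::vector<Node> nodes;

    // Exact floor(sqrt(x)); the loops fix floating-point rounding.
    static ll isqrt(ll x) {
        ll r = (ll)std::sqrt((long double)x);
        while (r * r > x) --r;
        while ((r + 1) * (r + 1) <= x) ++r;
        return r;
    }
    // With positive elements, total == length means every element is 1.
    bool allOnes(int i) const {
        return nodes[i].total == nodes[i].end - nodes[i].start + 1;
    }
    void pull(int i) { nodes[i].total = nodes[2 * i].total + nodes[2 * i + 1].total; }

    void build(const std::vector<ll>& a, int i, int start, int end) {
        nodes[i].start = start; nodes[i].end = end;
        if (start == end) { nodes[i].total = a[start]; return; }
        int mid = (start + end) / 2;
        build(a, 2 * i, start, mid);
        build(a, 2 * i + 1, mid + 1, end);
        pull(i);
    }
    // Square-roots every leaf under node i, but abandons any subtree
    // whose elements are already all 1.
    void sqrtSubtree(int i) {
        if (allOnes(i)) return;
        if (nodes[i].start == nodes[i].end) {
            nodes[i].total = isqrt(nodes[i].total);
            return;
        }
        sqrtSubtree(2 * i);
        sqrtSubtree(2 * i + 1);
        pull(i);
    }
    void update(int i, int start, int end) {
        if (nodes[i].start == start && nodes[i].end == end) { sqrtSubtree(i); return; }
        int mid = (nodes[i].start + nodes[i].end) / 2;
        if (start > mid) update(2 * i + 1, start, end);
        else if (end <= mid) update(2 * i, start, end);
        else { update(2 * i, start, mid); update(2 * i + 1, mid + 1, end); }
        pull(i);
    }
    ll query(int i, int start, int end) const {
        if (nodes[i].start == start && nodes[i].end == end) return nodes[i].total;
        int mid = (nodes[i].start + nodes[i].end) / 2;
        if (start > mid) return query(2 * i + 1, start, end);
        if (end <= mid) return query(2 * i, start, end);
        return query(2 * i, start, mid) + query(2 * i + 1, mid + 1, end);
    }

public:
    explicit SqrtSumTree(const std::vector<ll>& a) : nodes(4 * a.size()) {
        assert(!a.empty());
        build(a, 1, 0, (int)a.size() - 1);
    }
    void sqrtRange(int a, int b) { update(1, a, b); }
    ll sum(int a, int b) const { return query(1, a, b); }
};
```

Once a segment has collapsed to all 1’s, sqrtRange() over it returns at the first fully covered node, so repeated updates on the same range quickly become cheap.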

Problem (HORRIBLE):

Given an array A of N integers, support two operations on any range A[a..b] (0<=a<=b<N):

  • add a given value v to each A[i] (a<=i<=b)
  • find sum A[a]+…+A[b]

Solution:

This problem demonstrates the second idea: for some problems, updates can be applied to (or stored inside) internal nodes. Any node whose segment is completely covered by the update segment can just store the update value and not propagate it down the tree. The query() function, while coming back up the tree, picks up pending updates and applies them to the result.

SegmentTreeNode query(int stIndex, int start, int end) {
if (nodes[stIndex].start == start && nodes[stIndex].end == end) {
SegmentTreeNode result = nodes[stIndex];
if (result.hasPendingUpdate())
result.applyPendingUpdate();
return result;
}
int mid = (nodes[stIndex].start + nodes[stIndex].end) >> 1,
leftChildIndex = stIndex << 1,
rightChildIndex = leftChildIndex + 1;
SegmentTreeNode result;
if (start > mid)
result = query(rightChildIndex, start, end);
else if (end <= mid)
result = query(leftChildIndex, start, end);
else {
SegmentTreeNode leftResult = query(leftChildIndex, start, mid),
rightResult = query(rightChildIndex, mid+1, end);
result.start = leftResult.start;
result.end = rightResult.end;
result.merge(leftResult, rightResult);
}
if (nodes[stIndex].hasPendingUpdate()) {
result.addUpdate(nodes[stIndex].getPendingUpdate());
result.applyPendingUpdate();
}
return result;
}
void update(int stIndex, int start, int end, UpdateType value) {
if (nodes[stIndex].start == start && nodes[stIndex].end == end) {
nodes[stIndex].addUpdate(value);
return;
}
int mid = (nodes[stIndex].start + nodes[stIndex].end) >> 1,
leftChildIndex = stIndex << 1,
rightChildIndex = leftChildIndex + 1;
if (start > mid)
update(rightChildIndex, start, end, value);
else if (end <= mid)
update(leftChildIndex, start, end, value);
else {
update(leftChildIndex, start, mid, value);
update(rightChildIndex, mid+1, end, value);
}
nodes[stIndex].merge(nodes[leftChildIndex], nodes[rightChildIndex]);
}

The SegmentTreeNode for this problem becomes:

struct SegmentTreeNode {
int start, end; // this node is responsible for the segment [start...end]
ll total, pendingUpdate;
SegmentTreeNode() : total(0), pendingUpdate(0) {}
void assignLeaf(ll value) {
total = value;
}
void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
total = left.total + right.total;
// children's pending updates are not yet included in their totals;
// adding them unconditionally also handles negative update values
total += left.pendingUpdate * (left.end - left.start + 1);
total += right.pendingUpdate * (right.end - right.start + 1);
}
ll query() {
return total;
}
bool hasPendingUpdate() {
return pendingUpdate != 0;
}
void applyPendingUpdate() {
total += (end - start + 1) * pendingUpdate;
pendingUpdate = 0;
}
void addUpdate(ll value) {
pendingUpdate += value;
}
ll getPendingUpdate() {
return pendingUpdate;
}
};

Note how this change places a special requirement on the merge() function. When we store an update inside an internal node, we ensure that queries to all other nodes (its ancestors and descendants) return correct results by:

  • modifying the query() function such that it collects pending updates while going back up the tree (as the recursion unwinds). This ensures that all descendants of a node will see an update made to that node.
  • modifying the merge() function such that the update() function correctly updates all ancestor nodes of an updated node while going back up the tree (as the recursion unwinds).
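
The two points above can be sketched in a compact, self-contained form. This is only an illustration under the same convention (a node’s total excludes its own pending update), not the linked solution; the class RangeAddSumTree and the helpers value() and pull() are hypothetical names.

```cpp
#include <cassert>
#include <vector>

typedef long long ll;

// Illustrative only: range add and range sum, with updates stored in the
// topmost fully covered nodes and never pushed down.
class RangeAddSumTree {
    struct Node { int start, end; ll total, pending; };
    std::vector<Node> nodes;

    ll length(int i) const { return nodes[i].end - nodes[i].start + 1; }
    // Sum of node i's segment, including the update stored at node i itself.
    ll value(int i) const { return nodes[i].total + nodes[i].pending * length(i); }
    // Plays the role of merge(): the parent's total includes its children's pending updates.
    void pull(int i) { nodes[i].total = value(2 * i) + value(2 * i + 1); }

    void build(const std::vector<ll>& a, int i, int start, int end) {
        nodes[i].start = start; nodes[i].end = end;
        nodes[i].total = 0; nodes[i].pending = 0;
        if (start == end) { nodes[i].total = a[start]; return; }
        int mid = (start + end) / 2;
        build(a, 2 * i, start, mid);
        build(a, 2 * i + 1, mid + 1, end);
        pull(i);
    }
    void update(int i, int start, int end, ll v) {
        if (nodes[i].start == start && nodes[i].end == end) {
            nodes[i].pending += v;  // store the update here and stop
            return;
        }
        int mid = (nodes[i].start + nodes[i].end) / 2;
        if (start > mid) update(2 * i + 1, start, end, v);
        else if (end <= mid) update(2 * i, start, end, v);
        else { update(2 * i, start, mid, v); update(2 * i + 1, mid + 1, end, v); }
        pull(i);  // ancestors see the update through pull()
    }
    ll query(int i, int start, int end) const {
        if (nodes[i].start == start && nodes[i].end == end) return value(i);
        int mid = (nodes[i].start + nodes[i].end) / 2;
        ll result;
        if (start > mid) result = query(2 * i + 1, start, end);
        else if (end <= mid) result = query(2 * i, start, end);
        else result = query(2 * i, start, mid) + query(2 * i + 1, mid + 1, end);
        // Going back up: the update stored at this node covers every
        // element of the queried range [start, end].
        return result + nodes[i].pending * (end - start + 1);
    }

public:
    explicit RangeAddSumTree(const std::vector<ll>& a) : nodes(4 * a.size()) {
        assert(!a.empty());
        build(a, 1, 0, (int)a.size() - 1);
    }
    void add(int a, int b, ll v) { update(1, a, b, v); }
    ll sum(int a, int b) const { return query(1, a, b); }
};
```

Both add() and sum() visit O(log N) nodes here, because neither ever descends below a node whose segment matches the requested range.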

The complete solution for this problem is available here.

Problem (FLIPCOIN):

Given an array A of N booleans, support two operations on any range A[a..b] (0<=a<=b<N):

  • flip each A[i] (a<=i<=b)
  • count how many of A[a],…,A[b] are true

Solution:

This problem is very similar to the previous problem. Here also, we can store updates in internal nodes, modify the merge() function to correctly update ancestors and the query() function to correctly update descendants. The SegmentTree template for the previous problem can be used exactly as is for this problem. The only change is to SegmentTreeNode:

struct SegmentTreeNode {
int start, end; // this node is responsible for the segment [start...end]
int count;
bool pendingUpdate;
SegmentTreeNode() : count(0), pendingUpdate(false) {}
// nothing to do: count is already initialized to 0 by the constructor
void assignLeaf(bool value) {}
void merge(SegmentTreeNode& left, SegmentTreeNode& right) {
count = (left.pendingUpdate ? (left.end - left.start + 1 - left.count) : left.count)
+ (right.pendingUpdate ? (right.end - right.start + 1 - right.count) : right.count);
}
int query() {
return count;
}
bool hasPendingUpdate() {
return pendingUpdate;
}
void applyPendingUpdate() {
count = (end - start + 1) - count;
pendingUpdate = false;
}
void addUpdate(bool value) {
pendingUpdate = !pendingUpdate;
}
bool getPendingUpdate() {
return true;
}
};

The complete solution for this problem is available here.
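
As with the previous problem, the idea can be sketched in a compact, self-contained form. This is an illustration rather than the linked solution: the class FlipCountTree and its methods are hypothetical names, and it assumes all N values start out false.

```cpp
#include <cassert>
#include <vector>

// Illustrative only: range flip and range count of true values, with the
// flip flag stored in the topmost fully covered nodes and never pushed down.
class FlipCountTree {
    struct Node { int start, end, count; bool flip; };
    std::vector<Node> nodes;

    int length(int i) const { return nodes[i].end - nodes[i].start + 1; }
    // Number of true values in node i's segment, including its own pending flip.
    int value(int i) const {
        return nodes[i].flip ? length(i) - nodes[i].count : nodes[i].count;
    }
    void pull(int i) { nodes[i].count = value(2 * i) + value(2 * i + 1); }

    void build(int i, int start, int end) {
        nodes[i].start = start; nodes[i].end = end;
        nodes[i].count = 0; nodes[i].flip = false;
        if (start == end) return;
        int mid = (start + end) / 2;
        build(2 * i, start, mid);
        build(2 * i + 1, mid + 1, end);
    }
    void update(int i, int start, int end) {
        if (nodes[i].start == start && nodes[i].end == end) {
            nodes[i].flip = !nodes[i].flip;  // two flips cancel each other
            return;
        }
        int mid = (nodes[i].start + nodes[i].end) / 2;
        if (start > mid) update(2 * i + 1, start, end);
        else if (end <= mid) update(2 * i, start, end);
        else { update(2 * i, start, mid); update(2 * i + 1, mid + 1, end); }
        pull(i);
    }
    int query(int i, int start, int end) const {
        if (nodes[i].start == start && nodes[i].end == end) return value(i);
        int mid = (nodes[i].start + nodes[i].end) / 2;
        int result;
        if (start > mid) result = query(2 * i + 1, start, end);
        else if (end <= mid) result = query(2 * i, start, end);
        else result = query(2 * i, start, mid) + query(2 * i + 1, mid + 1, end);
        // A flip stored at this node inverts every element of [start, end].
        return nodes[i].flip ? (end - start + 1) - result : result;
    }

public:
    explicit FlipCountTree(int n) : nodes(4 * n) {
        assert(n > 0);
        build(1, 0, n - 1);
    }
    void flipRange(int a, int b) { update(1, a, b); }
    int countTrue(int a, int b) const { return query(1, a, b); }
};
```

The only real difference from the range-add sketch is how a stored update changes a count: instead of adding a multiple of the segment length, it replaces the count with the segment length minus the count.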

I hope this post provided a good introduction to range updates and lazy propagation in segment trees. There are some ideas that I’ve skipped, which are either very similar to what is discussed in this post (like lazy propagation in query() instead of in update()) or are rare (like holding off updates in internal nodes until they cross a threshold, after which they should be lazily propagated). Both of these are natural extensions of what is described in this post. There may also be other ideas that I’m not yet aware of. Please share your thoughts, feedback and suggestions in the comments.

22 thoughts on “A simple approach to segment trees, part 2”

  1. Pingback: A simple approach to segment trees | Everything Under The Sun

  2. The explanation is excellent. The best I know of on the web.
    Please, if possible, give an example of lazy propagation in query() and of internal nodes with the threshold concept.

  3. How can we change the value of a range [x,y] to v in logarithmic time? E.g., 1 4 7 would mean changing the values of all elements at indices 1 to 4 to 7.

    • We can achieve this through lazy propagation. We’ll lazily propagate the updates, stopping at the first segment that is completely covered by the update range and storing it there. Then, while querying if we go down that path, we can propagate the update downwards, essentially hiding the update time behind several query calls.

  4. How can we multiply each element of a range [x,y] by v in logarithmic time? E.g., 1 4 7 would mean multiplying the values of all elements at indices 1 to 4 by 7.

  5. Hey Kartik, don’t you think that in FLIPCOIN, it’s nicer that the value would be a tuple with (#heads, #tails)? or (#heads, #total)? this way, the merge function would be self-contained.
    Note that in your code, you break the encapsulation of “merge” and outside it you handle the total number of coins:

    result.start = leftResult.start;
    result.end = rightResult.end;

    WDYS?

    • Hey, in your earlier comment, you said that storing segment start and end indices in SegmentTreeNode is better for encapsulation and code reuse. In this problem, there’s added benefit to storing them in the node because we needed them to calculate the statistics.

      I didn’t understand a few things about your question. Are you saying that the argument of assignLeaf() and addUpdate() should remain a boolean value but we should store the number of heads & tails in each node? Isn’t what is done almost the same? count stores the number of heads and we use (end-start+1) to calculate the total number of coins.

    • When we store start and end indices in SegmentTreeNode, we don’t have to pass those as arguments to query() and update() methods. Basically these are two ways to accomplish the same thing:
      1. Either you pass the start and end indices of the current interval to query() and update() methods, which looks more messy.
      2. Or store the start and end indices in the node itself and then you only have to pass the node index to these methods. It has added benefits if a node has to use the start and end index values for query or update (for an example, look at the solution of the problem HORRIBLE).

  6. Can you please give a link to a problem in which we have to use the third idea,
    i.e. one in which updates can be accumulated in internal nodes until they cross a threshold and only afterwards must they be propagated?

    • I’m sure there will be many problems that utilize this concept. A sample problem that I was able to think of: Given an integer array, support two operations: adding an integer to every element in a segment and querying how many elements are larger than X (a constant) in a segment. For this problem, we could store two values in each node: maximum value of any element in this segment and the sum of pending updates. Until the max + updates reaches X, we don’t have to propagate updates.

    • I think it would depend on the problem but in general, you will store some values for both updates in each node. In this example, how do you apply sqrt? It’s not an aggregate operation. Just sqrt all elements in a range?
