/* Source listing captured from xref: /linux-6.15/lib/interval_tree.c (revision fff3fd8a) */
1 #include <linux/init.h>
2 #include <linux/interval_tree.h>
3 
4 /* Callbacks for augmented rbtree insert and remove */
5 
6 static inline unsigned long
7 compute_subtree_last(struct interval_tree_node *node)
8 {
9 	unsigned long max = node->last, subtree_last;
10 	if (node->rb.rb_left) {
11 		subtree_last = rb_entry(node->rb.rb_left,
12 			struct interval_tree_node, rb)->__subtree_last;
13 		if (max < subtree_last)
14 			max = subtree_last;
15 	}
16 	if (node->rb.rb_right) {
17 		subtree_last = rb_entry(node->rb.rb_right,
18 			struct interval_tree_node, rb)->__subtree_last;
19 		if (max < subtree_last)
20 			max = subtree_last;
21 	}
22 	return max;
23 }
24 
25 RB_DECLARE_CALLBACKS(static, augment_callbacks, struct interval_tree_node, rb,
26 		     unsigned long, __subtree_last, compute_subtree_last)
27 
28 /* Insert / remove interval nodes from the tree */
29 
30 void interval_tree_insert(struct interval_tree_node *node,
31 			  struct rb_root *root)
32 {
33 	struct rb_node **link = &root->rb_node, *rb_parent = NULL;
34 	unsigned long start = node->start, last = node->last;
35 	struct interval_tree_node *parent;
36 
37 	while (*link) {
38 		rb_parent = *link;
39 		parent = rb_entry(rb_parent, struct interval_tree_node, rb);
40 		if (parent->__subtree_last < last)
41 			parent->__subtree_last = last;
42 		if (start < parent->start)
43 			link = &parent->rb.rb_left;
44 		else
45 			link = &parent->rb.rb_right;
46 	}
47 
48 	node->__subtree_last = last;
49 	rb_link_node(&node->rb, rb_parent, link);
50 	rb_insert_augmented(&node->rb, root, &augment_callbacks);
51 }
52 
53 void interval_tree_remove(struct interval_tree_node *node,
54 			  struct rb_root *root)
55 {
56 	rb_erase_augmented(&node->rb, root, &augment_callbacks);
57 }
58 
59 /*
60  * Iterate over intervals intersecting [start;last]
61  *
62  * Note that a node's interval intersects [start;last] iff:
63  *   Cond1: node->start <= last
64  * and
65  *   Cond2: start <= node->last
66  */
67 
68 static struct interval_tree_node *
69 subtree_search(struct interval_tree_node *node,
70 	       unsigned long start, unsigned long last)
71 {
72 	while (true) {
73 		/*
74 		 * Loop invariant: start <= node->__subtree_last
75 		 * (Cond2 is satisfied by one of the subtree nodes)
76 		 */
77 		if (node->rb.rb_left) {
78 			struct interval_tree_node *left =
79 				rb_entry(node->rb.rb_left,
80 					 struct interval_tree_node, rb);
81 			if (start <= left->__subtree_last) {
82 				/*
83 				 * Some nodes in left subtree satisfy Cond2.
84 				 * Iterate to find the leftmost such node N.
85 				 * If it also satisfies Cond1, that's the match
86 				 * we are looking for. Otherwise, there is no
87 				 * matching interval as nodes to the right of N
88 				 * can't satisfy Cond1 either.
89 				 */
90 				node = left;
91 				continue;
92 			}
93 		}
94 		if (node->start <= last) {		/* Cond1 */
95 			if (start <= node->last)	/* Cond2 */
96 				return node;	/* node is leftmost match */
97 			if (node->rb.rb_right) {
98 				node = rb_entry(node->rb.rb_right,
99 					struct interval_tree_node, rb);
100 				if (start <= node->__subtree_last)
101 					continue;
102 			}
103 		}
104 		return NULL;	/* No match */
105 	}
106 }
107 
108 struct interval_tree_node *
109 interval_tree_iter_first(struct rb_root *root,
110 			 unsigned long start, unsigned long last)
111 {
112 	struct interval_tree_node *node;
113 
114 	if (!root->rb_node)
115 		return NULL;
116 	node = rb_entry(root->rb_node, struct interval_tree_node, rb);
117 	if (node->__subtree_last < start)
118 		return NULL;
119 	return subtree_search(node, start, last);
120 }
121 
122 struct interval_tree_node *
123 interval_tree_iter_next(struct interval_tree_node *node,
124 			unsigned long start, unsigned long last)
125 {
126 	struct rb_node *rb = node->rb.rb_right, *prev;
127 
128 	while (true) {
129 		/*
130 		 * Loop invariants:
131 		 *   Cond1: node->start <= last
132 		 *   rb == node->rb.rb_right
133 		 *
134 		 * First, search right subtree if suitable
135 		 */
136 		if (rb) {
137 			struct interval_tree_node *right =
138 				rb_entry(rb, struct interval_tree_node, rb);
139 			if (start <= right->__subtree_last)
140 				return subtree_search(right, start, last);
141 		}
142 
143 		/* Move up the tree until we come from a node's left child */
144 		do {
145 			rb = rb_parent(&node->rb);
146 			if (!rb)
147 				return NULL;
148 			prev = &node->rb;
149 			node = rb_entry(rb, struct interval_tree_node, rb);
150 			rb = node->rb.rb_right;
151 		} while (prev == rb);
152 
153 		/* Check if the node intersects [start;last] */
154 		if (last < node->start)		/* !Cond1 */
155 			return NULL;
156 		else if (start <= node->last)	/* Cond2 */
157 			return node;
158 	}
159 }
160