1 #include <linux/init.h> 2 #include <linux/interval_tree.h> 3 4 /* Callbacks for augmented rbtree insert and remove */ 5 6 static inline unsigned long 7 compute_subtree_last(struct interval_tree_node *node) 8 { 9 unsigned long max = node->last, subtree_last; 10 if (node->rb.rb_left) { 11 subtree_last = rb_entry(node->rb.rb_left, 12 struct interval_tree_node, rb)->__subtree_last; 13 if (max < subtree_last) 14 max = subtree_last; 15 } 16 if (node->rb.rb_right) { 17 subtree_last = rb_entry(node->rb.rb_right, 18 struct interval_tree_node, rb)->__subtree_last; 19 if (max < subtree_last) 20 max = subtree_last; 21 } 22 return max; 23 } 24 25 RB_DECLARE_CALLBACKS(static, augment_callbacks, struct interval_tree_node, rb, 26 unsigned long, __subtree_last, compute_subtree_last) 27 28 /* Insert / remove interval nodes from the tree */ 29 30 void interval_tree_insert(struct interval_tree_node *node, 31 struct rb_root *root) 32 { 33 struct rb_node **link = &root->rb_node, *rb_parent = NULL; 34 unsigned long start = node->start, last = node->last; 35 struct interval_tree_node *parent; 36 37 while (*link) { 38 rb_parent = *link; 39 parent = rb_entry(rb_parent, struct interval_tree_node, rb); 40 if (parent->__subtree_last < last) 41 parent->__subtree_last = last; 42 if (start < parent->start) 43 link = &parent->rb.rb_left; 44 else 45 link = &parent->rb.rb_right; 46 } 47 48 node->__subtree_last = last; 49 rb_link_node(&node->rb, rb_parent, link); 50 rb_insert_augmented(&node->rb, root, &augment_callbacks); 51 } 52 53 void interval_tree_remove(struct interval_tree_node *node, 54 struct rb_root *root) 55 { 56 rb_erase_augmented(&node->rb, root, &augment_callbacks); 57 } 58 59 /* 60 * Iterate over intervals intersecting [start;last] 61 * 62 * Note that a node's interval intersects [start;last] iff: 63 * Cond1: node->start <= last 64 * and 65 * Cond2: start <= node->last 66 */ 67 68 static struct interval_tree_node * 69 subtree_search(struct interval_tree_node *node, 70 unsigned long start, unsigned long last) 71 { 72 while (true) { 73 /* 74 * Loop invariant: start <= node->__subtree_last 75 * (Cond2 is satisfied by one of the subtree nodes) 76 */ 77 if (node->rb.rb_left) { 78 struct interval_tree_node *left = 79 rb_entry(node->rb.rb_left, 80 struct interval_tree_node, rb); 81 if (start <= left->__subtree_last) { 82 /* 83 * Some nodes in left subtree satisfy Cond2. 84 * Iterate to find the leftmost such node N. 85 * If it also satisfies Cond1, that's the match 86 * we are looking for. Otherwise, there is no 87 * matching interval as nodes to the right of N 88 * can't satisfy Cond1 either. 89 */ 90 node = left; 91 continue; 92 } 93 } 94 if (node->start <= last) { /* Cond1 */ 95 if (start <= node->last) /* Cond2 */ 96 return node; /* node is leftmost match */ 97 if (node->rb.rb_right) { 98 node = rb_entry(node->rb.rb_right, 99 struct interval_tree_node, rb); 100 if (start <= node->__subtree_last) 101 continue; 102 } 103 } 104 return NULL; /* No match */ 105 } 106 } 107 108 struct interval_tree_node * 109 interval_tree_iter_first(struct rb_root *root, 110 unsigned long start, unsigned long last) 111 { 112 struct interval_tree_node *node; 113 114 if (!root->rb_node) 115 return NULL; 116 node = rb_entry(root->rb_node, struct interval_tree_node, rb); 117 if (node->__subtree_last < start) 118 return NULL; 119 return subtree_search(node, start, last); 120 } 121 122 struct interval_tree_node * 123 interval_tree_iter_next(struct interval_tree_node *node, 124 unsigned long start, unsigned long last) 125 { 126 struct rb_node *rb = node->rb.rb_right, *prev; 127 128 while (true) { 129 /* 130 * Loop invariants: 131 * Cond1: node->start <= last 132 * rb == node->rb.rb_right 133 * 134 * First, search right subtree if suitable 135 */ 136 if (rb) { 137 struct interval_tree_node *right = 138 rb_entry(rb, struct interval_tree_node, rb); 139 if (start <= right->__subtree_last) 140 return subtree_search(right, start, last); 141 } 142 143 /* Move up the tree until we come from a node's left child */ 144 do { 145 rb = rb_parent(&node->rb); 146 if (!rb) 147 return NULL; 148 prev = &node->rb; 149 node = rb_entry(rb, struct interval_tree_node, rb); 150 rb = node->rb.rb_right; 151 } while (prev == rb); 152 153 /* Check if the node intersects [start;last] */ 154 if (last < node->start) /* !Cond1 */ 155 return NULL; 156 else if (start <= node->last) /* Cond2 */ 157 return node; 158 } 159 } 160