23#include <triqs/utility/first_include.hpp>
24#include <triqs/utility/exceptions.hpp>
29#include "./rbt_iterators.hpp"
/// Exception type thrown by rb_tree::insert when the tree already holds an equivalent key
/// (i.e. neither key compares less than the other under the tree's comparator).
struct rbt_insert_error {
};
39 template <
typename Key,
typename Value,
typename Compare = std::less<Key>>
class rb_tree {
41 static const bool RED =
true;
42 static const bool BLACK =
false;
47 using node = node_t *;
// node_t: one tree node. It derives from Value, so the payload is stored inside the node itself.
// NOTE(review): the key / color / N member declarations used by the constructors are not visible in this view.
50 struct node_t :
public Value {
// Child links; nullptr means "no child".
54 node left =
nullptr, right =
nullptr;
// modified: true on construction, and set along a root-to-key path by set_modified_from_root_to.
// delete_flag: marks the node for deletion (clear_modified_impl raises an error on a modified node still flagged).
55 bool modified, delete_flag;
// Builds a fresh node: copies val into the Value base, starts with no children,
// modified = true and delete_flag = false.
57 node_t(Key
const &key, Value
const &val,
bool color,
int N)
58 : Value(val), key(key), color(color), N(N), left{nullptr}, right{nullptr}, modified(true), delete_flag(false) {}
// Deep copy: each existing child is cloned recursively with new node_t.
// NOTE(review): part of this constructor's initializer list is not visible in this view.
60 node_t(node_t
const &n)
65 left(n.left ? new node_t(*n.left) : nullptr),
66 right(n.right ? new node_t(*n.right) : nullptr),
68 delete_flag(n.delete_flag) {}
// Nodes are not copy-assignable.
69 node_t &operator=(node_t
const &) =
delete;
// Re-initialises the node in place; the extra arguments are forwarded to Value::reset.
// NOTE(review): how k is used is not visible in this view.
70 template <
typename... T>
void reset(Key
const &k, T &&... x) {
74 Value::reset(std::forward<T>(x)...);
78 using iterator = detail::rbt_iterator<rb_tree, node>;
79 using const_iterator = detail::rbt_iterator<const rb_tree, const node>;
80 iterator begin() noexcept {
return {
this,
false}; }
81 iterator end() noexcept {
return {
this,
true}; }
82 const_iterator begin() const noexcept {
return {
this,
false}; }
83 const_iterator end() const noexcept {
return {
this,
true}; }
84 const_iterator cbegin() const noexcept {
return {
this,
false}; }
85 const_iterator cend() const noexcept {
return {
this,
true}; }
// Recursive walk of the subtree rooted at n: recurses into n->left, then (presumably) applies f to n,
// then recurses into n->right.
// NOTE(review): the call of f on n itself is not visible here; n must be non-null (n->left is read unconditionally).
93 template <
typename Fnt>
void apply_recursive(Fnt
const &f, node n)
const {
94 if (n->left) apply_recursive(f, n->left);
96 if (n->right) apply_recursive(f, n->right);
// Mirror image of apply_recursive: right subtree first, then left subtree.
// NOTE(review): the call of f on n itself is not visible here; n must be non-null.
99 template <
typename Fnt>
void apply_recursive_reverse(Fnt
const &f, node n)
const {
100 if (n->right) apply_recursive_reverse(f, n->right);
102 if (n->left) apply_recursive_reverse(f, n->left);
// Visits both subtrees (left, then right) before, presumably, applying f to n itself.
// NOTE(review): the remainder of the body is not visible here; n must be non-null.
105 template <
typename Fnt>
void apply_recursive_subtree_first(Fnt
const &f, node n)
const {
106 if (n->left) apply_recursive_subtree_first(f, n->left);
107 if (n->right) apply_recursive_subtree_first(f, n->right);
// True iff x is a non-null node whose color is RED; null links count as not red.
112 bool is_red(node x) {
113 if (x ==
nullptr)
return false;
114 return (x->color == RED);
// Size of the subtree rooted at x; 0 for a null link.
// NOTE(review): the non-null return is not visible here (insert/balance keep x->N equal to the subtree size).
118 int size(node x)
const {
119 if (x ==
nullptr)
return 0;
// Presumably releases the subtree rooted at n (the destructor calls it on root); a null n is a no-op.
// NOTE(review): the rest of the body is not visible in this view.
123 void rec_free(node n) {
124 if (n ==
nullptr)
return;
// foreach(tr, f): applies f to the nodes of tr via apply_recursive; no-op on an empty tree.
140 template <
typename Fnt>
friend void foreach (rb_tree
const &tr, Fnt
const &f) {
141 if (tr.root) tr.apply_recursive(f, tr.root);
// foreach(tr, n, f): same walk restricted to the subtree rooted at n; no-op if n is null.
143 template <
typename Fnt>
friend void foreach (rb_tree
const &tr, node n, Fnt
const &f) {
144 if (n) tr.apply_recursive(f, n);
// foreach_reverse(tr, f): like foreach but via apply_recursive_reverse (right subtree first).
148 template <
typename Fnt>
friend void foreach_reverse(rb_tree
const &tr, Fnt
const &f) {
149 if (tr.root) tr.apply_recursive_reverse(f, tr.root);
// foreach_reverse(tr, n, f): reverse walk of the subtree rooted at n; no-op if n is null.
151 template <
typename Fnt>
friend void foreach_reverse(rb_tree
const &tr, node n, Fnt
const &f) {
152 if (n) tr.apply_recursive_reverse(f, n);
// foreach_subtree_first(tr, n, f): walk via apply_recursive_subtree_first (children before, presumably, the node).
156 template <
typename Fnt>
friend void foreach_subtree_first(rb_tree
const &tr, node n, Fnt
const &f) {
157 if (n) tr.apply_recursive_subtree_first(f, n);
159 template <
typename Fnt>
friend void foreach_subtree_first(rb_tree
const &tr, Fnt
const &f) { foreach_subtree_first(tr, tr.root, f); }
161 rb_tree() : root(nullptr) {}
162 ~rb_tree() { rec_free(root); }
// Deep copy: copies the comparator and clones the node tree (node_t's copy constructor recurses into children).
// NOTE(review): when n is empty, root is not assigned here; the default constructor initialises it explicitly,
// so confirm root has a default member initializer, otherwise the destructor would free an indeterminate pointer.
// NOTE(review): no copy-assignment operator is visible; with a user-defined destructor the implicit one would
// share nodes between two trees. Confirm it is deleted or defined elsewhere.
165 rb_tree(rb_tree
const &n) : compare(n.compare) {
166 if (n.root) root =
new node_t(*n.root);
170 int size()
const {
return size(root); }
172 bool empty()
const {
return root ==
nullptr; }
174 node
const &get_root()
const {
return root; }
175 node &get_root() {
return root; }
178 Compare
const &get_comparator()
const {
return compare; }
// Writes each node's key on its own line, traversing the tree with apply_recursive.
// NOTE(review): apply_recursive dereferences its node argument without a null check, so this looks unsafe
// on an empty tree (root == nullptr). Consider guarding with `if (root)`.
181 void print(std::ostream &out)
const {
182 apply_recursive([&out](node n) { out << n->key << std::endl; }, root);
// Default node printer for graphviz: writes the node's key converted to double.
186 struct print_key_as_double {
187 void operator()(std::ostream &os, node n) { os << double(n->key); }
191 template <
typename NodePr
inter = pr
int_key_as_
double>
void graphviz(std::ostream &&out, NodePrinter np = {})
const { graphviz(out, np); }
// Writes the tree to out in Graphviz "dot" format; np renders node labels (default: key as double).
// NOTE(review): the template parameter names NodePrinter / print_key_as_double are split across the
// following lines and will not parse as written; parts of the body are also not visible in this view.
193 template <
typename NodePr
inter = pr
int_key_as_
double>
void graphviz(std::ostream &out, NodePrinter np = {})
const {
// Node colour: green if flagged for deletion, red if modified (other cases not visible here).
194 auto color_node_to_string = [](node n) -> std::string {
195 if (n->delete_flag)
return "green";
196 if (n->modified)
return "red";
// Graph header; ordering="out" keeps each node's children in emission order (left before right).
199 out <<
"digraph G{ graph [ordering=\"out\"];" << std::endl;
202 out <<
"[color = " << color_node_to_string(root) <<
"];" << std::endl;
// Per-node emitter: writes each existing child with its node colour, then a line that is
// "[color = red];" when the child's link colour is RED (presumably the edge attributes).
204 auto f = [&out, &color_node_to_string, &np](node n) {
208 out <<
"[color = " << color_node_to_string(n->left) <<
"];\n";
212 out << (n->left->color == RED ?
"[color = red];" :
";") << std::endl;
216 out <<
"[color = " << color_node_to_string(n->right) <<
"];\n";
220 out << (n->right->color == RED ?
"[color = red];" :
";") << std::endl;
// Close the digraph.
225 out <<
"}" << std::endl;
// Debug check: prints every node still marked modified (walks the tree with foreach_subtree_first).
228 void check_no_node_modified()
const {
229 foreach_subtree_first(*
this, [&](node y) {
230 if (y && y->modified) std::cout <<
"node modified " << y->key << std::endl;
233 void check_no_node_flagged_for_delete()
const {
234 foreach_subtree_first(*
this, [&](node y) {
235 if (y && y->modified) std::cout <<
"node flagged for deletion " << y->key << std::endl;
// Clears the modified marks via clear_modified_impl(root), keeping its result in r.
// With TRIQS_RBT_CHECKS defined, it then verifies that no node is left modified or flagged for deletion.
// NOTE(review): the #endif and the return statement are not visible in this view.
239 int clear_modified() {
240 int r = clear_modified_impl(root);
241#ifdef TRIQS_RBT_CHECKS
242 check_no_node_modified();
243 check_no_node_flagged_for_delete();
// Recursively processes the modified part of the subtree rooted at n. For a modified node it raises
// an error if the node is still flagged for delete, resets delete_flag, and recurses into both
// children, adding their results to r.
// NOTE(review): the declaration of r, the reset of n->modified and the return are not visible here.
// NOTE(review): delete_flag is only reset after the error check, so that assignment appears redundant.
249 int clear_modified_impl(node n) {
251 if (n && n->modified) {
254 if (n->delete_flag) TRIQS_RUNTIME_ERROR <<
" node " << n->key <<
" is flagged for delete";
255 n->delete_flag =
false;
256 r += clear_modified_impl(n->left);
257 r += clear_modified_impl(n->right);
// Walks from x towards key, branching on compare in both directions; set_modified_from_root_to uses it
// to apply f along the root-to-key path.
// NOTE(review): the branch bodies, the call to f and the loop exit are not visible in this view.
268 template <
typename Fnt>
void apply_until_key_impl(node x, Key
const &key, Fnt
const &f)
const {
269 while (x !=
nullptr) {
271 if (compare(key, x->key))
273 else if (compare(x->key, key))
// Marks the nodes on the path from the root towards key as modified (via apply_until_key_impl).
281 void set_modified_from_root_to(Key
const &key) {
282 apply_until_key_impl(root, key, [](node y) { y->modified =
true; });
290 node get(Key
const &key)
const {
return get(root, key); }
293 bool contains(Key
const &key)
const {
return (get(key) !=
nullptr); }
296 bool contains(node x, Key
const &key)
const {
return (get(x, key) !=
nullptr); }
// Iterative binary-search lookup of key in the subtree rooted at x using compare;
// presumably returns nullptr when not found (contains() relies on that).
// NOTE(review): the branch bodies and return statements are not visible in this view.
300 node get(node x, Key
const &key)
const {
301 while (x !=
nullptr) {
302 if (compare(key, x->key))
304 else if (compare(x->key, key))
316 template <
typename Fnt>
friend node find_if(rb_tree
const &tr, Fnt f) {
return tr.find_if_impl(tr.root, f); }
// Recursive search helper for find_if. It searches x->left first, then (presumably) tests x itself,
// then searches x->right; returns nullptr for an empty subtree.
// NOTE(review): the lines between the left and right searches are not visible in this view.
320 template <
typename Fnt> node find_if_impl(node x, Fnt &f)
const {
321 if (x ==
nullptr)
return nullptr;
322 auto r = find_if_impl(x->left, f);
325 return find_if_impl(x->right, f);
// Inserts (key, val); the recursive insert returns the (possibly new) root.
// Throws rbt_insert_error if an equivalent key is already present.
// NOTE(review): the rest of the body (e.g. recolouring the root) is not visible in this view.
334 void insert(Key
const &key, Value
const &val) {
335 root = insert(root, key, val);
// Recursive left-leaning red-black insert into the subtree rooted at h; returns the new subtree root.
// A new leaf is created RED (color argument true) with size 1.
342 node insert(node h, Key
const &key, Value
const &val) {
343 if (h ==
nullptr)
return new node_t(key, val,
true, 1);
// Descend by compare; when neither key is less than the other the key already exists.
345 if (compare(key, h->key))
346 h->left = insert(h->left, key, val);
347 else if (compare(h->key, key))
348 h->right = insert(h->right, key, val);
350 throw rbt_insert_error{};
// Fix-ups on the way back up: rotate away a right-leaning red link, break two consecutive left reds,
// split a node whose children are both red; then refresh the cached subtree size.
353 if (is_red(h->right) && !is_red(h->left)) h = rotateLeft(h);
354 if (is_red(h->left) && is_red(h->left->left)) h = rotateRight(h);
355 if (is_red(h->left) && is_red(h->right)) flipColors(h);
356 h->N = size(h->left) + size(h->right) + 1;
// Removes the minimum node of the subtree rooted at h and returns the new subtree root. Before descending,
// a red link is pushed down the left spine (moveRedLeft) when h->left and h->left->left are both black.
// NOTE(review): the base case and the final rebalance are not visible in this view.
367 node deleteMin(node h) {
368 if (h->left ==
nullptr) {
372 if (!is_red(h->left) && !is_red(h->left->left)) h = moveRedLeft(h);
373 h->left = deleteMin(h->left);
// Removes the maximum node of the subtree rooted at h and returns the new subtree root.
// A red left link is first rotated to the right, and moveRedRight is applied when needed before descending.
// NOTE(review): the base case and the final rebalance are not visible in this view.
378 node deleteMax(node h) {
379 if (is_red(h->left)) h = rotateRight(h);
380 if (h->right ==
nullptr) {
385 if (!is_red(h->right) && !is_red(h->right->left)) h = moveRedRight(h);
386 h->right = deleteMax(h->right);
// Body of the public deleteMin() (signature not visible in this view): rejects an empty tree,
// temporarily colours the root RED when both children are black, deletes, then recolours the new root BLACK.
393 if (empty()) TRIQS_RUNTIME_ERROR <<
"BST underflow";
395 if (!is_red(root->left) && !is_red(root->right)) root->color = RED;
396 root = deleteMin(root);
397 if (!empty()) root->color = BLACK;
// Body of the public deleteMax() (signature not visible in this view): same pattern using deleteMax.
403 if (empty()) TRIQS_RUNTIME_ERROR <<
"BST underflow";
405 if (!is_red(root->left) && !is_red(root->right)) root->color = RED;
406 root = deleteMax(root);
407 if (!empty()) root->color = BLACK;
// Removes key from the tree and raises a runtime error if it is absent.
// It uses the same root recolouring as deleteMin/deleteMax: RED before the delete, BLACK afterwards.
412 void delete_node(Key
const &key) {
413 if (!contains(key)) TRIQS_RUNTIME_ERROR <<
"symbol table does not contain " << key;
415 if (!is_red(root->left) && !is_red(root->right)) root->color = RED;
416 root = delete_node(root, key);
417 if (!empty()) root->color = BLACK;
// Recursive left-leaning red-black delete of key from the subtree rooted at h; returns the new subtree root.
// NOTE(review): contains(h, key) is re-evaluated at every recursion level, an extra search per level;
// the public delete_node(key) already checks the root.
423 node delete_node(node h, Key
const &key) {
424 if (!contains(h, key)) TRIQS_RUNTIME_ERROR <<
" oops";
426 if (compare(key, h->key)) {
427 if (!is_red(h->left) && !is_red(h->left->left)) h = moveRedLeft(h);
428 h->left = delete_node(h->left, key);
// Key is not in the left subtree: lean the red link right, then either remove h or descend right.
// NOTE(review): equality is tested with operator== here while the rest of the tree uses compare.
431 if (is_red(h->left)) h = rotateRight(h);
432 if (key == h->key && (h->right ==
nullptr)) {
436 if (!is_red(h->right) && !is_red(h->right->left)) h = moveRedRight(h);
// Replace h's Value payload with that of its successor (minimum of the right subtree), clear the
// delete flag, then delete the successor. NOTE(review): any key copy is not visible in this view.
438 node x = min(h->right);
440 h->Value::operator=(*x);
442 h->delete_flag =
false;
443 h->right = deleteMin(h->right);
445 h->right = delete_node(h->right, key);
// Right rotation around h (requires a red left child). The new subtree root x takes the colour of its
// new right child (presumably h), that child becomes RED, and h's size is recomputed.
// NOTE(review): the pointer re-linking and x's size update are not visible in this view.
456 node rotateRight(node h) {
457 rbt_assert((h !=
nullptr) && is_red(h->left));
461 x->color = x->right->color;
462 x->right->color = RED;
464 h->N = size(h->left) + size(h->right) + 1;
// Left rotation around h (requires a red right child); the mirror image of rotateRight.
456 node rotateLeft(node h) {
// Given a red h with black h->left and h->left->left, makes h->left or one of its children red.
// If h->right->left is red, it rotates h->right right (the rest of the body is not visible here).
497 node moveRedLeft(node h) {
498 rbt_assert((h !=
nullptr));
499 rbt_assert(is_red(h) && !is_red(h->left) && !is_red(h->left->left));
502 if (is_red(h->right->left)) {
503 h->right = rotateRight(h->right);
// Given a red h with black h->right and h->right->left, makes h->right or one of its children red.
// If h->left->left is red, it rotates h right (the rest of the body is not visible here).
511 node moveRedRight(node h) {
512 rbt_assert((h !=
nullptr));
513 rbt_assert(is_red(h) && !is_red(h->right) && !is_red(h->right->left));
515 if (is_red(h->left->left)) { h = rotateRight(h); }
// Restores the local red-black invariants at h after a deletion: rotate left on a red right link,
// rotate right on two consecutive left reds, flip when both children are red; then refresh h->N.
520 node balance(node h) {
521 rbt_assert((h !=
nullptr));
523 if (is_red(h->right)) h = rotateLeft(h);
524 if (is_red(h->left) && is_red(h->left->left)) h = rotateRight(h);
525 if (is_red(h->left) && is_red(h->right)) flipColors(h);
527 h->N = size(h->left) + size(h->right) + 1;
538 int height()
const {
return height(root); }
// Height of the subtree rooted at x, counted in edges: -1 for an empty subtree, 0 for a single node.
541 int height(node x)
const {
542 if (x ==
nullptr)
return -1;
543 return 1 + std::max(height(x->left), height(x->right));
551 Key min_key()
const {
552 if (empty()) TRIQS_RUNTIME_ERROR <<
"rbt: taking max_key of an empty tree.";
553 return min(root)->key;
556 Key min_key(node x)
const {
return min(x)->key; }
// Leftmost (minimum-key) node of the non-null subtree rooted at x; presumably recurses on x->left
// until it has no left child, mirroring max().
// NOTE(review): the return statements are not visible in this view.
559 node min(node x)
const {
560 rbt_assert(x !=
nullptr);
561 if (x->left ==
nullptr)
// Largest key in the tree; raises a runtime error when the tree is empty.
568 Key max_key()
const {
569 if (empty()) TRIQS_RUNTIME_ERROR <<
"rbt: taking max_key of an empty tree.";
570 return max(root)->key;
573 Key max_key(node x)
const {
return max(x)->key; }
// Rightmost (maximum-key) node of the non-null subtree rooted at x: recurses on x->right until it has none.
576 node max(node x)
const {
577 rbt_assert(x !=
nullptr);
578 if (x->right ==
nullptr)
581 return max(x->right);
// Largest key not greater than key, found via floor(root, key).
// NOTE(review): the handling of "no such key" and the return are not visible in this view.
585 Key floor(Key
const &key)
const {
586 node x = floor(root, key);
// Node holding the largest key not greater than key in the subtree rooted at x; nullptr for an empty subtree.
// NOTE(review): exact match is tested with operator== rather than compare, so Key must provide ==
// consistent with compare. The tail of the function is not visible in this view.
595 node floor(node x, Key
const &key)
const {
596 if (x ==
nullptr)
return nullptr;
597 if (key == x->key)
return x;
598 if (compare(key, x->key))
return floor(x->left, key);
599 node t = floor(x->right, key);
// Smallest key not less than key, found via ceiling(root, key).
// NOTE(review): the handling of "no such key" and the return are not visible in this view.
608 Key ceiling(Key
const &key)
const {
609 node x = ceiling(root, key);
// Node holding the smallest key not less than key in the subtree rooted at x; nullptr for an empty subtree.
// NOTE(review): exact match is tested with operator== rather than compare, so Key must provide ==
// consistent with compare. The tail of the function is not visible in this view.
618 node ceiling(node x, Key
const &key)
const {
619 if (x ==
nullptr)
return nullptr;
620 if (key == x->key)
return x;
621 if (compare(x->key, key))
return ceiling(x->right, key);
622 node t = ceiling(x->left, key);
// Key of 0-based rank k (the k-th smallest key); raises a runtime error if k is outside [0, size()).
// NOTE(review): the error text reads "unknow key" (typo); the return is not visible in this view.
631 Key select(
int k)
const {
632 if (k < 0 || k >= size()) TRIQS_RUNTIME_ERROR <<
" unknow key";
633 node x = select(root, k);
// Node of 0-based rank k within the subtree rooted at x, using the cached subtree sizes:
// go left when k < size(left), right with k - size(left) - 1 when k is larger (x itself when equal, presumably).
639 node select(node x,
int k)
const {
640 rbt_assert(x !=
nullptr);
641 rbt_assert(k >= 0 && k < size(x));
642 int t = size(x->left);
644 return select(x->left, k);
646 return select(x->right, k - t - 1);
653 int rank(Key
const &key)
const {
return rank(key, root); }
// Number of keys in the subtree rooted at x that compare strictly less than key.
657 int rank(Key
const &key, node x)
const {
658 if (x ==
nullptr)
return 0;
659 if (compare(key, x->key))
660 return rank(key, x->left);
661 else if (compare(x->key, key))
662 return 1 + size(x->left) + rank(key, x->right);
// Equivalent key: everything in the left subtree is smaller.
664 return size(x->left);
// Internal invariant check: raises a runtime error with the message "Error" when c is false.
// As written it is not conditional on NDEBUG.
672 void rbt_assert(
bool c)
const {
673 if (!c) TRIQS_RUNTIME_ERROR <<
"Error";
691 bool isBST() {
return isBST(root, std::numeric_limits<int>::min(), std::numeric_limits<int>::max()); }
// True iff every key in the subtree rooted at x lies within (min, max) and both subtrees are BSTs,
// with x->key as the new upper bound on the left and the new lower bound on the right.
// NOTE(review): the bound checks themselves are not visible in this view.
697 bool isBST(node x, Key min, Key max) {
698 if (x ==
nullptr)
return true;
701 return isBST(x->left, min, x->key) && isBST(x->right, x->key, max);
705 bool isSizeConsistent() {
return isSizeConsistent(root); }
// True iff x->N equals size(left) + size(right) + 1 for x and, recursively, for every node below it.
707 bool isSizeConsistent(node x) {
708 if (x ==
nullptr)
return true;
709 if (x->N != size(x->left) + size(x->right) + 1)
return false;
710 return isSizeConsistent(x->left) && isSizeConsistent(x->right);
// Checks rank(select(i)) == i for every index 0 .. size()-1.
// NOTE(review): the remainder of the check and the final return are not visible in this view.
714 bool isRankConsistent() {
715 for (
int i = 0; i < size(); i++)
716 if (i != rank(select(i)))
return false;
724 bool is23() {
return is23(root); }
// Body of is23(node x) (signature not visible in this view): rejects any red right child, and any
// non-root red node with a red left child; then recurses into both subtrees.
726 if (x ==
nullptr)
return true;
727 if (is_red(x->right))
return false;
728 if (x != root && is_red(x) && is_red(x->left))
return false;
729 return is23(x->left) && is23(x->right);
// Body of isBalanced() (signature not visible in this view): counts the black nodes on one downward path
// (the path taken is not visible), then checks via isBalanced(root, black) that every path has that count.
736 while (x !=
nullptr) {
737 if (!is_red(x)) black++;
740 return isBalanced(root, black);
744 bool isBalanced(node x,
int black) {
745 if (x ==
nullptr)
return black == 0;
746 if (!is_red(x)) black--;
747 return isBalanced(x->left, black) && isBalanced(x->right, black);