diff --git a/src/rbtree.c b/src/rbtree.c index 68b7ad2..204f42e 100644 --- a/src/rbtree.c +++ b/src/rbtree.c @@ -2,10 +2,15 @@ #include +enum rbColor {rbBlack=1, rbRed=2}; + + struct element { int data; + enum rbColor color; + struct element * parent; struct element * left; struct element * right; @@ -14,13 +19,14 @@ struct element struct element * newElement(int data) { - struct element * el = malloc(sizeof(struct element)); - el->data = data; - el->parent = NULL; - el->left = NULL; - el->right = NULL; + struct element * node = malloc(sizeof(struct element)); + node->data = data; + node->color = rbRed; + node->parent = NULL; + node->left = NULL; + node->right = NULL; - return el; + return node; } /** @@ -47,14 +53,14 @@ findElement(struct element * tree, int data) /** * insert element in tree */ -void +struct element * insertElement(struct element ** tree, int data) { struct element * node = *tree; if (NULL == node) { *tree = newElement(data); - return; + return *tree; } while (data != node->data) { @@ -62,7 +68,7 @@ insertElement(struct element ** tree, int data) if (NULL == node->left) { node->left = newElement(data); node->left->parent = node; - return; + return node->left; } else { node = node->left; } @@ -70,12 +76,14 @@ insertElement(struct element ** tree, int data) if (NULL == node->right) { node->right = newElement(data); node->right->parent = node; - return; + return node->right; } else { node = node->right; } } } + + return NULL; } /** @@ -188,10 +196,16 @@ traverse(struct element * tree, void (*cb)(int, int)) struct element * node = tree; int depth = 1; - // I think this has something like O(n+log(n)) on a ballanced - // tree because I have to traverse back the rightmost leaf to - // the root to get a break condition. + /* + * I think this has something like O(n+log(n)) on a ballanced + * tree because I have to traverse back the rightmost leaf to + * the root to get a break condition. + */ while (node) { + /* + * If we come from the right so nothing and go to our + * next parent. + */ if (previous == node->right) { previous = node; node = node->parent; @@ -200,6 +214,10 @@ traverse(struct element * tree, void (*cb)(int, int)) } if ((NULL == node->left || previous == node->left)) { + /* + * If there are no more elements to the left or we + * came from the left, process data. + */ cb(node->data, depth); previous = node; @@ -211,6 +229,9 @@ traverse(struct element * tree, void (*cb)(int, int)) depth--; } } else { + /* + * if there are more elements to the left go there. + */ previous = node; node = node->left; depth++; @@ -222,8 +243,188 @@ void printElement(int data, int depth) { int i; - for (i=0; iparent) { + return node->parent->parent; + } + + return NULL; +} + +struct element * +uncle(struct element * node) +{ + struct element * gp = grandparent(node); + + if (NULL == gp) { + return NULL; + } + + if (node->parent == gp->left) { + return gp->right; + } + + return gp->left; +} + +void +rotateLeft(struct element ** tree, struct element * node) +{ + struct element * rightChild = node->right; + struct element * rcLeftSub = node->right->left; + + rightChild->left = node; + rightChild->parent = node->parent; + node->right = rcLeftSub; + rcLeftSub->parent = node; + + if (node->parent) { + if (node->parent->left == node) { + node->parent->left = rightChild; + } else { + node->parent->right = rightChild; + } + } else { + *tree = rightChild; + } + + node->parent = rightChild; +} + +void +rotateRight(struct element ** tree, struct element * node) +{ + struct element * leftChild = node->left; + struct element * lcRightSub = node->left->right; + + leftChild->right = node; + leftChild->parent = node->parent; + node->left = lcRightSub; + lcRightSub->parent = node; + + if (node->parent) { + if (node->parent->left == node) { + node->parent->left = leftChild; + } else { + node->parent->right = leftChild; + } + } else { + *tree = leftChild; + } + + node->parent = leftChild; +} + +void +insertCase5(struct element ** tree, struct element * node) +{ + struct element *g = grandparent(node); + + node->parent->color = rbBlack; + g->color = rbRed; + + if (node == node->parent->left) { + rotateRight(tree, g); + } else { + rotateLeft(tree, g); + } +} + +void +insertCase4(struct element ** tree, struct element * node) +{ + struct element * g = grandparent(node); + + if ((node == node->parent->right) && (node->parent == g->left)) { + rotateLeft(tree, node->parent); + + /* + * rotate_left can be the below because of already + * having *g = grandparent(n) + * + * struct node *saved_p=g->left, *saved_left_n=n->left; + * g->left=n; + * n->left=saved_p; + * saved_p->right=saved_left_n; + * + * and modify the parent's nodes properly + */ + node = node->left; + + } else if ( + (node == node->parent->left) && + (node->parent == g->right)) { + + rotateRight(tree, node->parent); + + /* + * rotate_right can be the below to take advantage + * of already having *g = grandparent(n) + * + * struct node *saved_p=g->right, *saved_right_n=n->right; + * g->right=n; + * n->right=saved_p; + * saved_p->left=saved_right_n; + * + */ + node = node->right; + } + + insertCase5(tree, node); +} + +void insertCase1(struct element **, struct element *); + +void +insertCase3(struct element ** tree, struct element * node) +{ + struct element * u = uncle(node); + struct element * g; + + if ((u != NULL) && (u->color == rbRed)) { + node->parent->color = rbBlack; + u->color = rbBlack; + g = grandparent(node); + g->color = rbRed; + + insertCase1(tree, g); + } else { + insertCase4(tree, node); + } +} + +void +insertCase2(struct element ** tree, struct element * node) +{ + if (node->parent->color == rbBlack) { + return; + // Tree is still valid ... wow, again we're done... :) + } else { + insertCase3(tree, node); + } +} + +void +insertCase1(struct element ** tree, struct element * node) +{ + if (node->parent == NULL) { + node->color = rbBlack; + // we're done.... :) + } else { + insertCase2(tree, node); + } } /** @@ -232,17 +433,18 @@ void printElement(int data, int depth) int main(int argc, char * argv[]) { - int i; - struct element * root = NULL; - - insertElement(&root, 13); - insertElement(&root, 8); - insertElement(&root, 16); - insertElement(&root, 11); - insertElement(&root, 3); - insertElement(&root, 9); - insertElement(&root, 12); - insertElement(&root, 10); + struct element * root1 = NULL; + struct element * root2 = NULL; + struct element * inserted = NULL; + + insertElement(&root1, 13); + insertElement(&root1, 8); + insertElement(&root1, 16); + insertElement(&root1, 11); + insertElement(&root1, 3); + insertElement(&root1, 9); + insertElement(&root1, 12); + insertElement(&root1, 10); /* * after this I have the following: @@ -277,19 +479,39 @@ main(int argc, char * argv[]) * Looks like the insert works properly. * So the problem is out traversing... */ - puts("elements:"); - for (i=1; i<20; i++) { - struct element * element = findElement(root, i); - printf("Element %02d: n=0x%p p=0x%p l=0x%p r=0x%p\n", - i, - element, - element ? element->parent : 0x0, - element ? element->left : 0x0, - element ? element->right : 0x0); - } + // puts("elements:"); + // for (i=1; i<20; i++) { + // struct element * element = findElement(root, i); + // printf("Element %02d: n=0x%p p=0x%p l=0x%p r=0x%p\n", + // i, + // element, + // element ? element->parent : 0x0, + // element ? element->left : 0x0, + // element ? element->right : 0x0); + // } + + puts("traverse"); + traverse(root1, printElement); + + inserted = insertElement(&root2, 13); + insertCase1(&root2, inserted); + inserted = insertElement(&root2, 8); + insertCase1(&root2, inserted); + inserted = insertElement(&root2, 16); + insertCase1(&root2, inserted); + inserted = insertElement(&root2, 11); + insertCase1(&root2, inserted); + inserted = insertElement(&root2, 3); + insertCase1(&root2, inserted); + inserted = insertElement(&root2, 9); + insertCase1(&root2, inserted); + inserted = insertElement(&root2, 12); + insertCase1(&root2, inserted); + inserted = insertElement(&root2, 10); + insertCase1(&root2, inserted); puts("traverse"); - traverse(root, printElement); + traverse(root2, printElement); return 0; }