#include <stdio.h>
#include <stdlib.h>
#include <iostream>
using namespace std;
// Tree node. `color` is 0 for red and 1 for black: has_red_child() tests for 0,
// while the NIL sentinel and the tree root are given 1.
struct Node {
    int data, color;
    Node* lchild, * rchild;
};

// Shared sentinel used for every empty child pointer. It is black, holds data 0,
// and both of its children point back at itself, so code can read
// NIL->lchild->color (e.g. has_red_child(NIL)) without a null check.
// NOTE(review): identifiers starting with "__" are reserved for the
// implementation in C++; consider renaming __NIL once all uses can be updated.
Node __NIL;
#define NIL (&__NIL)

// Helper type whose constructor fills in the sentinel's fields at program
// start-up, before main() in this file touches NIL.
class init_NIL {
public:
    init_NIL() {
        Node& sentinel = __NIL;
        sentinel.data = 0;
        sentinel.color = 1;
        sentinel.rchild = &sentinel;
        sentinel.lchild = &sentinel;
    }
};
init_NIL p;  // global instance: constructing it is what initializes __NIL
// Allocates a new node holding `val`. New nodes start red (color 0) with both
// children pointing at the NIL sentinel. The node comes from malloc and must be
// released with free (clear() does this).
// If the allocation fails, the program prints a diagnostic to stderr and exits
// with EXIT_FAILURE instead of dereferencing a null pointer.
Node* getNewNode(int val) {
    Node* p = static_cast<Node*>(malloc(sizeof(Node)));
    if (p == nullptr) {
        fprintf(stderr, "getNewNode: out of memory allocating node for %d\n", val);
        exit(EXIT_FAILURE);
    }
    p->data = val;
    p->lchild = p->rchild = NIL;
    p->color = 0;
    return p;
}
// Returns 1 if at least one child of `root` is red (color 0), otherwise 0.
// NIL children are black, so calling this on a leaf (or on NIL) yields 0.
int has_red_child(Node* root) {
    const bool neither_red = root->lchild->color != 0 && root->rchild->color != 0;
    return neither_red ? 0 : 1;
}
// Left rotation of the subtree rooted at `root`. The right child is promoted to
// subtree root, `root` becomes its left child, and the promoted node's former
// left subtree is re-attached as `root`'s right subtree. Colors are untouched.
// Prints a trace line naming `root`, then returns the new subtree root.
// The right child is not checked: if it were NIL, the sentinel's lchild would be overwritten.
Node* left_rotate(Node* root) {
    printf("left node: %d\n", root->data);
    Node* pivot = root->rchild;
    Node* inner = pivot->lchild;
    pivot->lchild = root;
    root->rchild = inner;
    return pivot;
}
// Right rotation of the subtree rooted at `root` (mirror of left_rotate). The
// left child is promoted to subtree root, `root` becomes its right child, and the
// promoted node's former right subtree is re-attached as `root`'s left subtree.
// Colors are untouched. Prints a trace line naming `root`, then returns the new
// subtree root. The left child is not checked: if it were NIL, the sentinel's
// rchild would be overwritten.
Node* right_rotate(Node* root) {
    printf("right node: %d\n", root->data);
    Node* pivot = root->lchild;
    Node* inner = pivot->rchild;
    pivot->rchild = root;
    root->lchild = inner;
    return pivot;
}
// Trace labels printed by insert_maintain, indexed by the case it applied.
const char* insert_maintain_type[] = {
    "1: change color",  // [0] both children red: recolor only
    "2 : LL",           // [1] red left child, red left grandchild
    "2 : LR",           // [2] red left child, red right grandchild
    "2 : RR",           // [3] red right child, red right grandchild
    "2: RL"             // [4] red right child, red left grandchild
};

// Fixes a "red child with a red child" conflict directly below `root`, for the
// subtree rooted at `root`, and returns the (possibly new) subtree root.
// The subtree is returned unchanged if neither child is red with a red child.
// Otherwise:
//   - right child black -> the conflict is on the left: right-rotate `root`,
//     first left-rotating the left child when the red grandchild is its right
//     child (LR);
//   - left child black  -> mirror image (RR / RL);
//   - both children red -> recolor only.
// In every fixed case the returned root is painted red (0) and both of its
// children black (1). A red root can create a new conflict one level up, which
// __insert's caller level handles; insert() finally blackens the tree root.
// Prints the applied case label (the rotations also print their pivots).
Node* insert_maintain(Node* root) {
    Node* left = root->lchild;
    Node* right = root->rchild;
    const bool left_conflict = left->color == 0 && has_red_child(left);
    const bool right_conflict = right->color == 0 && has_red_child(right);
    if (!left_conflict && !right_conflict) {
        return root;
    }

    int type;
    if (right->color == 1) {
        // An "inner" grandchild (LR) is rotated outward first so a single
        // right rotation at `root` finishes the job.
        const bool inner = left->rchild->color == 0;
        if (inner) {
            root->lchild = left_rotate(left);
        }
        root = right_rotate(root);
        type = inner ? 2 : 1;
    } else if (left->color == 1) {
        const bool inner = right->lchild->color == 0;
        if (inner) {
            root->rchild = right_rotate(right);
        }
        root = left_rotate(root);
        type = inner ? 4 : 3;
    } else {
        type = 0;
    }

    printf("insert maintain type = %s\n", insert_maintain_type[type]);
    root->color = 0;
    root->lchild->color = 1;
    root->rchild->color = 1;
    return root;
}
// Recursively inserts `val` into the subtree rooted at `root`. Smaller values go
// left and larger values go right; a value already present is ignored. Each
// level is rebalanced with insert_maintain on the way back up. Returns the
// (possibly new) subtree root, which may be red; insert() recolors the overall root.
Node* __insert(Node* root, int val) {
    if (root == NIL) {
        return getNewNode(val);
    }
    if (val == root->data) {
        return root;  // duplicate: tree unchanged, nothing to rebalance
    }
    Node** branch = (val < root->data) ? &root->lchild : &root->rchild;
    *branch = __insert(*branch, val);
    return insert_maintain(root);
}
// Public insertion entry point: adds `val` to the tree rooted at `root` and
// returns the new root, which is always painted black (color 1).
Node* insert(Node* root, int val) {
    Node* new_root = __insert(root, val);
    new_root->color = 1;
    return new_root;
}
// Frees every node of the tree rooted at `root`, children before parent.
// The NIL sentinel is a global object and is never passed to free.
void clear(Node* root) {
    if (root != NIL) {
        clear(root->lchild);
        clear(root->rchild);
        free(root);
    }
}
// Prints one node as "( data(color) | left_data, right_data )".
// A NIL child prints as 0, the sentinel's data value.
void print_node(Node* root) {
    const Node* left = root->lchild;
    const Node* right = root->rchild;
    printf("( %d(%d) | %d, %d )\n", root->data, root->color, left->data, right->data);
}
// Prints every node of the tree in pre-order (node, left subtree, right subtree).
// Recursion handles the left subtree; the walk down the right spine is a loop.
void output(Node* root) {
    for (Node* cur = root; cur != NIL; cur = cur->rchild) {
        print_node(cur);
        output(cur->lchild);
    }
}
// Reads integers from standard input until end of input or the first token that
// is not a valid int. Each value is inserted into the red-black tree, and the
// tree is then printed in pre-order. The tree is freed before exiting.
int main() {
    Node* root = NIL;
    int val = 0;
    // `cin >> val` fails on EOF *and* on malformed input. The previous
    // `~scanf_s(...)` test only stopped on EOF (-1): a non-numeric token made
    // scanf_s return 0, `~0` is true, and the loop spun forever reprocessing the
    // previous value. scanf_s is also an MSVC/Annex K extension that most
    // other toolchains do not provide.
    while (std::cin >> val) {
        printf("\n=== insert %d to red black tree ===\n", val);
        root = insert(root, val);
        output(root);
    }
    clear(root);
    return 0;
}