插入结点
- 使用维基百科红黑树的插入实现
- 时间复杂度为 O(log n)。注意可能需要递归向上调整,所以旋转时要把子树(包括各结点的 parent 指针)正确地转移好
- 代码如下
void insert(int key, T value)
{
Node *insertNode = new Node(key, value);
Node *traver = root;
while (traver)
{
if (key < traver->key)
{
if (traver->left_Son)
traver = traver->left_Son;
else
{
traver->left_Son = insertNode;
break;
}
}
else if (key > traver->key)
{
if (traver->right_Son)
traver = traver->right_Son;
else
{
traver->right_Son = insertNode;
break;
}
}
else
{
traver->value = value;
delete insertNode;
return;
}
}
insertNode->parent = traver;
insert_case1(insertNode);
}
// Insert fix-up, case 1: a node without a parent is the root, which must be black.
// Every other node is handed on to case 2.
void insert_case1(Node *son)
{
    cout << "insert_case1" << endl;
    if (son->parent != nullptr)
    {
        insert_case2(son);
        return;
    }
    root = son;
    root->color = Black;
}
// Insert fix-up, case 2: a black parent cannot conflict with the new red node, so the
// fix-up stops; a red parent needs case 3.
void insert_case2(Node *son)
{
    cout << "insert_case2" << endl;
    if (son->parent->color != Black)
        insert_case3(son);
}
// Insert fix-up, case 3: parent and uncle are both red. Push the redness up: parent and
// uncle become black, the grandparent red, and the grandparent is re-checked from case 1.
// Without a red uncle, continue with the rotation cases (4 and 5).
void insert_case3(Node *son)
{
    cout << "insert_case3 " << endl;
    Node *uncle = get_uncle(son);
    if (uncle == nullptr || uncle->color != Red)
    {
        insert_case4(son);
        return;
    }
    Node *grandparent = get_grandparent(son);
    uncle->color = Black;
    son->parent->color = Black;
    grandparent->color = Red;
    insert_case1(grandparent);
}
// Insert fix-up, case 4: red parent, black/absent uncle. If son is an "inner" grandchild
// (left-right or right-left zig-zag), rotate son above its parent so the configuration
// becomes a straight line, then continue with the old parent as the node to fix (case 5).
void insert_case4(Node *son)
{
    cout << "insert_case4" << endl;
    Node *dad = son->parent;
    Node *grand = get_grandparent(son);
    const bool sonIsRight = (son == dad->right_Son);
    const bool dadIsLeft = (dad == grand->left_Son);
    if (sonIsRight && dadIsLeft)
    {
        cout << "LR" << endl;
        rotated_left(son);
        son = dad; // after the rotation the old parent is son's left child
    }
    else if (!sonIsRight && !dadIsLeft)
    {
        cout << "RL" << endl;
        rotated_right(son);
        son = dad; // after the rotation the old parent is son's right child
    }
    insert_case5(son);
}
// Insert fix-up, case 5: son, its red parent and the grandparent lie on a straight line.
// Recolor parent black and grandparent red, then rotate the parent above the grandparent.
void insert_case5(Node *son)
{
    cout << "insert_case5" << endl;
    Node *dad = son->parent;
    dad->color = Black;
    get_grandparent(son)->color = Red;
    const bool sonIsLeft = (son == dad->left_Son);
    if (sonIsLeft)
        rotated_right(dad);
    else
        rotated_left(dad);
}
删除结点
本来也想按维基百科的实现,但它的实现假定空叶子用实际的(非 NULL 的)哨兵结点表示。虽然那样代码很简洁,但把空叶子当作实际结点来处理,在这里不太好实现。
- 采用了博客的删除做法
- 兄弟结点是黑色:若兄弟结点的儿子全黑或全空,直接重新着色处理(必要时向上递归);否则(至少有一个红色侄子)先旋转转换成同一种情况再统一处理
- 兄弟结点是红色:旋转一次,转换成兄弟结点是黑色的情况
- 时间复杂度为 O(log n)。注意可能需要递归向上调整,所以旋转时要把子树(包括各结点的 parent 指针)正确地转移好
- 具体代码如下:代码看起来挺多,但按情况分类后其实并不复杂
bool erase(int key)
{
Node *traver = root;
while (traver)
{
cout << traver->key << endl;
if (key < traver->key)
traver = traver->left_Son;
else if (traver->key < key)
traver = traver->right_Son;
else
break;
}
if (traver == nullptr)
return false;
Node *preNode = get_pre_node(traver->left_Son);
if (preNode)
swap(traver, preNode);
else
preNode = traver;
erase_adjust(preNode);
return true;
}
// Physically removes `son`, a node with at most one child (the caller guarantees this:
// it is either an in-order predecessor, which has no right child, or a node without a
// left child), and restores the red-black properties.
// - son is the root: its right child (if any) becomes the new, black root.
// - son is red: in a valid red-black tree it is a leaf, so it is simply unlinked.
// - son is black with one child: that child must be a red leaf; its key/value move up into
//   son and the child node is deleted instead.
// - son is a black leaf: removing it shortens one path, so erase_case1 rebalances first,
//   then son is unlinked.
void erase_adjust(Node *son)
{
    cout << "erase_adjust" << endl;
    if (son == root)
    {
        root = son->right_Son;
        if (root != nullptr)
        {
            // The new root must not keep pointing at the node deleted below; rotations and
            // grandparent lookups would otherwise follow a dangling pointer.
            root->parent = nullptr;
            root->color = Black;
        }
        delete son;
    }
    else if (son->color == Red)
    {
        if (son->parent->left_Son == son)
            son->parent->left_Son = nullptr;
        else
            son->parent->right_Son = nullptr;
        delete son;
    }
    else if (son->color == Black && son->left_Son != nullptr)
    {
        swap(son, son->left_Son);
        delete son->left_Son;
        son->left_Son = nullptr;
    }
    else if (son->color == Black && son->right_Son != nullptr)
    {
        swap(son, son->right_Son);
        delete son->right_Son;
        son->right_Son = nullptr;
    }
    else
    {
        erase_case1(son);
        // son is not the root (handled above) and rebalancing never lifts a leaf to the
        // root, so son->parent is non-null here.
        if (son->parent->left_Son == son)
            son->parent->left_Son = nullptr;
        else
            son->parent->right_Son = nullptr;
        delete son;
    }
}
// Erase fix-up entry: son's subtree is one black node short. At the root there is nothing
// to fix; otherwise dispatch on the brother's color (black -> case 2, red -> case 3).
void erase_case1(Node *son)
{
    cout << "erase_case1" << endl;
    if (son != root)
    {
        Node *brother = get_brother(son);
        if (brother->color == Red)
            erase_case3(son);
        else
            erase_case2(son);
    }
}
// Erase fix-up when son's brother is black; son's subtree is one black node short.
// - At least one nephew is red: if only the near nephew is red, rotate it above the brother
//   so that the far nephew is red. Then the brother takes the parent's color, the parent
//   and the far nephew become black, and the brother is rotated above the parent. This
//   ends the fix-up.
// - Both nephews black or absent: the brother turns red. A red parent absorbs the deficit
//   by turning black; a black parent passes the deficit up via erase_case1.
void erase_case2(Node *son)
{
    Node *brother = get_brother(son);
    Node *parent = son->parent;
    if ((brother->left_Son != nullptr && brother->left_Son->color == Red) || (brother->right_Son != nullptr && brother->right_Son->color == Red))
    {
        if (brother == parent->right_Son)
        {
            if (brother->right_Son == nullptr || brother->right_Son->color == Black)
            {
                // Only the near nephew is red: rotate it above the brother.
                rotated_right(brother->left_Son);
                brother = brother->parent;
                brother->color = Black;
                brother->right_Son->color = Red;
            }
            // The brother inherits the parent's color and the parent turns black. Leaving the
            // parent red here would make it the parent of the brother's (possibly red) near
            // child after the rotation, i.e. a red-red violation.
            brother->color = parent->color;
            parent->color = Black;
            brother->right_Son->color = Black;
            rotated_left(brother);
        }
        else
        {
            if (brother->left_Son == nullptr || brother->left_Son->color == Black)
            {
                // Only the near nephew is red: rotate it above the brother.
                rotated_left(brother->right_Son);
                brother = brother->parent;
                brother->color = Black;
                brother->left_Son->color = Red;
            }
            // Mirror image of the branch above.
            brother->color = parent->color;
            parent->color = Black;
            brother->left_Son->color = Black;
            rotated_right(brother);
        }
    }
    else
    {
        if (parent->color == Black)
        {
            brother->color = Red;
            erase_case1(parent);
        }
        else
        {
            parent->color = Black;
            brother->color = Red;
        }
    }
}
// Erase fix-up when son's brother is red: rotate the brother above the parent, making
// the brother black and the parent red. son then has a black brother, handled by case 2.
void erase_case3(Node *son)
{
    Node *sib = get_brother(son);
    Node *dad = sib->parent;
    if (sib == dad->left_Son)
        rotated_right(sib);
    else
        rotated_left(sib);
    dad->color = Red;
    sib->color = Black;
    erase_case2(son);
}
测试插入删除调整后的树是否正确
- 每个结点的黑色平衡
- 每个结点是否满足二叉搜索树左小右大的性质
// Validates the subtree rooted at temp and returns its black height (null leaves count 0).
// Checks:
// - child->parent links;
// - key order between a node and its direct children;
// - equal black counts on both sides;
// - the red rule (a red node has no red child);
// - for the tree root, that it is black and has no parent.
// On the first violation it prints a message and exits the process with status -1.
int check_tree(Node *temp)
{
    if (temp == nullptr)
        return 0;
    if (temp == root && (temp->color != Black || temp->parent != nullptr))
    {
        cout << temp->key << " root error" << endl;
        exit(-1);
    }
    int ct_black = 0;
    if (temp->color == Black)
        ct_black = 1;
    if (temp->left_Son && temp->left_Son->parent != temp)
    {
        cout << temp->key << " left error " << temp->left_Son->key << endl;
        exit(-1);
    }
    if (temp->right_Son && temp->right_Son->parent != temp)
    {
        cout << temp->key << " right error " << temp->right_Son->key << endl;
        exit(-1);
    }
    if ((temp->left_Son != nullptr && temp->left_Son->key > temp->key) || (temp->right_Son != nullptr && temp->key > temp->right_Son->key))
    {
        cout << "check_Node error in key" << endl;
        exit(-1);
    }
    // Red rule: previously unchecked, so red-red violations went unnoticed.
    if (temp->color == Red && ((temp->left_Son != nullptr && temp->left_Son->color == Red) || (temp->right_Son != nullptr && temp->right_Son->color == Red)))
    {
        cout << temp->key << " red node has a red child" << endl;
        exit(-1);
    }
    int left_black = check_tree(temp->left_Son);
    if (left_black != check_tree(temp->right_Son))
    {
        cout << temp->key << endl;
        cout << "check_Node error in color" << endl;
        exit(-1);
    }
    return ct_black + left_black;
}
打印红黑树
- 把高度为 h 的二维红黑树(看成满二叉树)压缩到一维的一行上来确定每个结点的列位置
- 根结点在 $2^{h-1}$ 的位置
- 根结点的左儿子在 $2^{h-1} - 2^{h-2}$ 的位置,右儿子在 $2^{h-1} + 2^{h-2}$ 的位置
- 以此类推得到每个结点应在的位置。记 $n = 2^{h}$(即高度为 h 的满二叉树的结点数加 1),儿子所在的层数为 $d$(根结点为第 1 层),则:
- 左儿子坐标 = 父结点坐标 $- n / 2^{d}$
- 右儿子坐标 = 父结点坐标 $+ n / 2^{d}$
- 使用层序遍历打印结点,前面空缺的位置用空格补齐(按空格个数计数定位)
- 注意这种方法只适合结点个数不多的情况:每个结点打印出来的宽度并不等于一个空格的宽度,前面没有结点的位置不好准确计算空格个数,会导致位置有些偏差;而且屏幕宽度有限,结点一多一行就放不下了。
void tourist()
{
if (root == nullptr)
return;
vector<pair<int, Node *>> a, b;
int n = 1 << get_height(root);
a.push_back(pair<int, Node *>(n >> 1, root));
int h = 2;
while (1)
{
int cnt = 0;
for (auto it : a)
{
while (cnt < it.first)
{
++cnt;
printf(" ");
}
if (it.second->color)
printf("%3dB",it.second->key);
else
printf("%3dR",it.second->key);
if (it.second->left_Son)
{
b.push_back(pair<int, Node *>(it.first - (n >> h), it.second->left_Son));
}
if (it.second->right_Son)
{
b.push_back(pair<int, Node *>(it.first + (n >> h), it.second->right_Son));
}
}
cout << "\n\n\n";
++h;
a = b;
b.clear();
if (!a.size())
break;
}
}
随机生成数据测试
- 测试时先随机生成n个key插入,再随机把这n个key删除
- 每次插入删除后都要调用测试函数,判断是否满足黑色平衡和结点左小右大
void test()
{
int n;
cin >> n;
RB_Tree<int> rbTree{};
map<int, int> mp;
vector<int> index;
while (n--)
{
int key;
while (mp.find(key = rand() % 1000) != mp.end())
{}
mp[key] = 1;
index.push_back(key);
cout << "insert " << key << endl;
rbTree.insert(key, 1);
rbTree.check_tree(rbTree.get_root());
}
cout << "--------------------------------" << endl;
cout << "--------------------------------" << endl;
while (index.size())
{
int id = rand() % index.size();
int key = index[id];
cout << "key " << key << endl;
index.erase(index.begin() + id);
rbTree.erase(key);
rbTree.check_tree(rbTree.get_root());
}
cout << "success" << endl;
}
完整代码
#include <bits/stdc++.h>
#define Red 0
#define Black 1
using namespace std;
template <typename T>
class RB_Tree
{
private:
struct Node
{
Node(int k, T v) : color(Red), parent(nullptr), left_Son(nullptr), right_Son(nullptr), key(k), value(v) {}
Node() : color(Red), parent(nullptr), left_Son(nullptr), right_Son(nullptr) {}
bool color;
int key;
T value;
Node *parent;
Node *left_Son;
Node *right_Son;
};
Node *root;
public:
RB_Tree() : root(nullptr) {}
~RB_Tree()
{
queue<Node *> q;
if (root)
{
q.push(root);
while (!q.empty())
{
root = q.front();
if (root->left_Son)
q.push(root->left_Son);
if (root->right_Son)
q.push(root->right_Son);
delete root;
}
}
}
Node *get_root()
{
return root;
}
Node *get_grandparent(Node *son)
{
if (son->parent == nullptr || son->parent->parent == nullptr)
{
cout << "error: no get_uncle error " << endl;
return nullptr;
}
return son->parent->parent;
}
Node *get_uncle(Node *son)
{
Node *grandparent = get_grandparent(son);
Node *father = son->parent;
return grandparent->left_Son == father ? grandparent->right_Son : grandparent->left_Son;
}
Node *get_brother(Node *son)
{
if (son == son->parent->left_Son)
return son->parent->right_Son;
else
return son->parent->left_Son;
}
void rotated_left(Node *son)
{
Node *parent = son->parent;
Node *grandparent = parent->parent;
parent->right_Son = son->left_Son;
if (son->left_Son)
son->left_Son->parent = parent;
son->left_Son = parent;
parent->parent = son;
son->parent = grandparent;
if (grandparent)
if (grandparent->left_Son == parent)
grandparent->left_Son = son;
else
grandparent->right_Son = son;
else
root = son;
}
void rotated_right(Node *son)
{
Node *parent = son->parent;
Node *grandparent = parent->parent;
parent->left_Son = son->right_Son;
if (son->right_Son)
son->right_Son->parent = parent;
son->right_Son = parent;
parent->parent = son;
son->parent = grandparent;
if (grandparent)
if (grandparent->left_Son == parent)
grandparent->left_Son = son;
else
grandparent->right_Son = son;
else
root = son;
}
void insert(int key, T value)
{
Node *insertNode = new Node(key, value);
Node *traver = root;
while (traver)
{
if (key < traver->key)
{
if (traver->left_Son)
traver = traver->left_Son;
else
{
traver->left_Son = insertNode;
break;
}
}
else if (key > traver->key)
{
if (traver->right_Son)
traver = traver->right_Son;
else
{
traver->right_Son = insertNode;
break;
}
}
else
{
traver->value = value;
delete insertNode;
return;
}
}
insertNode->parent = traver;
insert_case1(insertNode);
}
void insert_case1(Node *son)
{
cout << "insert_case1" << endl;
if (son->parent == nullptr)
{
root = son;
root->color = Black;
}
else
insert_case2(son);
}
void insert_case2(Node *son)
{
cout << "insert_case2" << endl;
if (son->parent->color == Black)
return;
else
insert_case3(son);
}
void insert_case3(Node *son)
{
cout << "insert_case3 " << endl;
if (get_uncle(son) && get_uncle(son)->color == Red)
{
get_uncle(son)->color = son->parent->color = Black;
get_grandparent(son)->color = Red;
insert_case1(get_grandparent(son));
}
else
insert_case4(son);
}
void insert_case4(Node *son)
{
cout << "insert_case4" << endl;
if (son == son->parent->right_Son && son->parent == get_grandparent(son)->left_Son)
{
cout << "LR" << endl;
rotated_left(son);
son = son->left_Son;
}
else if (son == son->parent->left_Son && son->parent == get_grandparent(son)->right_Son)
{
cout << "RL" << endl;
rotated_right(son);
son = son->right_Son;
}
insert_case5(son);
}
void insert_case5(Node *son)
{
cout << "insert_case5" << endl;
son->parent->color = Black;
get_grandparent(son)->color = Red;
if (son == son->parent->left_Son)
{
rotated_right(son->parent);
}
else
{
rotated_left(son->parent);
}
}
Node *get_pre_node(Node *traver)
{
if (traver == nullptr)
return traver;
while (traver->right_Son != nullptr)
traver = traver->right_Son;
return traver;
}
void swap(Node *a, Node *b)
{
T t_value = a->value;
int t_key = a->key;
a->value = b->value;
a->key = b->key;
b->value = t_value;
b->key = t_key;
}
bool erase(int key)
{
Node *traver = root;
while (traver)
{
cout << traver->key << endl;
if (key < traver->key)
traver = traver->left_Son;
else if (traver->key < key)
traver = traver->right_Son;
else
break;
}
if (traver == nullptr)
return false;
Node *preNode = get_pre_node(traver->left_Son);
if (preNode)
swap(traver, preNode);
else
preNode = traver;
erase_adjust(preNode);
return true;
}
void erase_adjust(Node *son)
{
cout << "erase_adjust" << endl;
if (son == root)
{
root = son->right_Son;
if (root != nullptr)
root->color = Black;
delete son;
}
else if (son->color == Red)
{
if (son->parent->left_Son == son)
son->parent->left_Son = nullptr;
else
son->parent->right_Son = nullptr;
delete son;
}
else if (son->color == Black && son->left_Son != nullptr)
{
swap(son, son->left_Son);
delete son->left_Son;
son->left_Son = nullptr;
}
else if (son->color == Black && son->right_Son != nullptr)
{
swap(son, son->right_Son);
delete son->right_Son;
son->right_Son = nullptr;
}
else
{
erase_case1(son);
if (son->parent && son->parent->left_Son == son)
son->parent->left_Son = nullptr;
else
son->parent->right_Son = nullptr;
delete son;
}
}
void erase_case1(Node *son)
{
cout << "erase_case1" << endl;
if (son == root)
return;
if (get_brother(son)->color == Black)
erase_case2(son);
else
erase_case3(son);
}
void erase_case2(Node *son)
{
Node *brother = get_brother(son);
Node *parent = son->parent;
if ((brother->left_Son != nullptr && brother->left_Son->color == Red) || (brother->right_Son != nullptr && brother->right_Son->color == Red))
{
if (brother == parent->right_Son)
{
if (brother->right_Son == nullptr || brother->right_Son->color == Black)
{
rotated_right(brother->left_Son);
brother = brother->parent;
brother->color = Black;
brother->right_Son->color = Red;
}
if (parent->color == Black)
brother->right_Son->color = Black;
rotated_left(brother);
}
else
{
if (brother->left_Son == nullptr || brother->left_Son->color == Black)
{
rotated_left(brother->right_Son);
brother = brother->parent;
brother->color = Black;
brother->left_Son->color = Red;
}
if (parent->color == Black)
brother->left_Son->color = Black;
rotated_right(brother);
}
}
else
{
if (parent->color == Black)
{
brother->color = Red;
erase_case1(parent);
}
else
{
parent->color = Black;
brother->color = Red;
}
}
}
void erase_case3(Node *son)
{
Node *brother = get_brother(son);
Node *parent = brother->parent;
if (brother == parent->right_Son)
rotated_left(brother);
else
rotated_right(brother);
brother->color = Black;
parent->color = Red;
erase_case2(son);
}
int get_height(Node *root)
{
if (root == nullptr)
return 0;
return 1 + max(get_height(root->left_Son), get_height(root->right_Son));
}
void tourist()
{
if (root == nullptr)
return;
vector<pair<int, Node *>> a, b;
int n = 1 << get_height(root);
a.push_back(pair<int, Node *>(n >> 1, root));
int h = 2;
while (1)
{
int cnt = 0;
for (auto it : a)
{
while (cnt < it.first)
{
++cnt;
printf(" ");
}
if (it.second->color)
printf("%3dB",it.second->key);
else
printf("%3dR",it.second->key);
if (it.second->left_Son)
{
b.push_back(pair<int, Node *>(it.first - (n >> h), it.second->left_Son));
}
if (it.second->right_Son)
{
b.push_back(pair<int, Node *>(it.first + (n >> h), it.second->right_Son));
}
}
cout << "\n\n\n";
++h;
a = b;
b.clear();
if (!a.size())
break;
}
}
int check_tree(Node *temp)
{
if (temp == nullptr)
return 0;
int ct_black = 0;
if (temp->color == Black)
ct_black = 1;
if (temp->left_Son && temp->left_Son->parent != temp)
{
cout << temp->key << " left error " << temp->left_Son->key << endl;
exit(-1);
}
if (temp->right_Son && temp->right_Son->parent != temp)
{
cout << temp->key << " right error " << temp->right_Son->key << endl;
exit(-1);
}
if ((temp->left_Son != nullptr && temp->left_Son->key > temp->key) || (temp->right_Son != nullptr && temp->key > temp->right_Son->key))
{
cout << "check_Node error in key" << endl;
exit(-1);
}
int left_black = check_tree(temp->left_Son);
if (left_black != check_tree(temp->right_Son))
{
cout << temp->key << endl;
cout << "check_Node error in color" << endl;
exit(-1);
}
return ct_black + left_black;
}
T *serarch(int key)
{
Node *traver = root;
while (traver)
{
if (key < traver->key)
traver = traver->left;
else if (key > traver->key)
traver = traver->right;
else
return traver->value;
}
return nullptr;
}
};
void test()
{
int n;
cin >> n;
RB_Tree<int> rbTree{};
map<int, int> mp;
vector<int> index;
while (n--)
{
int key;
while (mp.find(key = rand() % 1000) != mp.end())
{}
mp[key] = 1;
index.push_back(key);
cout << "insert " << key << endl;
rbTree.insert(key, 1);
rbTree.check_tree(rbTree.get_root());
}
cout << "--------------------------------" << endl;
cout << "--------------------------------" << endl;
while (index.size())
{
int id = rand() % index.size();
int key = index[id];
cout << "key " << key << endl;
index.erase(index.begin() + id);
rbTree.erase(key);
rbTree.check_tree(rbTree.get_root());
}
cout << "success" << endl;
}
// Entry point: runs the randomized insert/erase self-check. Falling off the end of
// main returns 0.
int main()
{
    test();
}
|