/**
 * Union-find (disjoint-set) structure over the elements 0..cap-1.
 * Create with LcUnionFindInit(), release with LcUnionFindDestory().
 */
typedef struct {
int *parent;   /* parent[i] is the parent of element i; a root is its own parent */
int *size;     /* size[r] is the element count of the set rooted at r (kept up to date only for roots) */
int cap;       /* number of elements the structure was created for */
int count;     /* current number of disjoint sets */
}LcUnionFind;
/**
 * Creates a union-find structure over the elements 0..cap-1.
 *
 * Every element starts as its own singleton set. Its parent is itself and
 * its set size ("weight") is 1, so the initial number of sets equals cap.
 *
 * @param cap number of elements; must be >= 0 (0 yields an empty structure).
 * @return a new structure to be released with LcUnionFindDestory(), or NULL
 *         if cap is negative or any allocation fails.
 */
LcUnionFind * LcUnionFindInit(int cap) {
    if (cap < 0) {
        /* A negative count would convert to a huge size_t allocation request. */
        return NULL;
    }

    /* calloc zeroes the struct, replacing the unchecked malloc + memset. */
    LcUnionFind *unionFind = calloc(1, sizeof *unionFind);
    if (unionFind == NULL) {
        return NULL;
    }

    /* calloc checks count * size for overflow, unlike sizeof(int) * cap. */
    unionFind->parent = calloc((size_t)cap, sizeof *unionFind->parent);
    unionFind->size = calloc((size_t)cap, sizeof *unionFind->size);
    /* For cap == 0 a NULL result is allowed and is not a failure. */
    if (cap > 0 && (unionFind->parent == NULL || unionFind->size == NULL)) {
        free(unionFind->parent);
        free(unionFind->size);
        free(unionFind);
        return NULL;
    }

    unionFind->cap = cap;
    for (int i = 0; i < cap; ++i) {
        /*
         * Each element's parent is initialised to itself, and each
         * element's "weight" (set size) is initialised to 1.
         */
        unionFind->parent[i] = i;
        unionFind->size[i] = 1;
    }
    unionFind->count = cap;
    return unionFind;
}
/**
 * Returns the representative (root) of the set that contains element n.
 *
 * Applies path compression. Every element visited on the way to the root
 * is re-pointed directly at the root, so later lookups on those elements
 * take a single step.
 *
 * @param unionFind structure to search.
 * @param n element index; presumably must lie in [0, cap) (not checked here).
 * @return the root element of n's set.
 */
int LcUnionFindFind(LcUnionFind* unionFind, int n)
{
    int *up = unionFind->parent;

    /* Pass 1: climb parent links until reaching a self-parented root. */
    int root = n;
    while (up[root] != root) {
        root = up[root];
    }

    /* Pass 2: walk the same path again, pointing each node at the root. */
    int node = n;
    while (node != root) {
        int next = up[node];
        up[node] = root;
        node = next;
    }

    return root;
}
/**
 * Merges the sets containing elements i and j. Does nothing if they are
 * already in the same set.
 *
 * The merge is balanced by set size ("weight"). The root of the smaller set
 * is attached beneath the root of the larger one. On a tie, i's root goes
 * under j's root. Each real merge reduces the set count by one.
 */
void LcUnionFindUnion(LcUnionFind *unionFind, int i, int j)
{
    const int rootI = LcUnionFindFind(unionFind, i);
    const int rootJ = LcUnionFindFind(unionFind, j);
    if (rootI == rootJ) {
        return;
    }

    /* By default i's tree hangs under j's. Flip only when i's set is strictly larger. */
    int child = rootI;
    int keeper = rootJ;
    if (unionFind->size[rootI] > unionFind->size[rootJ]) {
        child = rootJ;
        keeper = rootI;
    }

    unionFind->parent[child] = keeper;
    unionFind->size[keeper] += unionFind->size[child];
    --unionFind->count;
}
/**
 * Returns the number of disjoint sets currently tracked. The count starts
 * at cap and drops by one each time LcUnionFindUnion() joins two distinct sets.
 */
int LcUnionFindGetCount(LcUnionFind* unionFind)
{
    const int setCount = unionFind->count;
    return setCount;
}
/**
 * Reports whether elements i and j currently belong to the same set.
 * Both lookups apply path compression as a side effect, i first and then j.
 */
bool LcUnionFindIsConnectd(LcUnionFind* unionFind, int i, int j)
{
    const int rootOfI = LcUnionFindFind(unionFind, i);
    const int rootOfJ = LcUnionFindFind(unionFind, j);
    if (rootOfI != rootOfJ) {
        return false;
    }
    return true;
}
/**
 * Releases a structure created by LcUnionFindInit(), including both of its
 * internal arrays. Passing NULL is a no-op.
 *
 * @param find structure to free; must not be used after this call.
 */
void LcUnionFindDestory(LcUnionFind *find)
{
    if (find == NULL) {
        return;
    }
    /* free(NULL) is a no-op, so the arrays need no individual NULL guards. */
    free(find->parent);
    free(find->size);   /* was previously leaked */
    free(find);
}
|