线段树的定义: 首先,线段树是一棵完全二叉树。它的特点是:每个结点表示的是一个线段,或者说是一个区间。事实上,一棵线段树的根结点表示的是“整体”区间,而它的左右子树也是一棵线段树,分别表示区间的左半边和右半边。树中的每个结点表示一个区间[a,b]。每一个叶子结点表示一个单位区间。对于每一个非叶结点所表示的结点[a,b],其左孩子表示的区间为[a,(a+b)/2],右孩子表示的区间为[(a+b)/2,b]。 用T(a, b)表示一棵线段树,参数a,b表示区间[a,b],其中b-a称为区间的长度,记为L。如下图 实际上,用语言描述有点麻烦,具体可以看B站视频,还有一些其他人的介绍:线段树
话不多说,上代码:
class SegTree {
vector<int> nums;
int numsLen;
vector<int> trees;
void buildTree(int index, int left, int right) {
if (left >= right) {
trees[index] = nums[left];
return;
}
int mid = (left + right) / 2;
int node_left = index * 2 + 1;
int node_right = index * 2 + 2;
buildTree(node_left, left, mid);
buildTree(node_right, mid + 1, right);
trees[index] = trees[node_left] + trees[node_right];
}
int query(int index, int left, int right, int target_left, int target_right) {
if (left == target_left && right == target_right) {
return trees[index];
}
int mid = (left + right) / 2;
int node_left = index * 2 + 1;
int node_right = index * 2 + 2;
if (target_left > mid) {
return query(node_right, mid + 1, right, target_left, target_right);
}
else if (target_right <= mid) {
return query(node_left, left, mid, target_left, target_right);
}
else {
return query(node_left, left, mid, target_left, mid) + query(node_right, mid + 1, right, mid + 1, target_right);
}
}
void updata(int index, int left, int right, int target_index, int data) {
if (left == right) {
if (left == target_index) {
nums[target_index] = data;
trees[index] = data;
cout << "index = " << index << ",updata success \n";
}
return;
}
int mid = (left + right) / 2;
int node_left = index * 2 + 1;
int node_right = index * 2 + 2;
if (target_index > mid) {
updata(node_right, mid + 1, right, target_index, data);
}
else if (target_index <= mid) {
updata(node_left, left, mid, target_index, data);
}
trees[index] = trees[node_left] + trees[node_right];
}
public:
SegTree(vector<int> nums_) {
nums = nums_;
numsLen = nums.size();
trees = vector<int>(4 * numsLen, 0);
buildTree(0, 0, numsLen - 1);
}
int query(int left, int right) {
return query(0, 0, nums.size()-1 ,left, right);
}
void updata(int index, int data) {
updata(0, 0, nums.size()-1, index, data);
}
};
int main() {
vector<int> nums = { 1, 3, 5 };
SegmentTree* obj = new SegmentTree(nums);
SegTree* obj1 = new SegTree(nums);
cout << obj1->query(1, 2) << endl;
obj1->updata(1, 9);
cout << obj1->query(1, 2) << endl;
cout << obj1->query(0, 2) << endl;
obj1->updata(0, 9);
cout << obj1->query(1, 2) << endl;
cout << obj1->query(0, 2) << endl;
return 0;
}
输出:
|