- 本人的LeetCode账号:魔术师的徒弟,欢迎关注获取每日一题题解,快来一起刷题呀~
- 本人Gitee账号:路由器,欢迎关注获取博客内容源码。
一、单点修改的线段树
1 线段树的原理及其基本操作
一个最简单的线段树一般会提供五个操作:
pushup(u) :根据子结点信息来算父节点的信息;
build() :将一段区间初始化为线段树;
modify() :修改操作,修改单点(easy)或修改某个区间(使用pushdown,比较hard的一个操作)
query() :查询操作,查询某一端区间的信息。
pushdown() :把当前父节点的修改信息下传到子结点。
??线段树是一个满二叉树,假设要用线段树维护1~10的闭区间。
??先看看build 操作的伪代码:
??要注意的是,是建完结点回来然后根据子结点信息更新父节点信息。
??再看看query 操作的伪代码,比如我们要查询某区间的最大值,每个结点存当前区间的最大值。
??假设查[5, 9] 的最大值:
??证明查询时总访问的区间数量一定是在log(n) 范围内。
??每个1 2情况展开有两个点,估算一下最多有4logn 个点,这一点,比树状数组的复杂度常数上要高一些。
??modify 单点修改:直接递归就好了,比如要更新6这个点:
2 Acwing1275 最大数
??如果要动态的添加点是在是太难了,因为最多有m个数,所以我们可以直接开m个坑,用n维护当前有多少个位置已经被占了,增加操作相当于把n + 1位置的数修改为我们要添加的数,然后n++。
??所以我们总共需要两个操作:
- 在某一个位置修改一个数;
- 询问
[n - L + 1, n] 区间内的最大值;
??这就是线段树的一个经典操作。
??线段树首先需要一个结点,首先必然要存的是左端点和右端点l和r,然后在本题中,我们另外需要存储的就是区间内的最大值。
??如何判断线段树中要存什么信息?看看问的是某个区间的某种属性,一般问的属性要存下来,有时候可能还要存辅助信息,一般就看看当前属性能否由两个子树的属性求出来,如果不能就增加一下属性,能的话就ok了。
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long LL;
const int N = 200010;
int m, p;
struct Node
{
int l, r;
int v;
}tr[4 * N];
void build(int u, int l, int r)
{
tr[u] = { l, r };
if (l == r) return;
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
void pushup(int u)
{
tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}
int query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u].v;
int mid = tr[u].l + tr[u].r >> 1;
int v = 0;
if (l <= mid) v = query(u << 1, l, r);
if (r > mid) v = max(v, query(u << 1 | 1, l, r));
return v;
}
void modify(int u, int x, int v)
{
if (tr[u].l == x && tr[u].r == x)
{
tr[u].v = v;
return;
}
int mid = tr[u].l + tr[u].r >> 1;
if (mid >= x) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
int main()
{
cin >> m >> p;
int n = 0;
build(1, 1, m);
int t;
int last;
char op[2];
while (m--)
{
scanf("%s%d", op, &t);
if (op[0] == 'Q')
{
last = query(1, n - t + 1, n);
printf("%d\n", last);
}
else
{
modify(1, n + 1, ((LL)t + last) % p);
++n;
}
}
return 0;
}
??单点修改,动态维护区间的最大值:
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long LL;
const int N = 200010;
struct Node
{
int l, r;
int v;
}tr[4 * N];
void build(int u, int l, int r)
{
tr[u] = { l, r };
if (l == r) return;
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
}
void pushup(int u)
{
tr[u].v = max(tr[u << 1].v, tr[u << 1 | 1].v);
}
int query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r) return tr[u].v;
int mid = tr[u].l + tr[u].r >> 1;
int v = 0;
if (l <= mid) v = query(u << 1, l, r);
if (r > mid) v = max(v, query(u << 1 | 1, l, r));
return v;
}
void modify(int u, int x, int v)
{
if (tr[u].l == x && tr[u].r == x)
{
tr[u].v = v;
return;
}
int mid = tr[u].l + tr[u].r >> 1;
if (mid >= x) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
3 Acwing245.你能回答这些问题吗
??更新策略如下:
#include <cstdio>
#include <iostream>
using namespace std;
const int N = 500010;
struct Node
{
int l, r;
int tmax, lmax, rmax, sum;
}tr[N * 4];
int w[N];
int n, m;
void merge(Node& u, Node& l, Node& r)
{
u.sum = l.sum + r.sum;
u.lmax = max(l.lmax, l.sum + r.lmax);
u.rmax = max(r.rmax, r.sum + l.rmax);
u.tmax = max(max(l.tmax, r.tmax), l.rmax + r.lmax);
}
void pushup(int u)
{
merge(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
if (l == r)
{
tr[u] = { r, r, w[r], w[r], w[r], w[r] };
return;
}
else
{
tr[u] = { l, r };
int mid = (l + r) >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int x, int v)
{
if (tr[u].l == x && tr[u].r == x)
{
tr[u] = { x, x, v, v, v, v };
return;
}
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
Node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
return tr[u];
}
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
else if (l > mid) return query(u << 1 | 1, l, r);
else
{
auto left = query(u << 1, l, r);
auto right = query(u << 1 | 1, l, r);
Node res;
merge(res, left, right);
return res;
}
}
int main()
{
cin >> n >> m;
for (int i = 1; i <= n; ++i) scanf("%d", &w[i]);
build(1, 1, n);
int k, x, y;
while (m--)
{
scanf("%d%d%d", &k, &x, &y);
if (k == 1)
{
if (x > y) swap(x, y);
int res = query(1, x, y).tmax;
printf("%d\n", res);
}
else
{
modify(1, x, y);
}
}
return 0;
}
4 Acwing246.区间的最大公约数
??考虑线段树结点需要存的信息,首先要存最大公约数,然后考虑怎么由左区间的最大公约数和右区间的最大公约数得到父节点的最大公约数:
g
c
d
(
[
L
,
R
]
)
=
g
c
d
(
g
c
d
(
[
L
,
m
i
d
]
)
,
g
c
d
(
[
m
i
d
+
1
,
R
]
)
)
gcd([L,R]) = gcd(gcd([L,mid]), gcd([mid + 1, R]))
gcd([L,R])=gcd(gcd([L,mid]),gcd([mid+1,R])) ??即父区间的最大公约数等于左区间的最大公约数与右区间的最大公约数再取一个最大公约数,那么查询操作就只存最大公约数就够了。
??发现对于一个区间同时增加1个数非常难以操作,如果每次只修改一个数,那么非常好操作(一个数的最大公约数就是它自己)。
??所以有没有什么方法可以把区间的修改变成单点的修改?
??联想到了差分的技巧。
??注意到有等式:
(
a
1
,
a
2
,
.
.
.
,
a
n
)
=
(
a
1
,
a
2
?
a
1
,
a
3
?
a
2
,
.
.
.
,
a
n
?
a
n
?
1
)
(a_1, a_2, ..., a_n) = (a_1, a_2 - a_1, a_3 - a_2,...,a_n - a_{n - 1})
(a1?,a2?,...,an?)=(a1?,a2??a1?,a3??a2?,...,an??an?1?) Proof :
假
设
d
=
(
a
1
,
a
2
,
.
.
.
,
a
n
)
,
证
明
它
一
定
是
右
边
这
n
个
数
的
一
个
约
数
这
是
显
然
的
,
右
边
的
每
一
项
都
整
除
d
,
又
因
为
右
边
是
右
边
n
个
数
的
最
大
公
约
数
所
以
d
<
=
(
a
1
,
a
2
?
a
1
,
a
3
?
a
2
,
.
.
.
,
a
n
?
a
n
?
1
)
设
d
=
(
a
1
,
a
2
?
a
1
,
a
3
?
a
2
,
.
.
.
,
a
n
?
a
n
?
1
)
首
先
d
能
整
除
a
1
,
有
因
为
d
能
整
除
a
1
且
能
整
除
a
2
?
a
1
所
以
d
能
整
除
他
们
的
和
,
a
2
以
此
类
推
,
d
能
整
除
a
1
到
a
n
中
的
每
一
个
数
所
以
d
是
a
1
?
a
n
的
公
约
数
所
以
d
<
=
(
a
1
,
a
2
,
.
.
.
,
a
n
)
综
上
,
(
a
1
,
a
2
,
.
.
.
,
a
n
)
=
(
a
1
,
a
2
?
a
1
,
a
3
?
a
2
,
.
.
.
,
a
n
?
a
n
?
1
)
假设d =(a_1, a_2, ..., a_n),证明它一定是右边这n个数的一个约数\\ 这是显然的,右边的每一项都整除d,又因为右边是右边n个数的最大公约数\\ 所以d<=(a_1, a_2 - a_1, a_3 - a_2,...,a_n - a_{n - 1})\\ 设d =(a_1, a_2 - a_1, a_3 - a_2,...,a_n - a_{n - 1})\\ 首先d能整除a_1,有因为d能整除a_1且能整除a_2 - a_1\\ 所以d能整除他们的和,a_2\\ 以此类推,d能整除a_1到a_n中的每一个数\\ 所以d是a_1-a_n的公约数\\ 所以d<=(a_1, a_2, ..., a_n)\\ 综上,(a_1, a_2, ..., a_n) = (a_1, a_2 - a_1, a_3 - a_2,...,a_n - a_{n - 1})
假设d=(a1?,a2?,...,an?),证明它一定是右边这n个数的一个约数这是显然的,右边的每一项都整除d,又因为右边是右边n个数的最大公约数所以d<=(a1?,a2??a1?,a3??a2?,...,an??an?1?)设d=(a1?,a2??a1?,a3??a2?,...,an??an?1?)首先d能整除a1?,有因为d能整除a1?且能整除a2??a1?所以d能整除他们的和,a2?以此类推,d能整除a1?到an?中的每一个数所以d是a1??an?的公约数所以d<=(a1?,a2?,...,an?)综上,(a1?,a2?,...,an?)=(a1?,a2??a1?,a3??a2?,...,an??an?1?) ??所以一个数列的最大公约数等于其差分数列的最大公约数,如果要对原数列进行区间整体加x,那么对差分数列进行两个单点增加即可。
??所以如果要求区间[L,R] 的最大公约数,只要求a[L], b[L + 1], ..., b[R] 的最大公约数即可,右边很好维护,也很好求。
??a[L] 就是要求一个前缀和,所以我们需要求的是:
g
c
d
(
a
[
L
]
,
g
c
d
(
b
[
L
+
1
]
,
.
.
.
,
b
[
R
]
)
)
gcd(a[L], gcd(b[L + 1], ... ,b[R]))
gcd(a[L],gcd(b[L+1],...,b[R])) ??所以维护两个信息:sum 和gcd ,差分序列的前缀和和差分序列的最大公约数,因为他们维护的是统一个序列,所以可以放在一起写。
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long LL;
LL gcd(LL a, LL b)
{
return b ? gcd(b, a % b) : a;
}
const int N = 500010;
LL w[N];
struct Node
{
int l, r;
LL d, sum;
}tr[N * 4];
void merge(Node& u, Node& l, Node& r)
{
u.sum = l.sum + r.sum;
u.d = gcd(l.d, r.d);
}
void pushup(int u)
{
merge(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r)
{
if (l == r)
{
tr[u] = { r, r, w[r] - w[r - 1], w[r] - w[r - 1] };
return;
}
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int x, LL v)
{
if (tr[u].l == x && tr[u].r == x)
{
tr[u].sum += v;
tr[u].d += v;
return;
}
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
Node query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
return tr[u];
}
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
if (l > mid) return query(u << 1 | 1, l, r);
auto left = query(u << 1, l, r);
auto right = query(u << 1 | 1, l, r);
Node res;
merge(res, left, right);
return res;
}
int n, m;
int main()
{
cin >> n >> m;
for (int i = 1; i <= n; ++i) scanf("%lld", &w[i]);
build(1, 1, n);
int l, r;
LL d;
char op[2];
while (m--)
{
scanf("%s", op);
if (op[0] == 'Q')
{
scanf("%d%d", &l, &r);
auto left = query(1, 1, l);
Node right({0, 0, 0, 0});
if (l + 1 <= r) right = query(1, l + 1, r);
LL res = abs(gcd(left.sum, right.d));
printf("%lld\n", res);
}
else
{
scanf("%d%d%lld", &l, &r, &d);
modify(1, l, d);
if (r + 1 <= n) modify(1, r + 1, -d);
}
}
return 0;
}
二、区间修改的线段树
1 原理
- 懒标记:pushdown,把父节点信息下传子结点;
- 扫描线法
??什么是懒标记?如果只有pushup 操作,那么我们只能做单点修改,如果用单点修改来修改区间,最坏情况下要修改O(n) 个区间,为了解决这个问题,提出了pushdown 操作。
??它的思想与query 类似,当我们查询到一个区间被完全包含的时候,就不往下走了,直接打上一个懒标记。
??以区间和为例:
??区间属性中增加一个add 懒标记,其含义为给以当前结点为根的子树中的每一个结点都修改(加上)这个数(不包含当前区间自己),这样可以保证我们的修改的操作时间复杂度在O(LOGN) 内。
??查询时:我们用到的每一个区间的值,必须要把祖宗的懒标记加上,因此我们增加一个操作,如果当前区间不符合要求,我们就先把它的懒标记清空,然后传给两个孩子结点,这样的操作就是pushdown操作,加上了这种操作时,计算某区间和时,递归到单点时,其祖宗的懒标记就已经全部清空(被计算过了)了,累加到了根节点上。
root add;
left.add += root.add; left.sum += (left.r - left.l + 1) * root.add;
right.add += root.add; right.sum += (right.r - left.l + 1) * root.add;
root.add = 0;
??对于修改操作,如果我们仅仅要修改当前区间的某一部分,那么一定要把懒标记往子孙传,否则可能会出现一个区间的左右两部分懒标记不同。
2 Acwing243.简单的整数问题
??线段树结点中维护两个信息:
- sum:如果考虑当前结点及子结点上的所有标记,当前区间和为多少(没有考虑祖先结点的标记)。
- add:给当前区间的所有儿子(不包括它自己)加上add。
#include <iostream>
using namespace std;
const int N = 1e5 + 10;
typedef long long LL;
struct Node
{
int l, r;
LL sum, add;
}tr[N * 4];
int w[N];
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)
{
auto& root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add != 0)
{
left.add += root.add;
left.sum += (LL)root.add * (left.r - left.l + 1);
right.add += root.add;
right.sum += (LL)root.add * (right.r - right.l + 1);
root.add = 0;
}
}
void build(int u, int l, int r)
{
if (l == r)
{
tr[u] = {r, r, w[r], 0};
return;
}
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int l, int r, int d)
{
if (tr[u].l >= l && tr[u].r <= r)
{
tr[u].sum += (LL)d * (tr[u].r - tr[u].l + 1);
tr[u].add += d;
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, d);
if (r > mid) modify(u << 1 | 1, l, r, d);
pushup(u);
}
LL query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
return tr[u].sum;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
if (l > mid) return query(u << 1 | 1, l, r);
return query(u << 1, l, r) + query(u << 1 | 1, l, r);
}
int n, m;
int main()
{
cin >> n >> m;
for (int i = 1; i <= n; ++i) scanf("%d", &w[i]);
build(1, 1, n);
int l, r, d;
char op[2];
while (m--)
{
scanf("%s%d%d", op, &l, &r);
if (op[0] == 'C')
{
scanf("%d", &d);
modify(1, l, r, d);
}
else
{
printf("%lld\n", query(1, l, r));
}
}
return 0;
}
3 区间加d 求区间和的线段树模板
??简单写了个模板:
struct Node
{
int l, r;
LL sum, add;
};
class SegmentTree
{
typedef long long LL;
public:
template<class InputIterator>
SegmentTree(InputIterator first, InputIterator last)
{
int i = 1;
while (first != last)
{
w[i] = *first;
++first;
++i;
}
n = i - 1;
build(1, 1, n);
}
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)
{
auto& root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add)
{
left.add += root.add;
left.sum += (LL)root.add * (left.r - left.l + 1);
right.add += root.add;
right.sum += (LL)root.add * (right.r - right.l + 1);
root.add = 0;
}
}
void build(int u, int l, int r)
{
if (l == r)
{
tr[u] = {r, r, w[r], 0};
return;
}
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int l, int r, int d)
{
_modify(1, l, r, d);
}
void _modify(int u, int l, int r, int d)
{
if (tr[u].l >= l && tr[u].r <= r)
{
tr[u].add += d;
tr[u].sum += (LL)d * (tr[u].r - tr[u].l + 1);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) _modify(u << 1, l, r, d);
if (r > mid) _modify(u << 1 | 1, l, r, d);
pushup(u);
}
LL query(int l, int r)
{
return _query(1, l, r);
}
LL _query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
return tr[u].sum;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return _query(u << 1, l, r);
if (l > mid) return _query(u << 1 | 1, l, r);
return _query(u << 1, l, r) + _query(u << 1 | 1, l, r);
}
private:
static const int N = 1e5 + 10;
Node tr[N * 4];
int w[N];
int n;
};
4 Acwing1277. 维护序列
??思路:
#include <iostream>
using namespace std;
const int N = 1e5 + 10;
typedef long long LL;
int n, m, p;
struct Node
{
int l, r;
int sum;
int mul;
int add;
}tr[N * 4];
int w[N];
void pushup(int u)
{
tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}
void eval(Node& t, int mul, int add)
{
t.sum = ((LL)t.sum * mul + (LL)add * (t.r - t.l + 1)) % p;
t.mul = (LL)t.mul * mul % p;
t.add = ((LL)t.add * mul + add) % p;
}
void pushdown(int u)
{
eval(tr[u << 1], tr[u].mul, tr[u].add);
eval(tr[u << 1 | 1], tr[u].mul, tr[u].add);
tr[u].mul = 1, tr[u].add = 0;
}
void build(int u, int l, int r)
{
if (l == r)
{
tr[u] = {r, r, w[r], 1, 0};
return;
}
tr[u] = {l, r, 0, 1, 0};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int l, int r, int mul, int add)
{
if (tr[u].l >= l && tr[u].r <= r)
{
eval(tr[u], mul, add);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, mul, add);
if (r > mid) modify(u << 1 | 1, l, r, mul, add);
pushup(u);
}
int query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
return tr[u].sum;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
if (l > mid) return query(u << 1 | 1, l, r);
return (query(u << 1, l, r) + query(u << 1 | 1, l, r)) % p;
}
int main()
{
cin >> n >> p;
for (int i = 1; i <= n; ++i) scanf("%d", &w[i]);
build(1, 1, n);
scanf("%d", &m);
int op, l, r, d;
while (m--)
{
scanf("%d%d%d", &op, &l, &r);
if (op == 1)
{
scanf("%d", &d);
modify(1, l, r, d, 0);
}
else if (op == 2)
{
scanf("%d", &d);
modify(1, l, r, 1, d);
}
else printf("%d\n", query(1, l, r));
}
return 0;
}
5 LCP 52.二叉树染色的线段树做法
??本题可以把染成蓝色看做是对应区间[l, r] 乘以0,染成红色是对应区间[l, r] 加上1,因为二叉树结点值比较大,但二叉树结点的个数是1e5 数量级的,所以可以用离散化。最后,询问每个区间(i, i) 的值,若不为零,说明此点为蓝色,答案加1。
struct Node
{
int l, r;
int sum, mul, add;
};
class SegmentTree
{
public:
template<class InputIterator>
SegmentTree(InputIterator first, InputIterator last)
{
int i = 1;
while (first != last)
{
w[i] = *first;
first++;
i++;
}
n = i - 1;
build(1, 1, n);
}
SegmentTree(int sz)
: n(sz)
{
memset(w, 0, sizeof(w));
build(1, 1, n);
}
void modify(int l, int r, int mul, int add)
{
_modify(1, l, r, mul, add);
}
int query(int l, int r)
{
return _query(1, l, r);
}
private:
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void eval(Node& t, int mul, int add)
{
t.sum = t.sum * mul + add * (t.r - t.l + 1);
t.mul *= mul;
t.add = t.add * mul + add;
}
void pushdown(int u)
{
eval(tr[u << 1], tr[u].mul, tr[u].add);
eval(tr[u << 1 | 1], tr[u].mul, tr[u].add);
tr[u].mul = 1, tr[u].add = 0;
}
void build(int u, int l, int r)
{
if (l == r)
{
tr[u] = {r, r, w[r], 1, 0};
return;
}
tr[u] = {l, r, 0, 1, 0};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void _modify(int u, int l, int r, int mul, int add)
{
if (tr[u].l >= l && tr[u].r <= r)
{
eval(tr[u], mul, add);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) _modify(u << 1, l, r, mul, add);
if (r > mid) _modify(u << 1 | 1, l, r, mul, add);
pushup(u);
}
int _query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
return tr[u].sum;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return _query(u << 1, l, r);
if (l > mid) return _query(u << 1 | 1, l, r);
return _query(u << 1, l, r) + _query(u << 1 | 1, l, r);
}
static const int N = 1e5 + 10;
int w[N];
Node tr[N * 4];
int n;
};
class Solution {
public:
vector<int> inorder;
int getNumber(TreeNode* root, vector<vector<int>>& ops)
{
Inorder(root);
unordered_map<int, int> myhash;
int n = inorder.size();
for (int i = 0; i < n; ++i) myhash[inorder[i]] = i + 1;
SegmentTree SegT(n);
for (auto& op : ops)
{
if (op[0] == 0)
{
int l = myhash[op[1]], r = myhash[op[2]];
SegT.modify(l, r, 0, 0);
}
else
{
int l = myhash[op[1]], r = myhash[op[2]];
SegT.modify(l, r, 1, 1);
}
}
int ans = 0;
for (int i = 1; i <= n; ++i)
{
if (SegT.query(i, i) != 0) ++ans;
}
return ans;
}
void Inorder(TreeNode* root)
{
if (root == nullptr) return;
Inorder(root->left);
inorder.push_back(root->val);
Inorder(root->right);
}
};
6 区间乘c加d 求区间和的线段树模板
??简单总结一个模板:
struct Node
{
int l, r;
int sum, mul, add;
};
class SegmentTree
{
public:
template<class InputIterator>
SegmentTree(InputIterator first, InputIterator last)
{
int i = 1;
while (first != last)
{
w[i] = *first;
first++;
i++;
}
n = i - 1;
build(1, 1, n);
}
SegmentTree(int sz)
: n(sz)
{
memset(w, 0, sizeof(w));
build(1, 1, n);
}
void modify(int l, int r, int mul, int add)
{
_modify(1, l, r, mul, add);
}
int query(int l, int r)
{
return _query(1, l, r);
}
private:
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void eval(Node& t, int mul, int add)
{
t.sum = t.sum * mul + add * (t.r - t.l + 1);
t.mul *= mul;
t.add = t.add * mul + add;
}
void pushdown(int u)
{
eval(tr[u << 1], tr[u].mul, tr[u].add);
eval(tr[u << 1 | 1], tr[u].mul, tr[u].add);
tr[u].mul = 1, tr[u].add = 0;
}
void build(int u, int l, int r)
{
if (l == r)
{
tr[u] = {r, r, w[r], 1, 0};
return;
}
tr[u] = {l, r, 0, 1, 0};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void _modify(int u, int l, int r, int mul, int add)
{
if (tr[u].l >= l && tr[u].r <= r)
{
eval(tr[u], mul, add);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) _modify(u << 1, l, r, mul, add);
if (r > mid) _modify(u << 1 | 1, l, r, mul, add);
pushup(u);
}
int _query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
{
return tr[u].sum;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return _query(u << 1, l, r);
if (l > mid) return _query(u << 1 | 1, l, r);
return _query(u << 1, l, r) + _query(u << 1 | 1, l, r);
}
static const int N = 1e5 + 10;
int w[N];
Node tr[N * 4];
int n;
};
|