线段树
细节要点
线段树: 建立一棵二叉树,用递归建图
- 在叶子节点处无需下放懒惰标记,所以懒惰标记可以不下传到叶子节点。
- 下放懒惰标记可以写一个专门的函数 pushdown,从儿子节点更新当前节点也可以写一个专门的函数 maintain(或者对称地用 pushup),降低代码编写难度。
- 标记永久化,如果确定懒惰标记不会在中途被加到溢出(即超过了该类型数据所能表示的最大范围),那么就可以将标记永久化。标记永久化可以避免下传懒惰标记,只需在进行询问时把标记的影响加到答案当中,从而降低程序常数。具体如何处理与题目特性相关,需结合题目来写。这也是树套树和可持久化数据结构中会用到的一种技巧。
模板
递归建树
void build(int s, int t, int p) {
if(s == t) {
d[p] = a[s];
return;
}
int mid = s + ((t-s) >> 1);
build(s, mid, p*2), build(mid+1, t, p*2+1);
d[p] = d[p*2] + d[(p*2)+1];
}
区间求和
int getsum(int l, int r, int s, int t, int p) {
if(l <= s && t <= r) return d[p];
int mid = s + ((t-s) >> 1), sum = 0;
if(l <= mid) sum += getsum(l, r, s, mid, p*2);
if(r > mid) sum += getsum(l, r, mid+1, t, p*2+1);
return sum;
}
区间修改(区间加上某个值)
void update(int l, int r, int c, int s, int t, int p) {
if(l <= s && t <= r) {
d[p] += (t-s+1) * c;
lag[p] += c;
return ;
}
int mid = s + ((t-s) >> 1);
if(lag[p] && s!=t) {
d[p*2] += lag[p] * (mid-s+1), d[p*2+1] += lag[p] * (t-mid);
lag[p*2] += lag[p], lag[p*2+1] += lag[p];
lag[p] = 0;
}
if (l <= mid) update(l, r, c, s, mid, p*2);
if (r > mid) update(l, r, c, mid+1, t, p*2+1);
d[p] = d[p*2] + d[p*2+1];
}
区间修改(区间求和)
int getsum(int l, int r, int s, int t, int p) {
if(l <= s && t <= r) return d[p];
int mid = s + ((t-s) >> 1);
if(lag[p]) {
d[p*2] += lag[p] * (mid-s+1), d[p*2+1] += lag[p] * (t-mid);
lag[p*2] += lag[p], lag[p*2+1] += lag[p];
lag[p] = 0;
}
int sum = 0;
if(l <= mid) sum += getsum(l, r, s, mid, p*2);
if(r > mid) sum += getsum(l, r, mid, t, p*2+1);
return sum;
}
区间修改(将区间修改为某一特定值)
void pushdown() {
d[p*2] += lag[p] * (mid-s+1), d[p*2+1] += lag[p] * (t-mid);
lag[p*2] += lag[p]; lag[p*2+1] += lag[p];
lag[p] = 0;
}
void update(int l, int r, int c, int s, int t, int p) {
if(l <= s && t <= r) {
d[p] = (t-s+1) * c;
lag[p] = c;
return ;
}
int mid = s + ((t-s) >> 1);
if(lag[p]) pushdown();
if(l <= mid) update(l, r, c, s, mid, p*2);
if(r > mid) update(l, r, c, mid, t, p*2+1);
}
int getsum(int l, int r, int s, int t, int p) {
if(l <= s && t <= r) return d[p];
int mid = s + ((t-s) >> 1);
if (lag[p]) pushdown();
int sum = 0;
if(l <= mid) sum = getsum(l, r, s, mid, p*2);
if(r > mid) sum += getsum(l, r, mid+1, t, p*2+1);
return sum;
}
P3372 【模板】线段树 1
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e5+10;
typedef unsigned long long ULL;
ll lag[maxn*4], d[maxn*4];
ll a[maxn];
void Build(ll s, ll t, ll p) {
if(s == t) {
d[p] = a[s];
return ;
}
ll mid = s + ((t-s)>>1);
Build(s, mid, p*2);
Build(mid+1, t, p*2+1);
d[p] = d[p*2] + d[p*2+1];
}
void pushdown(ll p, ll s, ll t) {
ll mid = s + ((t-s) >> 1);
d[p*2] += lag[p]*(mid-s+1); d[p*2+1] += lag[p]*(t-mid);
lag[p*2] += lag[p]; lag[p*2+1] += lag[p];
lag[p] = 0;
}
void add(ll l, ll r, ll c, ll s, ll t, ll p) {
if(l <= s && t <= r) {
d[p] += c*(t-s+1);
lag[p] += c;
return;
}
ll mid = s + ((t-s) >> 1);
if(lag[p] && s!=t) pushdown(p, s, t);
if(l <= mid) add(l, r, c, s, mid, p*2);
if(r > mid) add(l, r, c, mid+1, t, p*2+1);
d[p] = d[p*2] + d[p*2+1];
}
ll getsum(ll l, ll r, ll s, ll t, ll p) {
if(l <= s && r >= t) return d[p];
int mid = s + ((t-s) >> 1);
if(lag[p]) pushdown(p, s, t);
ll sum = 0;
if(l <= mid) sum = getsum(l, r, s, mid, p*2);
if(r > mid) sum += getsum(l, r, mid+1, t, p*2+1);
return sum;
}
int main() {
ll n, m;
scanf("%lld%lld", &n, &m);
for(int i = 1; i <= n; i++) scanf("%lld", &a[i]);
Build(1, n, 1);
while(m--) {
int op;
scanf("%d", &op);
if(op == 1) {
ll l, r, k;
scanf("%lld%lld%lld", &l, &r, &k);
add(l, r, k, 1, n, 1);
}
else {
ll l, r;
scanf("%lld%lld", &l, &r);
printf("%lld\n", getsum(l, r, 1, n, 1));
}
}
return 0;
}
P3373 【模板】线段树 2
实现区间乘和区间加 debug 易错:
- 每次运算完记得 laz 下放
- 加法运算的时候 记得+laz*(区间长)
#include<cstdio>
#include<iostream>
#define ll long long
using namespace std;
const int maxn = 1e5+10;
inline ll read() {
ll x = 0, f = 1;
char c = getchar();
while(c < '0' || c > '9') {
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
x = x*10+c-'0';
c = getchar();
}
return x*f;
}
ll n, m, mod;
ll a[maxn], sum[maxn*4], mul[maxn*4], laz[maxn*4];
void up(int i) {
sum[i] = (sum[i*2] + sum[i*2+1])%mod;
}
void pd(int i, int s, int t) {
int l = (i*2), r = (i*2+1), mid = s+t>>1;
if(mul[i] != 1) {
mul[l] *= mul[i]; mul[l] %= mod;
mul[r] *= mul[i]; mul[r] %= mod;
laz[l] *= mul[i]; laz[l] %= mod;
laz[r] *= mul[i]; laz[r] %= mod;
sum[l] *= mul[i]; sum[l] %= mod;
sum[r] *= mul[i]; sum[r] %= mod;
mul[i] = 1;
}
if(laz[i] != 0) {
sum[l] += laz[i]*(mid-s+1); sum[l] %= mod;
sum[r] += laz[i]*(t-mid); sum[r] %= mod;
laz[l] += laz[i]; laz[l] %= mod;
laz[r] += laz[i]; laz[r] %= mod;
laz[i] = 0;
}
return ;
}
void build(int s, int t, int i) {
mul[i] = 1;
laz[i] = 0;
if(s == t) {
sum[i] = a[s];
return ;
}
int mid = s+((t-s) >> 1);
build(s, mid, i*2);
build(mid+1, t, i*2+1);
up(i);
}
void multi(int l, int r, int s, int t, ll k, int i) {
int mid = s+((t-s) >> 1);
if(l <= s && t <= r) {
sum[i] *= k; sum[i] %= mod;
mul[i] *= k; mul[i] %= mod;
laz[i] *= k; laz[i] %= mod;
return ;
}
pd(i, s, t);
if(l <= mid) multi(l, r, s, mid, k, i*2);
if(r > mid) multi(l, r, mid+1, t, k, i*2+1);
up(i);
}
void add(int l, int r, int s, int t, ll k, int i) {
int mid = s+((t-s) >> 1);
if(l <= s && r >= t) {
sum[i] += k*(t-s+1); sum[i] %= mod;
laz[i] += k; laz[i] %= mod;
return ;
}
pd(i, s, t);
if(l <= mid) add(l, r, s, mid, k, i*2);
if(r > mid) add(l, r, mid+1, t, k, i*2+1);
up(i);
}
ll getsum(int l, int r, int s, int t, int i) {
int mid = s+((t-s) >> 1);
if(l <= s && t <= r) return sum[i] % mod;
pd(i, s, t);
int sum = 0;
if(l <= mid) sum = getsum(l, r, s, mid, i*2);
if(mid < r) sum += getsum(l, r, mid+1, t, i*2+1);
return sum % mod;
}
int main() {
n = read(); m = read(); mod = read();
for(int i = 1; i <= n; i++) a[i] = read();
build(1, n, 1);
while(m--) {
int op, l, r;
scanf("%d%d%d", &op, &l, &r);
if(op == 1) {
ll k = read();
multi(l, r, 1, n, k, 1);
}
else if(op == 2) {
ll k = read();
add(l, r, 1, n, k, 1);
}
else {
printf("%lld\n", getsum(l, r, 1, n, 1));
}
}
return 0;
}
将区间修改为同一个值 关键代码部分
void up(i) {
sum[i] = sum[i*2] + sum[i*2+1];
}
void pd(int i, int s, int t) {
int l = i*2, r = i*2+1, mid = s+t>>1;
sum[l] += laz[i]*(mid-s+1);
sum[r] += laz[i]*(t-mid);
laz[i*2] += laz[i];
laz[i*2+1] += laz[i];
return ;
}
void build(int s, int t, int i) {
if(s == t) {
sum[i] = a[s];
return ;
}
int mid = l+r>>1;
build(s, mid, i*2);
build(mid+1, r, i*2+1);
up(i);
}
void update(int l, int r, int s, int t, int k, int i){
if(l <= s && r >= k) {
laz[i] = k;
sum[i] = k*(t-s+1);
return ;
}
pd(i, s, t);
int mid = s+t>>1;
if(l <= mid) update(l, r, s, mid, k, i*2);
if(r > mid) update(l, r, mid+1, t, k, i*2+1);
}
ll getsum(int l, int r, int s, int t, int i) {
if(l <= s && r >= t) return sum[i];
int mid = s+t>>1;
pd(i, s, t);
ll sum = 0;
if(l <= mid) sum = getsum(l, r, s, mid, i*2);
if(r > mid) sum += getsum(l, r, mid+1, t, i*2+1);
return sum;
}
|