题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=3308
大意:给出n个数,m次操作。操作如下:
1:U A B 表示将下标为A的数修改为B。
2:Q A B 询问区间[A,B]内最长连续递增子序列的长度。
思路:线段树单点修改+区间查询。
先看看变量的意义:
- ?表示以编号l作为左端点的最长连续递增子序列长度。
- ?表示以编号r作为右端点的最长连续递增子序列长度。
- ?表示[l,r]内最长连续递增子序列长度。
那区间怎样合并呢?
首先看,有三种情况:
- 左区间中的最大值:
- 右区间中的最大值:
- 如果左区间最右边的数比右区间最左边的数小的话,那么以左区间右端点的和以右区间左端点的是可以合并在一起作为连续递增子序列的。所以还有:
- 取三者中的最大值
然后是:同样,如果左区间最右边的数比右区间最左边的数小的话,可以合并,那么。
最后是:和上面一样,。
那区间查询呢?
?和普通线段树一样,只是最后注意是如果左区间最右边的数比右区间最左边的数小的话,是可以合并的,但是和取最小值,和取最小值,因为肯定要在区间内吧。
int res = 0;
res = max(query(l, mid, ql, qr, rt << 1),
query(mid + 1, r, ql, qr, rt << 1 | 1));
int lsum = min(qr - mid, lmax[rt << 1 | 1]);
int rsum = min(mid - ql + 1, rmax[rt << 1]);
if(val[mid] < val[mid+1]) res = max(res, lsum + rsum);
Code
#include <bits/stdc++.h>
#define ll long long
#define pir pair<int, int>
#define pirl pair<ll, ll>
#define debug(x) cout << #x << ":" << x << "\n"
const int mod = 1e9 + 7;
const ll ds = 1e18;
const double eps = 1e-8;
using namespace std;
const int N = 1e6 + 5;
int val[N];
int lmax[N], rmax[N], mmax[N];
void pushUp(int l, int r, int rt) {
int mid = (l + r) >> 1;
mmax[rt] = max(mmax[rt << 1], mmax[rt << 1 | 1]);
if (val[mid] < val[mid + 1])
mmax[rt] = max(mmax[rt], rmax[rt << 1] + lmax[rt << 1 | 1]);
lmax[rt] = lmax[rt << 1];
if (lmax[rt] == mid - l + 1 && val[mid] < val[mid + 1])
lmax[rt] += lmax[rt << 1 | 1];
rmax[rt] = rmax[rt << 1 | 1];
if (rmax[rt] == r - mid && val[mid] < val[mid + 1])
rmax[rt] += rmax[rt << 1];
}
void build(int l, int r, int rt) {
if (l == r) {
lmax[rt] = rmax[rt] = mmax[rt] = 1;
return;
}
int mid = (l + r) >> 1;
build(l, mid, rt << 1);
build(mid + 1, r, rt << 1 | 1);
pushUp(l, r, rt);
}
void update(int l, int r, int x, int y, int rt) {
if (l == r) {
val[l] = y;
lmax[rt] = rmax[rt] = mmax[rt] = 1;
return;
}
int mid = (l + r) >> 1;
if (x <= mid)
update(l, mid, x, y, rt << 1);
else
update(mid + 1, r, x, y, rt << 1 | 1);
pushUp(l, r, rt);
}
int query(int l, int r, int ql, int qr, int rt) {
if (ql <= l && qr >= r) return mmax[rt];
int mid = (l + r) >> 1;
if (qr <= mid)
return query(l, mid, ql, qr, rt << 1);
else if (ql > mid)
return query(mid + 1, r, ql, qr, rt << 1 | 1);
int res = 0;
res = max(query(l, mid, ql, qr, rt << 1),
query(mid + 1, r, ql, qr, rt << 1 | 1));
int lsum = min(qr - mid, lmax[rt << 1 | 1]);
int rsum = min(mid - ql + 1, rmax[rt << 1]);
if(val[mid] < val[mid+1]) res = max(res, lsum + rsum);
return res;
}
void solve() {
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", &val[i]);
}
build(1, n, 1);
char c;
int a, b;
while (m--) {
cin >> c;
scanf("%d%d", &a, &b);
if (c == 'U') {
update(1, n, a + 1, b, 1);
} else {
printf("%d\n", query(1, n, a + 1, b + 1, 1));
}
}
}
int main() {
int T;
cin >> T;
while (T--) solve();
// system("exit");
return 0;
}
|