题目链接:点击跳转
代码如下:
#include<bits/stdc++.h>
#include <ostream>
using namespace std;
// ---- Type aliases -------------------------------------------------------
using ll = long long;
// `endl` is redefined to a plain newline so that `cout << endl` does not
// flush the stream after every line (deliberate fast-output trick).
#define endl '\n'
using PII = pair<int, int>;
// Explicit flush of stdout (since `endl` no longer flushes).
#define debug() cout.flush()

// ---- Loop helpers -------------------------------------------------------
// Bound arguments are parenthesized so that an argument containing an
// operator that binds looser than `<` / `<=` / `>=` (e.g. `x & 1`,
// `a ? b : c`) is still treated as a single bound expression.
#define for0(i, a) for (int i = 0; i < (a); ++i)                   // i in [0, a)
#define REP(i, a, b) for (int i = (a); i < (b); ++i)               // i in [a, b)
#define FOR(i, a, b) for (int i = (a); i <= (b); ++i)              // i in [a, b]
#define REPC(i, a, b, c) for (ll i = (a); i < (b) && i < (c); ++i) // i in [a, min(b, c))
#define RREP(i, a, b) for (int i = (a); i >= (b); --i)             // i from a down to b

// ---- Constants ----------------------------------------------------------
const ll MOD = 1e9 + 7;        // 1000000007
const ll mod = 998244353;
const int INF = 0x3f3f3f3f;    // "infinity" sentinel; -INF is used as the max identity
const int MAXN = 1e5 + 5e3;    // 105000: capacity of the per-vertex arrays
// Switches the standard streams to fast mode: C++ streams stop synchronizing
// with C stdio, and neither cin nor cout is tied to another stream any more
// (so reading does not force a flush of pending output).
inline void init() {
    std::ios_base::sync_with_stdio(false);
    std::cout.tie(nullptr);
    std::cin.tie(nullptr);
}
// One directed half of an undirected tree edge, kept in a forward-star
// adjacency list: `to` is the target vertex and `nxt` is the index of the
// next edge leaving the same source vertex (-1 ends the list).
struct Edge {
    int to;
    int nxt;
};
// Each undirected edge is stored twice (once per direction), hence 2 * MAXN.
Edge e[MAXN << 1];
int n;              // number of vertices; vertices are numbered 1..n
int num;            // next DFS position handed out by dfs2 (solve() starts it at 1)
int head[MAXN];     // head[u]: index in e[] of the first edge out of u, -1 if none
int cnt;            // number of edges stored in e[] so far
int w[MAXN];        // w[u]: weight of u as read in solve(); updates go only to the segment tree

// Heavy-light decomposition data, filled by dfs1 / dfs2.
int f[MAXN];        // parent of u (0 for the root)
int deep[MAXN];     // depth of u (root is 1, sentinel vertex 0 is 0)
int siz[MAXN];      // size of the subtree rooted at u
int son[MAXN];      // heavy child of u (0 if u has no children)
int top[MAXN];      // topmost vertex of the heavy chain containing u
int id[MAXN];       // DFS position of u (segment-tree index)
int rk[MAXN];       // vertex stored at DFS position p (inverse of id)

// Segment tree over DFS positions 1..n (4 * MAXN nodes, 1-based heap layout).
int sum[MAXN << 2];   // range sums
int maxn[MAXN << 2];  // range maxima
// Appends a directed edge u -> v to the front of u's adjacency list.
inline void add(int u, int v) {
    e[cnt].to = v;
    e[cnt].nxt = head[u];
    head[u] = cnt;
    ++cnt;
}
// First heavy-light pass over the subtree of u (whose parent is fa):
// records parent, depth and subtree size, and picks as heavy child the
// child with the strictly largest subtree (the first one found wins ties).
// Recursive: recursion depth equals the tree height.
inline void dfs1(int u, int fa) {
    f[u] = fa;
    deep[u] = deep[fa] + 1;
    siz[u] = 1;
    for (int i = head[u]; ~i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == fa) continue;     // do not walk back up to the parent
        dfs1(v, u);
        siz[u] += siz[v];
        if (siz[son[u]] < siz[v]) son[u] = v;
    }
}
// Second heavy-light pass: assigns DFS positions (id / rk), visiting the
// heavy child first so that every heavy chain occupies a contiguous range.
// t is the top vertex of the chain that u belongs to; each light child
// starts a new chain of its own.
inline void dfs2(int u, int t) {
    top[u] = t;
    rk[num] = u;
    id[u] = num++;
    if (son[u] != 0) dfs2(son[u], t);
    for (int i = head[u]; ~i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == f[u] || v == son[u]) continue;
        dfs2(v, v);
    }
}
// Recomputes node x of the segment tree from its two children (2x, 2x+1).
inline void pushup(int x) {
    const int lc = x << 1;
    const int rc = lc | 1;
    sum[x] = sum[lc] + sum[rc];
    maxn[x] = maxn[lc] > maxn[rc] ? maxn[lc] : maxn[rc];
}
// Builds segment-tree node x covering DFS positions [l, r]. Each leaf
// holds the weight of the vertex at that position (w[rk[l]]).
inline void build(int l, int r, int x) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        build(l, mid, x << 1);
        build(mid + 1, r, x << 1 | 1);
        pushup(x);
        return;
    }
    sum[x] = w[rk[l]];
    maxn[x] = sum[x];
}
// Point assignment: sets DFS position q to val inside node x covering
// [l, r], then refreshes the aggregates on the way back up.
inline void update(int l, int r, int x, int q, int val) {
    if (l == r) {
        sum[x] = val;
        maxn[x] = val;
        return;
    }
    const int mid = (l + r) >> 1;
    if (q > mid) {
        update(mid + 1, r, x << 1 | 1, q, val);
    } else {
        update(l, mid, x << 1, q, val);
    }
    pushup(x);
}
// Sum over DFS positions [ql, qr] within node x covering [l, r].
// Only children overlapping the query range are visited.
// NOTE(review): accumulates in int — confirm path sums fit for the input limits.
inline int querySum(int l, int r, int x, int ql, int qr) {
    if (ql <= l && r <= qr) return sum[x];
    const int mid = (l + r) >> 1;
    const int fromLeft = ql <= mid ? querySum(l, mid, x << 1, ql, qr) : 0;
    const int fromRight = qr > mid ? querySum(mid + 1, r, x << 1 | 1, ql, qr) : 0;
    return fromLeft + fromRight;
}
// Maximum over DFS positions [ql, qr] within node x covering [l, r].
// -INF is the identity (and floor) when combining the partial answers.
inline int queryMax(int l, int r, int x, int ql, int qr) {
    if (ql <= l && r <= qr) return maxn[x];
    const int mid = (l + r) >> 1;
    const int fromLeft = ql <= mid ? queryMax(l, mid, x << 1, ql, qr) : -INF;
    const int fromRight = qr > mid ? queryMax(mid + 1, r, x << 1 | 1, ql, qr) : -INF;
    return max(max(-INF, fromLeft), fromRight);
}
// Sum of vertex weights on the tree path a .. b (both ends included).
// Repeatedly consumes the chain whose top is deeper, jumping to the
// parent of that chain's top, until both ends share a heavy chain.
inline int getSum(int a, int b) {
    int total = 0;
    for (; top[a] != top[b]; a = f[top[a]]) {
        if (deep[top[b]] > deep[top[a]]) swap(a, b);
        total += querySum(1, n, 1, id[top[a]], id[a]);
    }
    // Same chain now: the shallower vertex has the smaller DFS position.
    if (deep[b] < deep[a]) swap(a, b);
    return total + querySum(1, n, 1, id[a], id[b]);
}
// Maximum vertex weight on the tree path a .. b (both ends included),
// using the same chain-jumping scheme as getSum with max instead of +.
inline int getMax(int a, int b) {
    int best = -INF;
    for (; top[a] != top[b]; a = f[top[a]]) {
        if (deep[top[b]] > deep[top[a]]) swap(a, b);
        best = max(best, queryMax(1, n, 1, id[top[a]], id[a]));
    }
    // Same chain now: the shallower vertex has the smaller DFS position.
    if (deep[b] < deep[a]) swap(a, b);
    return max(best, queryMax(1, n, 1, id[a], id[b]));
}
// Reads the tree (n - 1 edges, then n vertex weights), builds the
// heavy-light decomposition rooted at vertex 1 plus the segment tree, then
// answers q queries of the form "<op> <u> <t>":
//   QMAX u t  -> maximum weight on the path u .. t
//   QSUM u t  -> total weight on the path u .. t
//   otherwise -> set the weight of vertex u to t
inline void solve() {
    // Fresh state: DFS positions start at 1, no edges stored yet,
    // every adjacency list empty (-1 terminates a list).
    num = 1;
    cnt = 0;
    memset(head, -1, sizeof(head));

    // Each undirected edge is stored in both directions.
    for (int remaining = n - 1; remaining > 0; --remaining) {
        int u, v;
        cin >> u >> v;
        add(u, v);
        add(v, u);
    }
    for (int v = 1; v <= n; ++v) {
        cin >> w[v];
    }

    dfs1(1, 0);
    dfs2(1, 1);
    build(1, n, 1);

    int q;
    cin >> q;
    while (q--) {
        string op;
        int u, t;
        cin >> op >> u >> t;
        const bool isQuery = op == "QMAX" || op == "QSUM";
        if (!isQuery) {
            update(1, n, 1, id[u], t);
            continue;
        }
        cout << (op == "QMAX" ? getMax(u, t) : getSum(u, t)) << endl;
    }
}
// Entry point: enable fast I/O before any read, read the vertex count,
// then let solve() consume the rest of the input and print the answers.
signed main() {
    init();
    cin >> n;
    solve();
    // Falling off the end of main returns 0.
}
|