题意: 一个n 个点n-1 条边的连通图,图中距离为2 的点对
(
u
,
v
)
(u,v)
(u,v)会产生
w
u
?
w
v
w_u * w_v
wu??wv?的联合权值,求图中所有的点对的联合权值的最大值和所有联合权值的和(和要对10007取模)
首先注意只让对和取模了,对最大值没有让取模,所以要注意最大值不能取模 最大值权值就是最大节点权值乘上次大节点权值
思路一:
首先对距离为2进行入手,可以发现,完全可以枚举,枚举每个点,然后遍历与该点相连的点,对所有的点进行计算就行。这里牵扯到一个计算几个数之间的两两相乘的方法,我也是不知道,做了这道题就知道了 比如
A
B
C
D
A B C D
ABCD四个值,我们需要计算两两相乘的和 我们只需要从前往后遍历 求一个前缀和sum ,用前缀和乘上当前遍历的值,注意顺序为:先求前缀和,再乘上当前值 过程如下:
res = 0
sum = A
res = res + sum * B = A * B
sum = A + B
res = res + sum * C = A * B + (A + B ) * C
sum = A + B + C
res = res + sum * D = A * B + A * C + B * C + (A + B + C) * D
sum = A + B + C + D
所以这就很好记录结果了
!!!注意啊 !!! 建图一定要注意为双向图呀,边的个数最大为2* N个,我就是因为这损失了好长的查bug的时间
思路二:
还是利用遍历的思想,我们使用dfs进行遍历,从随便地一个节点开始即可。 但是我们换了一个统计结果的思路:这个思路还是很巧妙的,我也是没想到 两两之间的乘积我们可以利用数学方法进行推导
2
a
b
=
(
a
+
b
)
2
?
(
a
2
+
b
2
)
2ab = (a+b)^2 - (a^2+b^2)
2ab=(a+b)2?(a2+b2)
2
(
a
b
+
a
c
+
b
c
)
=
(
a
+
b
+
c
)
2
?
(
a
2
+
b
2
+
c
2
)
2(ab+ac+bc) = (a+b+c)^2-(a^2+b^2+c^2)
2(ab+ac+bc)=(a+b+c)2?(a2+b2+c2) 一般而言,可以得 两两之间的乘积 = 所有值的和的平方 - 所有值平方的和
然后我们就利用上述的思路进行求解,使用dfs,只能向子节点遍历,但是父节点的信息也需要统计上
!!!注意啊!!!
我们统计结果的时候出现了减法的操作,而且减法的两个操作数都有取模运算,可能存在一种结果,前面的数取模之后小于后面取模之后的数,最后结果可能为负数,所以我们要对结果加上一个模数然后再取模
思路一代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 2e5+5;
const int mod = 10007;
int h[N],e[N*2],ne[N*2],idx;
int n,w[N],mx,res;
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
int main()
{
scanf("%d",&n);
memset(h,-1,sizeof h);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
for(int u=1;u<=n;u++)
{
int fmx = 0,smx = 0,sum = 0;
for(int i=h[u];~i;i=ne[i])
{
int v = e[i];
if(w[v]>=fmx) smx = fmx,fmx = w[v];
else if(w[v]>smx) smx = w[v];
res = (res + sum * w[v]) % mod;
sum = (sum + w[v]) % mod;
}
mx = max(mx,fmx * smx);
}
printf("%d %d\n",mx,(res*2)%mod);
return 0;
}
思路二代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 2e5+5,M = 2*N;
const int mod = 10007;
int h[N],e[M],ne[M],idx;
int w[N],n;
int mx,res;
void add(int a,int b)
{
e[idx]=b,ne[idx]=h[a],h[a]=idx++;
}
void dfs(int u,int p)
{
int sum = 0,fmx = 0,smx = 0,ans = 0;
for(int i=h[u];~i;i=ne[i])
{
int v = e[i];
if(v!=p) dfs(v,u);
if(w[v]>=fmx) smx = fmx,fmx = w[v];
else if(w[v]>smx) smx = w[v];
sum = (sum + w[v]) % mod;
ans = (ans + w[v] * w[v]) % mod;
}
mx = max(mx,fmx * smx);
res = (res + sum * sum % mod - ans ) %mod;
}
int main()
{
scanf("%d",&n);
memset(h,-1,sizeof h);
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v);
add(v,u);
}
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
dfs(1,-1);
printf("%d %d\n",mx,(res+mod)%mod);
return 0;
}
往期优质文章推荐
|