题目链接:登录—专业IT笔试面试备考平台_牛客网
题意比较简单,就是先给出一棵树的点及其边,然后让我们对其边进行分组,分组要求是任意两条边必须要有公共点,问我们一共的方案组数。
这道题目是一道树形DP题目,我们令f[i]表示以i为根的子树中的合法方案数,我们先来思考一个问题,假如我们当前遍历到了节点x,我们应该怎么确定其子节点与x的连边与以子节点为根的子树的边分到一组呢还是与x的另一条边分为一组呢?其实我建议大家不要这样思考问题,这样会有很多种情况,我建议从下往上进行思考,比如说我现在已经遍历到了叶子节点,当然这种情况分组情况肯定默认为1,现在我们再向后退一层,就是当前节点x的所有子节点全是叶子节点,如果当前子节点个数为cnt个,如果cnt为偶数,我们以x为根的子树中有没有可能存在边与x以及其父节点之间连边分为一组呢?其实这种情况是不会存在的,我们假如以x为根的子树中有一条边与x以及其父节点之间的连边分为一组,那么由于x与其父节点只会有一条连边,所以x与其其他子节点的连边则不可能与x与其父节点之间的连边分为一组,那么只能自己内部分组,但是我们发现,除了一条已经分好组的边,其他剩余的边数为奇数,所以无论如何也不可能再完成分组,所以这种情况是不会存在的。再思考一种情况,就是如果cnt是奇数,那么有没有可能以x为根的子树中的边内部分组呢?这显然也是不可能的对吧,所以我们就可以得到一条信息,一棵子树中的边是否与其父节点之间的边分为一组取决于当前未分组的边数,如果是偶数则不会也不可能与其父节点之间的边分为一组,而如果是奇数那就必须要分为一组了,所以现在问题来了,那是不是对于每一层而言,这个cnt仅仅是他的子节点个数呢?答案其实也不是的,比如当前x节点有一个子节点j,j内部的边需要与x与j之间的边分为一组,则x的这条边则不再参与分组,也不会被计入cnt了,所以cnt的含义到底是什么呢?其实就是当前节点还未分组的边的数量,知道了这一点就好说了,注意任何节点与其子节点的分组情况都是独立的,是符合乘法原理的,则f[x]里面不仅要包括所有的子节点f[j]的乘积,还要包括x当前未分组的边的分组情况,所以我们剩下的问题就是处理如何求对剩余的边进行分组的方案数了。
比方说,当前有n个点,我们要两两分为一组,那么一共有多少种分组方案呢?
为了方便讲解,我就把思考过程写在纸上了。
?我们先来思考一下n为偶数的情况,先对n进行一个全排列,显然有n!种方案,但是每两个分为一组的数是没有先后之分的,这样的组有n/2组,那么也就是除以2^(n/2),组与组之间的分组顺序也是不做区分的,所以还需要除以(n/2)!,那么再进行一下化简即可得到(n-1)!!
再来看一下奇数的情况,还是先对n进行一个全排列,有n!种方案,我们先忽略单独的一个,但是每两个分为一组没有先后之分,这样的组有(n-1)/2组,那么也就是除以2^((n-1)/2),又由于组与组之间的分组顺序也是不做区分的,所以还需要除以((n-1)/2)!,所以化简完也可以得到n!!,我们再来换一种思考方式,就是我们先随机拿出一个,这样的可能是n种,然后剩下的n-1就是偶数了,所以总的方案数是(n-1-1)!!,再乘以n刚好是n!!
总的来说,将n个数两两一组进行分组,总的方案数就是1~n中的所有奇数的乘积。
处理完这个问题,这道题目基本上就算是解决了,下面是代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<iostream>
#include<queue>
using namespace std;
const int N=1e6+10,mod=998244353;
int h[N],e[N],ne[N],idx;
long long f[N];
void add(int x,int y)
{
e[idx]=y;
ne[idx]=h[x];
h[x]=idx++;
}
int dfs(int x,int father)//返回当前节点内部边是否能自行分组,1代表不能,0代表能
{
f[x]=1;
int cnt=0;//记录当前节点有多少个子节点的边是内部完成分组的
for(int i=h[x];i!=-1;i=ne[i])
{
int j=e[i];
if(j==father) continue;
if(dfs(j,x)==0) cnt++;
f[x]=f[x]*f[j]%mod;//当前节点分组与子节点的分组是符合乘法原理的
}
if(cnt&1)
for(int i=cnt+1;i>=2;i-=2)
f[x]=(i-1)*f[x]%mod;
else
for(int i=cnt;i>=2;i-=2)
f[x]=(i-1)*f[x]%mod;
return (cnt&1);
}
int main()
{
int n;
memset(h,-1,sizeof h);
cin>>n;
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
add(u,v);add(v,u);
}
dfs(1,-1);
printf("%lld",f[1]);
return 0;
}
|