题目描述
题解
既然是和
T
T
T 的前缀匹配,那么我们不妨枚举
S
S
S 中的每个子串,求出当其作为
T
T
T 的前缀能匹配的最长子串时
T
T
T 的数量。
由于是枚举子串且不重复,所以我们先暴力建出一个后缀字典树。先不考虑合不合法,当我们枚举子串
x
x
x 时,对应
T
T
T 的数量
v
a
l
x
=
(
5
?
o
u
t
x
)
?
5
m
?
l
e
n
x
?
1
val_x=(5-out_x)*5^{m-len_x-1}
valx?=(5?outx?)?5m?lenx??1(满足
l
e
n
x
len_x
lenx? 即子串长度
≤
m
\le m
≤m,特别地,当
l
e
n
x
=
m
len_x=m
lenx?=m 时,
v
a
l
x
=
1
val_x=1
valx?=1),其中
o
u
t
x
out_x
outx? 表示后缀字典树上
x
x
x 的儿子个数。这个其实很好理解,只需要在
x
x
x 后面加一个字符使新前缀在
S
S
S 中不存在,那么后面的字符随便选。
显然只要
x
x
x 合法,那么贡献一定为
v
a
l
x
val_x
valx?。接下来考虑合法条件的转化:题目中的匹配算法相当于要求我们在做 KMP 算法的时候,一旦遇到失配的地方,此处的最长 Border 必须为0。由于我们枚举了最长匹配前缀
x
x
x,那么所有失配的地方必然是后缀字典树上
x
x
x 的祖先(前缀)的儿子节点。当然
x
x
x 的儿子也算。 考虑某一个失配点
y
y
y,当其在
S
S
S 中不作为
x
x
x 的子串,也不作为
x
x
x 的前缀的子串,而是作为独立的一个串出现的时候,假设出现位置前面的匹配都是合法的,那么此处枚举到
y
y
y 的时候匹配必然清零,也就是要满足
k
m
p
y
=
0
kmp_y=0
kmpy?=0。
所以我们判断
x
x
x 是否合法的时候,可以判断所有能够独立出现的失配串
y
y
y 中有无最长 Border 不为 0 的,并且这个判断是充要的。
怎么维护能够独立出现的失配串集合呢?考虑在后缀字典树上 DFS,从
x
x
x 走到它儿子
v
v
v,我们需要把
v
v
v 的兄弟节点加入集合,然后对于所有作为
v
v
v 的后缀且在
v
v
v 中仅出现一次的失配串,扣除它们在
v
v
v 中出现的那部分。设
c
n
t
x
cnt_x
cntx? 表示子串
x
x
x 在
S
S
S 中的出现次数,那么你需要这些失配串在集合中出现次数减
c
n
t
v
cnt_v
cntv?,如果减到 0 了就移出集合。
找这个需要扣除次数的失配串其实很简单,就是
x
→
v
x\rightarrow v
x→v 做 KMP 的时候失配的那些 Border 的某个儿子。 我们要找的失配串是
v
v
v 的后缀,所以它的父亲一定是
x
x
x 的后缀,而它的父亲又是
x
x
x 的某个前缀,所以我们可以枚举
x
x
x 的 Border 然后判断。如果在跳 Border 的时候发现某个 Border 匹配上了
v
v
v,那么就没必要再跳了,因为后面即使找到了失配串,也一定是某个祖先扣除过的,再扣就重了。所以在用 KMP 求
v
v
v 的最长 Border 的同时检查失配跳过的那些串,刚好可以不重不漏。
维护集合只需要用桶记录出现次数,同时记录最长 Border 不为 0 的失配串个数即可。虽然是按 DFS 序做 trie 上 KMP,还要回溯,但是复杂度仍然不会劣于对
S
S
S 的每个后缀做 KMP 的复杂度,即
O
(
n
2
)
O(n^2)
O(n2)。
代码
#include<bits/stdc++.h>
#define ll long long
#define lll __int128
#define uns unsigned
#define fi first
#define se second
#define IF (it->fi)
#define IS (it->se)
#define END putchar('\n')
#define lowbit(x) ((x)&-(x))
#define inline jzmyyds
using namespace std;
const int MAXN=2005;
const ll INF=1e18;
ll read(){
ll x=0;bool f=1;char s=getchar();
while((s<'0'||s>'9')&&s>0){if(s=='-')f^=1;s=getchar();}
while(s>='0'&&s<='9')x=(x<<1)+(x<<3)+(s^48),s=getchar();
return f?x:-x;
}
int ptf[50],lpt;
void print(ll x,char c='\n'){
if(x<0)putchar('-'),x=-x;
ptf[lpt=1]=x%10;
while(x>9)x/=10,ptf[++lpt]=x%10;
while(lpt>0)putchar(ptf[lpt--]^48);
if(c>0)putchar(c);
}
const ll MOD=998244353;
ll ksm(ll a,ll b,ll mo){
ll res=1;
for(;b;b>>=1,a=a*a%mo)if(b&1)res=res*a%mo;
return res;
}
ll mi[MAXN],ans;
int tr[4000005][5],IN;
int n,m,fa[4000005],cnt[4000005];
int st[4000005],nx[4000005],nm;
char s[MAXN];
void dfs(int x,int len){
for(int i=0;i<5;i++)if(tr[x][i]){
int v=tr[x][i],tf=fa[x];
while(tf>0&&nx[tf]!=i)tf=fa[tf];
if(x>0&&nx[tf]==i)fa[v]=tr[tf][i];
else fa[v]=0;
}
for(int i=0;i<5;i++)if(tr[x][i]){
int v=tr[x][i],tf=fa[x];
while(tf>0&&nx[tf]!=i){
if(tr[tf][i]){
int o=tr[tf][i];
st[o]-=cnt[v];
if(!st[o])nm-=(fa[o]>0);
}tf=fa[tf];
}
for(int j=0;j<5;j++)if((j^i)&&tr[x][j])
st[tr[x][j]]+=cnt[tr[x][j]],nm+=(fa[tr[x][j]]>0);
nx[x]=i,dfs(v,len+1);
for(int j=0;j<5;j++)if((j^i)&&tr[x][j])
st[tr[x][j]]-=cnt[tr[x][j]],nm-=(fa[tr[x][j]]>0);
tf=fa[x];
while(tf>0&&nx[tf]!=i){
if(tr[tf][i]){
int o=tr[tf][i];
if(!st[o])nm+=(fa[o]>0);
st[o]+=cnt[v];
}tf=fa[tf];
}
}
bool ok=1;int cnt=5;
for(int j=0;j<5;j++)if(tr[x][j])ok&=(fa[tr[x][j]]==0),cnt--;
if(!ok)return;
if(!nm){
if(len<m)(ans+=cnt*mi[m-len-1])%=MOD;
else if(len==m)ans++,ans%=MOD;
}
}
int main()
{
freopen("match.in","r",stdin);
freopen("match.out","w",stdout);
n=read(),m=read(),*new(int)=scanf("%s",s+1),mi[0]=1;
for(int i=1;i<=m+1;i++)mi[i]=mi[i-1]*5%MOD;
for(int l=1,p;l<=n;l++){
p=0;
for(int r=l;r<=n;r++){
int c=s[r]-'a';
if(!tr[p][c])tr[p][c]=++IN;
p=tr[p][c],cnt[p]++;
}
}dfs(0,0);
print(ans*ksm(mi[m],MOD-2,MOD)%MOD);
return 0;
}
|