题目分析
将每次添加的字符……看成一个节点(节点有标号)?!!
添加一个字符时删除的字符……都是它的儿子?!!
这样 0
就是叶子节点,1
就是非叶节点。任意一个添加字符的方案,构成了一种唯一的森林,而本题求的东西就变成了……这个森林只有一棵树,节点数为 n的方案数?!!
这什么神仙思路啊 (╯°Д°)╯︵ ┻━┻
然后设 f(i)表示 i个节点构成的树的方案数(该树根节点一定是一个字符 1
),g(i)表示 i个节点构成的森林的方案数(森林里每棵树的根节点都是字符 1
),A(i)=[i∈A],B(i)=[i∈B],就有如下 DP 方程:
f(n)=n−1∑i=0A(i)g(n−1−i)Cin−1+B(n−1)
g(n)=n∑i=1f(i)g(n−i)Ci−1n−1
其中
f的方程是考虑与根节点直接相连的叶子有几个,
g是考虑
1号节点所在的那棵树的大小。
然后写成卷积的形式:
f(n)(n−1)!=n−1∑i=0g(n−1−i)(n−1−i)!A(i)i!+B(n−1)n−1
g(n)n!=1n(n∑i=1f(i)(i−1)!g(n−i)(n−i)!)
发现 f的 DP 式中带的那个卷积其实卷出来是 n−1,所以将 A(i)整体往后平移一格,卷积式改成 ∑ni=1A(i)(i−1)!g(n−i)(n−i)!。
还有一点就是每次删除的子段不能为空,所以强制让 B(0)=0,最后全部算出来后再让 f(1)+=1(即只有一个字符 0
即可)
初值:f(0)=0,g(0)=1。
然后一个分治 NTT 即可解决,完结,撒花!

……个鬼啦!
这个分治 NTT 怎么做啊!g的方程里又有 f又有 g的!
我花了一下午在这上面啊!

好吧,其实是这样做的。首先,分治 NTT 的时候,是递归左半边区间,用 [l,mid]里的值去更新 [mid+1,r]中的值,然后递归右半边区间嘛。
对于 f,用 [l,mid]里的 g(已经求出正确的值了)和前 r−l+1项 A去卷积,加到右半边的 f中。
对于 g,用 [l,mid]里的 f(已经求出正确的值了)和前 r−l+1项 g去卷积,加到右半边的 g中。但是有一个问题,若 l=1,可能出现前 r−l+1项 g还没有完全算出的情况。
设 i+j=k,j>mid,i≤mid,那么此时我们漏算了 f(i)g(j)对 g(k)的贡献。
不过没关系,当我们递归到一个左半边包含 j的 [l,r](此时 l≠1)时,我们再用 [l,mid]里的 g和前 r−l+1项 f去卷积,加到右半边的 g中即可。
所以当 l≠1时,要额外多卷积一次。
另外,注意不要不小心使用在本次卷积算贡献时,用被更新了的 f和 g去卷。
代码
#include<bits/stdc++.h>
using namespace std;
#define RI register int
int read() {
int q=0;char ch=' ';
while(ch<'0'||ch>'9') ch=getchar();
while(ch>='0'&&ch<='9') q=q*10+ch-'0',ch=getchar();
return q;
}
const int mod=998244353,N=262150;
int n,ma,mb,ans;
int fac[N],inv[N],ifac[N],A[N],B[N],f[N],g[N],rev[N];
int k1[N],k2[N],k3[N],k4[N];
int qm(int x) {return x>=mod?x-mod:x;}
int ksm(int x,int y) {
int re=1;
for(;y;y>>=1,x=1LL*x*x%mod) if(y&1) re=1LL*re*x%mod;
return re;
}
void NTT(int *a,int n,int x) {
for(RI i=0;i<n;++i) if(rev[i]>i) swap(a[i],a[rev[i]]);
for(RI i=1;i<n;i<<=1) {
int gn=ksm(3,(mod-1)/(i<<1));
for(RI j=0;j<n;j+=(i<<1)) {
int t1,t2,g=1;
for(RI k=0;k<i;++k,g=1LL*g*gn%mod) {
t1=a[j+k],t2=1LL*g*a[j+i+k]%mod;
a[j+k]=qm(t1+t2),a[j+i+k]=qm(t1-t2+mod);
}
}
}
if(x==1) return;
reverse(a+1,a+n);int invn=ksm(n,mod-2);
for(RI i=0;i<n;++i) a[i]=1LL*a[i]*invn%mod;
}
void prework() {
fac[0]=1;for(RI i=1;i<=n;++i) fac[i]=1LL*fac[i-1]*i%mod;
inv[0]=inv[1]=1,ifac[0]=1;
for(RI i=2;i<=n;++i) inv[i]=1LL*(mod-mod/i)*inv[mod%i]%mod;
for(RI i=1;i<=n;++i) ifac[i]=1LL*ifac[i-1]*inv[i]%mod;
}
void work(int l,int r) {
int mid=(l+r)>>1,n=r-l+1,kn=1,len=0;
while(kn<(n<<1)) kn<<=1,++len;
for(RI i=0;i<kn;++i) rev[i]=(rev[i>>1]>>1)|((i&1)<<(len-1));
{ for(RI i=0;i<kn;++i) k1[i]=k2[i]=0;
for(RI i=l;i<=mid;++i) k1[i-l]=g[i];
for(RI i=1;i<=n;++i) k2[i]=1LL*A[i]*ifac[i-1]%mod;
NTT(k1,kn,1),NTT(k2,kn,1);
for(RI i=0;i<kn;++i) k1[i]=1LL*k1[i]*k2[i]%mod;
NTT(k1,kn,-1);
for(RI i=mid+1;i<=r;++i) f[i]=qm(f[i]+k1[i-l]);
}
{ for(RI i=0;i<kn;++i) k1[i]=k2[i]=0;
for(RI i=l;i<=mid;++i) k1[i-l]=f[i];
for(RI i=0;i<=n;++i) k2[i]=g[i];
NTT(k1,kn,1),NTT(k2,kn,1);
for(RI i=0;i<kn;++i) k1[i]=1LL*k1[i]*k2[i]%mod;
NTT(k1,kn,-1);
for(RI i=mid+1;i<=r;++i) g[i]=qm(g[i]+k1[i-l]);
}
if(l>1) {
for(RI i=0;i<kn;++i) k1[i]=k2[i]=0;
for(RI i=l;i<=mid;++i) k1[i-l]=g[i];
for(RI i=0;i<=n;++i) k2[i]=f[i];
NTT(k1,kn,1),NTT(k2,kn,1);
for(RI i=0;i<kn;++i) k1[i]=1LL*k1[i]*k2[i]%mod;
NTT(k1,kn,-1);
for(RI i=mid+1;i<=r;++i) g[i]=qm(g[i]+k1[i-l]);
}
}
void cdq(int l,int r) {
if(l==r) {
f[l]=qm(f[l]+1LL*B[l-1]*ifac[l-1]%mod);
g[l]=1LL*qm(g[l]+f[l])*inv[l]%mod;
return;
}
int mid=(l+r)>>1;cdq(l,mid),work(l,r),cdq(mid+1,r);
}
int main()
{
n=read(),ma=read(),mb=read();
for(RI i=1;i<=ma;++i) A[read()+1]=1;
for(RI i=1;i<=mb;++i) B[read()]=1;
prework();
B[0]=0,g[0]=1,cdq(1,n),ans=1LL*f[n]*fac[n-1]%mod;
if(n==1) ans=qm(ans+1);
printf("%d\n",ans);
return 0;
}
0 条评论