第一类斯特林数的一种 $O(n\log n)$倍增求法

摘要

本文介绍了一种可以 $O(n\log n)$求出第一类斯特林数 $S_n^1,S_n^2\cdots,S_n^n$的方法。

什么是第一类斯特林数

第一类斯特林数 $S_n^m$是将 $n$个元素划分为 $m$个圆排列的方案数。

圆排列:将 $1\cdots n$的元素排列成一个环的方案数,两个圆排列不同当且仅当其不旋转同构。

第一类斯特林数递推式

直接根据其组合意义得到:$S_n^m=S_{n-1}^{m-1}+(n-1)S_{n-1}^{m}$

将这个递推式写成生成函数的形式可以得到 $S_n^i$的生成函数:$S_n^i=[x^i]\prod_{a=0}^{n-1}(x+a)$

快速求出 $S_n^{0\cdots n}$

考虑使用倍增。

假设我们已知 $S_n$的生成函数 (一个 $n$次多项式),现在我们求 $S_{2n}$的生成函数。

设 $S_n$的生成函数为
$$
\prod_{i=0}^{n-1}(x+i)=\sum_{i=0}^{n}f_ix^i
$$
我们现在通过上面的生成函数求出 $\prod_{i=n}^{2n-1}(x+i)$是多少。

接着,将两个生成函数相乘就可以得到 $S_{2n}​$的生成函数。
$$
\begin{aligned}
& \prod_{i=n}^{2n-1}(x+i)\\
=& \prod_{i=0}^{n-1}(x+i+n)\\
=& \sum_{i=0}^{n}f_i(x+n)^i
\end{aligned}
$$
在这里使用二项式定理将其展开:
$$
\begin{aligned}
=& \sum_{i=0}^{n}f_i\sum_{j=0}^{i}C_i^jx^jn^{i-j}\\
=& \sum_{j=0}^{n}\sum_{i=j}^{n}f_i\frac{i!}{j!(i-j)!}x^jn^in^{-j}\\
=& \sum_{j=0}^{n}\frac{1}{j!n^j}(\sum_{i=j}^{n}(f_ii!n^i)\frac{1}{(i-j)!})x^j
\end{aligned}
$$
其中打括号的与 $i​$有关的式子是一个卷积的形式,可以用 NTT 或 FFT 在 $O(n\log n)​$的时间内求出,再在 $O(n)​$的时间内将对应项乘上相应的系数即可得到 $\frac{S_{2n}}{S_n}​$的生成函数。再用 $O(n\log n)​$的时间将 $\frac{S_{2n}}{S_n}​$与 $S_n​$相乘,即可得到 $S_{2n}​$的生成函数。整个过程需要 4 次 NTT。

在卷积时,需要注意以下几点:

  • 卷积的下标范围为 $0$到 $n$的闭区间,因此 NTT 时需要获得至少 $2n+1$个点值。
  • 求 $\frac{S_{2n}}{S_n}$时,设 $F(x)=\sum_{i=0}f_ii!n^ix^i$,$G(x)=\sum_{i=0}\frac{1}{i!}x^i$。由于公式中卷积的形式与我们平常见到的卷积方向相反,所以我们需要将 $F$翻转,与 $G$相乘,再将结果翻转,才能求出正确的结果。
  • 需要保证每次傅里叶变换前数组的高位都是 $0$。

现在,我们已经得到了如何将 $S_n$变成 $S_{2n}$,我们只需要能将 $S_n$变成 $S_{n+1}$,就可以求出 $n$为任何数时的 $S_n$了。

将 $S_n$变成 $S_{n+1}$显然只需要暴力乘上 $(x+n)$项就够了。

用途

可以比分治 $FFT$快一丁点。

经过比较,倍增算法的耗时是分治的 $40\%$,结合一些卡常技巧,更是可以将时间缩短至分治的 $15\%$左右。由于倍增过程中,每个阶段需要做 6 次变换,而分治只需要 3 次,$15\%$应该是这种算法的极限了。

这个方法可以用来帮助解决需要用到斯特林数的题目,如纯斯特林数、Dp、斯特林反演等。

代码 (CF960G)

#include <bits/stdc++.h>
#define MOD 998244353
#define G 3
#define GI 332748118LL
#define MOD2 996491788296388609LL
#define MX 262144

using namespace std;

typedef long long ll;

ll qm(const ll& x) {return (x>=MOD) ? (x-MOD) : (x);}

ll qpow(ll x, ll t)
{
    ll ans = 1;
    while(t)
    {
        if(t & 1) ans = ans * x % MOD;
        x = x * x % MOD;
        t >>= 1;
    }
    return ans;
}

ll inv(ll x)
{
    return qpow(x, MOD-2);
}

struct Fourier
{
    int n, bit, rev[MX];

    void init(int x)
    {
        n = 1, bit = 0;
        while(n < x) n <<=1, ++bit;
        for(int i=1; i<n; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(bit-1));
    }

    void dft(ll* x, int f)
    {
        for(int i=0; i<n; i++) if(i < rev[i]) swap(x[i], x[rev[i]]);
        for(int w=1,b=1; w<n; w<<=1,b++)
        {
            ll wx = (f==1 ? qpow(G, (MOD-1)>>b) : qpow(GI, (MOD-1)>>b));
            for(int j=0; j<n; j+=(w<<1))
            {
                ll wn = 1;
                for(int i=j; i<j+w; i++)
                {
                    ll a = x[i], b = wn*x[i+w];
                    x[i] = (a+b) % MOD;
                    x[i+w] = (a-b+MOD2) % MOD;
                    wn = wn * wx % MOD;
                }
            }
        }
        if(f == -1)
        {
            ll mul = inv(n);
            for(int i=0; i<n; i++) x[i] = x[i] * mul % MOD;
        }
    }
} NTT;

ll fac[MX], faci[MX];
ll t1[MX], t2[MX], t3[MX], t4[MX];

void get_s(ll* x, int n)
{
    if(n == 1)
    {
        x[0] = 0;
        x[1] = 1;
    }
    else if(n & 1)
    {
        get_s(x, n-1);
        x[n] = 0;
        for(int i=n-1; i>=0; i--) x[i+1] = ((n-1) * x[i+1] + x[i]) % MOD;
    }
    else
    {
        get_s(x, n/2);
        NTT.init(n+1);
        for(int i=n/2+1; i<NTT.n; i++) t1[i] = t2[i] = t3[i] = t4[i] = 0;
        for(int i=0; i<=n/2; i++)
        {
            t1[i] = x[i] * qpow(n/2, i) % MOD * fac[i] % MOD;
            t2[i] = faci[i];
        }
        reverse(t1, t1+n/2+1);
        NTT.dft(t1, +1);
        NTT.dft(t2, +1);
        for(int i=0; i<NTT.n; i++) t3[i] = t1[i] * t2[i] % MOD;
        NTT.dft(t3, -1);
        reverse(t3, t3+n/2+1);
        for(int i=0; i<=n/2; i++) t4[i] = t3[i] * inv(qpow(n/2, i)) % MOD * faci[i] % MOD;
        NTT.dft(t4, +1);
        NTT.dft(x, +1);
        for(int i=0; i<NTT.n; i++) x[i] = x[i] * t4[i] % MOD;
        NTT.dft(x, -1);
    }
}

void init()
{
    fac[0] = faci[0] = 1;
    for(int i=1; i<MX; i++) fac[i] = fac[i-1] * i % MOD;
    faci[MX-1] = inv(fac[MX-1]);
    for(int i=MX-1; i>=1; i--) faci[i-1] = faci[i] * i % MOD;
}

ll stiring[MX];
int n, a, b;

int main()
{
    scanf("%d%d%d", &n, &a, &b);
    if(n==1 && a==1 && b==1) printf("%d\n", 1);
    else if(a+b-1>n || a<1 || b<1) printf("%d\n", 0);
    else
    {
        init();
        get_s(stiring, n-1);
        ll way = stiring[a+b-2];
        printf("%lld\n", way * fac[a+b-2] % MOD * faci[a-1] % MOD * faci[b-1] % MOD);
    }
    return 0;
}
分类: 文章

发表评论

电子邮件地址不会被公开。 必填项已用*标注