Loading [MathJax]/jax/output/HTML-CSS/jax.js

题目链接

因为 n这个值是最大的,前缀最大不用考虑 n后面的,后缀最大不用考虑 n前面的

f(i,j)表示 i个数的排列,有 j个数字是前缀最大的方案数

考虑枚举最小的数字放在哪个位置,放在第一个位置则等于 f(i1,j1)。放在别的位置则等于 f(i1,j),共有 i1个 “别的位置”

所以:

f(i,j)=f(i1,j1)+(i1)×f(i1,j)

这就是第一类斯特林数 S1(i,j)

其原因是:有 j个前缀最大,相当于把 i个数的排列分成了 j段,第 k段为 [k个前缀最大 ,k+1个前缀最大 )

答案 Ans=ni=1S1(i1,a1)×S1(ni,b1)×Ci1n1

也就是选择 i1个数字放到 n前面,剩下的放到 n后面这样

n前面的结成 a1个环,n后面的结成 b1个环,相当于一共用 n1个数字结成 a+b2个环

当然这 a+b2个环中有 a1个是正向排列的,b1个是反向排列的,所以有:

Ans=S1(n1,a+b2)×Cb1a+b2

然后因为第一类斯特林数有生成函数:

n1i=0(x+i)

这个生成函数的 k次项的系数就是 S1(n,k)

这个用分治 NTT 求就行了

具体看代码吧

#include <bits/stdc++.h>

#define NS (262144)
#define LGS (18)
#define MOD (998244353)
#define G (3)

#define pls(a, b) ((a) + (b) < MOD ? (a) + (b) : (a) + (b) - MOD)
#define mns(a, b) ((a) - (b) < 0 ? (a) - (b) + MOD : (a) - (b))
#define mul(a, b) (1ll * (a) * (b) % MOD)
#define Inv(a) (qpow((a), MOD - 2))

using namespace std;

template<typename _Tp> inline void IN(_Tp& dig)
{
    char c; bool flag = 0; dig = 0;
    while (c = getchar(), !isdigit(c)) if (c == '-') flag = 1;
    while (isdigit(c)) dig = dig * 10 + c - '0', c = getchar();
    if (flag) dig = -dig;
}

int qpow(int a, int b)
{
    int res = 1;
    while (b)
    {
        if (b & 1) res = mul(res, a);
        a = mul(a, a), b >>= 1;
    }
    return res;
}

int n, A, B;

int rev[NS];

struct poly
{
    int d[NS], N, bs;
    int& operator [] (const int a) {return d[a];}
    void resize(int s)
    {
        int tmp = N;
        N = 1, bs = 0;
        while (N < s) N <<= 1, bs++;
        for (int i = tmp; i < N; i += 1) d[i] = 0;
    }
    void ntt(int t)
    {
        for (int i = 1; i < N; i += 1)
        {
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bs - 1));
            if (i < rev[i]) swap(d[i], d[rev[i]]);
        }
        for (int l = 1; l < N; l <<= 1)
        {
            int dt = qpow(G, (MOD - 1) / (l << 1));
            if (t == -1) dt = Inv(dt);
            for (int i = 0; i < N; i += (l << 1))
            {
                int g = 1, t1, t2;
                for (int j = i; j < i + l; j += 1, g = mul(g, dt))
                {
                    t1 = d[j], t2 = mul(g, d[j + l]);
                    d[j] = pls(t1, t2), d[j + l] = mns(t1, t2);
                }
            }
        }
        if (t == -1)
        {
            int inv = Inv(N);
            for (int i = 0; i < N; i += 1) d[i] = mul(d[i], inv);
        }
    }
    void operator *= (poly &oth)
    {
        for (int i = 0; i < N; i += 1) d[i] = mul(d[i], oth[i]);
    }
} P[LGS];

stack<int> rub;

int Binary(int l, int r)
{
    if (l == r)
    {
        int a = rub.top(); rub.pop();
        P[a].resize(2), P[a][0] = l, P[a][1] = 1;
        return a;
    }
    int mid = (l + r) >> 1;
    int a = Binary(l, mid), b = Binary(mid + 1, r);
    P[a].resize(r - l + 2), P[b].resize(r - l + 2);
    P[a].ntt(1), P[b].ntt(1), P[a] *= P[b], P[a].ntt(-1), rub.push(b);
    return a;
}

int C(int a, int b)
{
    int x = 1, y = 1;
    for (int i = a - b + 1; i <= a; i += 1) x = mul(x, i);
    for (int i = 1; i <= b; i += 1) y = mul(y, i);
    return mul(x, Inv(y));
}

int main(int argc, char const* argv[])
{
    IN(n), IN(A), IN(B), n--;
    if (!A || !B || A + B - 2 > n) puts("0"), exit(0);
    if (!n) puts("1"), exit(0);
    for (int i = 0; i < LGS; i += 1) rub.push(i);
    int a = Binary(0, n - 1);
    printf("%lld\n", mul(P[a][A + B - 2], C(A + B - 2, B - 1)));
    return 0;
}
C++
分类: 文章

Remmina

No puzzle that couldn't be solved.

5 条评论

boshi · 2019年1月27日 12:40 下午

有一种 O(nlogn)的倍增算法,比你的快到不知道哪里去了。

回复 boshi 取消回复

Avatar placeholder

您的邮箱地址不会被公开。 必填项已用 * 标注