在两年前,我学习了快速傅里叶变换。当时有许多的问题没能透彻地理解。

在半年前,我接触到了任意模数 NTT,但是当时只背了个板子,没有弄懂它的原理。

今天,隔壁机房一个高二的大佬点醒了我,我突然就明白任意模数 NTT 究竟再干啥了。

任意模数 NTT 原理

当我们处理模数为任意 $10^9$左右质数的多项式乘法时,我们往往将每一次 dft 拆分为两次元素大小在 $10^5$级别的 $dft$。

这样,我们总共需要 $8$次 $dft$。

然而,有一个技巧却可以允许我们使用 $4$次 $dft$做到原先的 $8$次 dft 完成的事情。这个技巧人称 $mtt$,它的精髓在于,将两次 $dft$合并为一次完成。

基础知识

默认大家对 fft,dft,ntt 的原理非常熟悉,并且已经学习了有关复数和三角函数的相关知识。

我们先定义一些东西。

我们默认所有的数列的长度、函数的项数等,统统为 $n$。

定义 $dft$是一个作用于函数 $f$的算子,其中 $F=dft(f)$得到的是一个数列 $F$,其中 $F_k=f(\omega_n^k)$。如果我们把多项式函数和数列同等对待的话,我们也可以认为 $F_k=\sum_{i=0}^{n-1}f_i\omega_n^{ki}$

定义 $reverse$是一个作用于数列 $s$的算子,它的结果是 $s$翻转之后的数列,如 $rev({1,2,3})={3,2,1}$

定义 $idft$是一个作用于数列 $F$的算子,它是 $dft$的逆运算。并且,如果我们把多项式函数和数列同等对待的话,$f=idft(F)$,则 $f_k=\frac{1}{n}\sum_{i=0}^{n-1}F_i\omega_n^{-ki}$

定义 $conj(z)$为 $z$的共轭复数。

如果将任意一个算子、算符作用于一个数列,得到的则是一个新的数列,每一项为原数列对应项在算子、算符作用下的结果。

算法原理

假设我们需要对于实多项式函数$A$和 $B$分别做 $dft$,即,求出
$$
\begin{aligned}
F_i & =A(\omega_n^i)=\sum_{k=0}^{n-1}a_k\omega_n^{ki}\\
G_i & =B(\omega_n^i)=\sum_{k=0}^{n-1}b_k\omega_n^{ki}
\end{aligned}
$$
怎么一次完成呢?

有公式如下:

$$
\begin{aligned}
dft(A+iB) _ k & =\sum _ {i=0}^{n-1}(a_i+ib_i)\omega _ n^{ki}=A(\omega _ n^k)+iB(\omega _ n^k)\\
dft(A-iB) _ k & =\sum _ {i=0}^{n-1}(a_i-ib_i)\omega _ n^{ki}=A(\omega _ n^k)-iB(\omega _ n^k)=conj(dft(A+iB) _ {n-k})
\end{aligned}
$$

大家对于上述公式的最后一个等号可能不是很能理解。这里待会会给出证明。

假设上述公式是正确的,我们只需要计算 $dft(A+iB)$,即可还原出 $dft(A-iB)=conj(reverse(dft(A+iB)))$,而:

$$
\begin{aligned}
dft(A)_k & =\frac{dft(A+iB)+dft(A-iB)}{2}\\
\
dft(B)_k & =\frac{dft(A+iB)-dft(A-iB)}{2i}
\end{aligned}
$$

这样,我们就通过一次 $dft$完成之前两次 dft 的工作了。

下面给出上面那个等号的证明。

$$
\begin{aligned}
dft(A+iB) _ {n-k} & =\sum _ {i=0}^{n-1}(a _ i+ib _ i)\omega _ n^{-ki}\\
& =\sum _ {i=0}^{n-1}(a _ i+ib _ i)(\cos \theta-i\sin \theta)\\
& =\sum _ {i=0}^{n-1}(a _ i\cos\theta+b _ i\sin\theta)-i(a _ i\sin\theta-b _ i\cos\theta)\\
& =\sum _ {i=0}^{n-1}conj((a _ i\cos\theta+b _ i\sin\theta)+i(a _ i\sin\theta-b _ i\cos\theta))\\
& =\sum _ {i=0}^{n-1}conj((a _ i-ib _ i)(cos\theta+i\sin\theta))\\
& =\sum _ {i=0}^{n-1}conj((a _ i-ib _ i)\omega _ n^{ki})\\
& =conj(dft(A-iB) _ k)
\end{aligned}
$$

注意,这里的多项式函数的系数必须是实数,因为推导过程中默认 $a_i,b_i$都是实数。

算法实现

在高效地实现 $mtt$时,往往会用到另一个技巧,即 $idft(s)=\frac{1}{n}dft(rev(s))$

另外,运用到一些其他的七里八里的技巧,可以让你的代码更短,但是不会太易懂。

// luogu-judger-enable-o2
#include <bits/stdc++.h>
#define ML 262149
#define pi (acos(-1))
#define M 32768ll

using namespace std;

typedef long long ll;

typedef long double ldb;

void read(ll& x)
{
    x = 0; char c = getchar();
    while(!isdigit(c)) c = getchar();
    while(isdigit(c)) x = x*10+c-'0', c = getchar();
}

ll mod;

namespace fft
{
    struct Z
    {
        ldb r, i;

        Z (const ldb &r0 = 0, const ldb &i0 = 0) : r(r0), i(i0) {}
        Z operator + (const Z& t) const {return Z(r+t.r, i+t.i);}
        Z operator - (const Z& t) const {return Z(r-t.r, i-t.i);}
        Z operator * (const Z& t) const {return Z(r*t.r-i*t.i, r*t.i+i*t.r);}
        Z conj() const {return Z(r, -i);}
        void operator /= (const ldb& t) {r /= t, i /= t;}
    };

    struct Fourier
    {
        int n, bit, rev[ML];

        void init(int x)
        {
            n = 1, bit = 0;
            while(n < x) n <<= 1, bit++;
            for(int i=1; i<n; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(bit-1));
        }

        void dft(Z *x, int f)
        {
            for(int i=0; i<n; i++)
                if(i < rev[i])
                    swap(x[i], x[rev[i]]);
            for(int w=1; w<n; w<<=1)
            {
                for(int i=0; i<n; i+=(w<<1))
                {
                    for(int j=0; j<w; j++)
                    {
                        Z a = x[i+j], b = x[i+j+w] * Z(cos(pi/w*j), f*sin(pi/w*j));;
                        x[i+j] = a + b;
                        x[i+j+w] = a - b;
                    }
                }
            }
            if(f == -1) for(int i=0; i<n; i++) x[i] /= n;
        }
    } F;

    Z Xq[ML], Yq[ML], xlyl[ML], xlyh[ML], xhyl[ML], xhyh[ML];

    void fast_multiply(ll *x, ll *y, ll *ret)
    {
        for(int i=0; i<F.n; i++)
            Xq[i] = Z(x[i]>>15, x[i]&((1<<15)-1)),
            Yq[i] = Z(y[i]>>15, y[i]&((1<<15)-1));
        F.dft(Xq, +1), F.dft(Yq, +1);
        for(int i=0; i<F.n; i++)
        {
            int j = (F.n-i) & (F.n-1);
            Z xh = (Xq[i]+Xq[j].conj()) * Z(0.5, 0);
            Z xl = (Xq[i]-Xq[j].conj()) * Z(0, -0.5);
            Z yh = (Yq[i]+Yq[j].conj()) * Z(0.5, 0);
            Z yl = (Yq[i]-Yq[j].conj()) * Z(0, -0.5);
            xhyh[j] = xh*yh, xhyl[j] = xh*yl, xlyh[j] = xl*yh, xlyl[j] = xl*yl;
        }
        for(int i=0; i<F.n; i++)
            Xq[i] = xhyh[i] + xhyl[i] * Z(0, 1),
            Yq[i] = xlyh[i] + xlyl[i] * Z(0, 1);
        F.dft(Xq, +1), F.dft(Yq, +1);
        for(int i=0; i<F.n; i++)
        {
            ll xhyh = ll(Xq[i].r/F.n + 0.5) % mod;
            ll xhyl = ll(Xq[i].i/F.n + 0.5) % mod;
            ll xlyh = ll(Yq[i].r/F.n + 0.5) % mod;
            ll xlyl = ll(Yq[i].i/F.n + 0.5) % mod;
            ret[i] = ((xhyh<<30) + (xhyl<<15) + (xlyh<<15) + (xlyl)) % mod;
        }
    }
}

ll a[ML], b[ML], c[ML], n, m;

int main()
{
    read(n); read(m); read(mod);
    fft::F.init(n+m+2);
    for(int i=0; i<=n; i++) read(a[i]);
    for(int i=0; i<=m; i++) read(b[i]);
    fft::fast_multiply(a, b, c);
    for(int i=0; i<=n+m; i++) printf("%lld ", c[i]); putchar('\n');
    return 0;
}
分类: 文章

0 条评论

发表回复

Avatar placeholder

您的电子邮箱地址不会被公开。 必填项已用 * 标注