compro_library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub siro53/compro_library

:heavy_check_mark: math/convolution/ntt.hpp

Depends on

Verified with

Code

#pragma once

#include <array>
#include <utility>
#include <vector>

#include "../../modint/modint.hpp"
#include "../primitive-root.hpp"

namespace ntt {
    // Largest k such that 2^k divides mod - 1, i.e. the longest power-of-two
    // transform length this modulus supports.
    constexpr int exp_limit(int mod) { return __builtin_ctz(mod - 1); }

    // Precomputed power-of-two roots of unity for `mod` with primitive root `g`.
    // root[i] = root[i + 1]^2, so root[i] has order 2^i (a primitive 2^i-th root
    // of unity when g is a primitive root); iroot[i] is its inverse.
    template <class mint, int mod = mint::get_mod(),
              int g = primitive_root(mod)>
    struct ntt_info {
        static constexpr int limit = exp_limit(mod);
        std::array<mint, limit + 1> root;
        std::array<mint, limit + 1> iroot;

        ntt_info() {
            root[limit] = mint(g).pow((mod - 1) >> limit);
            iroot[limit] = root[limit].inv();
            for(int i = limit - 1; i >= 0; i--) {
                root[i] = root[i + 1] * root[i + 1];
                iroot[i] = iroot[i + 1] * iroot[i + 1];
            }
        }
    };

    // Reverses the bits of `mask`.
    //  - width < 0 (default): reverses within mask's own bit length, so the
    //    highest set bit becomes bit 0 (e.g. revbit(6) == 3).
    //  - 0 <= width <= 64: reverses the low `width` bits; mask must be < 2^width
    //    (e.g. revbit(1, 3) == 4).
    // mask == 0 always yields 0 (this avoids __builtin_clzll(0), which is UB).
    inline uint64_t revbit(uint64_t mask, int width = -1) {
        if(mask == 0 || width == 0) return 0;
        uint64_t x = mask;
        x = (x >> 32) | (x << 32);
        x = ((x >> 16) & 0x0000FFFF0000FFFF) | ((x << 16) & 0xFFFF0000FFFF0000);
        x = ((x >> 8) & 0x00FF00FF00FF00FF) | ((x << 8) & 0xFF00FF00FF00FF00);
        x = ((x >> 4) & 0x0F0F0F0F0F0F0F0F) | ((x << 4) & 0xF0F0F0F0F0F0F0F0);
        x = ((x >> 2) & 0x3333333333333333) | ((x << 2) & 0xCCCCCCCCCCCCCCCC);
        x = ((x >> 1) & 0x5555555555555555) | ((x << 1) & 0xAAAAAAAAAAAAAAAA);
        const int shift = width < 0 ? __builtin_clzll(mask) : 64 - width;
        return x >> shift;
    }

    // In-place forward NTT.
    // Precondition: a.size() is a power of two and at most 2^ntt_info<mint>::limit
    // (larger sizes would index past the root table).
    // On return, a[k] = sum_j a[j] * w^(j*k) in natural order, where
    // w = root[log2(a.size())]. Sizes 0 and 1 are left unchanged.
    template<class mint>
    void ntt(std::vector<mint>& a) {
        const int n = (int)a.size();
        if(n <= 1) return;
        const int bitlen = __builtin_ctz(n);
        static const ntt_info<mint> info;
        // Decimation-in-frequency butterflies; output ends up bit-reversed.
        for(int len = n, b = bitlen; len > 1; len >>= 1, b--) {
            const int t = len >> 1;
            for(int i = 0; i < n; i += len) {
                mint wj = 1;
                for(int j = 0; j < t; j++) {
                    const int p = i + j;
                    const mint l = a[p] + a[p + t];
                    const mint r = (a[p] - a[p + t]) * wj;
                    a[p] = l, a[p + t] = r;
                    wj *= info.root[b];
                }
            }
        }
        // Restore natural order: reverse each index over exactly `bitlen` bits.
        for(int i = 0; i < n; i++) {
            const int j = (int)revbit(i, bitlen);
            if(i < j) std::swap(a[i], a[j]);
        }
    }

    // In-place inverse NTT (inverse of ntt above, including the 1/n scaling).
    // Same size precondition as ntt; sizes 0 and 1 are left unchanged.
    template<class mint>
    void intt(std::vector<mint>& a) {
        const int n = (int)a.size();
        if(n <= 1) return;
        const int bitlen = __builtin_ctz(n);
        static const ntt_info<mint> info;
        // Decimation-in-time butterflies expect bit-reversed input.
        for(int i = 0; i < n; i++) {
            const int j = (int)revbit(i, bitlen);
            if(i < j) std::swap(a[i], a[j]);
        }
        for(int len = 2, b = 1; len <= n; len <<= 1, b++) {
            const int t = len >> 1;
            for(int i = 0; i < n; i += len) {
                mint wj = 1;
                for(int j = 0; j < t; j++) {
                    const int p = i + j;
                    const mint v = a[p + t] * wj;
                    const mint l = a[p] + v;
                    const mint r = a[p] - v;
                    a[p] = l, a[p + t] = r;
                    wj *= info.iroot[b];
                }
            }
        }
        const mint invn = mint(n).inv();
        for(auto& x : a) x *= invn;
    }

    // Linear convolution: c[k] = sum_{i + j = k} a[i] * b[j], with
    // |c| = |a| + |b| - 1. Returns an empty vector if either input is empty.
    // The padded length (next power of two >= |c|) must not exceed
    // 2^ntt_info<mint>::limit.
    template<class mint>
    std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
        if(a.empty() || b.empty()) return {};
        const int m = (int)a.size() + (int)b.size() - 1;
        int n = 1;
        while(n < m) n <<= 1;
        a.resize(n), b.resize(n);
        ntt<mint>(a);
        ntt<mint>(b);
        for(int i = 0; i < n; i++) a[i] *= b[i];
        intt<mint>(a);
        a.resize(m);
        return a;
    }
}; // namespace ntt
#line 2 "math/convolution/ntt.hpp"

#include <array>
#include <utility>
#include <vector>

#line 2 "modint/modint.hpp"

#include <istream>
#include <ostream>
#line 6 "modint/modint.hpp"

// Integer modulo the compile-time modulus `mod` (assumes mod >= 1).
// The stored value x is always kept in [0, mod).
template <int mod> class ModInt {
  public:
    ModInt() : x(0) {}
    // Accepts any long long, including negatives and LLONG_MIN, and reduces it
    // into [0, mod). Implicit on purpose so integers mix with ModInt operands.
    ModInt(long long y) : x(normalize(y)) {}
    unsigned int val() const { return x; }
    ModInt &operator+=(const ModInt &p) {
        if((x += p.x) >= umod()) x -= umod();
        return *this;
    }
    ModInt &operator-=(const ModInt &p) {
        // x + (mod - p.x) < 2 * mod <= 2^32 - 2, so this cannot wrap.
        if((x += umod() - p.x) >= umod()) x -= umod();
        return *this;
    }
    ModInt &operator*=(const ModInt &p) {
        x = (unsigned int)(1ULL * x * p.x % umod());
        return *this;
    }
    ModInt &operator/=(const ModInt &p) {
        *this *= p.inv();
        return *this;
    }
    ModInt operator-() const { return ModInt(-(long long)x); }
    ModInt operator+(const ModInt &p) const { return ModInt(*this) += p; }
    ModInt operator-(const ModInt &p) const { return ModInt(*this) -= p; }
    ModInt operator*(const ModInt &p) const { return ModInt(*this) *= p; }
    ModInt operator/(const ModInt &p) const { return ModInt(*this) /= p; }
    bool operator==(const ModInt &p) const { return x == p.x; }
    bool operator!=(const ModInt &p) const { return x != p.x; }
    // Multiplicative inverse via the extended Euclidean algorithm.
    // Meaningful only when gcd(x, mod) == 1; for x == 0 this returns 0.
    ModInt inv() const {
        long long a = x, b = mod, u = 1, v = 0, t;
        while(b > 0) {
            t = a / b;
            std::swap(a -= t * b, b);
            std::swap(u -= t * v, v);
        }
        return ModInt(u);
    }
    // x^n by binary exponentiation.
    ModInt pow(unsigned long long n) const {
        ModInt ret(1), mul(x);
        while(n > 0) {
            if(n & 1) ret *= mul;
            mul *= mul;
            n >>= 1;
        }
        return ret;
    }
    friend std::ostream &operator<<(std::ostream &os, const ModInt &p) {
        return os << p.x;
    }
    friend std::istream &operator>>(std::istream &is, ModInt &a) {
        long long t;
        is >> t;
        a = ModInt<mod>(t);
        return (is);
    }
    static constexpr int get_mod() { return mod; }

  private:
    unsigned int x;
    static constexpr unsigned int umod() { return mod; }
    // Least non-negative residue of y. It never negates y, because -y is
    // signed overflow (UB) when y == LLONG_MIN.
    static constexpr unsigned int normalize(long long y) {
        long long r = y % static_cast<long long>(mod);
        if(r < 0) r += mod;
        return static_cast<unsigned int>(r);
    }
};
#line 2 "math/primitive-root.hpp"

#line 2 "math/pow_mod.hpp"

// x^k mod m by binary exponentiation, usable in constant expressions.
// Requires m >= 1 and k >= 0 (the loop is not designed for negative k).
// Negative x is first reduced to its residue in [0, m); the result is in [0, m).
constexpr long long pow_mod(long long x, long long k, long long m) {
    // 1 % m so that m == 1 yields 0 even when k == 0.
    long long res = 1 % m;
    long long mul = x % m;
    if(mul < 0) mul += m;
    while(k) {
        // __int128 keeps the product exact for any m that fits in long long.
        if(k & 1) res = (__int128_t)res * mul % m;
        mul = (__int128_t)mul * mul % m;
        k >>= 1;
    }
    return res;
}
#line 4 "math/primitive-root.hpp"

// Smallest primitive root modulo p, usable in constant expressions.
// Assumes p is prime. p == 2 returns 1, and 998244353 is short-circuited to 3.
constexpr int primitive_root(int p) {
    if(p == 2) return 1;
    if(p == 998244353) return 3;
    // Distinct prime factors of p - 1. An int has at most 9 distinct prime
    // factors, so 31 slots are plenty.
    int primes[31] = {};
    int sz = 0, t = p - 1;
    // `i <= t / i` is equivalent to `i * i <= t` for positive t, but cannot
    // overflow int (overflow would be a hard error in constant evaluation).
    for(int i = 2; i <= t / i; i++) {
        if(t % i == 0) {
            primes[sz++] = i;
            while(t % i == 0) t /= i;
        }
    }
    if(t > 1) primes[sz++] = t;
    // For prime p, g is a primitive root iff g^((p-1)/q) != 1 (mod p) for
    // every prime q dividing p - 1.
    for(int g = 2;; g++) {
        bool ok = true;
        for(int i = 0; i < sz && ok; i++) {
            if(pow_mod(g, (p - 1) / primes[i], p) == 1) ok = false;
        }
        if(ok) return g;
    }
}
#line 9 "math/convolution/ntt.hpp"

namespace ntt {
    // Largest k such that 2^k divides mod - 1, i.e. the longest power-of-two
    // transform length this modulus supports.
    constexpr int exp_limit(int mod) { return __builtin_ctz(mod - 1); }

    // Precomputed power-of-two roots of unity for `mod` with primitive root `g`.
    // root[i] = root[i + 1]^2, so root[i] has order 2^i (a primitive 2^i-th root
    // of unity when g is a primitive root); iroot[i] is its inverse.
    template <class mint, int mod = mint::get_mod(),
              int g = primitive_root(mod)>
    struct ntt_info {
        static constexpr int limit = exp_limit(mod);
        std::array<mint, limit + 1> root;
        std::array<mint, limit + 1> iroot;

        ntt_info() {
            root[limit] = mint(g).pow((mod - 1) >> limit);
            iroot[limit] = root[limit].inv();
            for(int i = limit - 1; i >= 0; i--) {
                root[i] = root[i + 1] * root[i + 1];
                iroot[i] = iroot[i + 1] * iroot[i + 1];
            }
        }
    };

    // Reverses the bits of `mask`.
    //  - width < 0 (default): reverses within mask's own bit length, so the
    //    highest set bit becomes bit 0 (e.g. revbit(6) == 3).
    //  - 0 <= width <= 64: reverses the low `width` bits; mask must be < 2^width
    //    (e.g. revbit(1, 3) == 4).
    // mask == 0 always yields 0 (this avoids __builtin_clzll(0), which is UB).
    inline uint64_t revbit(uint64_t mask, int width = -1) {
        if(mask == 0 || width == 0) return 0;
        uint64_t x = mask;
        x = (x >> 32) | (x << 32);
        x = ((x >> 16) & 0x0000FFFF0000FFFF) | ((x << 16) & 0xFFFF0000FFFF0000);
        x = ((x >> 8) & 0x00FF00FF00FF00FF) | ((x << 8) & 0xFF00FF00FF00FF00);
        x = ((x >> 4) & 0x0F0F0F0F0F0F0F0F) | ((x << 4) & 0xF0F0F0F0F0F0F0F0);
        x = ((x >> 2) & 0x3333333333333333) | ((x << 2) & 0xCCCCCCCCCCCCCCCC);
        x = ((x >> 1) & 0x5555555555555555) | ((x << 1) & 0xAAAAAAAAAAAAAAAA);
        const int shift = width < 0 ? __builtin_clzll(mask) : 64 - width;
        return x >> shift;
    }

    // In-place forward NTT.
    // Precondition: a.size() is a power of two and at most 2^ntt_info<mint>::limit
    // (larger sizes would index past the root table).
    // On return, a[k] = sum_j a[j] * w^(j*k) in natural order, where
    // w = root[log2(a.size())]. Sizes 0 and 1 are left unchanged.
    template<class mint>
    void ntt(std::vector<mint>& a) {
        const int n = (int)a.size();
        if(n <= 1) return;
        const int bitlen = __builtin_ctz(n);
        static const ntt_info<mint> info;
        // Decimation-in-frequency butterflies; output ends up bit-reversed.
        for(int len = n, b = bitlen; len > 1; len >>= 1, b--) {
            const int t = len >> 1;
            for(int i = 0; i < n; i += len) {
                mint wj = 1;
                for(int j = 0; j < t; j++) {
                    const int p = i + j;
                    const mint l = a[p] + a[p + t];
                    const mint r = (a[p] - a[p + t]) * wj;
                    a[p] = l, a[p + t] = r;
                    wj *= info.root[b];
                }
            }
        }
        // Restore natural order: reverse each index over exactly `bitlen` bits.
        for(int i = 0; i < n; i++) {
            const int j = (int)revbit(i, bitlen);
            if(i < j) std::swap(a[i], a[j]);
        }
    }

    // In-place inverse NTT (inverse of ntt above, including the 1/n scaling).
    // Same size precondition as ntt; sizes 0 and 1 are left unchanged.
    template<class mint>
    void intt(std::vector<mint>& a) {
        const int n = (int)a.size();
        if(n <= 1) return;
        const int bitlen = __builtin_ctz(n);
        static const ntt_info<mint> info;
        // Decimation-in-time butterflies expect bit-reversed input.
        for(int i = 0; i < n; i++) {
            const int j = (int)revbit(i, bitlen);
            if(i < j) std::swap(a[i], a[j]);
        }
        for(int len = 2, b = 1; len <= n; len <<= 1, b++) {
            const int t = len >> 1;
            for(int i = 0; i < n; i += len) {
                mint wj = 1;
                for(int j = 0; j < t; j++) {
                    const int p = i + j;
                    const mint v = a[p + t] * wj;
                    const mint l = a[p] + v;
                    const mint r = a[p] - v;
                    a[p] = l, a[p + t] = r;
                    wj *= info.iroot[b];
                }
            }
        }
        const mint invn = mint(n).inv();
        for(auto& x : a) x *= invn;
    }

    // Linear convolution: c[k] = sum_{i + j = k} a[i] * b[j], with
    // |c| = |a| + |b| - 1. Returns an empty vector if either input is empty.
    // The padded length (next power of two >= |c|) must not exceed
    // 2^ntt_info<mint>::limit.
    template<class mint>
    std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
        if(a.empty() || b.empty()) return {};
        const int m = (int)a.size() + (int)b.size() - 1;
        int n = 1;
        while(n < m) n <<= 1;
        a.resize(n), b.resize(n);
        ntt<mint>(a);
        ntt<mint>(b);
        for(int i = 0; i < n; i++) a[i] *= b[i];
        intt<mint>(a);
        a.resize(m);
        return a;
    }
}; // namespace ntt
Back to top page