compro_library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub siro53/compro_library

✔ ローリングハッシュ (Rolling Hash)
(string/rolling-hash.hpp)

Depends on

Required by

Verified with

Code

#pragma once

#include <algorithm>
#include <cassert>
#include <random>
#include <string>
#include <vector>

#include "../modint/modint_2_61.hpp"

struct RollingHash {
    using u64 = unsigned long long;
    using mint = ModInt_2_61;
    static constexpr u64 mod = (1ULL << 61) - 1;
    const u64 base;
    std::vector<mint> hashed, power;

    explicit RollingHash(const std::vector<int> &v, u64 base) : base(base) {
        int n = (int)v.size();
        hashed.assign(n + 1, 0);
        power.assign(n + 1, 0);
        power[0] = 1;
        for(int i = 0; i < n; i++) {
            power[i + 1] = power[i] * base;
            hashed[i + 1] = (hashed[i] * base) + v[i];
        }
    }
    explicit RollingHash(const std::string &s, u64 base) : base(base) {
        int n = (int)s.size();
        hashed.assign(n + 1, 0);
        power.assign(n + 1, 0);
        power[0] = 1;
        for(int i = 0; i < n; i++) {
            power[i + 1] = power[i] * base;
            hashed[i + 1] = (hashed[i] * base) + s[i];
        }
    }
    static inline u64 gen_base() {
        std::random_device seed_gen;
        std::mt19937_64 engine(seed_gen());
        std::uniform_int_distribution<u64> rand(2, mod - 2);
        return rand(engine);
    }
    mint get(int l, int r) {
        assert(0 <= l);
        assert(l <= r);
        assert(r < (int)power.size());
        return (hashed[r] - (hashed[l] * power[r - l]));
    }
    mint connect(mint h1, mint h2, int h2len) {
        return (h1 * power[h2len] + h2);
    }
    int get_lcp(RollingHash &b, int l1, int r1, int l2, int r2) {
        assert(mod == b.mod);
        int len = std::min(r1 - l1, r2 - l2);
        int low = -1, high = len + 1;
        while(high - low > 1) {
            int mid = (low + high) >> 1;
            if(get(l1, l1 + mid) == b.get(l2, l2 + mid)) {
                low = mid;
            } else {
                high = mid;
            }
        }
        return low;
    }
};
#line 2 "string/rolling-hash.hpp"

#include <algorithm>
#include <cassert>
#include <random>
#include <string>
#include <vector>

#line 2 "modint/modint_2_61.hpp"

#include <istream>
#include <utility>

// Modular integer over the Mersenne prime 2^61 - 1, intended for rolling hashes.
// https://qiita.com/keymoon/items/11fac5627672a6d6a9f6
class ModInt_2_61 {
  public:
    using M = ModInt_2_61;
    // Zero.
    ModInt_2_61() : x(0) {}
    // Reduces any signed value (including LLONG_MIN) into [0, mod).
    ModInt_2_61(long long y) : x(normalize(y)) {}
    // Canonical representative in [0, mod).
    unsigned long long val() const { return x; }
    M &operator+=(const M &m) {
        if((x += m.x) >= mod) x -= mod;
        return *this;
    }
    M &operator-=(const M &m) {
        if((x += mod - m.x) >= mod) x -= mod;
        return *this;
    }
    // Uses 2^61 == 1 (mod 2^61 - 1): the 128-bit product t = hi * 2^61 + lo
    // reduces to hi + lo. Both operands are < mod, so hi + lo < 2 * mod and a
    // single conditional subtraction yields the canonical value.
    M &operator*=(const M &m) {
        __uint128_t t = static_cast<__uint128_t>(x) * m.x;
        unsigned long long na = static_cast<unsigned long long>(t >> 61);
        unsigned long long nb = static_cast<unsigned long long>(t & mod);
        if((na += nb) >= mod) na -= mod;
        x = na;
        return *this;
    }
    M &operator/=(const M &m) {
        *this *= m.inv();
        return *this;
    }
    M operator-() const { return M(-static_cast<long long>(x)); }
    M operator+(const M &m) const { return M(*this) += m; }
    M operator-(const M &m) const { return M(*this) -= m; }
    M operator*(const M &m) const { return M(*this) *= m; }
    M operator/(const M &m) const { return M(*this) /= m; }
    bool operator==(const M &m) const { return (x == m.x); }
    bool operator!=(const M &m) const { return (x != m.x); }
    // Multiplicative inverse via the extended Euclidean algorithm.
    // The value must be nonzero; for zero this returns 0.
    M inv() const {
        long long a = x, b = mod, u = 1, v = 0, t;
        while(b > 0) {
            t = a / b;
            std::swap(a -= t * b, b);
            std::swap(u -= t * v, v);
        }
        return M(u);
    }
    // this^n by binary exponentiation.
    M pow(unsigned long long n) const {
        M ret(1), mul(*this);
        while(n > 0) {
            if(n & 1) ret *= mul;
            mul *= mul;
            n >>= 1;
        }
        return ret;
    }
    friend std::ostream &operator<<(std::ostream &os, const M &p) {
        return os << p.x;
    }
    // Reads a long long and reduces it modulo mod.
    friend std::istream &operator>>(std::istream &is, M &a) {
        long long t;
        is >> t;
        a = M(t);
        return (is);
    }
    static constexpr unsigned long long get_mod() { return mod; }

  private:
    unsigned long long x;
    static constexpr unsigned long long mod = (1ULL << 61) - 1;

    // Maps y into [0, mod) without evaluating -y, which would overflow for
    // y == LLONG_MIN. The truncating remainder keeps y's sign, so a negative
    // result needs one addition of mod.
    static unsigned long long normalize(long long y) {
        const long long smod = static_cast<long long>(mod);
        long long r = y % smod;
        if(r < 0) r += smod;
        return static_cast<unsigned long long>(r);
    }
};
#line 10 "string/rolling-hash.hpp"

struct RollingHash {
    using u64 = unsigned long long;
    using mint = ModInt_2_61;
    static constexpr u64 mod = (1ULL << 61) - 1;
    const u64 base;
    std::vector<mint> hashed, power;

    explicit RollingHash(const std::vector<int> &v, u64 base) : base(base) {
        int n = (int)v.size();
        hashed.assign(n + 1, 0);
        power.assign(n + 1, 0);
        power[0] = 1;
        for(int i = 0; i < n; i++) {
            power[i + 1] = power[i] * base;
            hashed[i + 1] = (hashed[i] * base) + v[i];
        }
    }
    explicit RollingHash(const std::string &s, u64 base) : base(base) {
        int n = (int)s.size();
        hashed.assign(n + 1, 0);
        power.assign(n + 1, 0);
        power[0] = 1;
        for(int i = 0; i < n; i++) {
            power[i + 1] = power[i] * base;
            hashed[i + 1] = (hashed[i] * base) + s[i];
        }
    }
    static inline u64 gen_base() {
        std::random_device seed_gen;
        std::mt19937_64 engine(seed_gen());
        std::uniform_int_distribution<u64> rand(2, mod - 2);
        return rand(engine);
    }
    mint get(int l, int r) {
        assert(0 <= l);
        assert(l <= r);
        assert(r < (int)power.size());
        return (hashed[r] - (hashed[l] * power[r - l]));
    }
    mint connect(mint h1, mint h2, int h2len) {
        return (h1 * power[h2len] + h2);
    }
    int get_lcp(RollingHash &b, int l1, int r1, int l2, int r2) {
        assert(mod == b.mod);
        int len = std::min(r1 - l1, r2 - l2);
        int low = -1, high = len + 1;
        while(high - low > 1) {
            int mid = (low + high) >> 1;
            if(get(l1, l1 + mid) == b.get(l2, l2 + mid)) {
                low = mid;
            } else {
                high = mid;
            }
        }
        return low;
    }
};
Back to top page