Siyuan

「Codeforces 438E」The Child and Binary Tree
「Codeforces 438E」The Child and Binary Tree
扫描右侧二维码阅读全文
27
2019/06

「Codeforces 438E」The Child and Binary Tree

题目链接:Codeforces 438E

我们的小朋友很喜欢计算机科学,尤其喜欢二叉树。

考虑一个含有 $n$ 个互不相同的正整数序列 $c_1, c_2, \dots, c_n$。如果一棵带点权有根二叉树满足其所有节点的权值都属于集合 $\{c_1, c_2, \dots, c_n\}$ 中,那么小朋友就会将其称作「好的」。并且他认为,这棵二叉树的权值是所有节点的权值总和。

给出一个整数 $m$,你需要对于所有整数 $s \in [1, m]$,计算出权值为 $s$ 的「好的」二叉树数量。答案对 $998244353$ 取模。

数据范围:$1 \le n, m, c_i \le 10 ^ 5$。


Solution

构造 $g(x) = [x \in \{c_1, c_2, \dots, c_n\}]$,设 $f(x)$ 表示权值为 $x$ 的二叉树数量,则有递推式:

$$ f(n) = \begin{cases} 1 & n = 0 \\ \sum_{i = 0} ^ n g(i) \sum_{j = 0} ^ {n - i} f(j) \cdot f(n - i - j) & n > 0 \end{cases} $$

发现在 $n > 0$ 的情况下,$f(x)$ 的值就是一个卷积形式,那么我们考虑 $g, f$ 的生成函数:

$$ G(x) = \sum_{i = 0} ^ m [i \in \{c_1, c_2, \dots, c_n\}] \, x ^ i, \qquad F(x) = \sum_{i = 0} ^ m f(i) \, x ^ i \\ F \equiv 1 + G \ast F \ast F \pmod{x ^ {m + 1}} $$

利用求根公式得到:

$$ F = \frac{1 \pm \sqrt{1 - 4G}}{2G} $$

由于 $G(0) = 0$ 使得分母 $2G$ 无法直接求逆,化为:

$$ F = \frac{2}{1 \mp \sqrt{1 - 4G}} $$

通过 $G(0) = 0, F(0) = 1$ 可以得到上式中 $\mp$ 应该取 $+$ 号:若取 $-$ 号,分母 $1 - \sqrt{1 - 4G}$ 的常数项为 $0$,无法求逆;取 $+$ 号时 $F(0) = \frac{2}{1 + 1} = 1$,符合条件。

套上多项式开根和求逆就行了。

时间复杂度:$\mathcal O(n + m \log m)$。


Code

#include <cstdio>
#include <cmath>
#include <algorithm>
#include <vector>
#include <unordered_map>
#include <cassert>

// Polynomial as a coefficient vector: element i is the coefficient of x^i.
using Vec = std::vector<int>;

// Prime modulus for all arithmetic, and the base G from which the NTT roots
// of unity (and the discrete logarithms in degree()) are derived.
constexpr int MOD = 998244353;
constexpr int G = 3;

// In-place modular addition: x becomes x + y, with a single subtraction of
// MOD when the sum reaches MOD.
void add(int &x, int y) {
    x += y;
    if (x >= MOD) x -= MOD;
}
// In-place modular subtraction: x becomes x - y, with a single addition of
// MOD when the difference is negative.
void sub(int &x, int y) {
    x -= y;
    if (x < 0) x += MOD;
}
// Returns x with MOD subtracted once if x >= MOD; otherwise x unchanged.
int add(int x) {
    if (x >= MOD) return x - MOD;
    return x;
}
// Returns x with MOD added once if x is negative; otherwise x unchanged.
int sub(int x) {
    if (x < 0) return x + MOD;
    return x;
}
// In-place correction by at most one MOD in each direction: first subtract
// MOD if x >= MOD, then add MOD if x is negative.
void mod(int &x) {
    if (x >= MOD) x -= MOD;
    if (x < 0) x += MOD;
}
// Modular product x * y % MOD; the multiplication is done in 64 bits so the
// intermediate value cannot overflow.
int mul(int x, int y) {
    const long long product = 1LL * x * y;
    return product % MOD;
}
// Binary exponentiation: returns x^p modulo MOD.
// The loop consumes p bit by bit from the least significant end, squaring x
// after each bit.
int pow(int x, int p) {
    int result = 1;
    while (p) {
        if (p & 1) result = 1LL * result * x % MOD;
        x = 1LL * x * x % MOD;
        p >>= 1;
    }
    return result;
}
// Modular inverse of x computed as x^(MOD - 2) (Fermat's little theorem,
// MOD being prime).
int inv(int x) {
    const int exponent = MOD - 2;
    return pow(x, exponent);
}
// Baby-step giant-step discrete logarithm.
// Returns some x >= 0 with a^x == b (mod p), or -1 when no exponent is found.
// Expects 0 <= a, b < p and a invertible modulo p.
int BSGS(int a, int b, int p) {
    std::unordered_map<int, int> baby;
    const int m = static_cast<int>(std::ceil(std::sqrt(static_cast<double>(p))));
    // Baby steps: for every value b * a^k with k in [0, m], keep the largest k.
    for (int k = 0; k <= m; b = 1LL * b * a % p, k++) baby[b] = k;
    // Giant stride a^m, reduced modulo the caller's p (not the global modulus).
    int stride = 1 % p;
    for (int e = m, base = a % p; e; e >>= 1, base = 1LL * base * base % p) {
        if (e & 1) stride = 1LL * stride * base % p;
    }
    // Giant steps: a^(i*m) == b * a^k  implies  x = i*m - k.
    // i has to reach m inclusive, otherwise exponents above (m-1)*m are missed
    // (e.g. 5^21 == 14 mod 23 with m == 5).
    int cur = 1 % p;
    for (int i = 0; i <= m; cur = 1LL * cur * stride % p, i++) {
        const auto hit = baby.find(cur);
        if (hit == baby.end()) continue;
        const long long x = 1LL * i * m - hit->second;
        if (x >= 0) return static_cast<int>(x);
    }
    return -1;
}
// k-th root of a through discrete logarithms: finds x with G^x == a via BSGS,
// asserts that x exists and is divisible by k, and forms G^(x / k).
// Of that root and its negation p - root, the smaller value is returned.
int degree(int a, int k, int p) {
    const int x = BSGS(G, a, p);
    assert(x >= 0 && x % k == 0);
    const int root = pow(G, x / k);
    const int negated = p - root;
    return root < negated ? root : negated;
}

// Number-theoretic transform over Z/MOD, built on the root derived from G.
namespace FFT {
    int extend(int x);
    void NTT(Vec &A, bool opt);
    void DFT(Vec &A);
    void IDFT(Vec &A);

    // Returns the smallest power of two that is >= x (1 when x <= 1).
    int extend(int x) {
        int n = 1;
        for (; n < x; n <<= 1);
        return n;
    }
    // In-place iterative radix-2 transform of A modulo MOD.
    // opt == false uses the forward roots, opt == true the inverse roots; no
    // 1/n scaling is applied here (IDFT does that).
    // The callers in this file always pass a power-of-two length.
    void NTT(Vec &A, bool opt) {
        int n = A.size(), k = 0;
        // A transform of length 0 or 1 is the identity. Returning early also
        // keeps the bit-reversal below from shifting by (k - 1) == -1, which
        // is undefined behaviour.
        if (n <= 1) return;
        for (; (1 << k) < n; k++);
        // rev[i] is i with its k low bits reversed; swap once per pair.
        Vec rev(n);
        for (int i = 0; i < n; i++) {
            rev[i] = rev[i >> 1] >> 1 | (i & 1) << (k - 1);
            if (i < rev[i]) std::swap(A[i], A[rev[i]]);
        }
        // Butterfly passes over blocks of length l = 2, 4, ..., n.
        for (int l = 2; l <= n; l <<= 1) {
            // w is a primitive l-th root of unity (inverted for the inverse pass).
            int m = l >> 1, w = pow(G, (MOD - 1) / l);
            if (opt) w = inv(w);
            for (int j = 0; j < n; j += l) {
                int wk = 1;
                for (int i = 0; i < m; i++, wk = 1LL * wk * w % MOD) {
                    int p = A[i + j], q = 1LL * wk * A[i + j + m] % MOD;
                    A[i + j] = (p + q) % MOD;
                    A[i + j + m] = (p - q + MOD) % MOD;
                }
            }
        }
    }
    // Forward transform of A, in place.
    void DFT(Vec &A) {
        NTT(A, false);
    }
    // Inverse transform of A, in place, including the division by A.size().
    void IDFT(Vec &A) {
        NTT(A, true);
        int t = inv(A.size());
        for (auto &x : A) x = 1LL * x * t % MOD;
    }
}
using namespace FFT;

// Polynomial arithmetic modulo MOD on coefficient vectors (index i = x^i).
// Unless stated otherwise, results are truncated to the size of the input.
namespace Poly {
    Vec operator + (Vec A, Vec B);
    Vec operator + (Vec A, int v);
    Vec operator - (Vec A, Vec B);
    Vec operator - (Vec A, int v);
    Vec operator - (Vec A);
    Vec operator * (Vec A, Vec B);
    Vec operator * (Vec A, int v);
    Vec operator / (Vec A, Vec B);
    Vec operator / (Vec A, int v);
    Vec operator % (Vec A, Vec B);
    Vec operator ~ (Vec A);
    Vec operator ^ (Vec A, int k);
    Vec operator << (Vec A, int x);
    Vec operator >> (Vec A, int x);
    Vec fix(Vec A, int n);
    Vec der(Vec A);
    Vec inte(Vec A);
    Vec sqrt(Vec A);
    Vec root(Vec A, int k);
    Vec ln(Vec A);
    Vec exp(Vec A);
    Vec sin(Vec A);
    Vec cos(Vec A);
    void print(Vec A);

    // Coefficient-wise sum; the shorter operand is zero-padded.
    Vec operator + (Vec A, Vec B) {
        int n = std::max(A.size(), B.size());
        A.resize(n), B.resize(n);
        for (int i = 0; i < n; i++) add(A[i], B[i]);
        return A;
    }
    // Adds v to the constant term. An empty polynomial is treated as 0
    // instead of indexing A[0] out of bounds.
    Vec operator + (Vec A, int v) {
        if (A.empty()) A.resize(1);
        add(A[0], v);
        return A;
    }
    // Coefficient-wise difference; the shorter operand is zero-padded.
    Vec operator - (Vec A, Vec B) {
        int n = std::max(A.size(), B.size());
        A.resize(n), B.resize(n);
        for (int i = 0; i < n; i++) sub(A[i], B[i]);
        return A;
    }
    // Subtracts v from the constant term. An empty polynomial is treated as 0
    // instead of indexing A[0] out of bounds.
    Vec operator - (Vec A, int v) {
        if (A.empty()) A.resize(1);
        sub(A[0], v);
        return A;
    }
    // Negates every coefficient modulo MOD.
    Vec operator - (Vec A) {
        for (auto &x : A) x = sub(-x);
        return A;
    }
    // Full product of A and B (size A.size() + B.size() - 1) via NTT.
    Vec operator * (Vec A, Vec B) {
        // An empty operand previously led to a size of -1 and a failing resize.
        if (A.empty() || B.empty()) return Vec();
        int n = A.size() + B.size() - 1;
        // Two constants: multiply directly rather than run a length-1 transform.
        if (n == 1) return Vec(1, mul(A[0], B[0]));
        int N = extend(n);
        A.resize(N), DFT(A);
        B.resize(N), DFT(B);
        for (int i = 0; i < N; i++) A[i] = mul(A[i], B[i]);
        IDFT(A), A.resize(n);
        return A;
    }
    // Multiplies every coefficient by the scalar v.
    Vec operator * (Vec A, int v) {
        for (auto &x : A) x = mul(x, v);
        return A;
    }
    // Polynomial quotient of A by B using reversed coefficients and a power
    // series inverse. Returns {0} when A is shorter than B.
    Vec operator / (Vec A, Vec B) {
        // Signed arithmetic: the size difference may be negative.
        int n = static_cast<int>(A.size()) - static_cast<int>(B.size()) + 1;
        if (n <= 0) return Vec(1, 0);
        std::reverse(A.begin(), A.end());
        std::reverse(B.begin(), B.end());
        A.resize(n), B.resize(n);
        A = A * ~B;
        A.resize(n);
        std::reverse(A.begin(), A.end());
        return A;
    }
    // Divides every coefficient by the scalar v (multiplies by inv(v)).
    Vec operator / (Vec A, int v) {
        return A * inv(v);
    }
    // Remainder A - (A / B) * B, truncated to B.size() - 1 coefficients.
    Vec operator % (Vec A, Vec B) {
        int n = B.size() - 1;
        A = A - A / B * B;
        A.resize(n);
        return A;
    }
    // Power series inverse of A modulo x^(A.size()) by Newton iteration
    // I <- I * (2 - A * I), doubling the precision l each round.
    // The constant term of the result is inv(A[0]).
    Vec operator ~ (Vec A) {
        int n = A.size(), N = extend(n);
        A.resize(N);
        Vec I(N, 0);
        I[0] = inv(A[0]);
        for (int l = 2; l <= N; l <<= 1) {
            // Transform length 2l so the product A * I * I does not wrap
            // into the first l coefficients that are kept.
            int t = l << 1;
            Vec P(t, 0), Q(t, 0);
            std::copy(A.begin(), A.begin() + l, P.begin());
            std::copy(I.begin(), I.begin() + l, Q.begin());
            DFT(P), DFT(Q);
            for (int i = 0; i < t; i++) {
                P[i] = 1LL * Q[i] * (2 - 1LL * P[i] * Q[i] % MOD + MOD) % MOD;
            }
            IDFT(P);
            std::copy(P.begin(), P.begin() + l, I.begin());
        }
        I.resize(n);
        return I;
    }
    // A raised to the k-th power, truncated to A.size() coefficients.
    // The lowest nonzero term v * x^x is factored out so that ln() sees a
    // constant term of 1; the result is exp(k * ln(A')) * v^k shifted by x * k.
    Vec operator ^ (Vec A, int k) {
        int n = A.size(), x = 0;
        for (; x < n && A[x] == 0; x++);
        // Everything would be shifted past the truncation point.
        if (1LL * x * k >= n) return Vec(n, 0);
        A = A >> x;
        int v = A[0];
        A = A / v;
        A = exp(ln(A) * k) * pow(v, k);
        return A << (x * k);
    }
    // Multiplies by x^x: coefficients move x places up, size stays the same
    // and coefficients shifted past the end are dropped.
    Vec operator << (Vec A, int x) {
        int n = A.size();
        Vec B(n, 0);
        for (int i = 0; i < n - x; i++) B[i + x] = A[i];
        return B;
    }
    // Divides by x^x: coefficients move x places down, size stays the same
    // and the lowest x coefficients are dropped.
    Vec operator >> (Vec A, int x) {
        int n = A.size();
        Vec B(n, 0);
        for (int i = 0; i < n - x; i++) B[i] = A[i + x];
        return B;
    }
    // Returns A resized to exactly n coefficients (truncated or zero-padded).
    Vec fix(Vec A, int n) {
        A.resize(n);
        return A;
    }
    // Formal derivative. A constant or empty polynomial yields {0}
    // (an empty input previously requested a vector of size -1).
    Vec der(Vec A) {
        int n = A.size();
        if (n <= 1) return Vec(1, 0);
        Vec D(n - 1, 0);
        for (int i = 1; i < n; i++) D[i - 1] = mul(i, A[i]);
        return D;
    }
    // Formal integral with constant term 0; the result has one more coefficient.
    Vec inte(Vec A) {
        int n = A.size();
        Vec I(n + 1, 0);
        for (int i = 1; i <= n; i++) I[i] = mul(inv(i), A[i - 1]);
        return I;
    }
    // Power series square root of A modulo x^(A.size()) by Newton iteration
    // R <- (A + R^2) / (2R). The constant term is degree(A[0], 2, MOD), i.e.
    // the smaller of the two modular square roots of A[0].
    Vec sqrt(Vec A) {
        int n = A.size(), N = extend(n);
        A.resize(N);
        Vec R(N, 0);
        R[0] = degree(A[0], 2, MOD);
        int i2 = inv(2);
        for (int l = 2; l <= N; l <<= 1) {
            int t = l << 1;
            Vec P(t, 0), Q(t, 0);
            std::copy(A.begin(), A.begin() + l, P.begin());
            std::copy(R.begin(), R.begin() + l, Q.begin());
            // I = 1 / R modulo x^l, padded to the transform length.
            Vec I = ~fix(Q, l);
            I.resize(t);
            DFT(P), DFT(Q), DFT(I);
            for (int i = 0; i < t; i++) {
                P[i] = 1LL * (P[i] + mul(Q[i], Q[i])) * i2 % MOD * I[i] % MOD;
            }
            IDFT(P);
            std::copy(P.begin(), P.begin() + l, R.begin());
        }
        R.resize(n);
        return R;
    }
    // k-th root of A: identity for k == 1, sqrt() for k == 2, otherwise
    // exp(ln(A) / k).
    Vec root(Vec A, int k) {
        return k == 1 ? A : k == 2 ? sqrt(A) : exp(ln(A) / k);
    }
    // Power series logarithm: integral of A' / A. Requires A[0] == 1.
    Vec ln(Vec A) {
        assert(A[0] == 1);
        int n = A.size();
        A = inte(der(A) * ~A);
        A.resize(n);
        return A;
    }
    // Power series exponential by Newton iteration
    // E <- E * (1 - ln(E) + A). Requires A[0] == 0.
    Vec exp(Vec A) {
        assert(A[0] == 0);
        int n = A.size(), N = extend(n);
        A.resize(N);
        Vec E(N, 0);
        E[0] = 1;
        for (int l = 2; l <= N; l <<= 1) {
            Vec P = (-ln(fix(E, l)) + fix(A, l) + 1) * fix(E, l);
            std::copy(P.begin(), P.begin() + l, E.begin());
        }
        E.resize(n);
        return E;
    }
    // sin(A) = (e^{iA} - e^{-iA}) / (2i), with i a square root of -1 modulo MOD
    // and e^{-iA} obtained as the inverse of e^{iA}.
    Vec sin(Vec A) {
        int i = degree(MOD - 1, 2, MOD);
        Vec E = exp(A * i);
        return (E - ~E) / (2LL * i % MOD);
    }
    // cos(A) = (e^{iA} + e^{-iA}) / 2, with i a square root of -1 modulo MOD.
    Vec cos(Vec A) {
        int i = degree(MOD - 1, 2, MOD);
        Vec E = exp(A * i);
        return (E + ~E) / 2;
    }
    // Prints the coefficients on one line, space separated, newline terminated.
    void print(Vec A) {
        int n = A.size();
        for (int i = 0; i < n; i++) printf("%d%c", A[i], " \n"[i == n - 1]);
    }
}
using namespace Poly;

int main() {
    int n, m;
    scanf("%d%d", &n, &m);
    Vec G(m + 1, 0);
    for (int i = 1; i <= n; i++) {
        int x;
        scanf("%d", &x);
        if (x <= m) G[x] = 1;
    }
    Vec F = ~(sqrt(-G * 4 + 1) + 1) * 2;
    for (int i = 1; i <= m; i++) printf("%d\n", F[i]);
    return 0;
}
最后修改:2019 年 06 月 27 日 09 : 26 AM

发表评论