Siyuan

「算法笔记」点分治
「算法笔记」点分治
扫描右侧二维码阅读全文
21
2019/03

「算法笔记」点分治

点分治是一种针对带权树上简单路径统计问题的算法,利用树的重心性质来优化复杂度。


思路

首先我们把这棵树钦定一个根,假设为 $x$,再将这棵树上的所有简单路径分为两个部分:

  • 经过 $x$ 的简单路径。
  • 不经过 $x$ 的简单路径(在 $x$ 的某棵子树内)。

又发现不经过 $x$ 的简单路径可以递归到子树内的点计算,于是我们对于 $x$ 只需要计算经过它的路径即可。

但是我们怎么选择这个分治的点呢?我们的复杂度瓶颈在递归次数上,如果随便选择的话,对于一条链,我们每次递归下一个点,复杂度轻松被卡到 $\mathcal O(n^2)$,这是完全不能接受的。

考虑树上有哪个点的性质很优秀?很容易想到重心,因为重心的子树大小都不超过 $\frac{size}{2}$,其中 $size$ 为总大小。

我们每次递归处理当前子树的重心 $x$,得到以 $x$ 为根的树的所有节点的信息。然后统计答案。

最后证明一下复杂度:由于重心的性质,我们可以发现每次递归子树大小至少缩小一半,于是递归次数为 $\mathcal O(\log n)$!


实现

点分治统计答案时,大致有如下 $2$ 种实现方法:

  1. 考虑可以走过重复点,我们使用容斥:用点 $x$ 的答案减去其每个儿子的答案(这里的答案都考虑重复经过点)。
  2. 假设 $x$ 有儿子 $v_1, v_2, \cdots, v_k$,那么我们用子树 $v_1, v_2, \cdots, v_{i - 1}$ 的答案来更新 $v_i$ 的答案。计算完要清空答案。

发表一下个人看法:建议使用第二种方法。因为第一种方法不但常数较大,并且很容易有漏算、多算的情况;而第二种方法是严格正确的。


代码

我们以「Luogu 3806」【模板】点分治 为例。

方法 1

#include <cstdio>
#include <algorithm>

const int N = 1e4 + 5, M = 2e4 + 5, Q = 1e2 + 5;
const int INF = 0x7f7f7f7f;

int n, m, tot, lnk[N], ter[M], nxt[M], val[M], q[Q];
int root, sum, tp, s[N], dis[N], sz[N], mx[N];
bool vis[N], ans[Q];

void add(int u, int v, int w) {
    ter[++tot] = v, nxt[tot] = lnk[u], lnk[u] = tot, val[tot] = w;
}
void getRoot(int u, int p) {
    sz[u] = 1, mx[u] = 0;
    for (int i = lnk[u]; i; i = nxt[i]) {
        int v = ter[i];
        if (v == p || vis[v]) continue;
        getRoot(v, u);
        sz[u] += sz[v];
        mx[u] = std::max(mx[u], sz[v]);
    }
    mx[u] = std::max(mx[u], sum - sz[u]);
    if (mx[root] > mx[u]) root = u;
}
void getDis(int u, int p, int d) {
    s[++tp] = dis[u] = d;
    for (int i = lnk[u]; i; i = nxt[i]) {
        int v = ter[i];
        if (v == p || vis[v]) continue;
        getDis(v, u, d + val[i]);
    }
}
void solve(int u, int d, int sgn) {
    tp = 0, getDis(u, 0, d);
    std::sort(s + 1, s + tp + 1);
    for (int i = 1; i <= tp; i++) {
        for (int j = 1; j <= m; j++) {
            int k = q[j];
            int l = std::lower_bound(s + i + 1, s + tp + 1, k - s[i]) - s;
            int r = std::upper_bound(s + i + 1, s + tp + 1, k - s[i]) - s - 1;
            if (r >= l) ans[j] += sgn * (r - l + 1);
        }
    }
}
void divide(int u) {
    vis[u] = 1, solve(u, 0, 1);
    for (int i = lnk[u]; i; i = nxt[i]) {
        int v = ter[i];
        if (vis[v]) continue;
        solve(v, val[i], -1);
        root=0, sum = sz[u], getRoot(v, u);
        divide(root);
    }
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        add(u, v, w), add(v, u, w);
    }
    for (int i = 1; i <= m; i++) {
        scanf("%d", &q[i]);
    }
    mx[0] = INF;
    root = 0, sum = n, getRoot(1, 0);
    divide(root);
    for (int i = 1; i <= m; i++) {
        puts(ans[i] ? "AYE" : "NAY");
    }
    return 0;
}

方法 2

#include <cstdio>
#include <algorithm>

const int N = 1e4 + 5, M = 2e4 + 5, Q = 1e2 + 5, K = 1e7 + 5;
const int INF = 0x7f7f7f7f;

int n, m, tot, lnk[N], ter[M], nxt[M], val[M], q[N];
int root, sum, tp, s[N], t[N], dis[N], sz[N], mx[N];
bool f[K], vis[N], ans[Q];

void add(int u, int v, int w) {
    ter[++tot] = v, nxt[tot] = lnk[u], lnk[u] = tot, val[tot] = w;
}
void getRoot(int u, int p) {
    sz[u] = 1, mx[u] = 0;
    for (int i = lnk[u]; i; i = nxt[i]) {
        int v = ter[i];
        if (v == p || vis[v]) continue;
        getRoot(v, u);
        sz[u] += sz[v];
        mx[u] = std::max(mx[u], sz[v]);
    }
    mx[u] = std::max(mx[u], sum - sz[u]);
    if (mx[root] > mx[u]) root = u;
}
void getDis(int u, int p, int d) {
    s[++tp] = dis[u] = d;
    for (int i = lnk[u]; i; i = nxt[i]) {
        int v = ter[i];
        if (v == p || vis[v]) continue;
        getDis(v, u, d + val[i]);
    }
}
void solve(int u) {
    int cnt = 0;
    for (int i = lnk[u]; i; i = nxt[i]) {
        int v = ter[i];
        if (vis[v]) continue;
        tp = 0, getDis(v, u, val[i]);
        for (int j = 1; j <= tp; j++) {
            for (int k = 1; k <= m; k++) {
                if (q[k] >= s[j]) ans[k] |= f[q[k] - s[j]];
            }
        }
        for (int j = 1; j <= tp; j++) {
            f[t[++cnt] = s[j]] = 1;
        }
    }
    for (int i = 1; i <= cnt; i++) f[t[i]]=0;
}
void divide(int u) {
    vis[u] = f[0] = 1;
    solve(u);
    for (int i = lnk[u]; i; i = nxt[i]) {
        int v = ter[i];
        if (vis[v]) continue;
        sum = sz[v], root = 0, getRoot(v, 0);
        divide(root);
    }
}
int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1; i < n; i++) {
        int u, v, w;
        scanf("%d%d%d", &u, &v, &w);
        add(u, v, w), add(v, u, w);
    }
    for (int i = 1; i <= m; i++) {
        scanf("%d", &q[i]);
    }
    mx[0] = INF;
    sum = n, root = 0, getRoot(1,0);
    divide(root);
    for (int i = 1; i <= m; i++) {
        puts(ans[i] ? "AYE" : "NAY");
    }
    return 0;
}

习题

发表评论