跳转至

树上 GCD

题意

给定一棵有根树,定义 d(u,v) 为从 uv 的简单路径包含的边数。令 a=\operatorname{LCA}(u,v),f(u,v)=\gcd(d(u, a), d(a, v)),求满足 f(u,v)=i 的数对 (u,v) 个数。

解析

只要求出 i\mid d(u,a)i\mid d(a,v) 的数对个数,我们就可以通过容斥(莫比乌斯反演公式)解决这个问题。

考虑点分治。

设当前处理所有经过 c 的路径 (u,v),且当前树根为 r。一般的点分治是在无根树上的,而对于有根树点分治我们要讨论两种情况:

c 子树的高度为 Hu,v 同在 c 子树中时,只需暴力合并深度为 d=1..H 倍数的路径,合并单棵子树的复杂度为调和级数 O(H\log H)

uc 子树中,v 不在 c 子树中时,我们枚举 uv 的最近公共祖先 a。对于长度为 d 倍数的路径,考虑根号分治。若 d\le \sqrt{H},则利用一个数组 F[d][i] 存下深度除以 d 的余数为 i 的点的个数,建立 F 数组总复杂度为 O(H\sqrt H),对于单个 d 可以 O(1) 查询;若 d\ge \sqrt H,则每次需要时暴力在桶里查,对于单个 d 可以 O(\sqrt H) 查询。

单层的复杂度为 O(n\sqrt n),运用主定理,总复杂度也是 O(n\sqrt n)。看似常数爆炸,实际上常数并不大,可以通过该题。

实现

#include <bits/stdc++.h>

typedef long long int ll;
const int maxn = 2e5 + 19, maxb = 4e2 + 19;

struct Edge{
    int to, next;
}edge[maxn << 1];

int head[maxn];

inline void add(int from, int to){
    edge[++head[0]] = (Edge){to, head[from]};
    head[from] = head[0];
}

int n;
bool vist[maxn];
ll ans[maxn];

int fa[maxn], dep[maxn], size[maxn], mega[maxn], line[maxn];
int st[maxn], top;
int H, B, D, maxh, subh;
int cnt[maxn], tmp[maxn];
int dp[maxb][maxb], box[maxn];

void dfs1(int node){
    dep[node] = dep[fa[node]] + 1, size[node] = 1, mega[node] = 0;
    st[++top] = node, H = std::max(H, dep[node]);
    for(int i = head[node]; i; i = edge[i].next)
        if(!vist[edge[i].to] && edge[i].to != fa[node]){
            dfs1(edge[i].to);
            if(size[edge[i].to] > mega[node]) mega[node] = size[edge[i].to];
            size[node] += size[edge[i].to];
        }
}

void dfs2(int node){
    ++tmp[dep[node] - D], subh = std::max(subh, dep[node] - D), ++box[dep[node]];
    for(int i = 1; i <= B; ++i)
        ++dp[i][dep[node] % i];
    for(int i = head[node]; i; i = edge[i].next)
        if(!vist[edge[i].to] && edge[i].to != fa[node])
            dfs2(edge[i].to);
}

void dfs3(int node){
    ++cnt[dep[node] - D], maxh = std::max(maxh, dep[node] - D);
    for(int i = head[node]; i; i = edge[i].next)
        if(!vist[edge[i].to] && edge[i].to != fa[node])
            dfs3(edge[i].to);
}

void solve(int node, int root){
    top = H = 0, dep[fa[root]] = 0, dfs1(root), B = std::pow(H, 0.319);
    for(int i = 1; i <= top; ++i)
        if(std::max(mega[st[i]], size[root] - size[st[i]]) <= size[root] / 2){
            node = st[i];
            break;
        }
    D = dep[node];

    maxh = 0, ++box[dep[node]];
    for(int i = 1; i <= B; ++i) ++dp[i][dep[node] % i];
    for(int i = head[node]; i; i = edge[i].next)
        if(!vist[edge[i].to] && edge[i].to != fa[node]){
            subh = 0, dfs2(edge[i].to);
            for(int d = subh; d; --d){
                int a = 0;
                for(int j = 1; j * d <= subh; ++j) a += tmp[j * d];
                ans[d] += (ll)a * cnt[d], cnt[d] += a;
            }
            for(int j = 1; j <= subh; ++j) tmp[j] = 0;
            maxh = std::max(maxh, subh);
        }
    for(int i = 1; i <= maxh; ++i) cnt[i] = 0;

    int a = node, p;
    while(a != root){
        p = a, a = fa[a], maxh = 0, D = dep[a];
        for(int i = head[a]; i; i = edge[i].next)
            if(!vist[edge[i].to] && edge[i].to != fa[a] && edge[i].to != p)
                dfs3(edge[i].to);
        for(int d = std::min(maxh, B); d; --d){
            int b = 0;
            for(int j = 1; j * d <= maxh; ++j) b += cnt[j * d];
            ans[d] += (ll)b * dp[d][dep[a] % d];
        }
        for(int d = std::min(maxh, B) + 1; d <= maxh; ++d){
            int x = 0, y = 0;
            for(int j = dep[a]; j <= H; j += d) x += box[j];
            for(int j = 1; j * d <= maxh; ++j) y += cnt[j * d];
            ans[d] += (ll)x * y;
        }
        for(int i = 1; i <= maxh; ++i) cnt[i] = 0;
    }

    for(int i = 1; i <= H; ++i) box[i] = 0;

    for(int i = 1; i <= B; ++i)
        for(int j = 0; j < i; ++j)
            dp[i][j] = 0;

    vist[node] = true;  
    for(int i = head[node]; i; i = edge[i].next)
        if(!vist[edge[i].to] && edge[i].to != fa[node])
            solve(edge[i].to, edge[i].to);
    if(fa[node] && !vist[fa[node]]) solve(fa[node], root);
}

int main(){
    std::scanf("%d", &n);
    for(int i = 2; i <= n; ++i){
        std::scanf("%d", fa + i);
        add(fa[i], i), add(i, fa[i]), ++line[dep[i] = dep[fa[i]] + 1];
    }

    solve(1, 1);

    for(int i = n; i >= 1; --i){
        line[i] += line[i + 1];
        for(int j = 2; i * j <= n; ++j)
            ans[i] -= ans[i * j];
    }

    for(int i = 1; i < n; ++i) std::printf("%lld\n", line[i] + ans[i]);
    return 0;
}

来源

UOJ Round #2 C 树上 GCD

评论