跳转至

树上的博弈

题意

给定一棵 n 个点的无根树,点编号为 1,2,\ldots,n。在第 x 天,第 i 条边的边权为 k_ix+b_i

Bob 有两个棋子,初始时位置分别为 a=1b=2。Bob 可以执行如下两种操作之一:

  • a 从当前位置移动至一个未遍历过的点 a^\prime (a^\prime>a),并获取等于它们最短路径上的边权总和的分数。除了 a^\prime,路径上的其他点不算被遍历过。
  • b 从当前位置移动至一个未遍历过的点 b^\prime (b^\prime>b),并获取等于它们最短路径上的边权总和的分数。除了 b^\prime,路径上的其他点不算被遍历过。

Bob 会执行 n-2 次操作,遍历树上所有的点,并总是以最优方案得到最多的分数。请选择一天 x 让 Bob 获得的分数最小。

解析

如果我们已经确定了 x,就可以通过 O(n^2) 的动态规划计算出 Bob 能够得到的最多分数。设 dp_{i,j} 表示两个棋子分别在 i,j (i>j) 时 Bob 能获得的最大分数,则有转移方程:

  • 若将 i 移动至 i+1,用 dp_{i,j}+\operatorname{dist}(i,i+1) 更新 dp_{i+1,j}
  • 若将 j 移动至 i+1,用 dp_{i,j}+\operatorname{dist}(j,i+1) 更新 dp_{i+1,i}

这个 DP 还是比较简单的。dp_{i+1}dp_i 的不同之处只在于,dp_{i+1} 中的每个元素都比 dp_i 中的多出了 \operatorname{dist}(i,i+1),还有一个额外的元素 dp_{i+1,i}=\max (dp_{i,j}+\operatorname{dist}(j,i+1))

对于全体增加 \operatorname{dist}(i,i+1) 这一操作,可以用一个变量轻松维护。唯一困难之处在于,我们要在低于 O(n) 时间内找到一个能最大化 dp_{i,j}+\operatorname{dist}(j,i+1)j,用它来更新 dp_{i+1,i}

考虑构建原树的点分树,对于每个点记录其子树中的最大值和次大值。每一次我们可以 O(\log^2 n) (还有一个 log 是求距离产生的) 查询最大的 dp_{i,j}+\operatorname{dist}(j,i+1) 并修改。

这样,对于确定的某一天,我们就可以 O(n\log^2 n) 求出 Bob 的最大分数了。观察到,每种方案的分数是若干个线性函数的和,同样是线性函数;而每天的最大分数是所有方案的最大值,作为若干线性函数的最大值,它一定是一个开口向上的凸函数。我们可以用三分求出这个函数的最小值。总复杂度为 O(n\log^2 n\log m)

实现

#include <bits/stdc++.h>

template <typename T1, typename T2>
inline void chkmax(T1 &a, const T2 &b){
    if(b > a) a = b;
}

typedef long long int ll;
const int maxn = 1e5 + 19;
const ll inf = 1e18;

struct Edge{
    int to, next, k, b;
}edge[maxn << 1];

int head[maxn];

inline void add(int from, int to, int k, int b){
    edge[++head[0]] = (Edge){to, head[from], k, b};
    head[from] = head[0];
}

bool vist[maxn];

struct TreeEngine{
    int sz, rt, size[maxn], heavy[maxn];
    int st[maxn], tp;
    void getsz(int node, int f){
        st[++tp] = node;
        size[node] = 1, heavy[node] = 0;
        for(int i = head[node]; i; i = edge[i].next)
            if(!vist[edge[i].to] && edge[i].to != f){
                getsz(edge[i].to, node);
                size[node] += size[edge[i].to];
                heavy[node] = std::max(heavy[node], size[edge[i].to]);
            }
    }
    int find(int node){
        tp = 0, getsz(node, 0), sz = size[node];
        rt = st[1];
        for(int i = 1; i <= tp; ++i){
            int node = st[i];
            heavy[node] = std::max(heavy[node], sz - size[node]);
            if(heavy[node] < heavy[rt]) rt = node;
        }
        return rt;
    }
    int dep[maxn], fa[maxn], son[maxn], top[maxn];
    std::pair<int, ll> d[maxn];
    void dfs1(int node, int f){
        size[node] = 1, fa[node] = f, dep[node] = dep[f] + 1;
        for(int i = head[node]; i; i = edge[i].next)
            if(edge[i].to != f){
                d[edge[i].to] = std::make_pair(d[node].first + edge[i].k, d[node].second + edge[i].b);
                dfs1(edge[i].to, node);
                size[node] += size[edge[i].to];
                if(size[edge[i].to] > size[son[node]])
                    son[node] = edge[i].to;
            }
    }
    void dfs2(int node, int f, int t){
        top[node] = t;
        if(son[node]) dfs2(son[node], node, t);
        for(int i = head[node]; i; i = edge[i].next)
            if(edge[i].to != f && edge[i].to != son[node])
                dfs2(edge[i].to, node, edge[i].to);
    }
    int lca(int x, int y){
        while(top[x] != top[y])
            if(dep[top[x]] > dep[top[y]]) x = fa[top[x]];
            else y = fa[top[y]];
        return dep[x] < dep[y] ? x : y;
    }
    ll dist(int u, int v, int x){
        int l = lca(u, v);
        int k = d[u].first + d[v].first - d[l].first * 2;
        ll b = d[u].second + d[v].second - d[l].second * 2;
        return (ll)k * x + b;
    }
}mt;

int fa[maxn];
int build(int node){
    node = mt.find(node);
    vist[node] = true;
    for(int i = head[node]; i; i = edge[i].next)
        if(!vist[edge[i].to]){
            int v = build(edge[i].to);
            fa[v] = node;
        }
    return node;
}

int n, m;

struct Solver{
    int x;
    ll dp[maxn], tag;
    std::pair<ll, int> val[maxn][2];
    void push(int node, const std::pair<ll, int> &v){
        if(v.second == val[node][0].second || v.second == val[node][1].second){
            if(v.second == val[node][0].second) chkmax(val[node][0].first, v.first);
            else chkmax(val[node][1].first, v.first);
            if(val[node][1] > val[node][0]) std::swap(val[node][0], val[node][1]);
        }else{
            if(v.first > val[node][0].first)
                val[node][1] = val[node][0], val[node][0] = v;
            else if(v.first > val[node][1].first)
                val[node][1] = v;
        }
    }
    ll dist(int u, int v){
        return mt.dist(u, v, x);
    }
    void insert(int node){
        int p = node; push(node, std::make_pair(dp[node], 0));
        while(fa[p]){
            push(fa[p], std::make_pair(dp[node] + dist(fa[p], node), p));
            p = fa[p];
        }
    }
    ll query(int node){
        int p = node; ll res = val[p][0].first;
        while(fa[p]){
            if(p == val[fa[p]][0].second) chkmax(res, val[fa[p]][1].first + dist(fa[p], node));
            else chkmax(res, val[fa[p]][0].first + dist(fa[p], node));
            p = fa[p];
        }
        return res;
    }
    ll check(int x){
        for(int i = 1; i <= n; ++i) val[i][0] = val[i][1] = std::make_pair(-inf, 0);
        this->x = x, tag = 0ll, dp[1] = 0ll, insert(1);
        for(int i = 3; i <= n; ++i){
            dp[i - 1] = query(i) + tag;
            tag += dist(i, i - 1), dp[i - 1] -= tag;
            insert(i - 1);
        }
        ll res = -inf;
        for(int i = 1; i < n; ++i) chkmax(res, dp[i]);
        return res + tag;
    }
}sol;

int main(){
    std::scanf("%d%d", &n, &m);
    for(int i = 2, u, v, k, b; i <= n; ++i){
        std::scanf("%d%d%d%d", &u, &v, &k, &b);
        add(u, v, k, b), add(v, u, k, b);
    }
    build(1), mt.dfs1(1, 0), mt.dfs2(1, 0, 1);
    int l = 1, r = m;
    while(l < r){
        int mid = (l + r) >> 1;
        if(sol.check(mid) < sol.check(mid + 1)) r = mid;
        else l = mid + 1;
    }
    std::printf("%lld\n", sol.check(l));
    return 0;
}

来源

NOI.AC #2021 树上的博弈

评论