题目链接

UOJ

洛谷

题意简述

给你一棵树,$i$ 号点有 $p_i$、$q_i$ 和 $l_i$ 三个属性,每条边有给定的长度。

从一个点出发可以到达其祖先中与其距离不超过 $l_i$ 的点,费用为 $p_i\cdot dis+q_i$,求每个点到根的最小费用。

点数不超过 $2\cdot 10^5$ 。

简要做法

用 $dis[i]$ 表示$i$ 到根的距离,$f_i$ 表示点 $i$ 的答案,那么: $$ f_i=\min\limits_{\begin{array}{cc}j\text{ is an ancestor of }i\\dis[j]\ge dis[i]-l_i\end{array}}\{f_j-p_i dis[j]+p_i dis[i]+q[i]\} $$ 这个东西可以用斜率优化,然而有三个不太友好的地方:

  1. 这是棵树;
  2. 有 $dis[j]\ge dis[i]-l_i$ 这个限制;
  3. $p_i$ 不是单调的。

类似于 「NOI2007」货币兑换 ,可以用类似于 CDQ 分治的点分治来解决上述三个问题。(然而我不仅没做过货币兑换,而且想到这个做法的时候都没意识到它是 CDQ 分治,只是在 四色的 NOI 中听说了这题可以点分治然后就 yy 出来了..)

具体来说,每个转移都可以看成一条路径,但只有竖直向上的路径是合法的,所以“向上”的那个子树需要特殊处理。类似于 CDQ 分治先处理左半部分再计算左半部分对右半部分的贡献,这题每次分治时先处理“向上”的那个子树,然后用分治中心到根的链除了分治中心本身外在当前分治树上的部分来更新分治中心的 DP 值,再用分治中心到根的链在当前分治树上的部分来更新除了“向上”的那个外的其它子树的 DP 值,最后再分治下去处理除了“向上”的那个外的其它子树。

更新除了“向上”的那个外的其它子树的 DP 值时,需要把这些子树里的点按 $dis[i]-l_i$ 排序,然后从下往上把分治中心到根的链在当前分治树上的部分在可行时加入凸包。

由于 $p_i$ 不是单调的,并不是用单调队列/单调栈维护凸包,而是保留整个凸包,查询的时候二分。

用叉积判斜率会爆 long long,可以用 __int128 / double 解决。

总复杂度 $O(n\log^2 n)$ 。

参考代码

#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>

using namespace std;

typedef long long ll;

const ll INF = 1e18;

vector<bool> vis;
int n, type, tsiz, rt;
vector<vector<int> > g;
vector<ll> dis, p, q, k, f;
vector<int> siz, wt, dep, pa;

void getroot(int u, int fa)
{
    siz[u] = wt[u] = 1;
    for (int i = 0; i < g[u].size(); ++i)
    {
        int v = g[u][i];
        if (v == fa || vis[v]) continue;
        getroot(v, u);
        siz[u] += siz[v];
        wt[u] = max(wt[u], siz[v]);
    }
    wt[u] = max(wt[u], tsiz - siz[u]);
    if (!rt || wt[u] < wt[rt]) rt = u;
}

void getchildren(int u, vector<int>& children)
{
    for (int i = 0; i < g[u].size(); ++i)
    {
        int v = g[u][i];
        if (v == pa[u] || vis[v]) continue;
        children.push_back(v);
        getchildren(v, children);
    }
}

bool cmp(int x, int y)
{
    return dis[x] - k[x] > dis[y] - k[y];
}

bool check(int x, int y, int z)
{
    return 1.0 * (f[x] - f[y]) / (dis[x] - dis[y]) < 1.0 * (f[y] - f[z]) / (dis[y] - dis[z]);
}

ll calc(int i, int j)
{
    return f[j] + p[i] * (dis[i] - dis[j]) + q[i];
}

void solve(int u)
{
    vis[u] = true;
    vector<int> anc(1, u);
    while (pa[anc.back()] && !vis[pa[anc.back()]]) anc.push_back(pa[anc.back()]);
    if (anc.size() > 1)
    {
        rt = 0;
        tsiz = siz[pa[u]];
        getroot(pa[u], u);
        solve(rt);
        for (int i = 1; i < anc.size() && dis[anc[i]] >= dis[u] - k[u]; ++i) f[u] = min(f[u], calc(u, anc[i]));
    }
    vector<int> children;
    getchildren(u, children);
    sort(children.begin(), children.end(), cmp);
    vector<int> convex;
    int t = 0;
    for (int i = 0; i < children.size(); ++i)
    {
        int v = children[i];
        while (t < anc.size() && dis[anc[t]] >= dis[v] - k[v])
        {
            while (convex.size() >= 2 && check(convex[convex.size() - 2], convex.back(), anc[t])) convex.pop_back();
            convex.push_back(anc[t++]);
        }
        if (convex.empty()) continue;
        int l = 0, r = convex.size() - 1;
        while (l < r)
        {
            int mid = (l + r + 1) >> 1;
            if (mid && calc(v, convex[mid - 1]) < calc(v, convex[mid])) r = mid - 1;
            else l = mid;
        }
        f[v] = min(f[v], calc(v, convex[l]));
    }
    for (int i = 0; i < children.size(); ++i)
    {
        int v = children[i];
        if (v == pa[u] || vis[v]) continue;
        rt = 0;
        tsiz = siz[v];
        getroot(v, u);
        solve(v);
    }
}

int main()
{
    scanf("%d%d", &n, &type);

    vis.resize(n + 1, false);
    f.resize(n + 1, INF);
    dis.resize(n + 1, 0);
    dep.resize(n + 1, 0);
    siz.resize(n + 1, 0);
    wt.resize(n + 1, 0);
    pa.resize(n + 1, 0);
    g.resize(n + 1);
    p.resize(n + 1);
    q.resize(n + 1);
    k.resize(n + 1);

    for (int i = 2; i <= n; ++i)
    {
        scanf("%d%lld%lld%lld%lld", &pa[i], &dis[i], &p[i], &q[i], &k[i]);
        g[pa[i]].push_back(i);
        g[i].push_back(pa[i]);
        dis[i] += dis[pa[i]];
    }

    f[1] = 0;
    tsiz = n;
    getroot(1, 0);
    solve(rt);

    for (int i = 2; i <= n; ++i) printf("%lld\n", f[i]);

    return 0;
}