BZOJ3572 [HNOI2014]世界树

好久没有详细地写一写单题的题解了。

其实当时学虚树的时候就想写了,但是介于这题实在太dark… 我今天写了一个晚上都才写出来,主要麻烦的还是Debug.

分析

数据范围已经很明显地暗示了这题是用虚树来解的。我们不妨先建出一棵虚树——就以我们以前的单调栈的方式。然后,我们考虑虚树上的每一条边来统计答案。具体说明之前我们不妨给虚树中议事处的节点涂上黑色,其他节点为白色。那么边分为如下几类:

  1. 两端都是黑点
  2. 一端黑一端白
  3. 两端都是白点

类1,我们知道虚树上的边长度可能不止1,这条边上面还可能连出许多分支,不同的分支选择的黑点是不同的。但是我们可以二分(换到树上可以直接倍增)出这个分界点。(我写代码出锅也主要是这里出锅…)

对于类2和类3,我们找到离白点最近的黑点,这两类就变为同一类了。

注意对于虚树上的一点,可能还有原树中一些到其路径上不经过虚树边的点,这些点我们可以分摊到虚树上每个点来计算,算出虚树上的分支的大小,再用子树大小减去即可。

那么整理一下,我们的程序需要这些步骤:

  1. dfs预处理
  2. 建虚树
  3. 预处理虚树上每个点离它最近的黑点(主要需要编号)
  4. dfs计算答案

Code

感觉自己写得好简洁啊,还以为写错了呢…

#define forto(_) for (int e = last[_], v = E[e].to; e; v = E[e = E[e].next].to)
typedef int IntAr[MAXN];
struct Edge {
    int to, cost, next;
} E[MAXN << 1];

int N, Q, CS, tote, dfs_clock, F[MAXN][Lg_N + 2];
IntAr last, dep, C, D, S, H, H0, ans, tag, sz, dfn;

void dfs(int u);
void dfs_up(int u);
void dfs_dn(int u);
void dfs_vt(int u);
int lca(int u, int v);
inline bool cmp(int u, int v) {
    return dfn[u] < dfn[v];
}
inline int dist(int u, int v) {
    return dep[u] + dep[v] - dep[lca(u, v)] * 2;
}
inline void add_edge(int u, int v, int c) {
    E[++tote] = (Edge){v, c, last[u]}, last[u] = tote;
}
inline void add_vte(int u, int v) {
//  这个是给虚树加边专用的
    add_edge(u, v, dist(u, v));
}
inline void init_G() {
    tote = 0;
    memset(last, 0, sizeof last);
}

int main() {
    scanf("%d", &N);
    for (int i = 1, u, v; i < N; i++) {
        scanf("%d%d", &u, &v);  
        add_edge(u, v, 1);
        add_edge(v, u, 1);
    }

    dep[1] = 1;
    dfs(1);

    scanf("%d", &Q);
    int M;
    for (CS = 1; CS <= Q; CS++) {
        scanf("%d", &M);
        for (int i = 0; i < M; i++) {
            scanf("%d", H + i);
            tag[H[i]] = CS; // tag用于标记当前点的颜色
            ans[H[i]] = 0;
        }
        memcpy(H0, H, sizeof(int) * (M));

    //  build virtual tree
        std::sort(H, H + M, cmp);
        init_G();
        int top = 0;
        S[top++] = 1;
        for (int i = 0; i < M; i++) {
            int l = lca(S[top - 1], H[i]);
            while (dfn[l] < dfn[S[top - 1]]) {
                if (dfn[l] >= dfn[S[top - 2]]) break;
                add_vte(S[top - 2], S[top - 1]);
                l = lca(S[--top - 1], H[i]);
            }
            if (dfn[l] < dfn[S[top - 1]]) {
                int lst = S[--top];
                if (l != S[top - 1]) S[top++] = l;
                add_vte(S[top - 1], lst);
            }
            if (S[top - 1] != H[i]) S[top++] = H[i];
        }
        for (; top > 1; --top) add_vte(S[top - 2], S[top - 1]);

    //  find the closest taged vertex
        dfs_up(1);
        dfs_dn(1);
        dfs_vt(1);

        for (int i = 0; i < M; i++) printf("%d ", ans[H0[i]]);
        putchar('\n');
    }

    return 0;
}

void upd(int u, int v, int dis) {
    if (D[v] + dis < D[u] || D[v] + dis == D[u] && C[v] < C[u]) {
        D[u] = D[v] + dis;
        C[u] = C[v];
    }
}
// dfs on virtual tree
// update up
void dfs_up(int u) {
    if (tag[u] == CS) C[u] = u, D[u] = 0;
    else C[u] = -1, D[u] = INF;
    forto(u) {
        dfs_up(v);  
        upd(u, v, E[e].cost);
    }
}
// update down
void dfs_dn(int u) {
    forto(u) {
        upd(v, u, E[e].cost);
        dfs_dn(v);
    }
}

int find(int up, int low) {
    if (C[up] == C[low]) return low;
    int v = low, d = -dist(up, C[up]) + dist(low, C[low]) + dep[low] + dep[up];
    for (int i = Lg_N; i >= 0; --i) {
        if ((1 << i) > dep[low] - dep[up]) continue;
        if ((dep[F[v][i]] << 1) > d) v = F[v][i];
    //  对于相等的情况可以先退让一步,跳到下面再处理
    }
    if ((dep[F[v][0]] << 1) == d && C[up] > C[low]) v = F[v][0];
//  这些dep[]里面很容易吧F[v][]漏下直接写v,我就是这里写错拍了好久
    return v;
}

int son(int u, int from) {
    for (int del = dep[from] - dep[u] - 1, i = 0; i <= Lg_N && (1 << i) <= del; ++i)
        if (del >> i & 1) from = F[from][i];
    return from;
}
// get ans
void dfs_vt(int u) {
    int rem = sz[u];
    forto(u) {
        int t = find(u, v), sn;
        if (dep[t] <= dep[u]) t = son(u, v);
        ans[C[v]] += sz[t] - sz[v];
        ans[C[u]] += sz[sn = son(u, t)] - sz[t];
        rem -= sz[sn];
        dfs_vt(v);
    }
    ans[C[u]] += rem;
}

void dfs(int u) {
    dfn[u] = ++dfs_clock, sz[u] = 1;
    for (int i = 1; i <= Lg_N && F[u][i - 1]; i++) F[u][i] = F[F[u][i - 1]][i - 1];
    forto(u) {
        if (v == F[u][0]) continue;
        F[v][0] = u, dep[v] = dep[u] + E[e].cost;
        dfs(v);
        sz[u] += sz[v];
    }
}

总结

这个题,在边上计算比较难想,主要还是因为感觉这样太麻烦了… 其实,还行吧。写代码的时候还是要仔细一点,要不然又要被调试续好久了。不过考场上这种题还是一定要写暴力对拍,我好像对拍用的还比较少。

About The Author

发表评论

电子邮件地址不会被公开。 必填项已用*标注