跳过正文
  1. Posts/

树上启发式合并(dsu on tree)

·1122 字·3 分钟
📖 阅读 --
DogDu
作者
DogDu
工作结束或者累了, 要不休息一会, 看会动漫吧 ~

参考视频:【AgOHの算法胡扯】树上启发式合并(dsu on tree)

dsu on tree 的核心思想其实很朴素:始终保留重儿子的统计结果,把轻儿子的贡献往它上面合并。这样每个节点被“暴力加入统计结构”的次数是受控的,整体复杂度可以压到 O(n log n)

题目:P9233 [蓝桥杯 2023 省 A] 颜色平衡树 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)

理解模板之后,这道题真正麻烦的地方只剩一个:如何 O(1) 判断一棵子树是不是“颜色平衡树”。

我们设计三个数组: color[i],cnt[i],cntcnt[i],color[i]表示节点i的颜色,cnt[i]表示颜色i出现的次数,cntcnt[i]表示 颜色出现次数 i 的出现的次数。

对于以x为根的子树,如果 设 **c=cnt[color[x]] , f=cntcnt[c] 如果 cfsize[x] 那么他就是颜色平衡树。**原因:假设存在一个颜色出现的次数不等于c,那么必然 cf < size[x],这与c*fsize[x]矛盾,所以假设不成立。

当然,我们看到颜色个数这种问题,自然而然会想起来莫队的模板题(P1903 [国家集训队] 数颜色 / 维护队列 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn),这题也可以用莫队解决,只是需要先dfs一遍进行线性化而已,还是上面一样的判断方法,判断的时间复杂度为O(1),整体时间复杂度为O(n*根号n),可以通过本题。唯一需要注意一点的是,为了操作方便可能还要多存一个dfx表示时间戳x的原来的节点编号。

法一:启发式合并AC代码:

CPP 代码 · 共 132 行
#include <iostream>
#include <cmath>
#include <ctime>
#include <string>
#include <cstring>
#include <vector>
#include <algorithm>

#define ll long long

using namespace std;

const int maxn=200000+10;
vector<int> graph[maxn];

int n,root=1;
int color[maxn],son[maxn],t[2][maxn],siz[maxn];
//t[0][x]表示颜色x出现了几次,t[1][x]表示次数x出现的次数。

int ans=0,flag;//MINN,MAXN维护的是颜色出现的次数

void add(int x)
{
    t[1][t[0][x]]--;
    t[0][x]++;
    t[1][t[0][x]]++;
}

void del(int x)
{
    t[1][t[0][x]]--;
    t[0][x]--;
    t[1][t[0][x]]++;
}

void add_k(int now)
{
    add(color[now]);

    for(const auto&it : graph[now])
        if(it!=flag)add_k(it);
}


void del_k(int now)
{
    del(color[now]);

    for(const auto&it : graph[now])
        if(it!=flag)del_k(it);
}

void dfs(int now)
{
    int maxsize=-1;
    siz[now]=1;

    for(const auto&it : graph[now])
    {
        dfs(it);
        siz[now]+=siz[it];
        if(siz[it]>maxsize)
        {
            maxsize=siz[it];
            son[now]=it;
        }
    }
}


void dfs(int now,bool keep)
{
    for(const auto& it : graph[now])
        if(it != son[now])
            dfs(it,false);
    
    if(son[now]!=0)
    {
        dfs(son[now],true);
        flag=son[now];
    }

    add_k(now);
    flag=0;
    
    //这句话的含义是:color[now]出现的频次*这个频次的频次 ==  siz[now],
    //则说明颜色出现最多的次数和颜色出现最小的次数相等。
    //原因是:如果存在一个颜色,其出现的频次不等于color[now]出现的频次,
    //那么,这个算式是小于size[now]的。
    if(t[0][color[now]]*t[1][t[0][color[now]]]==siz[now])
        ++ans;

    if(keep==false)
    {
        del_k(now);
    }
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);

#ifdef LOCAL
    clock_t c1 = clock();
    freopen("in.in","r",stdin);
    freopen("out.out","w",stdout);
#endif

//-------------------------------------------------

    cin>>n;

    for(int i=1,f;i<=n;++i)
    {
        cin>>color[i]>>f;
        graph[f].push_back(i);
    }

    dfs(root);
    dfs(root,true);

    cout << ans << endl;
//-------------------------------------------------

#ifdef LOCAL
    cout << "Time Used:\n" << clock() - c1 << " ms" << endl;
#endif

    //system("pause");
    return 0;
}

法二:莫队AC代码:

CPP 代码 · 共 103 行
#include <iostream>
#include <cmath>
#include <ctime>
#include <string>
#include <cstring>
#include <vector>
#include <algorithm>

#define ll long long

using namespace std;
 
const int maxn=200000+10;
vector<int> graph[maxn];

int n,root=1,tim=0,block;
int color[maxn],siz[maxn],dfn[maxn],dfx[maxn];
int cnt[maxn],cntcnt[maxn];

struct QUERY{
    int l,r;
    bool operator<(const QUERY& other)const{
        if(l/block==other.l/block)return r<other.r;
        return l<other.l;
    }
}query[maxn];

void dfs(int now)
{
    siz[now]=1;
    dfn[now]=++tim;
    dfx[tim]=now;

    for(const auto&it : graph[now])
    {
        dfs(it);
        siz[now]+=siz[it];
    }
}

void add(int x,int k)
{
    cntcnt[cnt[x]]--;
    cnt[x]+=k;
    cntcnt[cnt[x]]++;
}

int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);

#ifdef LOCAL
    clock_t c1 = clock();
    freopen("in.in","r",stdin);
    freopen("out.out","w",stdout);
#endif

//-------------------------------------------------

    cin>>n;

    block=sqrt(n);

    for(int i=1,y;i<=n;++i)
    {
        cin>>color[i]>>y;
        graph[y].push_back(i);
    }

    dfs(root);

    for(int i=1;i<=n;++i)
        query[i]={dfn[i],dfn[i]+siz[i]-1};

    sort(query+1,query+1+n);

    int ans=0,l=1,r=1;
    add(color[dfn[1]],1);

    for(int i=1,c,f;i<=n;++i)
    {
        while(l<query[i].l)add(color[dfx[l++]],-1);
        while(l>query[i].l)add(color[dfx[--l]],1);
        while(r<query[i].r)add(color[dfx[++r]],1);
        while(r>query[i].r)add(color[dfx[r--]],-1);

        c=cnt[color[dfx[query[i].l]]],f=cntcnt[c];

        if(c*f==query[i].r-query[i].l+1)
            ++ans;
    }

    cout<<ans;
//-------------------------------------------------

#ifdef LOCAL
    cout << "\nTime Used: \n" << clock() - c1 << " ms" << endl;
#endif

    //system("pause");
    return 0;
}