跳过正文
  1. Posts/

AVL树全代码实现与验证

·3616 字·8 分钟· loading
DogDu
作者
DogDu
相信代码的力量,用实现驱动学习
目录

建议先了解 AVL 树的定义、平衡因子与四种基本旋转,再来看这篇实现记录。

这篇文章的重点不是重新讲一遍教科书定义,而是把“手写 AVL”这件事真正落到代码上:如何设计节点信息、如何维护高度、插入和删除时在什么位置检查失衡,以及如何把旋转逻辑写得稳定可复用。

我最初尝试过沿用教材里的非递归写法,在节点中维护 bf 和 parent,但实现过程非常拧巴:删除后的回溯修正尤其容易写乱。后来改成递归版本,并统一用 height 推导平衡状态,代码结构才变得清晰很多。

先把一个核心结论说清楚:AVL 树的失衡修正只有四种基本类型——LL、LR、RR、RL。只要数据结构和回溯规则设计正确,插入和删除都不会凭空冒出“第五种旋转”。后文的代码就是围绕这个原则展开的。

定义
#

CPP 代码 · 共 5 行
struct AVLNode{
    AVLNode*left,*right;//没有parent,所以需要递归,因为递归,所以简单
    int key,value;
    int height;//因为bf可以由height算出,所以存height
};

这里我用AVL树做桶排,用洛谷的P1177排序做简单的验证,所以有一个value表示key的个数。而且不要把bf直接存进AVL树节点中,不然会变的不幸!!!(也可以自己尝尝试试)

CPP 代码 · 共 12 行
class AVLTree{
private:
    AVLNode* tree;//所有空指针全部指向tree,方便更新高度和计算平衡因子
    int Size;

public:
    AVLTree(){
        tree=new AVLNode;
        tree->height=0;//设置成0
        tree->left=tree->right=tree;//全部都指向头结点
        Size=0;
    }

在类的定义中,为了方便后面的代码,我设置了一个头结点,把高度设置成0,把全部的空节点都指向头结点,不会触摸到空指针,也方便处理。

给出几个辅助函数:

CPP 代码 · 共 21 行
    void newnode(AVLNode*&p,int& val)
    {
        p=new AVLNode;
        p->key=val;
        p->value=1;
        p->left=p->right=tree;//指向tree
        p->height=1;
        ++Size;
    }

    void update(AVLNode*&p)//更新高度
    {
        if(p==tree)return;//如果是tree,不更新高度,一直为0
        p->height=max(p->left->height,p->right->height)+1;
    }

    int factor(AVLNode*&p)//获得bf
    {
        if(p==tree)return 0;
        return p->left->height-p->right->height;//返回因子
    }

注意factor和update一定要特判是否是头结点,因为实际上在删除函数的递归平衡中会出现头结点(我已经试过删之后会出现什么了。)

旋转
#

看左右旋转:

CPP 代码 · 共 19 行
    void lrotate(AVLNode*&T)//左旋
    {
        AVLNode*r=T->right;
        T->right=r->left;
        r->left=T;
        T=r;

        update(T->left),update(T);
    }

    void rrotate(AVLNode*&T)//右旋
    {
        AVLNode*l=T->left;
        T->left=l->right;
        l->right=T;
        T=l;

        update(T->right),update(T);
    }

一定要注意把T当做引用型参数。因为在平衡之后下一步需要进一步循环的时候就改变对象了。

非常重要的check函数帮忙平衡:

CPP 代码 · 共 20 行
    void check(AVLNode*&T)//check
    {
        int bf=factor(T);//获得因子

        if(bf>1)
        {
            int lf=factor(T->left);

            if(lf>0)rrotate(T);
            else lrotate(T->left),rrotate(T);
        }
        else if(bf<-1)
        {
            int rf=factor(T->right);

            if(rf<0)lrotate(T);
            else rrotate(T->right),lrotate(T);
        }
        else if(T!=tree)update(T);//如果平衡,并且非空,更新高度
    }

注意如果bf>1或者bf<-1时,需要平衡,之后 if(lf>0)显然右旋即可,但是为什么要写else呢?明明根据 LRLR , LLR 来讲,只剩下 lf==-1 的情况了啊,lf在bf>1的时候不可能为0啊,按照旋转的四个分类的话。为什么呢?

实际上,lf==0的情况是会发生的,确实当bf==2的时候,lf只会==1或者-1,但这仅限于插入的时候,如果带上删除的话,比如:当前左边高度为2且左节点bf为0,右边高度为1,删除右节点,这个时候就出现了bf==2而lf==0的情况。那这怎么办呢?不是说只有四种情况吗?

我帮你们试过了,这种情况下,LLR旋转或者LRLR旋转都是可以完成平衡的,达到的是相同的效果,可以自己画图试一下。既然如此,为什么不选择单旋呢?额,我试了一下这两个,发现OJ不知道为什么是LRLR的双旋反而快2ms,诡异。

其实上面的就是所有的难点了。

功能函数
#

来看insert函数:

CPP 代码 · 共 48 行
    AVLnode*search(int x)
    {
        AVLnode*ans;
        _search(tree->left,x,ans);
        return ans;
    }
    void _search(AVLnode*&now,int &key,AVLnode*&ans)
    {
        if(now==tree)
        {
            ans=nullptr;
            return;
        }
        else if(key>now->key)_search(now->right,key,ans);
        else if(key<now->key)_search(now->left,key,ans);
        else
        {
            ans=now;
            return;
        }

    }

     bool insert(int x)
    {
        bool flag;
        ins(tree->left,x,flag);
        return flag;
    }
    void ins(AVLnode*&now,int& key,bool &flag)
    {
        if(now==tree)
        {
            newnode(now,key);
            flag=1;
        }
        else if(key>now->key)ins(now->right,key,flag);
        else if(key<now->key)ins(now->left,key,flag);
        else
        {
            flag=0;
            ++now->value;
            return;
        }

        if(flag)
            check(now);
    }

search函数没啥好说的,很简单。

insert函数对public,ins函数是private,注意指针传递一定是引用型的,通过设置了一个引用型的flag判断是否进行了insert,如果没有insert,就直接不再平衡,常数的优化罢了。最后函数进行检查是否平衡,不平衡就通过check旋转。

CPP 代码 · 共 58 行
    bool erase(int x)
    {
        bool flag;
        del(tree->left,x,flag);
        return flag;
    }
        void del(AVLnode*&now,int&key,bool&flag)
    {
        if(now==tree)
        {
            flag=0;
            return;
        }
        else if(key>now->key)del(now->right,key,flag);
        else if(key<now->key)del(now->left,key,flag);
        else
        {
            AVLnode*temp=now,*l=now->left,*r=now->right;

            if(l==tree)now=r;
            else if(r==tree)now=l;
            else
            {
                now=find(l,l);//把这个节点剖下来

                if(now!=l)//如果不是直接为左边的
                    now->left=l;

                now->right=r;
            }

            --Size;
            flag=1;
            delete temp;
        }

        if(flag)
            check(now);
        return;
    }

    AVLnode* find(AVLnode*&now,AVLnode*fa)
    {
        AVLnode*res;

        if(now->right==tree)
        {
            res=now;
            fa->right=now->left;
        }
        else
        {
            res=find(now->right,now);
            check(now);
        }

        return res;
    }

删除策略:找直接前驱(数据结构-严蔚敏),然后指针替代,这个节点更方便于删除。最后delete被查找节点。find函数做的就是得到直接前驱的节点,最后还是递归检查就好了。

完整代码和AC记录
#

最后看完整代码(对于P1177,因为没有写迭代器,只能以这种方式验证了….有时间一定写一个迭代器,前一段时间刚学二叉排序树迭代器怎么写,一定要找时间搓一个。):

CPP 代码 · 共 225 行
#include <iostream>
#include <cmath>
#include <string>
#include <cstring>
#include <vector>
#include <algorithm>

#define ll long long

using namespace std;

struct AVLnode{
    AVLnode*left,*right;
    int height;
    int key,value;
};

class AVLTree{
private:
    int Size=0;
    AVLnode*root;
public:
    AVLTree(){
        root=new AVLnode;
        Size=0;
        root->left=nullptr;
        root->right=nullptr;
        root->height=0;
    }
    bool insert(int x)
    {
        bool flag;
        ins(root->left,x,flag);
        return flag;
    }
    bool erase(int x)
    {
        bool flag;
        del(root->left,x,flag);
        return flag;
    }
    AVLnode*search(int x)
    {
        AVLnode*ans;
        _search(root->left,x,ans);
        return ans;
    }
private:
    void del(AVLnode*&now,int&key,bool&flag)
    {
        if(now==root||now==nullptr)
        {
            flag=0;
            return;
        }
        else if(key>now->key)del(now->right,key,flag);
        else if(key<now->key)del(now->left,key,flag);
        else
        {
            AVLnode*temp=now,*l=now->left,*r=now->right;

            if(l==root)now=r;
            else if(r==root)now=l;
            else
            {
                now=find(l,l);//把这个节点剖下来

                if(now!=l)//如果不是直接为左边的
                    now->left=l;

                now->right=r;
            }

            --Size;
            flag=1;
            delete temp;
        }

        if(flag)
            check(now);
        return;
    }

    AVLnode* find(AVLnode*&now,AVLnode*fa)
    {
        AVLnode*res;

        if(now->right==root)
        {
            res=now;
            fa->right=now->left;
        }
        else
        {
            res=find(now->right,now);
            check(now);
        }

        return res;
    }

    void _search(AVLnode*&now,int &key,AVLnode*&ans)
    {
        if(now==root||now==nullptr)
        {
            ans=nullptr;
            return;
        }
        else if(key>now->key)_search(now->right,key,ans);
        else if(key<now->key)_search(now->left,key,ans);
        else
        {
            ans=now;
            return;
        }

    }

    void ins(AVLnode*&now,int& key,bool &flag)
    {
        if(now==root||now==nullptr)
        {
            newnode(now,key);
            flag=1;
        }
        else if(key>now->key)ins(now->right,key,flag);
        else if(key<now->key)ins(now->left,key,flag);
        else
        {
            flag=0;
            ++now->value;
            return;
        }

        if(flag)
            check(now);
    }

    void check(AVLnode*&T)
    {
        int bf=factor(T);

        if(bf>1)
        {
            int lf=factor(T->left);
            if(lf>0)Rrotate(T);
            else Lrotate(T->left),Rrotate(T);

        }
        else if(bf<-1)
        {
            int rf=factor(T->right);
            if(rf<0)Lrotate(T);
            else Rrotate(T->right),Lrotate(T);
        }
        else if(T!=root)update(T);
    }

    int factor(AVLnode*&T)
    {
        if(T==root)return 0;
        return T->left->height-T->right->height;
    }
    
    void Rrotate(AVLnode*&T)
    {
        AVLnode*L=T->left;
        T->left=L->right;
        L->right=T;
        T=L;
        update(T->right),update(T);
    }
    void Lrotate(AVLnode*&T)
    {
        AVLnode*R=T->right;
        T->right=R->left;
        R->left=T;
        T=R;
        update(T->left),update(T);
    }
    void update(AVLnode*&T)
    {
        if(T==root)return;
        T->height=max(T->left->height,T->right->height)+1;
    }
    void newnode(AVLnode*&T,int &key)
    {
        T=new AVLnode;
        T->left=T->right=root;
        T->height=1;
        T->key=key;
        T->value=1;
        ++Size;
    }
};
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);

    int n;
    cin>>n;
    int a[n+1];
    AVLTree it;

    for(int i=1;i<=n;++i)
    {
        cin>>a[i];
        it.insert(a[i]);
    }

    sort(a+1,a+1+n);
    int tot=unique(a+1,a+1+n)-a-1;

    for(int i=1;i<=tot;++i)
    {
        int t=it.search(a[i])->value;
        while(t--)cout<<a[i]<<' ';
        it.erase(a[i]);
    }

    cout<<endl;
    system("pause");
    return 0;
}

![] (/img/CSDN/algo-avl-tree-full-impl/86809b085806610a958bece8466c6b83.webp)

CPP 代码 · 共 301 行
#include <iostream>

using namespace std;

//删去了value,允许重复插入,方便处理get_val函数

struct AVLnode{
    AVLnode*left,*right;
    int height,Size;
    int key;
};

class AVLTree{
private:
    int Size=0;
    AVLnode*root;
public:
    AVLTree(){
        root=new AVLnode;
        Size=0;
        root->left=nullptr;
        root->right=nullptr;
        root->height=0;
        root->Size=0;
    }

    bool insert(int x)
    {
        bool flag;
        ins(root->left,x,flag);
        return flag;
    }
    bool erase(int x)
    {
        bool flag;
        del(root->left,x,flag);
        return flag;
    }
    AVLnode*search(int x)
    {
        AVLnode*ans;
        _search(root->left,x,ans);
        return ans;
    }

    int get_rank(int x)
    {
        AVLnode*now=root->left;
        int rank=1;

        while(now!=root&&now!=nullptr)
        {
            if(x<=now->key)
                now=now->left;
            else
            {
                rank+=now->left->Size+1;
                now=now->right;
            }
        }

        return rank;
    }

    int get_val(int rank)
    {
        AVLnode*now=root->left;

        while(now!=root&&now!=nullptr)
        {
            if(now->left->Size+1==rank)
                break;
            else if(now->left->Size>=rank)
                now=now->left;
            else
            {
                rank-=now->left->Size+1;
                now=now->right;
            }
        }

        return now->key;
    }

    int get_pre(int x)
    {
        AVLnode*p=root->left;
        int pre;

        while(p!=root&&p!=nullptr)
        {
            if(p->key<x)pre=p->key,p=p->right;
            else p=p->left;
        }

        return pre;
    }

    int get_next(int x)
    {
        AVLnode*p=root->left;
        int next;

        while(p!=root&&p!=nullptr)
        {
            if(p->key>x)next=p->key,p=p->left;
            else p=p->right;
        }

        return next;
    }

private:

    void del(AVLnode*&now,int&key,bool&flag)
    {
        if(now==root||now==nullptr)
        {
            flag=0;
            return;
        }
        else if(key>now->key)del(now->right,key,flag);
        else if(key<now->key)del(now->left,key,flag);
        else
        {
            AVLnode*temp=now,*l=now->left,*r=now->right;

            if(l==root)now=r;
            else if(r==root)now=l;
            else
            {
                now=find(l,l);//把这个节点剖下来

                if(now!=l)//如果不是直接为左边的
                    now->left=l;

                now->right=r;
            }

            --Size;
            flag=1;
            delete temp;
        }

        if(flag)
            check(now);
        return;
    }

    AVLnode* find(AVLnode*&now,AVLnode*fa)
    {
        AVLnode*res;

        if(now->right==root)
        {
            res=now;
            fa->right=now->left;
        }
        else
        {
            res=find(now->right,now);
            check(now);
        }

        return res;
    }

    void _search(AVLnode*&now,int &key,AVLnode*&ans)
    {
        if(now==root||now==nullptr)
        {
            ans=nullptr;
            return;
        }
        else if(key>now->key)_search(now->right,key,ans);
        else if(key<now->key)_search(now->left,key,ans);
        else
        {
            ans=now;
            return;
        }

    }

    void ins(AVLnode*&now,int& key,bool &flag)
    {
        if(now==root||now==nullptr)
        {
            newnode(now,key);
            flag=1;
        }
        else if(key>now->key)ins(now->right,key,flag);
        else ins(now->left,key,flag);

        if(flag)
            check(now);
    }

    void check(AVLnode*&T)
    {
        int bf=factor(T);

        if(bf>1)
        {
            int lf=factor(T->left);
            if(lf>0)Rrotate(T);
            else Lrotate(T->left),Rrotate(T);

        }
        else if(bf<-1)
        {
            int rf=factor(T->right);
            if(rf<0)Lrotate(T);
            else Rrotate(T->right),Lrotate(T);
        }
        else if(T!=root)update(T);
    }

    int factor(AVLnode*&T)
    {
        if(T==root)return 0;
        return T->left->height-T->right->height;
    }
    
    void Rrotate(AVLnode*&T)
    {
        AVLnode*L=T->left;
        T->left=L->right;
        L->right=T;
        T=L;
        update(T->right),update(T);
    }
    void Lrotate(AVLnode*&T)
    {
        AVLnode*R=T->right;
        T->right=R->left;
        R->left=T;
        T=R;
        update(T->left),update(T);
    }
    void update(AVLnode*&T)
    {
        if(T==root)return;
        T->height=max(T->left->height,T->right->height)+1;
        T->Size=T->left->Size+T->right->Size+1;
    }
    void newnode(AVLnode*&T,int &key)
    {
        T=new AVLnode;
        T->left=T->right=root;
        T->height=1;
        T->key=key;
        ++Size;
    }
};
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0),cout.tie(0);

    int n;
    cin>>n;
    AVLTree it;

    while(n--)
    {
        int op,x;
        cin>>op>>x;

        if(op==1)
        {
            it.insert(x);
        }
        else if(op==2)
        {
            it.erase(x);
        }
        else if(op==3)
        {
            cout<<it.get_rank(x)<<'\n';
        }
        else if(op==4)
        {
            cout<<it.get_val(x)<<'\n';
        }
        else if(op==5)
        {
            int t=it.get_rank(x)-1;
            cout<<it.get_val(t)<<'\n';
        }
        else if(op==6)
        {
            int t=it.get_rank(x+1);
            cout<<it.get_val(t)<<'\n';
        }
    }

    cout<<endl;
    system("pause");
    return 0;
}

随便给大伙看一下我之前痛苦的代码,连一个insert也没有实现的版本(还有一个版本被我弄丢了。)

CPP 代码 · 共 374 行
#include <iostream>
#include <cmath>
#include <string>
#include <cstring>
#include <vector>
#include <algorithm>

#define ll long long

using namespace std;

typedef struct AVLNode
{
    AVLNode *l = nullptr, *r = nullptr, *p = nullptr;
    int data, bf = 0,value=0;
} *AVLTree;

class myAVLTree
{
private:
    AVLTree tree;
    int Size;

    void Lrotate(AVLTree &T)
    {
        AVLTree R = T->r;

        T->r = R->l;
        R->l = T;

        if (T->p->l == T)
            T->p->l = R;
        else
            T->p->r = R;

        R->p = T->p;
        T->p = R;

        if (T->r != nullptr)
            T->r->p = T;

        T->bf = R->bf = 0;
    }

    void Rrotate(AVLTree &T)
    {
        AVLTree L = T->l;

        T->l = L->r;
        L->r = T;

        if (T->p->l == T)
            T->p->l = L;
        else
            T->p->r = L;

        L->p = T->p;
        T->p = L;

        if (T->l != nullptr)
            T->l->p = T;

        T->bf = L->bf = 0;
    }

    void LLR(AVLTree &T)
    {
        Rrotate(T);
    }

    void LRLR(AVLTree &T)
    {
        AVLTree L = T->l;
        int bf = L->r->bf;

        Lrotate(T->l);
        Rrotate(T);

        if (bf == 1)
            T->bf = -1;
        else if (bf == -1)
            L->bf = 1;
    }

    void RRL(AVLTree &T)
    {
        Lrotate(T);
    }

    void RLRL(AVLTree &T)
    {
        AVLTree R = T->r;
        int bf = R->l->bf;

        Rrotate(T->r);
        Lrotate(T);

        if (bf == 1)
            R->bf = -1;
        else if (bf == -1)
            T->bf = 1;
    }

    void rebalance(AVLTree p, int k)
    {
        AVLTree f = p->p;

        while (f != tree)
        {
            if (f->l == p)
                f->bf += k;
            else
                f->bf -= k;

            if (f->bf == 0)
                break;
            else if (f->bf == 1 || f->bf == -1)
                p = f, f = f->p;
            else
            {
                if (f->bf == 2)
                {
                    if (p->bf == 1)
                        LLR(f);
                    else
                        LRLR(f);
                }
                else
                {
                    if (p->bf == -1)
                        RRL(f);
                    else
                        RLRL(f);
                }

                break;
            }
        }
    }

    int GetDepth(AVLTree root)
    {
        if (root == nullptr)
            return 0;
        return max(GetDepth(root->l), GetDepth(root->r)) + 1;
    }

    bool check(AVLTree root)
    {
        if (root == nullptr)
            return true;
        ;
        int l = GetDepth(root->l);
        int r = GetDepth(root->r);

        if (abs(l - r) > 1 || l - r != root->bf)
            return false;

        return check(root->l) && check(root->r);
    }

public:
    myAVLTree()
    {
        tree = new AVLNode;
        Size = 0;
    }

    bool insert(int x)
    {
        if (tree->l == nullptr)
        {
            tree->l = new AVLNode;
            tree->l->p = tree;
            tree->l->data = x;
            tree->l->value=1;
            ++Size;
            return true;
        }

        AVLNode *p = tree->l, *f = tree;

        while (p != nullptr)
        {
            if (p->data == x)
            {
                p->value++;
                return false;
            }

            f = p;

            if (x > p->data)
                p = p->r;
            else
                p = p->l;
        }

        p=new AVLNode;
        p->p=f;
        p->data=x;
        p->value=1;
        
        if(x>f->data)f->r=p;
        else f->l=p;

        while(p!=tree)
        {
            if(f->l==p)++f->bf;
            else --f->bf;

            if(f->bf==0)break;
            else if(f->bf==-1||f->bf==1)
                p=f,f=f->p;
            else
            {    
                if(f->bf==2)
                {
                    if(p->bf==1)LLR(f);
                    else LRLR(f);
                }
                else
                {
                    if(p->bf==-1)RRL(f);
                    else RLRL(f);
                }

                break;
            }
        }

        return true;
    }

    AVLTree search(int x)
    {
        AVLTree p = tree->l;

        while (p != nullptr)
        {
            if (p->data == x)
                return p;

            if (x > p->data)
                p = p->r;
            else
                p = p->l;
        }

        return nullptr;
    }

    bool erase(int x)
    {
        AVLTree p = tree->l;

        while (p != nullptr)
        {
            if (p->data == x)
                break;

            if (x > p->data)
                p = p->r;
            else
                p = p->l;
        }

        if (p == nullptr)
            return false;

        if (p->l == nullptr)
        {
            AVLTree f = p->p;

            if (f->l == p)
            {
                --f->bf;
                f->l = p->r;
            }
            else
            {
                ++f->bf;
                f->r = p->r;
            }

            if (p->r != nullptr)
                p->r->p = f;

            delete p;

            rebalance(f, -1);
        }
        else
        {
            AVLTree l = p->l;

            while (l->r != nullptr)
                l = l->r;

            p->data = move(l->data);

            if (l->p == p)
            {
                --p->bf;
                p->l = nullptr;
                free(l);
                rebalance(p, -1);
            }
            else
            {
                ++l->p->bf;
                l->p->r = nullptr;
                free(l);
                rebalance(p, -1);
            }
        }

        return true;
    }

    int getdepth()
    {
        return GetDepth(tree->l);
    }

    bool isAVLTree()
    {
        return check(tree->l);
    }
};
int main()
{
    ios::sync_with_stdio(false);
    cin.tie(0), cout.tie(0);

    myAVLTree it;

    int n;

    cin >> n;

    while (n--)
    {
        int op, x;
        cin >> op;

        if (op == 1)
        {
            cin >> x;
            AVLTree p = it.search(x);

            if (p == nullptr)
                cout << "empty" << endl;
            else
                cout << p->data << endl;
        }
        else if (op == 2)
        {
            cin >> x;

            cout << it.insert(x) << endl;
        }
        else if (op == 3)
        {
            cin >> x;

            cout << it.erase(x) << endl;
        }
    }

    cout << endl;
    system("pause");
    return 0;
}

透过这几个名字或许也能感受到我当时的想法。