Skip to content

AVL trees

Question

Define AVL trees and provide an implementation in Python.

Definition

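  • An AVL tree is a binary search tree (BST) in which, for every node:

    • abs(height(left) - height(right)) <= 1, i.e. the heights of the two subtrees differ by at most 1.
  • After every insert or delete, balance is restored with rotations.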

Example

Insert 30, 20, 10 (in that order) into an empty tree:

Before rebalancing (node 30 has balance factor +2):
    30
   /
  20
 /
10

After a single right rotation at 30 (the LL case):
   20
  /  \
10   30

Key trick

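Store each node's height in the node itself. On the way back up from an insert or delete, recompute heights and restore balance with local rotations.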

Trap

  • Forgetting to update heights before checking balance.
  • Mixing up the LL, RR, LR, and RL cases: LR and RL need a double rotation (first rotate the child, then the node).

Use

When you need worst-case \(O(\log n)\) search, insert, and delete.

Do not use

When implementation simplicity matters more than strict balance.

Idiomatic Python implementation

from dataclasses import dataclass

@dataclass
class AVLNode:
    """A single node of an AVL tree.

    ``height`` caches the height of the subtree rooted at this node. A freshly
    created node is a leaf, so it starts at 1. ``update`` recomputes the value
    after the node's children change.
    """

    key: int
    left: "AVLNode | None" = None  # subtree holding keys smaller than ``key``
    right: "AVLNode | None" = None  # subtree holding keys greater than ``key``
    height: int = 1  # height of this subtree; an empty subtree counts as 0

def h(node):
    """Return the cached height of *node*; an empty subtree (None) has height 0."""
    if not node:
        return 0
    return node.height

def update(node):
    """Recompute ``node.height`` from its children's cached heights.

    The children's heights must already be correct, which is why callers
    refresh lower nodes before higher ones.
    """
    tallest_child = max(h(node.left), h(node.right))
    node.height = tallest_child + 1

def balance_factor(node):
    """Return height(left) - height(right) for *node*.

    A positive value means the node is left-heavy and a negative value means
    it is right-heavy. The AVL invariant keeps the value within [-1, 1].
    """
    left_height, right_height = h(node.left), h(node.right)
    return left_height - right_height

def rotate_right(y):
    """Rotate the subtree rooted at *y* to the right and return its new root.

    y's left child (the pivot) becomes the root and y becomes the pivot's
    right child. The pivot's former right subtree is re-attached as y's left
    subtree, which preserves in-order key order.
    """
    pivot = y.left
    # The right-hand side is evaluated first, so this re-links in one step:
    # y adopts the pivot's old right subtree, and the pivot adopts y.
    y.left, pivot.right = pivot.right, y
    # y now sits below the pivot, so its height must be refreshed first.
    update(y)
    update(pivot)
    return pivot

def rotate_left(x):
    """Rotate the subtree rooted at *x* to the left and return its new root.

    x's right child (the pivot) becomes the root and x becomes the pivot's
    left child. The pivot's former left subtree is re-attached as x's right
    subtree, which preserves in-order key order.
    """
    pivot = x.right
    # The right-hand side is evaluated first, so this re-links in one step:
    # x adopts the pivot's old left subtree, and the pivot adopts x.
    x.right, pivot.left = pivot.left, x
    # x now sits below the pivot, so its height must be refreshed first.
    update(x)
    update(pivot)
    return pivot

def rebalance(node):
    """Refresh *node*'s height, restore the AVL bound there, and return the subtree root.

    This is meant to be called bottom-up, as ``insert`` and ``delete`` do, so
    the children's heights are already up to date when *node* is examined.
    """
    update(node)
    skew = balance_factor(node)

    if -1 <= skew <= 1:
        # Already within the AVL bound; nothing to rotate.
        return node

    if skew > 0:
        # Left-heavy. If the left child leans right (LR case), straighten it
        # first so that a single right rotation fixes the subtree.
        child = node.left
        if balance_factor(child) < 0:
            node.left = rotate_left(child)
        return rotate_right(node)

    # Right-heavy. If the right child leans left (RL case), straighten it
    # first so that a single left rotation fixes the subtree.
    child = node.right
    if balance_factor(child) > 0:
        node.right = rotate_right(child)
    return rotate_left(node)

def insert(node, key):
    """Insert *key* into the AVL subtree rooted at *node* and return the new root.

    A key that is already present is ignored and the subtree is returned
    unchanged. Otherwise, every node on the insertion path is rebalanced
    on the way back up.
    """
    if not node:
        return AVLNode(key)

    if key < node.key:
        node.left = insert(node.left, key)
        return rebalance(node)

    if key > node.key:
        node.right = insert(node.right, key)
        return rebalance(node)

    # Neither smaller nor larger: the key is already stored here.
    return node

def min_node(node):
    """Return the leftmost (smallest-key) node of the non-empty subtree *node*."""
    current = node
    while current.left:
        current = current.left
    return current

def delete(node, key):
    """Remove *key* from the AVL subtree rooted at *node*.

    Returns the new root of the subtree. Rotations may make it a different
    node than *node*, and the result is ``None`` if the subtree becomes
    empty. Deleting a key that is not present leaves the stored keys
    unchanged.
    """
    if not node:
        return None

    if key < node.key:
        node.left = delete(node.left, key)
    elif key > node.key:
        node.right = delete(node.right, key)
    else:
        # Zero or one child: splice this node out by returning its only
        # child, or None when it is a leaf.
        if not node.left:
            return node.right
        if not node.right:
            return node.left

        # Two children: copy the in-order successor's key into this node,
        # then delete that successor from the right subtree.
        succ = min_node(node.right)
        node.key = succ.key
        node.right = delete(node.right, succ.key)

    # node cannot be None here, because the empty case returned at the top.
    # So rebalance unconditionally on the way back up.
    return rebalance(node)

def search(node, key):
    """Return True if *key* is stored in the subtree rooted at *node*, else False."""
    cursor = node
    while cursor:
        if key == cursor.key:
            return True
        # Use the BST ordering to pick which side to descend into.
        if key < cursor.key:
            cursor = cursor.left
        else:
            cursor = cursor.right
    return False

def inorder_keys(node):
    """Return the keys of the subtree rooted at *node* in in-order (ascending) order.

    The traversal uses an explicit stack instead of recursion. Each node is
    pushed and popped exactly once, so it runs in O(n) time. It avoids the
    repeated list copying of ``left + [key] + right``, and it cannot hit
    Python's recursion limit on deep trees. An empty tree (None) yields [].
    """
    keys = []
    stack = []
    current = node
    while stack or current:
        # Walk as far left as possible, remembering the path.
        while current:
            stack.append(current)
            current = current.left
        current = stack.pop()
        keys.append(current.key)
        # After visiting a node, traverse its right subtree.
        current = current.right
    return keys

Pytest

import pytest

def _assert_avl_invariant(node):
    """Check cached heights and the AVL balance bound recursively; return the subtree height.

    Sorted in-order keys and correct search results do not depend on
    balance, so without this check an unbalanced plain BST would also pass.
    """
    if not node:
        return 0
    left_height = _assert_avl_invariant(node.left)
    right_height = _assert_avl_invariant(node.right)
    assert node.height == 1 + max(left_height, right_height), node.key
    assert abs(left_height - right_height) <= 1, node.key
    return node.height

@pytest.mark.parametrize(
    "values, deletes, search_keys, expected_inorder, expected_search",
    [
        ([30, 20, 10], [], [20, 30, 5], [10, 20, 30], [True, True, False]),
        ([10, 20, 30, 40, 50, 25], [40], [25, 40], [10, 20, 25, 30, 50], [True, False]),
        ([5, 3, 7, 2, 4, 6, 8], [5], [5, 6], [2, 3, 4, 6, 7, 8], [False, True]),
        # Empty tree, including deleting from it.
        ([], [5], [5], [], [False]),
        # Deleting a missing key is a no-op; deleting a node with two children.
        ([1, 2, 3, 4, 5, 6, 7], [4, 99], [4, 7], [1, 2, 3, 5, 6, 7], [False, True]),
        # Sorted inserts force many rotations; deleting the left side forces rebalancing.
        (
            list(range(1, 16)),
            [8, 1, 2, 3],
            [8, 15],
            [4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 15],
            [False, True],
        ),
    ],
)
def test_avl(values, deletes, search_keys, expected_inorder, expected_search):
    """Build a tree from *values*, apply *deletes*, then check contents, lookups, and balance."""
    root = None
    for v in values:
        root = insert(root, v)
        _assert_avl_invariant(root)
    for d in deletes:
        root = delete(root, d)
        _assert_avl_invariant(root)

    assert inorder_keys(root) == expected_inorder
    assert [search(root, k) for k in search_keys] == expected_search