AVL trees
Question
Define AVL trees and provide an implementation in Python.
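Definition
- An AVL tree is a binary search tree (BST) in which, for every node, the heights of the left and right subtrees differ by at most one: \(\lvert h(\text{left}) - h(\text{right}) \rvert \le 1\), where an empty subtree has height 0.
- After every insert or delete, the tree is rebalanced with rotations so that this property holds again.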
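Example
Inserting 10, 20, 30 into an empty tree leaves node 10 right-heavy (left subtree height 0, right subtree height 2); a single left rotation at 10 restores balance, making 20 the root with children 10 and 30.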
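Key trick
Store each node's subtree height in the node. On the way back up from an insertion or deletion, recompute heights and fix any imbalance with local rotations.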
Trap
- Forgetting to update heights before checking balance.
- Mixing up the LL, RR, LR, and RL cases (LR and RL need a double rotation).
Use
When you need guaranteed worst-case \(O(\log n)\) search, insert, and delete.
Do not use
When implementation simplicity matters more than strict balance.
Idiomatic Python implementation
from dataclasses import dataclass
@dataclass
class AVLNode:
    """A single node of an AVL tree.

    Field order defines the generated constructor:
    ``AVLNode(key, left=None, right=None, height=1)``.
    """

    key: int  # ordering key; insert() ignores a key that is already present
    left: "AVLNode | None" = None  # subtree of keys smaller than ``key``
    right: "AVLNode | None" = None  # subtree of keys larger than ``key``
    height: int = 1  # height of the subtree rooted here; a new leaf has height 1
def h(node):
    """Return the stored height of *node*; an empty subtree has height 0."""
    if not node:
        return 0
    return node.height
def update(node):
    """Recompute ``node.height`` from its children's stored heights.

    Only one level is inspected, so the children's heights must already
    be up to date.
    """
    tallest_child = max(h(node.left), h(node.right))
    node.height = tallest_child + 1
def balance_factor(node):
    """Return left-subtree height minus right-subtree height.

    Positive means left-heavy, negative means right-heavy; a node is
    AVL-balanced when the result is -1, 0, or 1.
    """
    left_height, right_height = h(node.left), h(node.right)
    return left_height - right_height
def rotate_right(y):
    r"""Rotate the subtree rooted at *y* to the right; return its new root.

    *y* must have a left child ``x``; ``x``'s right subtree ``B`` moves
    across to become *y*'s left subtree::

            y                x
           / \              / \
          x   C    -->     A   y
         / \                  / \
        A   B                B   C

    Heights are refreshed bottom-up: *y* first (now a child), then ``x``.
    """
    pivot = y.left
    y.left, pivot.right = pivot.right, y
    update(y)
    update(pivot)
    return pivot
def rotate_left(x):
    r"""Rotate the subtree rooted at *x* to the left; return its new root.

    *x* must have a right child ``y``; ``y``'s left subtree ``B`` moves
    across to become *x*'s right subtree::

          x                    y
         / \                  / \
        A   y      -->       x   C
           / \              / \
          B   C            A   B

    Heights are refreshed bottom-up: *x* first (now a child), then ``y``.
    """
    pivot = x.right
    x.right, pivot.left = pivot.left, x
    update(x)
    update(pivot)
    return pivot
def rebalance(node):
    """Refresh *node*'s height and restore the AVL property at *node*.

    Returns the root of the (possibly rotated) subtree; callers must store
    it back in place of *node*. Child heights are read as stored, so they
    must already be current.
    """
    update(node)  # the balance check below needs a fresh height
    skew = balance_factor(node)
    if -1 <= skew <= 1:
        return node  # already balanced; no rotation needed

    if skew > 0:
        # Left-heavy. A right-leaning left child is the LR case: rotate the
        # child left first so a single right rotation finishes the job.
        if balance_factor(node.left) < 0:
            node.left = rotate_left(node.left)
        return rotate_right(node)

    # Right-heavy (mirror image). A left-leaning right child is the RL case.
    if balance_factor(node.right) > 0:
        node.right = rotate_right(node.right)
    return rotate_left(node)
def insert(node, key):
    """Insert *key* into the AVL subtree rooted at *node*; return the new root.

    *node* may be ``None`` (empty tree). A key that is already present is
    ignored. Callers must rebind the result: ``root = insert(root, key)``.
    """
    if not node:
        return AVLNode(key)

    if key < node.key:
        node.left = insert(node.left, key)
    elif key > node.key:
        node.right = insert(node.right, key)
    else:
        # Duplicate: nothing below changed, so there is nothing to rebalance.
        return node

    # A child subtree may have grown; fix height and balance on the way up.
    return rebalance(node)
def min_node(node):
    """Return the node with the smallest key in the subtree rooted at *node*.

    Follows left links to the leftmost node. *node* must not be ``None``.
    """
    leftmost = node
    while leftmost.left:
        leftmost = leftmost.left
    return leftmost
def delete(node, key):
    """Remove *key* from the AVL subtree rooted at *node*; return the new root.

    *node* may be ``None``. Deleting a key that is not present leaves the set
    of keys unchanged. Callers must rebind the result:
    ``root = delete(root, key)``.
    """
    if not node:
        return None  # empty subtree: key not found

    if key < node.key:
        node.left = delete(node.left, key)
    elif key > node.key:
        node.right = delete(node.right, key)
    else:
        # Zero or one child: splice this node out. The surviving child (or
        # None) is returned as-is; it is already a balanced subtree.
        if not node.left:
            return node.right
        if not node.right:
            return node.left
        # Two children: take over the in-order successor's key, then remove
        # the successor from the right subtree. The successor is the leftmost
        # node there, so it has no left child and the recursive call ends in
        # the zero/one-child case above.
        succ = min_node(node.right)
        node.key = succ.key
        node.right = delete(node.right, succ.key)

    # ``node`` cannot be None here (that case returned at the top), so the
    # former ``... if node else None`` guard was unreachable.
    return rebalance(node)
def search(node, key):
    """Return True if *key* is stored in the BST rooted at *node*, else False.

    Walks down iteratively, so no recursion is involved.
    """
    current = node
    while current:
        if key == current.key:
            return True
        if key < current.key:
            current = current.left
        else:
            current = current.right
    return False
def inorder_keys(node):
    """Return the keys of the subtree at *node* in in-order (left, node, right).

    For a valid BST this is ascending order. Uses an explicit stack and a
    single output list instead of recursion and list concatenation.
    """
    keys = []
    pending = []
    current = node
    while pending or current:
        # Descend as far left as possible, remembering the path.
        while current:
            pending.append(current)
            current = current.left
        current = pending.pop()
        keys.append(current.key)
        current = current.right
    return keys
Pytest
import pytest
def _assert_avl(node):
    """Assert the AVL invariants on the subtree at *node*; return its height.

    Checks that every stored height is correct and every node's balance
    factor is -1, 0, or 1. Key ordering is checked separately by comparing
    ``inorder_keys`` against a sorted expectation.
    """
    if node is None:
        return 0
    left_height = _assert_avl(node.left)
    right_height = _assert_avl(node.right)
    assert node.height == 1 + max(left_height, right_height), node.key
    assert abs(left_height - right_height) <= 1, node.key
    return node.height


@pytest.mark.parametrize(
    "values, deletes, search_keys, expected_inorder, expected_search",
    [
        # LL case: single right rotation.
        ([30, 20, 10], [], [20, 30, 5], [10, 20, 30], [True, True, False]),
        # RR case: single left rotation.
        ([10, 20, 30], [], [10, 20, 30], [10, 20, 30], [True, True, True]),
        # LR case: left-right double rotation.
        ([30, 10, 20], [], [20], [10, 20, 30], [True]),
        # RL case: right-left double rotation.
        ([10, 30, 20], [], [20], [10, 20, 30], [True]),
        ([10, 20, 30, 40, 50, 25], [40], [25, 40], [10, 20, 25, 30, 50], [True, False]),
        ([5, 3, 7, 2, 4, 6, 8], [5], [5, 6], [2, 3, 4, 6, 7, 8], [False, True]),
        # Duplicate inserts are ignored; deleting a missing key is a no-op.
        ([5, 5, 3], [9], [3, 5, 9], [3, 5], [True, True, False]),
        # Deleting every key empties the tree.
        ([1, 2, 3], [1, 2, 3], [1, 2], [], [False, False]),
        # Deleting 3 leaves node 4 right-heavy and forces a rotation on delete.
        (
            list(range(1, 16)),
            [1, 2, 3, 4],
            [4, 5, 15],
            list(range(5, 16)),
            [False, True, True],
        ),
    ],
)
def test_avl(values, deletes, search_keys, expected_inorder, expected_search):
    """Build a tree, delete keys, then check order, membership, and balance."""
    root = None
    for v in values:
        root = insert(root, v)
        _assert_avl(root)  # the invariant must hold after every single insert
    for d in deletes:
        root = delete(root, d)
        _assert_avl(root)  # ... and after every single delete
    assert inorder_keys(root) == expected_inorder
    assert [search(root, k) for k in search_keys] == expected_search