Skip to content

104. Maximum Depth of Binary Tree

On LeetCode ->

Reformulated question

Given the root of a binary tree, return the number of nodes on its longest root-to-leaf path.

Example:

[3, 9, 20, None, None, 15, 7] -> 3
# a longest path: 3 -> 20 -> 15 (3 -> 20 -> 7 is equally long)

Key trick

Use the recursive definition of depth:

  • empty tree => 0
  • non-empty tree => 1 + max(depth(left), depth(right))

Trap

  • Confusing depth in nodes with depth in edges: this problem counts nodes, so a single-node tree has depth 1.
  • Returning 1 for an empty tree instead of 0.
  • In tree-building helpers, treating a node value of 0 as missing (e.g. via a truthiness test) instead of treating only None as missing.

Why is this question interesting?

It tests whether you can see a tree problem as a direct recurrence, and it is one of the cleanest examples of DFS recursion.

Solve the problem with idiomatic Python

Solution 1: Recursive (expected solution)

class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return the number of nodes on the longest root-to-leaf path.

        An empty tree (``None``) has depth 0.
        """
        if root is not None:
            # Non-empty tree: this node plus the deeper of its two subtrees.
            deeper_child = max(self.maxDepth(root.left), self.maxDepth(root.right))
            return deeper_child + 1
        # Only None is "no node"; a node holding the value 0 still counts.
        return 0

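Iterative DFS also works, but recursion is the most direct and idiomatic here. Caveat: each tree level adds a Python stack frame, so a degenerate (list-shaped) tree deeper than CPython's default recursion limit (1000) raises RecursionError; the iterative versions below avoid this.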

Solution 2: DFS with an explicit stack

  • Recursion can be converted into iteration by storing pending work in a stack.
  • This is useful if you want explicit control over traversal order, or need to avoid Python's recursion limit on very deep trees.
def max_depth_iterative_dfs(root: Optional[TreeNode]) -> int:
    """Return the maximum depth of ``root`` using an explicit DFS stack.

    Each stack entry pairs a node with its depth (the root has depth 1); the
    answer is the largest depth ever popped. An empty tree has depth 0.
    """
    if root is None:
        return 0

    pending = [(root, 1)]
    deepest = 0

    while pending:
        current, level = pending.pop()
        if level > deepest:
            deepest = level

        # Push right before left so the left child is popped (visited) first.
        for child in (current.right, current.left):
            if child is not None:
                pending.append((child, level + 1))

    return deepest

Solution 3: BFS level traversal

  • Breadth-first traversal naturally groups nodes by depth.
  • Counting levels directly gives the maximum depth.
def max_depth_bfs(root: Optional[TreeNode]) -> int:
    """Return the maximum depth of ``root`` by counting BFS levels.

    Each outer iteration drains exactly one level of the tree, so the number
    of iterations equals the depth. An empty tree has depth 0.
    """
    if root is None:
        return 0

    frontier = deque((root,))
    levels = 0

    while frontier:
        levels += 1
        # Snapshot the level size: children appended below belong to the next level.
        level_size = len(frontier)
        for _ in range(level_size):
            current = frontier.popleft()
            frontier.extend(
                child for child in (current.left, current.right) if child is not None
            )

    return levels

Pytest test

@pytest.mark.parametrize(
    ("values", "expected"),
    [
        ([], 0),
        ([1], 1),
        ([1, None, 2], 2),
        ([3, 9, 20, None, None, 15, 7], 3),
        ([1, 2, None, 3, None, 4, None], 4),
        ([1, None, 2, None, 3, None, 4], 4),
        ([0], 1),
        # Incomplete last level: the list stops before the parent's right slot.
        ([1, 2], 2),
        # A value of 0 below the root is a real node, not a missing one.
        ([1, 0, None, 0], 3),
    ],
)
def test_max_depth(values, expected):
    """maxDepth counts nodes (not edges) on the longest root-to-leaf path."""
    root = build_tree(values)
    assert Solution().maxDepth(root) == expected

Comment my solution

  • Your recursive solution is correct and is the best interview answer here.
  • Your iterative DFS solution is also correct, but it is longer and less clear than needed.
  • build_tree is fragile:
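      • its type hint says Iterable[int], but the input list also contains None placeholders (it should be Iterable[Optional[int]])
      • it treats a value of 0 as a missing node, because if (val := ...) tests truthiness instead of `is not None`
      • it raises IndexError when the last level is incomplete (e.g. [1, 2]), because it always reads two values per slot of the previous level
      • LeetCode level-order trees should be built with a queue, not with full-level index math: LeetCode omits the children of None nodes, so positions do not follow the 2*i+1 / 2*i+2 heap layout (e.g. [1, None, 2, None, 3])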
## Solution

from __future__ import annotations

from collections import deque
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class TreeNode:
    """A binary-tree node: a value plus optional left/right children."""

    val: int
    # Child links; None means "no child on this side".
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def build_tree(values: Iterable[Optional[int]]) -> Optional[TreeNode]:
    """Build a binary tree from LeetCode's level-order list format.

    ``values`` lists node values level by level, left to right, with ``None``
    marking a missing child. Children of a missing node are not listed (so
    positions do NOT follow the ``2*i+1`` / ``2*i+2`` heap layout), and the
    list may stop early, leaving the remaining child slots empty.

    Examples:
        [3, 9, 20, None, None, 15, 7] -> 3 with children 9 and 20;
                                         20 with children 15 and 7.
        [1, None, 2, None, 3]          -> 1 -> (right) 2 -> (right) 3.

    Returns:
        The root node, or None when ``values`` is empty or starts with None.
    """
    items = list(values)
    if not items or items[0] is None:
        return None

    root = TreeNode(items[0])
    # Real nodes still waiting for their children, in level order.
    parents = deque([root])
    i = 1  # index of the next unread value in items
    while parents and i < len(items):
        parent = parents.popleft()
        # Each real parent consumes the next two slots: left, then right.
        # Only None means "missing"; a value of 0 is a real node.
        if items[i] is not None:
            parent.left = TreeNode(items[i])
            parents.append(parent.left)
        i += 1
        if i < len(items) and items[i] is not None:
            parent.right = TreeNode(items[i])
            parents.append(parent.right)
        i += 1
    return root


def tree_to_list(root: Optional[TreeNode]) -> list[Optional[int]]:
    """Serialize a tree to LeetCode's level-order list format.

    Nodes are listed level by level, left to right. ``None`` marks a missing
    child of a real node, children of missing nodes are not listed, and
    trailing ``None`` entries are dropped. Output size is linear in the
    number of nodes, even for deep, skewed trees.

    Example: a root 3 with children 9 and 20, where 20 has children 15 and 7,
    serializes to ``[3, 9, 20, None, None, 15, 7]``.
    """
    if root is None:
        return []

    values: list[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            values.append(None)
            continue
        values.append(node.val)
        # Enqueue both child slots (even empty ones) so missing children show
        # up as None, but never expand a missing node's own children.
        queue.append(node.left)
        queue.append(node.right)

    # The deepest nodes' empty child slots produce a run of trailing Nones.
    while values and values[-1] is None:
        values.pop()
    return values


# values = [3, 9, 20, None, None, 15, 7]
# root = build_tree(values)
# root # TreeNode(val=3, left=TreeNode(val=9, left=None, right=None), right=TreeNode(val=20, left=TreeNode(val=15, left=None, right=None), right=TreeNode(val=7, left=None, right=None)))
# tree_to_list(root) # [3, 9, 20, None, None, 15, 7]


class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return the node count of the longest root-to-leaf path (0 if empty)."""
        return (
            0
            if root is None
            else 1 + max(self.maxDepth(root.left), self.maxDepth(root.right))
        )


from collections import namedtuple


class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return the maximum depth of ``root`` using an iterative DFS.

        Each stack entry is a ``(node, depth)`` pair (the root has depth 1).
        The deepest node popped is on a longest root-to-leaf path. Only
        ``None`` is treated as a missing child. An empty tree has depth 0.
        """
        if root is None:
            return 0

        # Plain tuples: no need to build a namedtuple class on every call.
        stack = [(root, 1)]
        max_depth = 0
        while stack:
            node, depth = stack.pop()
            max_depth = max(max_depth, depth)
            # Push right first so the left subtree is explored first.
            for child in (node.right, node.left):
                if child is not None:
                    stack.append((child, depth + 1))
        return max_depth

import pytest

@pytest.mark.parametrize(
    ("values", "max_depth"),
    [
        ([3, 9, 20, None, None, 15, 7], 3),
        ([1, None, 2], 2),
    ],
)
def test_maximum_depth_of_binary_tree(values, max_depth):
    """maxDepth of a tree built from a level-order list matches the expectation."""
    assert Solution().maxDepth(build_tree(values)) == max_depth

Code

from __future__ import annotations

from collections import deque
from dataclasses import dataclass
from typing import Optional

import pytest


@dataclass
class TreeNode:
    """Binary-tree node holding a value and optional left/right children."""

    val: int
    # A child of None means "no child on that side".
    left: Optional[TreeNode] = None
    right: Optional[TreeNode] = None


def build_tree(values: list[Optional[int]]) -> Optional[TreeNode]:
    """Build a binary tree from LeetCode's level-order list format.

    Values are listed level by level, left to right, with ``None`` marking a
    missing child. Children of a missing node are NOT listed, so positions do
    not follow the ``2*i+1`` / ``2*i+2`` heap layout; the list may also stop
    early, leaving the remaining child slots empty.

    Returns the root node, or None when ``values`` is empty or starts with None.
    """
    # [5, 4, 9, None, None, 8, 11]
    #
    #     5
    #    / \
    #   4   9
    #      / \
    #     8   11
    #
    # [1, None, 2, None, 3] -> 1 -(right)-> 2 -(right)-> 3
    # (4 has no children listed after it, and children of None are skipped.)
    if not values or values[0] is None:
        return None

    root = TreeNode(values[0])
    # Real nodes still waiting for their children, in level order.
    parents = deque([root])
    i = 1  # index of the next unread value
    while parents and i < len(values):
        parent = parents.popleft()
        # Each real parent consumes the next two slots: left, then right.
        # Only None means "missing"; a value of 0 is a real node.
        if values[i] is not None:
            parent.left = TreeNode(values[i])
            parents.append(parent.left)
        i += 1
        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            parents.append(parent.right)
        i += 1

    return root


class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return how many nodes lie on the longest root-to-leaf path.

        ``None`` is the empty tree and has depth 0; any real node (even one
        whose value is 0) adds 1 to the deeper of its two subtrees.
        """
        if root is None:
            return 0
        left_depth = self.maxDepth(root.left)
        right_depth = self.maxDepth(root.right)
        return max(left_depth, right_depth) + 1


@pytest.mark.parametrize(
    ("values", "expected"),
    [
        ([], 0),
        ([1], 1),
        ([1, None, 2], 2),
        ([3, 9, 20, None, None, 15, 7], 3),
        ([1, 2, None, 3, None, 4, None], 4),
        ([1, None, 2, None, 3, None, 4], 4),
        ([0], 1),
        # Incomplete last level: the list stops before the parent's right slot.
        ([1, 2], 2),
        # A value of 0 below the root is a real node, not a missing one.
        ([1, 0, None, 0], 3),
    ],
)
def test_max_depth(values, expected):
    """maxDepth counts nodes (not edges) on the longest root-to-leaf path."""
    root = build_tree(values)
    assert Solution().maxDepth(root) == expected


if __name__ == "__main__":
    # Quick manual check. The guard keeps it from running when the module is
    # imported (e.g. by pytest), and print() makes the results visible when
    # the file is run as a script.
    print(build_tree([3, 9, 20, None, None, 15, 7]))
    print(Solution().maxDepth(build_tree([3, 9, 20, None, None, 15, 7])))  # 3
    print(Solution().maxDepth(build_tree([1, None, 2])))  # 2
    print(Solution().maxDepth(build_tree([])))  # 0