Skip to content

104. Maximum Depth of Binary Tree

On LeetCode ->

Reformulated question

Given the root of a binary tree, return the number of nodes on its longest root-to-leaf path.

Example:

[3, 9, 20, None, None, 15, 7] -> 3
# one longest path: 3 -> 20 -> 15 (3 -> 20 -> 7 is equally long)

Key trick

Use the recursive definition of depth:

  • empty tree => 0
  • non-empty tree => 1 + max(depth(left), depth(right))

Trap

  • Confusing depth in nodes with depth in edges.
  • Returning 1 for an empty tree instead of 0.
  • In tree-building helpers, treating a node value of 0 as missing (for example through a truthiness check) when only None should mean "no node".

Why is this question interesting?

It tests whether you can see a tree problem as a direct recurrence, and it is one of the cleanest examples of DFS recursion.

Solve the problem with idiomatic python

Solution 1: Recursive (expected solution)

from typing import Optional
from utils import TreeNode

class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return the number of nodes on the longest root-to-leaf path.

        An empty subtree (``None``) contributes a depth of 0.
        """
        if root is not None:
            # Measure both subtrees, then count the current node on top of
            # whichever one is deeper.
            left_depth = self.maxDepth(root.left)
            right_depth = self.maxDepth(root.right)
            return max(left_depth, right_depth) + 1

        return 0

Iterative DFS also works, but recursion is the most direct and idiomatic here.

Solution 2: DFS with an explicit stack

  • Recursion can be converted into iteration by storing pending work in a stack.
  • This is useful if you want more control over the traversal order or need to avoid Python's recursion limit on very deep (skewed) trees.
from typing import Optional
from utils import TreeNode

def max_depth_iterative_dfs(root: Optional[TreeNode]) -> int:
    """Return the maximum depth of the tree using an explicit DFS stack.

    Depth is counted in nodes, so an empty tree yields 0 and a lone root
    yields 1.
    """
    deepest = 0
    # Each entry pairs a node with its depth; an empty tree starts with
    # nothing to visit.
    pending = [] if root is None else [(root, 1)]

    while pending:
        current, level = pending.pop()
        if level > deepest:
            deepest = level

        # Push right before left so the left child is popped first.
        for child in (current.right, current.left):
            if child is not None:
                pending.append((child, level + 1))

    return deepest

Solution 3: BFS level traversal

  • Breadth-first traversal naturally groups nodes by depth.
  • Counting levels directly gives the maximum depth.
from typing import Optional
from utils import TreeNode
from collections import deque

def max_depth_bfs(root: Optional[TreeNode]) -> int:
    """Return the maximum depth of the tree by counting BFS levels.

    Depth is counted in nodes, so an empty tree yields 0.
    """
    frontier = deque()
    if root is not None:
        frontier.append(root)

    levels = 0
    while frontier:
        levels += 1
        # Drain exactly the nodes that were queued when this level began;
        # their children are appended behind them for the next round.
        remaining = len(frontier)
        while remaining:
            current = frontier.popleft()
            remaining -= 1
            for child in (current.left, current.right):
                if child is not None:
                    frontier.append(child)

    return levels

Pytest test

from utils import TreeNode, tree_build
import pytest

# Each case pairs a list of node values (None marks a missing node) with the
# depth, counted in nodes, expected for the tree built from that list.
@pytest.mark.parametrize(
    ("values", "expected"),
    [
        # Empty input: no tree at all.
        ([], 0),
        # A lone root.
        ([1], 1),
        # Root with only a right child.
        ([1, None, 2], 2),
        # The example from the problem statement.
        ([3, 9, 20, None, None, 15, 7], 3),
        # NOTE(review): the expected depth of 3 for the next two cases matches
        # a full-level (index-math) listing, where entries listed under a
        # missing parent are dropped. A queue-based LeetCode-style decoding of
        # the same lists would give 4-node chains instead. Confirm which
        # format `tree_build` implements.
        ([1, 2, None, 3, None, 4, None], 3),
        ([1, None, 2, None, 3, None, 4], 3),
        # A root whose value is 0 must still count as a node.
        ([0], 1),
    ],
)
def test_max_depth(values, expected):
    """Check `Solution.maxDepth` on the tree `tree_build` makes from `values`."""
    # NOTE(review): `Solution` is not imported in this snippet; presumably it
    # is the class from the solution under test -- confirm.
    root = tree_build(values)
    assert Solution().maxDepth(root) == expected

Comment my solution

  • Your recursive solution is correct and is the best interview answer here.
  • Your iterative DFS solution is also correct, but it is longer and less clear than needed.
  • build_tree is fragile:
    • it says Iterable[int] but uses None
    • it treats 0 as missing because `if (val := ...)` tests truthiness, and 0 is falsy
    • it can index past the end of the list when the last level is incomplete (e.g. [1, 2] raises IndexError)
    • LeetCode level-order trees should be built with a queue, not with full-level index math
## Solution

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class TreeNode:
    """Binary tree node holding a value and optional left/right children."""

    # Value stored at this node.
    val: int
    # Left child; None (the default) when there is no left subtree.
    left: Optional[TreeNode] = None
    # Right child; None (the default) when there is no right subtree.
    right: Optional[TreeNode] = None


def build_tree(values: Iterable[Optional[int]]) -> Optional[TreeNode]:
    """Build a binary tree from a full-level (heap-style) listing of values.

    The listing stores level ``i`` of the tree in the ``2**i`` slots starting
    at index ``2**i - 1``; the children of the ``j``-th slot of a level occupy
    slots ``2*j`` and ``2*j + 1`` of the next level. ``None`` marks a missing
    node. This is the layout produced by ``tree_to_list``.

    Only ``None`` means "no node": falsy values such as ``0`` become real
    nodes. Trailing slots may be omitted; they are read as ``None``.

    Args:
        values: Node values in full-level order, with ``None`` for gaps.

    Returns:
        The root node, or ``None`` when ``values`` is empty.
    """
    values = list(values)
    if not values:
        return None

    total = len(values)

    def node_at(index: int) -> Optional[TreeNode]:
        # Slots past the end of the listing are treated as missing nodes so
        # an incomplete last level does not raise IndexError. The explicit
        # `is None` test keeps zero-valued nodes.
        if index >= total or values[index] is None:
            return None
        return TreeNode(val=values[index])

    root = TreeNode(values[0])
    parents = [root]
    start = 1  # index in values of the left-most slot of the current level

    while start < total:
        children = []
        for offset, parent in enumerate(parents):
            left = node_at(start + 2 * offset)
            right = node_at(start + 2 * offset + 1)
            children += (left, right)
            # Values listed under a missing parent are placeholders in this
            # layout and are not attached to the tree.
            if parent is not None:
                parent.left = left
                parent.right = right
        parents = children
        start = 2 * start + 1  # 2**i - 1 -> 2**(i + 1) - 1
    return root


def tree_to_list(root: Optional[TreeNode]):
    """Flatten a tree into a full-level listing of its values.

    Levels are emitted top-down, left to right. Every slot of a level is
    written out, with ``None`` standing in for a missing node (and for the
    would-be children of a missing node). Output stops before the first level
    that contains no nodes. An empty tree gives ``[]``.
    """
    if root is None:
        return []

    flattened = []
    level = [root]
    while any(slot is not None for slot in level):
        # Record this level's values, keeping None placeholders in position.
        flattened.extend(None if slot is None else slot.val for slot in level)

        # Every slot, present or not, contributes two slots to the next level.
        next_level = []
        for slot in level:
            if slot is None:
                next_level.extend((None, None))
            else:
                next_level.extend((slot.left, slot.right))
        level = next_level

    return flattened


# values = [3, 9, 20, None, None, 15, 7]
# root = build_tree(values)
# root # TreeNode(val=3, left=TreeNode(val=9, left=None, right=None), right=TreeNode(val=20, left=TreeNode(val=15, left=None, right=None), right=TreeNode(val=7, left=None, right=None)))
# tree_to_list(root) # [3, 9, 20, None, None, 15, 7]


class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return the tree's depth in nodes; an empty tree has depth 0."""
        # One node for the root itself plus the taller of its two subtrees.
        return (
            0
            if root is None
            else 1 + max(self.maxDepth(root.left), self.maxDepth(root.right))
        )


from collections import namedtuple


class Solution:
    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return the number of nodes on the longest root-to-leaf path.

        Iterative depth-first traversal: follow one branch downwards and,
        whenever a node has both children, park the right child (with its
        depth) on a stack to be resumed later.

        Args:
            root: Root of the tree, or ``None`` for an empty tree.

        Returns:
            The depth counted in nodes; 0 for an empty tree.
        """
        if root is None:
            return 0

        max_depth = 0
        # Plain (node, depth) tuples: no per-call namedtuple class is needed.
        pending = [(root, 1)]

        while pending:
            node, depth = pending.pop()
            # Walk down from this node until the branch runs out.
            while node is not None:
                max_depth = max(max_depth, depth)
                depth += 1  # depth of this node's children
                if node.left is not None:
                    # Go left first; remember the right sibling for later.
                    if node.right is not None:
                        pending.append((node.right, depth))
                    node = node.left
                else:
                    # Either continue to the right or stop at a leaf (None).
                    node = node.right

        return max_depth

import pytest

@pytest.mark.parametrize(
    ("values", "max_depth"),
    [
        ([3, 9, 20, None, None, 15, 7], 3),
        ([1, None, 2], 2),
    ],
)
def test_maximum_depth_of_binary_tree(values, max_depth):
    """Check maxDepth on trees built from the listed values."""
    tree = build_tree(values)
    assert Solution().maxDepth(tree) == max_depth