104. Maximum Depth of Binary Tree
On LeetCode ->
Reformulated question¶
Given the root of a binary tree, return the number of nodes on its longest root-to-leaf path.
Example:
Key trick¶
Use the recursive definition of depth:
- empty tree => `0`
- non-empty tree => `1 + max(depth(left), depth(right))`
Trap¶
- Confusing depth in nodes with depth in edges.
- Returning `1` for an empty tree instead of `0`.
- In helpers, treating `0` as missing instead of only `None` as missing.
Why is this question interesting?¶
It tests whether you can see a tree problem as a direct recurrence, and it is one of the cleanest examples of DFS recursion.
Solve the problem with idiomatic python¶
Solution 1: Recursive (expected solution)¶
class Solution:
    """Recursive depth-first solution."""

    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return the number of nodes on the longest root-to-leaf path.

        An empty tree (``root is None``) has depth 0.
        """
        # Base case: an absent subtree contributes no levels.
        if root is None:
            return 0
        left_depth = self.maxDepth(root.left)
        right_depth = self.maxDepth(root.right)
        # Count this node on top of the taller of its two subtrees.
        return max(left_depth, right_depth) + 1
Iterative DFS also works, but recursion is the most direct and idiomatic here.
Solution 2: DFS with an explicit stack¶
- Recursion can be converted into iteration by storing pending work in a stack.
- This is useful if you want more control or want to avoid recursive style.
def max_depth_iterative_dfs(root: Optional[TreeNode]) -> int:
    """Return the tree's maximum depth (in nodes) using an explicit DFS stack.

    Each stack entry pairs a node with its depth, so no recursion is needed.
    An empty tree has depth 0.
    """
    deepest = 0
    pending = [] if root is None else [(root, 1)]
    while pending:
        current, level = pending.pop()
        if level > deepest:
            deepest = level
        # Push right before left so the left child is popped (visited) first.
        pending.extend(
            (child, level + 1)
            for child in (current.right, current.left)
            if child is not None
        )
    return deepest
Solution 3: BFS level traversal¶
- Breadth-first traversal naturally groups nodes by depth.
- Counting levels directly gives the maximum depth.
def max_depth_bfs(root: Optional[TreeNode]) -> int:
    """Return the tree's maximum depth (in nodes) by counting BFS levels.

    The queue is drained one level at a time; the number of levels processed
    is the depth. An empty tree has depth 0.
    """
    levels = 0
    frontier = deque() if root is None else deque([root])
    while frontier:
        levels += 1
        # Snapshot the size: only nodes already queued belong to this level.
        remaining_in_level = len(frontier)
        while remaining_in_level:
            current = frontier.popleft()
            remaining_in_level -= 1
            for child in (current.left, current.right):
                if child is not None:
                    frontier.append(child)
    return levels
Pytest test¶
@pytest.mark.parametrize(
    ("values", "expected"),
    [
        # Empty input: no tree at all.
        ([], 0),
        # Single node.
        ([1], 1),
        # Root with only a right child.
        ([1, None, 2], 2),
        # Balanced-ish example from the problem statement.
        ([3, 9, 20, None, None, 15, 7], 3),
        # Left-leaning chain in LeetCode level-order form.
        ([1, 2, None, 3, None, 4, None], 4),
        # Right-leaning chain in LeetCode level-order form.
        ([1, None, 2, None, 3, None, 4], 4),
        # A node whose value is 0 still counts as a node.
        ([0], 1),
    ],
)
def test_max_depth(values, expected):
    """Build a tree from level-order ``values`` and check its maximum depth."""
    tree = build_tree(values)
    actual = Solution().maxDepth(tree)
    assert actual == expected
Comment my solution¶
- Your recursive solution is correct and is the best interview answer here.
- Your iterative DFS solution is also correct, but it is longer and less clear than needed.
- `build_tree` is fragile:
  - it says `Iterable[int]` but uses `None`
  - it treats `0` as missing because of `if (val := ...)`
  - it can index past the list on incomplete last levels
  - LeetCode level-order trees should be built with a queue, not with full-level index math
## Solution
from __future__ import annotations
from dataclasses import dataclass
from typing import Iterable, Optional
@dataclass
class TreeNode:
    """Binary tree node with an integer value and optional children."""

    val: int
    left: Optional[TreeNode] = None
    right: Optional[TreeNode] = None


def build_tree(values: Iterable[Optional[int]]) -> Optional[TreeNode]:
    """Build a binary tree from a LeetCode-style level-order sequence.

    In this format ``None`` marks an absent child, and absent nodes do not
    reserve slots for children of their own. Trailing entries may be omitted.

    Args:
        values: Node values in level order, with ``None`` for missing children.

    Returns:
        The root of the tree, or ``None`` when ``values`` is empty or its
        first entry is ``None``.
    """
    items = list(values)
    if not items or items[0] is None:
        return None

    root = TreeNode(items[0])
    level = [root]  # real nodes of the level whose children are read next
    pos = 1  # index of the next unread entry in items
    while level and pos < len(items):
        next_level = []
        for parent in level:
            for side in ("left", "right"):
                # Compare with None explicitly so a value of 0 is still a node;
                # running past the end simply leaves the child absent.
                if pos < len(items) and items[pos] is not None:
                    child = TreeNode(items[pos])
                    setattr(parent, side, child)
                    next_level.append(child)
                pos += 1
        level = next_level
    return root
def tree_to_list(root: Optional[TreeNode]):
    """Flatten a tree into a padded level-order list of values.

    Every level is emitted in full: a missing node appears as ``None`` and
    still contributes two ``None`` placeholders to the level below it.
    Output stops before the first level that holds no real node, so an
    empty tree yields ``[]``.
    """
    flattened = []
    level = [root]
    # An all-None level means the previous level was the last real one.
    while not all(node is None for node in level):
        children = []
        for node in level:
            if node is None:
                flattened.append(None)
                children += [None, None]
            else:
                flattened.append(node.val)
                children += [node.left, node.right]
        level = children
    return flattened
# values = [3, 9, 20, None, None, 15, 7]
# root = build_tree(values)
# root # TreeNode(val=3, left=TreeNode(val=9, left=None, right=None), right=TreeNode(val=20, left=TreeNode(val=15, left=None, right=None), right=TreeNode(val=7, left=None, right=None)))
# tree_to_list(root) # [3, 9, 20, None, None, 15, 7]
class Solution:
    """Recursive solution: depth of a node is 1 plus its deeper subtree."""

    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return the maximum depth in nodes; an empty tree has depth 0."""
        if root is None:
            return 0
        # Left is evaluated before right, then the larger depth is kept.
        subtree_depths = [self.maxDepth(child) for child in (root.left, root.right)]
        return 1 + max(subtree_depths)
from collections import namedtuple
class Solution:
    """Iterative depth-first solution using an explicit stack."""

    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return the number of nodes on the longest root-to-leaf path.

        Each stack entry is a ``(node, depth)`` pair, so the depth of every
        visited node is known without recursion. An empty tree has depth 0.
        """
        if root is None:
            return 0

        max_depth = 0
        stack = [(root, 1)]
        while stack:
            node, depth = stack.pop()
            max_depth = max(max_depth, depth)
            # Push right first so the left subtree is explored first.
            if node.right is not None:
                stack.append((node.right, depth + 1))
            if node.left is not None:
                stack.append((node.left, depth + 1))
        return max_depth
import pytest
@pytest.mark.parametrize(
    "values,max_depth",
    [
        ([3, 9, 20, None, None, 15, 7], 3),
        ([1, None, 2], 2),
    ],
)
def test_maximum_depth_of_binary_tree(values, max_depth):
    """Check the solver's depth for a tree built from level-order values."""
    tree = build_tree(values)
    assert Solution().maxDepth(tree) == max_depth
Code¶
from __future__ import annotations
from collections import deque
from dataclasses import dataclass
from typing import Optional
import pytest
@dataclass
class TreeNode:
    """Binary tree node with an integer value and optional children."""

    val: int
    left: Optional["TreeNode"] = None
    right: Optional["TreeNode"] = None


def build_tree(values: list[Optional[int]]) -> Optional[TreeNode]:
    """Build a binary tree from a LeetCode-style level-order list.

    ``None`` marks an absent child, and absent nodes do not reserve slots
    for children of their own, so heap-style ``2 * i + 1`` indexing does not
    apply. Trailing entries may be omitted.

    Example: ``[5, 4, 9, None, None, 8, 11]`` builds

            5
           / \\
          4   9
             / \\
            8   11

    Args:
        values: Node values in level order, with ``None`` for missing children.

    Returns:
        The root of the tree, or ``None`` when ``values`` is empty or its
        first entry is ``None``.
    """
    if not values or values[0] is None:
        return None

    remaining = iter(values)
    root = TreeNode(next(remaining))
    # Real nodes still waiting to receive their two children, oldest first.
    queue = deque([root])
    while queue:
        parent = queue.popleft()
        for side in ("left", "right"):
            # An exhausted input and an explicit None both mean "no child".
            value = next(remaining, None)
            if value is not None:
                child = TreeNode(value)
                setattr(parent, side, child)
                queue.append(child)
    return root
class Solution:
    """Recursive depth-first solution."""

    def maxDepth(self, root: Optional[TreeNode]) -> int:
        """Return the maximum depth in nodes; an empty tree has depth 0."""
        # Empty subtree contributes 0; otherwise this node plus the deeper child.
        return (
            0
            if root is None
            else 1 + max(self.maxDepth(root.left), self.maxDepth(root.right))
        )
@pytest.mark.parametrize(
    ("values", "expected"),
    [
        # Empty input: no tree at all.
        ([], 0),
        # Single node.
        ([1], 1),
        # Root with only a right child.
        ([1, None, 2], 2),
        # Example tree with two full levels below the root on the right.
        ([3, 9, 20, None, None, 15, 7], 3),
        # Left-leaning chain in LeetCode level-order form.
        ([1, 2, None, 3, None, 4, None], 4),
        # Right-leaning chain in LeetCode level-order form.
        ([1, None, 2, None, 3, None, 4], 4),
        # A node whose value is 0 still counts as a node.
        ([0], 1),
    ],
)
def test_max_depth(values, expected):
    """Build a tree from level-order ``values`` and check its maximum depth."""
    tree = build_tree(values)
    actual = Solution().maxDepth(tree)
    assert actual == expected
# Ad-hoc example calls; the results are neither stored nor asserted here.
# Build the sample tree from its level-order list.
build_tree([3, 9, 20, None, None, 15, 7])
# Depth of that same tree (the parametrized test expects 3 for this input).
Solution().maxDepth(build_tree([3, 9, 20, None, None, 15, 7]))
# Root with only a right child (the parametrized test expects 2).
Solution().maxDepth(build_tree([1, None, 2]))
# Empty input: build_tree returns None, for which maxDepth returns 0.
Solution().maxDepth(build_tree([]))