
669. Trim a Binary Search Tree


Reformulated question

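Given the root of a BST and integer bounds low and high, remove every node whose value lies outside [low, high] and return the root of the trimmed tree. The returned root may be a different node than the original root.

Do not change the relative structure of the remaining nodes: if a remaining node was a descendant of another remaining node, it must still be its descendant.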

  • Example:
    • Input: root=[3,0,4,null,2,null,null,1], low=1, high=3
    • Output: root=[3,2,null,1]
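The bracket notation in the example is LeetCode's level-order encoding. Below is a minimal sketch of it, assuming the usual convention that None marks a missing child and trailing None values are dropped. TreeNode, from_level_order, to_level_order and inorder are illustrative helpers, not part of the original notes.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Minimal binary tree node (same shape as the class LeetCode provides)."""

    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def from_level_order(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list where None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        # The next two entries are this node's left and right children.
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def to_level_order(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Serialize a tree back to level order, dropping trailing None values."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


def inorder(root: Optional[TreeNode]) -> List[int]:
    """In-order traversal; for a valid BST this yields the values sorted."""
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


tree = from_level_order([3, 0, 4, None, 2, None, None, 1])
print(inorder(tree))         # [0, 1, 2, 3, 4]
print(to_level_order(tree))  # [3, 0, 4, None, 2, None, None, 1]
```

The in-order check confirms the example input is a valid BST, and the round trip shows how [3,2,null,1] describes root 3 with left child 2, whose left child is 1.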

Key trick

Use the BST property to skip whole subtrees.

  • If root.val < low, root and its whole left subtree are too small, so the answer is the trimmed root.right.
  • If root.val > high, root and its whole right subtree are too large, so the answer is the trimmed root.left.
  • Otherwise keep root, recursively trim both children, and assign the results back to root.left and root.right.
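To see the subtree skipping happen, here is a small sketch of the three cases above with an added visited list that records every real node the recursion touches. The insert helper, the visited parameter and the test tree (built by inserting 5, 3, 8, 1, 4, 7, 9) are illustrative assumptions, not part of the original solution.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Plain BST insertion, used only to build a test tree."""
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def trim_bst(root: Optional[TreeNode], low: int, high: int,
             visited: List[int]) -> Optional[TreeNode]:
    """Trim using the BST property and record each real node visited."""
    if root is None:
        return None
    visited.append(root.val)
    if root.val < low:
        # root and its entire left subtree are < low: root.left is never explored.
        return trim_bst(root.right, low, high, visited)
    if root.val > high:
        # root and its entire right subtree are > high: root.right is never explored.
        return trim_bst(root.left, low, high, visited)
    root.left = trim_bst(root.left, low, high, visited)
    root.right = trim_bst(root.right, low, high, visited)
    return root


root = None
for v in [5, 3, 8, 1, 4, 7, 9]:
    root = insert(root, v)

visited: List[int] = []
trimmed = trim_bst(root, 7, 9, visited)
print(visited)  # [5, 8, 7, 9]: the subtree rooted at 3 (nodes 3, 1, 4) is skipped
```

With low=7 the root 5 is too small, so the whole subtree under 3 is discarded without being visited and the new root is 8.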

Trap

Common mistakes:

  • Discarding both subtrees of an out-of-range node.
    • One side may still contain valid nodes, and its trimmed version must be returned in place of the removed node.
  • Not using the BST property.
    • A generic traversal that visits every node can still be correct, but it misses the main idea.
  • Dereferencing node.left or node.right without checking for None.
  • Rewiring child pointers incorrectly, for example trimming a child without assigning the result back to root.left or root.right, which leaves removed nodes attached.
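For contrast with the "not using the BST property" trap, here is a sketch of a post-order trim that always recurses into both children before deciding about the current node. It is still correct on a BST, because an out-of-range node can keep at most one non-empty trimmed child, but it visits every node. trim_post_order, insert and the visited list are illustrative names.

```python
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def insert(root: Optional[TreeNode], val: int) -> TreeNode:
    """Plain BST insertion, used only to build a test tree."""
    if root is None:
        return TreeNode(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def trim_post_order(root: Optional[TreeNode], low: int, high: int,
                    visited: List[int]) -> Optional[TreeNode]:
    """Trim both children first, then decide about root; touches every node."""
    if root is None:
        return None
    visited.append(root.val)
    root.left = trim_post_order(root.left, low, high, visited)
    root.right = trim_post_order(root.right, low, high, visited)
    if low <= root.val <= high:
        return root
    # Out of range. By the BST property the child on the "wrong" side has
    # already been trimmed to None, so at most one child survives.
    return root.left if root.left is not None else root.right


root = None
for v in [5, 3, 8, 1, 4, 7, 9]:
    root = insert(root, v)

visited: List[int] = []
trimmed = trim_post_order(root, 7, 9, visited)
print(visited)  # [5, 3, 1, 4, 8, 7, 9]: all seven nodes, including the 3-subtree
```

It returns the same trimmed tree (root 8 with children 7 and 9) as the BST-aware version, but it cannot stop early when a whole subtree is known to be out of range.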

Why is this question interesting?

It is a clean BST pruning problem.

  • It tests whether you can turn ordering properties into simpler recursion.
  • The final code is short, but only if you see the subtree-skipping idea.

Solve the problem with idiomatic Python

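from typing import Optional


class TreeNode:
    # LeetCode provides this definition; it is repeated so the snippet runs on its own.
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None, right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def trimBST(self, root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
        # Empty subtree stays empty.
        if root is None:
            return None

        # Current node and its left subtree are too small.
        if root.val < low:
            return self.trimBST(root.right, low, high)

        # Current node and its right subtree are too large.
        if root.val > high:
            return self.trimBST(root.left, low, high)

        # Current node is valid, so trim both children and keep it.
        root.left = self.trimBST(root.left, low, high)
        root.right = self.trimBST(root.right, low, high)
        return root

  • Time: \(O(n)\), where \(n\) is the number of nodes; each node is visited at most once
  • Space: \(O(h)\) recursion stack, where \(h\) is the tree height (up to \(n\) for a skewed tree)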
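Because the recursion depth equals the tree height, a very skewed tree can hit CPython's recursion limit (typically 1000, see sys.getrecursionlimit()). A minimal iterative sketch that avoids recursion, assuming the same TreeNode shape: first walk down until the root is in range, then trim only the left spine against low and the right spine against high. It uses \(O(1)\) extra space and only walks two root-to-leaf paths, so it touches \(O(h)\) nodes. trim_bst_iterative is an illustrative name, not part of the LeetCode API.

```python
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def trim_bst_iterative(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    # 1. Move down until the root itself lies in [low, high].
    while root is not None and not (low <= root.val <= high):
        root = root.right if root.val < low else root.left
    if root is None:
        return None

    # 2. Left side: every value here is < root.val <= high, so only `low` can fail.
    node = root
    while node.left is not None:
        if node.left.val < low:
            # Drop node.left and its left subtree; its right subtree may still be valid.
            node.left = node.left.right
        else:
            # node.left is valid, and so is its whole right subtree: keep walking left.
            node = node.left

    # 3. Right side: mirror image, only `high` can fail.
    node = root
    while node.right is not None:
        if node.right.val > high:
            # Drop node.right and its right subtree; its left subtree may still be valid.
            node.right = node.right.left
        else:
            node = node.right

    return root
```

On the example tree with low=1 and high=3 this returns the tree [3,2,null,1], the same result as the recursive solution.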

Pytest test

[NOT PROVIDED]
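A minimal pytest sketch for the recursive solution. It bundles a copy of the solution plus two hypothetical helpers, build (level-order list to tree) and serialize (tree to level-order list with trailing None values dropped), so the file runs on its own; the test cases other than the problem example are illustrative choices.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    # Copy of the recursive solution, so this test file is self-contained.
    def trimBST(self, root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
        if root is None:
            return None
        if root.val < low:
            return self.trimBST(root.right, low, high)
        if root.val > high:
            return self.trimBST(root.left, low, high)
        root.left = self.trimBST(root.left, low, high)
        root.right = self.trimBST(root.right, low, high)
        return root


def build(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Hypothetical helper: level-order list (None = missing child) to tree."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        node = queue.popleft()
        if values[i] is not None:
            node.left = TreeNode(values[i])
            queue.append(node.left)
        i += 1
        if i < len(values) and values[i] is not None:
            node.right = TreeNode(values[i])
            queue.append(node.right)
        i += 1
    return root


def serialize(root: Optional[TreeNode]) -> List[Optional[int]]:
    """Hypothetical helper: tree to level-order list, trailing None values dropped."""
    out: List[Optional[int]] = []
    queue = deque([root])
    while queue:
        node = queue.popleft()
        out.append(None if node is None else node.val)
        if node is not None:
            queue.append(node.left)
            queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


def test_problem_example():
    root = build([3, 0, 4, None, 2, None, None, 1])
    assert serialize(Solution().trimBST(root, 1, 3)) == [3, 2, None, 1]


def test_small_child_is_replaced_by_its_right_side():
    # 0 is below low, so it is replaced by its trimmed right subtree (empty here).
    root = build([1, 0, 2])
    assert serialize(Solution().trimBST(root, 1, 2)) == [1, None, 2]


def test_new_root_when_old_root_is_too_small():
    root = build([1, None, 2, None, 3])
    assert serialize(Solution().trimBST(root, 2, 3)) == [2, None, 3]


def test_everything_removed():
    assert Solution().trimBST(build([2, 1, 3]), 4, 5) is None


def test_nothing_removed():
    root = build([2, 1, 3])
    assert serialize(Solution().trimBST(root, 1, 3)) == [2, 1, 3]
```

Running pytest in the file's directory collects the test_ functions, and plain assert statements are enough for pytest to report the compared values on failure.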

Comment my solution

Your solution should be discarded and rewritten from scratch.

Main issues:

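  • It accesses l.val and r.val without checking for None, so it raises AttributeError as soon as a queued node has a missing child.
  • The dummy-node approach is unnecessary here.
  • There are pointer bugs:
    • q.append(l) is used when processing the right child; it should be q.append(r).
    • node.right = l.left should use r, not l.
    • After removing r, it checks r.right, but the subtree that should replace r is r.left.
  • The queue logic does not match the BST trimming rules.
    • The real root (dummy.right) is only compared against high, so a root below low is kept.
    • Re-queueing node processes both of its children again, so the same node can be queued twice.
  • deque and Optional are used but not imported in that snippet.
  • This problem is much simpler recursively.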

The interview answer is the recursive BST-pruning solution above.

# original solution (kept for reference; see the issues listed above)
class Solution:
    def trimBST(self, root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
        dummy = TreeNode(-1)
        dummy.left = TreeNode(-2)
        dummy.right = root
        q=deque([dummy]) # only nodes that we keep (i.e. low <= node.val <= high)
        while q:
            node = q.popleft()
            r = node.right
            l = node.left

            # left of node
            if l.val >= low:
                q.append(l)
            else:
                # we remove l from the tree
                node.left = l.right
                if l.right is not None:
                    q.append(node) # we do it again with node

            # right of node
            if r.val <= high:
                q.append(l)
            else:
                # we remove r from the tree
                node.right = l.left
                if r.right is not None:
                    q.append(node) # we do it again with node

        return dummy.right
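If you want to keep the queue-based idea, here is one way it could be repaired, offered as a sketch rather than the reference answer. It drops the dummy node: first move the root into range, then queue only kept nodes. For a kept node, its left child is already < node.val <= high, so only low needs checking, and its right child only needs high. Out-of-range children are replaced in a while loop instead of re-queueing the parent, which avoids processing nodes twice. trim_bst_bfs is a hypothetical name.

```python
from collections import deque
from typing import Optional


class TreeNode:
    def __init__(self, val: int = 0, left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None):
        self.val = val
        self.left = left
        self.right = right


def trim_bst_bfs(root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
    # Move the root into range, checking it against both bounds.
    while root is not None and not (low <= root.val <= high):
        root = root.right if root.val < low else root.left
    if root is None:
        return None

    queue = deque([root])  # invariant: every queued node satisfies low <= val <= high
    while queue:
        node = queue.popleft()
        # Left child values are < node.val <= high, so only the low bound can fail.
        while node.left is not None and node.left.val < low:
            node.left = node.left.right
        # Right child values are > node.val >= low, so only the high bound can fail.
        while node.right is not None and node.right.val > high:
            node.right = node.right.left
        # Both remaining children are now in range; process their subtrees later.
        if node.left is not None:
            queue.append(node.left)
        if node.right is not None:
            queue.append(node.right)
    return root
```

This visits every kept node once, so it is \(O(n)\) time and \(O(n)\) queue space in the worst case; the recursive version remains the simpler interview answer.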

# solution after reading AI comments
class Solution:
    def trimBST(self, root: Optional[TreeNode], low: int, high: int) -> Optional[TreeNode]:
        # if root.val < low   -> trim root.right
        # if root.val > high  -> trim root.left
        # keep root and recurse on root.left and root.right
        if root is None:
            return None
        if root.val < low:
            return self.trimBST(root.right, low, high)
        if root.val > high:
            return self.trimBST(root.left, low, high)
        root.left = self.trimBST(root.left, low, high)
        root.right = self.trimBST(root.right, low, high)
        return root