
15. 3Sum


Reformulated question

Find all unique triplets of values in nums, taken from three distinct indices, whose sum is 0.

Compact example:

nums = [-1, 0, 1, 2, -1, -4]
-> [[-1, -1, 2], [-1, 0, 1]]
# unique triplets only, order does not matter

Key trick

Sort first, then fix one number and use two pointers on the rest.

  • Sorting lets you:
    • skip duplicates easily
    • move pointers by sum comparison
  • This gives \(O(n^2)\) time, instead of \(O(n^3)\) for checking all triples or the extra cost of expensive duplicate checks.
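The "repeated 2-sum" step can be isolated as its own function. This is a minimal sketch, assuming nums is already sorted; two_sum_sorted is a hypothetical helper name, not part of the solution below. It returns every unique value pair in nums[lo:] that adds up to target:

```python
def two_sum_sorted(nums: list[int], lo: int, target: int) -> list[tuple[int, int]]:
    """Unique value pairs (a, b) from sorted nums[lo:] with a + b == target."""
    pairs: list[tuple[int, int]] = []
    left, right = lo, len(nums) - 1
    while left < right:
        s = nums[left] + nums[right]
        if s < target:
            left += 1    # sum too small: move the small end up
        elif s > target:
            right -= 1   # sum too big: move the large end down
        else:
            pairs.append((nums[left], nums[right]))
            left += 1
            right -= 1
            # Skip repeated values so each pair is reported once.
            while left < right and nums[left] == nums[left - 1]:
                left += 1
            while left < right and nums[right] == nums[right + 1]:
                right -= 1
    return pairs


nums = sorted([-1, 0, 1, 2, -1, -4])  # [-4, -1, -1, 0, 1, 2]
# Anchor nums[1] = -1, so the remaining two values must sum to 1.
print(two_sum_sorted(nums, 2, 1))     # [(-1, 2), (0, 1)]
```

Prepending the anchor -1 to each pair gives exactly the two triplets from the example.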

Trap

Common mistakes:

  • Not sorting first.
  • Returning duplicate triplets.
  • Forgetting to skip duplicate anchors nums[i].
  • Forgetting to skip duplicates after finding a valid triplet.
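  • Deduplicating with triplet not in result, a linear scan of the result list on every check, which makes the solution much slower.
  • Enumerating all pairs with a value-to-indices hash, which degrades toward \(O(n^3)\) in practice when values repeat.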
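If a solution does collect candidate triplets first and dedupes afterwards, a set of sorted tuples makes each membership check \(O(1)\) on average instead of a scan of the whole result list. A minimal sketch; dedupe_triplets is a hypothetical helper name:

```python
def dedupe_triplets(triplets: list[list[int]]) -> list[list[int]]:
    """Keep the first occurrence of each triplet, ignoring order inside it."""
    seen: set[tuple[int, ...]] = set()
    unique: list[list[int]] = []
    for t in triplets:
        key = tuple(sorted(t))   # canonical form, e.g. [1, -1, 0] -> (-1, 0, 1)
        if key not in seen:      # average O(1) hash lookup, not a list scan
            seen.add(key)
            unique.append(list(key))
    return unique


print(dedupe_triplets([[1, -1, 0], [-1, 0, 1], [2, -1, -1]]))
# [[-1, 0, 1], [-1, -1, 2]]
```

This only removes the membership-check cost. The enumeration that produces the candidates still has to be \(O(n^2)\) for the whole solution to be.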

Why is this question interesting?

It is a classic pattern-compression problem.

  • Brute force is obvious.
  • The real skill is spotting how sorting turns 3-sum into repeated 2-sum with two pointers.
  • It tests duplicate handling, time complexity, and clean implementation.
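For reference, the obvious brute force fits in a few lines. It is \(O(n^3)\), so it is only useful as a test oracle on small inputs; three_sum_brute_force is a hypothetical name, and itertools.combinations picks by position, which enforces distinct indices:

```python
from itertools import combinations


def three_sum_brute_force(nums: list[int]) -> list[list[int]]:
    """O(n^3) reference: try every index triple, dedupe by sorted values."""
    found = {
        tuple(sorted(combo))
        for combo in combinations(nums, 3)  # every choice of 3 distinct indices
        if sum(combo) == 0
    }
    return [list(t) for t in sorted(found)]


print(three_sum_brute_force([-1, 0, 1, 2, -1, -4]))
# [[-1, -1, 2], [-1, 0, 1]]
```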

Solve the problem with idiomatic python

class Solution:
    def threeSum(self, nums: list[int]) -> list[list[int]]:
        nums.sort()
        res: list[list[int]] = []
        n = len(nums)

        for i in range(n - 2):
            # Same anchor value would generate duplicate triplets.
            if i > 0 and nums[i] == nums[i - 1]:
                continue

            # Since array is sorted, if anchor is already > 0, sum cannot be 0.
            if nums[i] > 0:
                break

            left, right = i + 1, n - 1

            while left < right:
                s = nums[i] + nums[left] + nums[right]

                if s < 0:
                    left += 1
                elif s > 0:
                    right -= 1
                else:
                    res.append([nums[i], nums[left], nums[right]])

                    left += 1
                    right -= 1

                    # Skip duplicate second values.
                    while left < right and nums[left] == nums[left - 1]:
                        left += 1

                    # Skip duplicate third values.
                    while left < right and nums[right] == nums[right + 1]:
                        right -= 1

        return res

Pytest test

# Assumes the Solution class above is defined in, or imported into, this test module.
import pytest


def normalize(triplets: list[list[int]]) -> list[tuple[int, int, int]]:
    return sorted(tuple(sorted(t)) for t in triplets)


@pytest.mark.parametrize(
    ("nums", "expected"),
    [
        ([-1, 0, 1, 2, -1, -4], [[-1, -1, 2], [-1, 0, 1]]),
        ([0, 1, 1], []),
        ([0, 0, 0], [[0, 0, 0]]),
        ([0, 0, 0, 0], [[0, 0, 0]]),
        ([-2, 0, 0, 2, 2], [[-2, 0, 2]]),
        ([-1, -1, -1, 2, 2], [[-1, -1, 2]]),
        ([-2, -1, 0, 1, 2, 3], [[-2, -1, 3], [-2, 0, 2], [-1, 0, 1]]),
        ([1, 2, -2, -1], []),
    ],
)
def test_three_sum(nums, expected):
    assert normalize(Solution().threeSum(nums)) == normalize(expected)

Comment my solution

Your solution is correct in idea but too expensive.

  • You scan all pairs \((i, j)\).
  • For each pair, you may scan many candidate k.
  • Then triplet not in triplets is another linear scan, over every triplet found so far.
  • Sorting each found triplet also adds repeated extra work.

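With many repeated values (for example, all zeros), the k loop alone makes this \(O(n^3)\), and the membership checks add more on top. So the practical cost is far worse than \(O(n^2)\), which is why it times out.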
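To see the cubic term concretely, this sketch replays only the loop structure of the solution shown below (value-to-indices map, every pair (i, j), every matching index k) and counts how often the innermost loop body runs. count_k_iterations is a hypothetical helper, and the all-zero input is chosen as an assumed worst case:

```python
def count_k_iterations(nums: list[int]) -> int:
    """Count executions of the innermost `for k in ...` loop body."""
    nums_to_index: dict[int, list[int]] = {}
    for i, num in enumerate(nums):
        nums_to_index.setdefault(num, []).append(i)

    steps = 0
    n = len(nums)
    for i in range(n):
        for j in range(i + 1, n):
            # Every index holding -(nums[i] + nums[j]) is visited.
            for _ in nums_to_index.get(-(nums[i] + nums[j]), []):
                steps += 1
    return steps


for n in (10, 20, 40):
    print(n, count_k_iterations([0] * n))
# 10 450
# 20 3800
# 40 31200
```

Each of the \(n(n-1)/2\) pairs sees all n indices of the value 0, so doubling n multiplies the count by roughly 8.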

Good parts:

  • You correctly reduced the condition to nums[k] = -(nums[i] + nums[j]).
  • You correctly enforced distinct indices.
  • You correctly normalized triplets before deduplicating.

Main interview improvement:

  • Replace hash-of-indices plus result-membership checks with sorted array + two pointers.
  • That gives clean duplicate handling and the expected \(O(n^2)\) time.

# Works (correct output)
# but rejected on LeetCode: Time Limit Exceeded
class Solution:
    def threeSum(self, nums: list[int]) -> list[list[int]]:
        # One pass to build a reverse index: value -> list of indices
        nums_to_index = {}
        for i, num in enumerate(nums):
            if num in nums_to_index:
                nums_to_index[num].append(i)
            else:
                nums_to_index[num] = [i]

        # nums[i] + nums[j] + nums[k] == 0
        # <=>
        # -(nums[i] + nums[j]) == nums[k]

        # We scan each pair (i, j) with i < j
        # and look for -(nums[i] + nums[j]) in nums_to_index
        triplets = []
        n = len(nums)
        for i in range(n):
            for j in range(i + 1, n):
                num = -(nums[i] + nums[j])
                if num in nums_to_index:
                    for k in nums_to_index[num]:
                        if k != i and k != j:
                            triplet = sorted([nums[i], nums[j], nums[k]])
                            if triplet not in triplets:
                                triplets.append(triplet)
        return triplets
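If you want to keep the hash-table idea, one way to remove both costs is to loop over distinct values instead of indices and check multiplicities with a counter. This is a swapped-in variant, not the two-pointer solution above; three_sum_counter is a hypothetical name:

```python
from collections import Counter


def three_sum_counter(nums: list[int]) -> list[list[int]]:
    """Hash-based 3Sum over distinct values; each triplet emitted once as a <= b <= c."""
    count = Counter(nums)
    values = sorted(count)  # distinct values only
    res: list[list[int]] = []
    for ai, a in enumerate(values):
        for b in values[ai:]:  # b >= a
            c = -(a + b)
            if c < b or c not in count:
                continue  # enforce a <= b <= c, and c must exist
            # How many copies of each value this triplet needs.
            needed = Counter((a, b, c))
            if all(count[v] >= k for v, k in needed.items()):
                res.append([a, b, c])
    return res


print(three_sum_counter([-1, 0, 1, 2, -1, -4]))
# [[-1, -1, 2], [-1, 0, 1]]
```

Because each triplet is generated only in the order a ≤ b ≤ c, no result-membership check is needed. The pair loop is \(O(d^2)\) for d distinct values, so \(O(n^2)\) at worst, and an all-zero input collapses to a single value.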

Extra

Complexity

The outer loop runs \(O(n)\) times.

For each i, the inner two-pointer scan is \(O(n)\).

So the total is \(O(n) \times O(n) = O(n^2)\).

Including the sort: \(O(n \log n) + O(n^2) = O(n^2)\), because \(O(n^2)\) dominates \(O(n \log n)\).
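The per-anchor \(O(n)\) bound can be made exact. Each pass of the while loop moves at least one pointer inward (the duplicate-skipping loops only move them further), so for anchor i the gap right - left, which starts at \(n - i - 2\), shrinks by at least 1 per iteration. Summing over all anchors:

```latex
\sum_{i=0}^{n-3} (n - i - 2) \;=\; \sum_{m=1}^{n-2} m \;=\; \frac{(n-2)(n-1)}{2} \;=\; O(n^2)
```

The early break on nums[i] > 0 and the skipped duplicate anchors can only lower this count.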