15. 3Sum
Reformulated question
Given an integer array `nums`, find all unique triplets of values `[nums[i], nums[j], nums[k]]`, taken at distinct indices `i`, `j`, `k`, whose sum is 0.
Compact example:
nums = [-1, 0, 1, 2, -1, -4]
-> [[-1, -1, 2], [-1, 0, 1]]
# unique triplets only, order does not matter
Key trick
Sort first, then fix one number and use two pointers on the rest.
- Sorting lets you:
  - skip duplicates easily
  - move pointers by sum comparison
- This gives \(O(n^2)\) time instead of the \(O(n^3)\) brute force over all triples, and it avoids expensive duplicate checks.
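The "two pointers on the rest" step can be isolated as its own helper. This is a sketch: `two_sum_sorted` is a hypothetical name, and it assumes `nums` is already sorted and that we search only `nums[lo:]`.

```python
def two_sum_sorted(nums: list[int], lo: int, target: int) -> list[tuple[int, int]]:
    """Return all unique value pairs in sorted nums[lo:] that sum to target."""
    pairs: list[tuple[int, int]] = []
    left, right = lo, len(nums) - 1
    while left < right:
        s = nums[left] + nums[right]
        if s < target:
            left += 1   # sum too small: moving left rightward is the only way to increase it
        elif s > target:
            right -= 1  # sum too large: moving right leftward is the only way to decrease it
        else:
            pairs.append((nums[left], nums[right]))
            left += 1
            right -= 1
            # Skip equal neighbours so each value pair is reported once.
            while left < right and nums[left] == nums[left - 1]:
                left += 1
            while left < right and nums[right] == nums[right + 1]:
                right -= 1
    return pairs


nums = sorted([-1, 0, 1, 2, -1, -4])  # [-4, -1, -1, 0, 1, 2]
# Anchor nums[1] = -1, so the remaining two values must sum to 1.
print(two_sum_sorted(nums, 2, 1))  # [(-1, 2), (0, 1)]
```

A 3Sum solution built this way runs the scan once per anchor `i`, with `lo = i + 1` and `target = -nums[i]`.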
Trap
Common mistakes:
- Not sorting first.
- Returning duplicate triplets.
- Forgetting to skip duplicate anchors `nums[i]`.
- Forgetting to skip duplicates after finding a valid triplet.
- Using `triplet not in result`, which is a linear scan of the result list on every match and makes the solution much slower.
- Building all pairs with hashes and ending up near \(O(n^3)\) in practice.
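To see why the skip steps matter, here is a hypothetical `three_sum_without_skips`: the sorted two-pointer approach with every duplicate-skipping step deliberately removed. It is an illustration of the trap, not a correct solution.

```python
def three_sum_without_skips(nums: list[int]) -> list[list[int]]:
    """Two-pointer 3Sum with all duplicate skipping removed on purpose (buggy)."""
    nums = sorted(nums)
    res: list[list[int]] = []
    n = len(nums)
    for i in range(n - 2):
        # Missing on purpose: the check that skips a repeated anchor nums[i].
        left, right = i + 1, n - 1
        while left < right:
            s = nums[i] + nums[left] + nums[right]
            if s < 0:
                left += 1
            elif s > 0:
                right -= 1
            else:
                res.append([nums[i], nums[left], nums[right]])
                # Missing on purpose: skipping repeated left/right values.
                left += 1
                right -= 1
    return res


print(three_sum_without_skips([-1, 0, 1, 2, -1, -4]))
# [[-1, -1, 2], [-1, 0, 1], [-1, 0, 1]]  <- repeated anchor -1
print(three_sum_without_skips([-2, 0, 0, 2, 2]))
# [[-2, 0, 2], [-2, 0, 2]]  <- repeated second/third values
```

The first input shows the missing anchor skip; the second shows the missing skip after a match.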
Why is this question interesting?
It is a classic pattern-recognition problem.
- Brute force is obvious.
- The real skill is spotting how sorting turns 3-sum into repeated 2-sum with two pointers.
- It tests duplicate handling, time complexity, and clean implementation.
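For reference, the obvious brute force can be written in a few lines. This is a sketch (the name `three_sum_brute_force` is hypothetical) that tries every combination of three positions and deduplicates by sorted value tuple. At \(O(n^3)\) it is too slow for LeetCode, but it is handy as a test oracle.

```python
from itertools import combinations


def three_sum_brute_force(nums: list[int]) -> list[list[int]]:
    """O(n^3) reference: try every index triple, dedupe by sorted value tuple."""
    seen: set[tuple[int, ...]] = set()
    for triple in combinations(nums, 3):  # combinations picks distinct positions
        if sum(triple) == 0:
            seen.add(tuple(sorted(triple)))
    return [list(t) for t in sorted(seen)]


print(three_sum_brute_force([-1, 0, 1, 2, -1, -4]))
# [[-1, -1, 2], [-1, 0, 1]]
```

Because it shares no logic with the two-pointer version, it makes a reasonable independent check, for example when comparing both on random small inputs.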
Solve the problem with idiomatic Python
```python
class Solution:
    def threeSum(self, nums: list[int]) -> list[list[int]]:
        nums.sort()
        res: list[list[int]] = []
        n = len(nums)

        for i in range(n - 2):
            # Same anchor value would generate duplicate triplets.
            if i > 0 and nums[i] == nums[i - 1]:
                continue

            # Since the array is sorted, if the anchor is already > 0, the sum cannot be 0.
            if nums[i] > 0:
                break

            left, right = i + 1, n - 1
            while left < right:
                s = nums[i] + nums[left] + nums[right]
                if s < 0:
                    left += 1
                elif s > 0:
                    right -= 1
                else:
                    res.append([nums[i], nums[left], nums[right]])
                    left += 1
                    right -= 1
                    # Skip duplicate second values.
                    while left < right and nums[left] == nums[left - 1]:
                        left += 1
                    # Skip duplicate third values.
                    while left < right and nums[right] == nums[right + 1]:
                        right -= 1

        return res
```
Pytest test
```python
import pytest

# Assumes the Solution class above is defined in, or imported into, this test module.


def normalize(triplets: list[list[int]]) -> list[tuple[int, int, int]]:
    # Order-insensitive comparison: sort inside each triplet, then sort the list.
    return sorted(tuple(sorted(t)) for t in triplets)


@pytest.mark.parametrize(
    ("nums", "expected"),
    [
        ([-1, 0, 1, 2, -1, -4], [[-1, -1, 2], [-1, 0, 1]]),
        ([0, 1, 1], []),
        ([0, 0, 0], [[0, 0, 0]]),
        ([0, 0, 0, 0], [[0, 0, 0]]),
        ([-2, 0, 0, 2, 2], [[-2, 0, 2]]),
        ([-1, -1, -1, 2, 2], [[-1, -1, 2]]),
        ([-2, -1, 0, 1, 2, 3], [[-2, -1, 3], [-2, 0, 2], [-1, 0, 1]]),
        ([1, 2, -2, -1], []),
    ],
)
def test_three_sum(nums, expected):
    assert normalize(Solution().threeSum(nums)) == normalize(expected)
```
Comment my solution
Your solution is correct, but too slow.
- You scan all pairs \((i, j)\).
- For each pair, you may scan many candidate indices `k` (every index holding the needed value).
- Then `triplet not in triplets` is another linear scan, this time over the result list.
- Sorting each found triplet also adds repeated extra work on every match.
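So the practical cost becomes much worse than \(O(n^2)\), which is why it times out. With many repeated values (for example, all zeros), each pair's index list holds up to \(n\) entries, so the `k` loop alone approaches \(O(n^3)\).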
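To put a number on that worst case, a hypothetical `count_inner_steps` helper can replay the pair loop of the slow solution below and count how many times its innermost `for k in ...` body would run. For `[0] * n`, every pair needs the value 0, whose index list has length \(n\), so the count is \(n \cdot n(n-1)/2\).

```python
def count_inner_steps(nums: list[int]) -> int:
    """Count how often the innermost `for k in ...` loop body of the slow solution runs."""
    # Same reverse index as the slow solution: value -> list of indices.
    nums_to_index: dict[int, list[int]] = {}
    for i, num in enumerate(nums):
        nums_to_index.setdefault(num, []).append(i)

    steps = 0
    n = len(nums)
    for i in range(n):
        for j in range(i + 1, n):
            # The k loop iterates over every index holding -(nums[i] + nums[j]).
            steps += len(nums_to_index.get(-(nums[i] + nums[j]), []))
    return steps


for n in (100, 200, 400):
    print(n, count_inner_steps([0] * n))
# 100 495000
# 200 3980000
# 400 31920000
```

Doubling \(n\) multiplies the count by roughly 8, which is the signature of cubic growth.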
Good parts:
- You correctly reduced the condition to `nums[k] = -(nums[i] + nums[j])`.
- You correctly enforced distinct indices.
- You correctly normalized triplets before deduplicating.
Main interview improvement:
- Replace the value-to-indices hash and the result-membership checks with a sorted array plus two pointers.
- That gives clean duplicate handling and the expected \(O(n^2)\) time.
```python
# Works
# But refused on LeetCode: Time Limit Exceeded
class Solution:
    def threeSum(self, nums: list[int]) -> list[list[int]]:
        # One pass to build a reverse index: value -> list of indices.
        nums_to_index = {}
        for i, num in enumerate(nums):
            if num in nums_to_index:
                nums_to_index[num].append(i)
            else:
                nums_to_index[num] = [i]

        # nums[i] + nums[j] + nums[k] == 0
        # <=>
        # -(nums[i] + nums[j]) == nums[k]
        # We scan each pair (i, j) with i < j
        # and look for -(nums[i] + nums[j]) in nums_to_index
        triplets = []
        n = len(nums)
        for i in range(n):
            for j in range(i + 1, n):
                num = -(nums[i] + nums[j])
                if num in nums_to_index:
                    for k in nums_to_index[num]:
                        if k != i and k != j:
                            triplet = sorted([nums[i], nums[j], nums[k]])
                            if triplet not in triplets:
                                triplets.append(triplet)
        return triplets
```
Extra
Complexity
The outer loop runs \(O(n)\) times.
For each `i`, the inner two-pointer scan is \(O(n)\).
So the total is \(O(n) \times O(n) = O(n^2)\).
Including the sort: \(O(n \log n) + O(n^2) = O(n^2)\), because \(O(n^2)\) dominates \(O(n \log n)\).
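The \(O(n)\) bound on the inner scan can be made explicit. For anchor `i` the pointers start at `left = i + 1` and `right = n - 1`. Every iteration of the inner `while` loop, and of the duplicate-skipping loops, starts with `left < right` and shrinks the gap `right - left` by at least 1, so the scan for anchor `i` runs at most \((n - 1) - (i + 1) = n - i - 2\) iterations. Summing over all anchors:

\[
\sum_{i=0}^{n-3} (n - i - 2) = \sum_{m=1}^{n-2} m = \frac{(n-2)(n-1)}{2} = O(n^2)
\]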