204. Count Primes
On LeetCode ->
Reformulated question¶
Count how many prime numbers are < n.
- Example: n = 10 -> 4, because the primes < 10 are [2, 3, 5, 7]
Key trick¶
Use the Sieve of Eratosthenes.
- Instead of testing each number by dividing by previous primes, mark multiples of each prime as non-prime.
- Start marking at p * p, because every smaller multiple k * p with k < p was already crossed out by a smaller prime factor of k.
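To make the p * p claim concrete, here is a minimal sketch that runs the sieve twice, once starting at p * p and once at 2 * p, and checks that both give the same primes. The function name `primes_below` and the `start_at_square` flag are hypothetical, introduced only for this illustration.

```python
def primes_below(n: int, start_at_square: bool = True) -> list[int]:
    """Return the primes strictly below n using a basic sieve.

    `start_at_square` is an illustrative switch (not part of the
    LeetCode solution): True starts crossing out at p * p,
    False starts at 2 * p.
    """
    if n <= 2:
        return []
    is_prime = [True] * n
    is_prime[0] = is_prime[1] = False
    p = 2
    while p * p < n:
        if is_prime[p]:
            start = p * p if start_at_square else 2 * p
            for multiple in range(start, n, p):
                is_prime[multiple] = False
        p += 1
    return [x for x in range(n) if is_prime[x]]


# Both starting points cross out the same composites;
# starting at 2 * p only repeats work already done by smaller primes.
print(primes_below(30) == primes_below(30, start_at_square=False))
```

Starting at 2 * p re-marks numbers such as 6, 12, and 18 when p = 3, even though p = 2 already handled them, which is why it is correct but slower.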
Trap¶
- Counting primes <= n instead of < n.
- Forgetting edge cases like n <= 2.
- Starting crossing out at 2 * p, which is correct but slower.
- Using repeated primality checks, which is too slow for n up to 5 * 10^6.
Why this question is interesting¶
It tests whether you recognize when a brute-force "check each number" approach should be replaced by preprocessing.
- It is a classic jump from local checking to global marking.
- It also checks time complexity awareness: \(O(n \log \log n)\) vs much slower trial division.
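One rough way to see the gap is to count basic operations instead of measuring wall-clock time. The sketch below counts the "mark as composite" writes done by the sieve and the `k % p` checks done by a naive trial division that tests each number against all previously found primes. Both function names are hypothetical helpers, and operation counts are only a proxy for running time.

```python
def sieve_marking_steps(n: int) -> int:
    """Count the 'mark as composite' writes the sieve performs below n."""
    if n <= 2:
        return 0
    is_prime = [True] * n
    steps = 0
    p = 2
    while p * p < n:
        if is_prime[p]:
            for multiple in range(p * p, n, p):
                is_prime[multiple] = False
                steps += 1
        p += 1
    return steps


def trial_division_steps(n: int) -> int:
    """Count the `k % p` checks naive trial division performs below n."""
    primes = []
    steps = 0
    for k in range(2, n):
        k_is_prime = True
        for p in primes:
            steps += 1
            if k % p == 0:
                k_is_prime = False
                break
        if k_is_prime:
            primes.append(k)
    return steps


print(sieve_marking_steps(10))   # 4  (4, 6, 8 for p = 2, then 9 for p = 3)
print(trial_division_steps(10))  # 11
```

The difference grows quickly with n, because every prime k is tested against all smaller primes in the naive version.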
Solve the problem with idiomatic Python¶
class Solution:
    def countPrimes(self, n: int) -> int:
        # No primes strictly below 2.
        if n <= 2:
            return 0
        # Assume all numbers are prime first.
        is_prime = [True] * n
        is_prime[0] = is_prime[1] = False
        p = 2
        while p * p < n:
            if is_prime[p]:
                # Mark multiples of p starting from p*p.
                for multiple in range(p * p, n, p):
                    is_prime[multiple] = False
            p += 1
        return sum(is_prime)
This implementation of the Sieve of Eratosthenes breaks for large values of n. Here are some tests I made:
# (Tony Aldon) My tests of the AI-generated implementation
# It passes the tests on LeetCode but raises MemoryError or crashes the Python process
# for n >= 2^31 - 1.
# This is due to the allocation of a list of length n, which exhausts
# my RAM (only 16G): this allocation needs more than 10G.
# Solution().countPrimes(2**32 - 1)
Solution().countPrimes(2**28 - 1) # 14630843
Solution().countPrimes(2**29 - 1) # 28192750
Solution().countPrimes(2**30 - 1) # 54400028
Solution().countPrimes(2**31 - 1) # Python crashes
Solution().countPrimes(2**32 - 1)
# Traceback (most recent call last):
# File "<string>", line 19, in __PYTHON_EL_eval
# File "/home/tony/work/tech-interviews/questions/leetcode_top_easy_08_math_02_count_primes.py", line 99, in <module>
# Solution().countPrimes(2**32 - 1)
# ~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^
# File "/home/tony/work/tech-interviews/questions/leetcode_top_easy_08_math_02_count_primes.py", line 55, in countPrimes
# is_prime = [True] * n
# ~~~~~~~^~~
# MemoryError
is_prime = [True] * (2**31 - 1) # Python crashes
is_prime = [True] * (2**32 - 1)
# Traceback (most recent call last):
# File "<string>", line 17, in __PYTHON_EL_eval
# File "/home/tony/work/tech-interviews/questions/leetcode_top_easy_08_math_02_count_primes.py", line 116, in <module>
# is_prime = [True] * (2**32 - 1)
# ~~~~~~~^~~~~~~~~~~~~
# MemoryError
import psutil
import os
import gc
p = psutil.Process(os.getpid())
p.memory_info().rss // 10**6 # 21M
is_prime = [True] * (2**30 - 1)
p.memory_info().rss // 10**6 # 8611M
del is_prime
gc.collect()
p.memory_info().rss // 10**6 # 21M
is_prime = bytearray(b"\x01") * 3
is_prime # bytearray(b'\x01\x01\x01')
is_prime[0] # 1
type(is_prime[0]) # <class 'int'>
[k for k in is_prime] # [1, 1, 1]
is_prime = is_prime + b"\x00"
[k for k in is_prime] # [1, 1, 1, 0]
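Building on the bytearray experiment above, here is a minimal sketch of the same sieve backed by a bytearray instead of a list. A bytearray stores one byte per entry, whereas a CPython list stores one pointer per entry (8 bytes on a 64-bit build), which matches the roughly 8.6 GB measured above for 2**30 - 1 entries. The function name `count_primes_bytearray` is hypothetical, and this variant has not been run at the sizes tested above.

```python
def count_primes_bytearray(n: int) -> int:
    """Count primes strictly below n with a bytearray-backed sieve."""
    if n <= 2:
        return 0
    # One byte per number: 1 means "still considered prime".
    is_prime = bytearray(b"\x01") * n
    is_prime[0] = is_prime[1] = 0
    p = 2
    while p * p < n:
        if is_prime[p]:
            # Extended slice assignment needs a right-hand side of the
            # same length as the slice; bytes(count) is `count` zero bytes.
            # Note: this allocates a temporary of up to about n / 2 bytes.
            count = len(range(p * p, n, p))
            is_prime[p * p : n : p] = bytes(count)
        p += 1
    # Iterating a bytearray yields ints, so sum() counts the 1s.
    return sum(is_prime)


print(count_primes_bytearray(10))   # 4
print(count_primes_bytearray(100))  # 25
```

The slice assignment replaces the inner Python loop, so the marking happens in C. The memory still grows linearly with n, only with a smaller constant, so very large n would still need a different approach such as a segmented sieve.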
Pytest test¶
import pytest

@pytest.mark.parametrize(
    ("n", "expected"),
    [
        (0, 0),
        (1, 0),
        (2, 0),
        (3, 1),
        (4, 2),
        (5, 2),
        (10, 4),
        (20, 8),
        (100, 25),
    ],
)
def test_count_primes(n, expected):
    assert Solution().countPrimes(n) == expected
Comment my solution¶
Your solution is correct but too slow.
- For each k, you test divisibility by many previous primes.
- You also keep checking primes larger than \(\sqrt{k}\), which is unnecessary.
- Even with that optimization, trial division is still slower than the sieve for this problem.
A small improvement to your idea would be:
- Stop checking when p * p > k.
But the expected interview answer here is still the sieve.
import pytest

# Maybe OK for small values of `n`
# Error on LeetCode: Time Limit Exceeded
class Solution:
    def countPrimes(self, n: int) -> int:
        if n in {0, 1, 2}:
            return 0
        primes = [2]
        for k in range(3, n):
            k_is_prime = True
            # Test k against every prime found so far.
            for p in primes:
                if k % p == 0:
                    k_is_prime = False
                    break
            if k_is_prime:
                primes.append(k)
        return len(primes)

@pytest.mark.parametrize(
    ("n", "count"),
    [(10, 4), (2, 0), (3, 1)]
)
def test_countPrimes(n, count):
    assert Solution().countPrimes(n) == count
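The small improvement suggested above (stop checking once p * p > k) could look like the following sketch. The function name `count_primes_trial_division` is hypothetical. It stays a trial-division approach, so it should still be slower than the sieve for large n.

```python
def count_primes_trial_division(n: int) -> int:
    """Count primes strictly below n by trial division with a sqrt cutoff."""
    if n <= 2:
        return 0
    primes = [2]
    for k in range(3, n):
        k_is_prime = True
        for p in primes:
            # A composite k has a prime factor p with p * p <= k,
            # so larger primes never need to be tested.
            if p * p > k:
                break
            if k % p == 0:
                k_is_prime = False
                break
        if k_is_prime:
            primes.append(k)
    return len(primes)


print(count_primes_trial_division(10))   # 4
print(count_primes_trial_division(100))  # 25
```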
Code¶
import pytest

class Solution:
    def countPrimes(self, n: int) -> int:
        # No primes strictly below 2.
        if n <= 2:
            return 0
        # Sieve of Eratosthenes:
        # is_prime[x] tells whether x is still considered prime.
        is_prime = [True] * n
        is_prime[0] = is_prime[1] = False
        p = 2
        while p * p < n:
            if is_prime[p]:
                # All smaller multiples of p were already handled
                # by smaller prime factors, so start at p*p.
                for multiple in range(p * p, n, p):
                    is_prime[multiple] = False
            p += 1
        return sum(is_prime)

@pytest.mark.parametrize(
    ("n", "expected"),
    [
        (0, 0),
        (1, 0),
        (2, 0),
        (3, 1),
        (4, 2),
        (5, 2),
        (10, 4),
        (20, 8),
        (100, 25),
    ],
)
def test_count_primes(n, expected):
    assert Solution().countPrimes(n) == expected
Solution().countPrimes(10) # 4
Solution().countPrimes(0) # 0
Solution().countPrimes(1) # 0
Solution().countPrimes(100) # 25