Generating Primes

2019-12-10 • edited 2020-03-21

This post is motivated by a Computerphile video by Graham Hutton.

His solution is very elegant.

Haskell solution from the video

primes = sieve [2..]

sieve (p:ps) = p : sieve [x | x <- ps, mod x p /= 0]

twin (x, y) = x + 2 == y

twins = filter twin (zip primes (tail primes))

factors n = [ x | x <- [1..n], mod n x == 0]

is_prime n = factors n == [1, n]

naive_primes n = filter is_prime [1..n]

It is short enough to write down on a napkin.

We can translate it into Python almost line by line, thanks to list comprehensions.

Python translation

We set up natural numbers by simply using the range function.

def tail(l):
    return l[1:]

def natural_numbers(upto, start=0):
    assert start >= 0, 'natural numbers start at 0'
    return range(start, upto+1)

def take(n):
    """Return the natural numbers from 2 up to and including n."""
    return natural_numbers(upto=n, start=2)

def factors(n): 
    return [ x for x in natural_numbers(n, start=1) if n % x == 0 ]

def is_prime(n): 
    return factors(n) == [1, n]
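
A few quick checks make these definitions concrete. This is a small sketch that repeats the helpers above so that it runs on its own; the sample inputs (12, 7 and 1) are arbitrary choices for illustration.

```python
def natural_numbers(upto, start=0):
    assert start >= 0, 'natural numbers start at 0'
    return range(start, upto + 1)

def factors(n):
    return [x for x in natural_numbers(n, start=1) if n % x == 0]

def is_prime(n):
    return factors(n) == [1, n]

print(factors(12))   # [1, 2, 3, 4, 6, 12]
print(is_prime(7))   # True
print(is_prime(1))   # False: factors(1) is [1], which is not [1, 1]
```

Comparing against [1, n] is what keeps 1 out of the primes without a special case.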

The easiest way is brute force: walk through the whole list and check each number for primality.

def naive_list_primes(n):
    return [ x for x in take(n) if is_prime(x) ]
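
The Haskell version also defines twins, which the Python translation has not used so far. One possible translation is sketched here. Python lists are not lazy, so the hypothetical function twins_upto takes an upper bound that the Haskell original does not need, and the primality test is inlined to keep the sketch self-contained.

```python
def tail(l):
    return l[1:]

def is_prime(n):
    # same trial-division idea as factors(n) == [1, n]
    return [x for x in range(1, n + 1) if n % x == 0] == [1, n]

def twin(pair):
    x, y = pair
    return x + 2 == y

def twins_upto(n):
    """Twin prime pairs among the primes up to and including n (hypothetical helper)."""
    primes = [x for x in range(2, n + 1) if is_prime(x)]
    # pair each prime with its successor, as zip primes (tail primes) does
    return [pair for pair in zip(primes, tail(primes)) if twin(pair)]

print(twins_upto(50))
# [(3, 5), (5, 7), (11, 13), (17, 19), (29, 31), (41, 43)]
```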

As you may expect, it is really slow.

It took about 3 seconds to iterate through 10,000 numbers.

%timeit naive_list_primes(10000)
3.22 s ± 195 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

When we increased the input to 100,000, it took more than 5 minutes.

%timeit naive_list_primes(100000)
5min 42s ± 7.46 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
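
A rough operation count suggests why the jump is so large. The sketch below assumes that the running time is dominated by the modulo tests inside factors; modulo_ops is a hypothetical helper introduced only for this estimate.

```python
def modulo_ops(n):
    """Count the `%` tests the naive approach performs for inputs 2..n."""
    # factors(k) tests k candidates, and it is called once for every k in 2..n
    return sum(range(2, n + 1))

small = modulo_ops(10_000)
large = modulo_ops(100_000)
print(small, large, round(large / small))
# 50004999 5000049999 100
```

Ten times more input means roughly a hundred times more work, and 100 × 3.22 s is about 5.4 minutes, which is in the same range as the measured time.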

Eratosthenes to the rescue

The idea of the sieve of Eratosthenes is that, instead of testing every number in the list, we take the next remaining number as a prime and delete all of its multiples before moving on.

def sieve(xs):
    xs = list(xs)
    if not xs: return []
    head, tail = xs[0], xs[1:]
    return [head] + sieve([ x for x in tail if x % head != 0 ])

It's faster than the naive method above, but on larger inputs it runs out of recursion depth (one nested call per prime found) and the interpreter exits.

%timeit sieve(take(10000))
66.6 ms ± 4.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
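
To see how deep the recursion goes, we can thread a counter through the same function. This is only an illustrative sketch: sieve_depth is a hypothetical variant that returns the number of nested calls instead of the primes.

```python
import sys

def sieve_depth(xs, depth=1):
    """Same recursion as sieve, but return how many nested calls were made."""
    xs = list(xs)
    if not xs:
        return depth
    head, tail = xs[0], xs[1:]
    return sieve_depth([x for x in tail if x % head != 0], depth + 1)

# 25 primes below 100, plus the final call on the empty list
print(sieve_depth(range(2, 101)))   # 26
# the interpreter's limit on nested calls (the value depends on the environment)
print(sys.getrecursionlimit())
```

Each prime costs one stack frame, so once the number of primes in the input exceeds the recursion limit the plain recursive version cannot finish.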

We can write it in tail-recursive form.

def sieve_tail(xs, acc=None):
    xs = list(xs)
    if acc is None:
        acc = []
    if not xs: return acc
    head, tail = xs[0], xs[1:]
    return sieve_tail([ x for x in tail if x % head != 0 ], acc + [head])

As you might expect, it still fails very quickly when the inputs get larger. We can instead trampoline the computation (more on this in [1]).

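# call, result and with_trampoline are assumed to be the trampoline
# helpers from [1]; they are not defined in this post.
def _sieve_tail_bounce(xs, acc=None):
    xs = list(xs)
    if acc is None:
        acc = []
    if not xs: return result(acc)
    head, tail = xs[0], xs[1:]
    # bounce back into this (unwrapped) function instead of recursing
    return call(_sieve_tail_bounce)([ x for x in tail if x % head != 0 ], acc + [head])

sieve_tail_trampolinized = with_trampoline(_sieve_tail_bounce)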
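
The helpers call, result and with_trampoline are not defined in this post. The following is a minimal self-contained sketch of how they could look, written in the spirit of [1] but not copied from it, so treat the exact definitions as an assumption. The names _sieve_bounce and sieve_trampolined are chosen here only to keep the unwrapped step function apart from the wrapped one.

```python
import functools

def call(f):
    """Tell the trampoline to call f with the arguments that follow."""
    def g(*args, **kwds):
        return f, args, kwds
    return g

def result(value):
    """Tell the trampoline to stop bouncing and return value."""
    return None, value, None

def with_trampoline(f):
    """Wrap f in a loop that keeps bouncing until a result is produced."""
    @functools.wraps(f)
    def g(*args, **kwds):
        h = f
        while h is not None:
            h, args, kwds = h(*args, **kwds)
        return args
    return g

def _sieve_bounce(xs, acc=None):
    xs = list(xs)
    if acc is None:
        acc = []
    if not xs:
        return result(acc)
    head, tail = xs[0], xs[1:]
    # return a description of the next call instead of making it
    return call(_sieve_bounce)([x for x in tail if x % head != 0], acc + [head])

sieve_trampolined = with_trampoline(_sieve_bounce)

print(sieve_trampolined(range(2, 30)))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```

Every step returns to the loop in with_trampoline before the next one starts, so the stack stays flat no matter how many primes are found.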

This time it does not fail, and here are the stats:

%timeit sieve_tail_trampolinized(take(100000))
4.24 s ± 255 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This is far from ideal. Since CPython does not optimize tail recursion automatically, we had better find an iterative version.

Instead of building new lists at every step, we can just mark which numbers are prime and which are not.

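def sieve_iter(n):
    """Return the primes strictly below n (assumes n > 2)."""
    is_prime = [True] * n
    answers = [2]

    # 2 is already in answers, so we only visit odd candidates
    for i in range(3, n, 2):
        if is_prime[i]:
            answers.append(i)
            # cross out every multiple of i
            for j in range(2 * i, n, i):
                is_prime[j] = False

    return answers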
        

%timeit sieve_iter(10000)
1.59 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit sieve_iter(100000)
18.8 ms ± 882 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
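
Before trusting the speed-up, it is worth checking that the iterative version agrees with trial division. The sketch below repeats sieve_iter so it runs on its own; is_prime_naive is a hypothetical stand-in for the naive test from earlier, and the bound 100 is arbitrary.

```python
def sieve_iter(n):
    is_prime = [True] * n
    answers = [2]
    for i in range(3, n, 2):
        if is_prime[i]:
            answers.append(i)
            for j in range(2 * i, n, i):
                is_prime[j] = False
    return answers

def is_prime_naive(k):
    return [x for x in range(1, k + 1) if k % x == 0] == [1, k]

# sieve_iter(n) covers 2..n-1, so the comparison range stops before n
assert sieve_iter(100) == [k for k in range(2, 100) if is_prime_naive(k)]
print(sieve_iter(30))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```

Note that the upper bound here is exclusive, unlike take(n), which includes n.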

As you can see, it is a lot faster than the naive approach.

Make it faster

We can borrow some power from numba.

from numba import jit

@jit(nopython=True)
def sieve_iter(n):

    is_prime = [True] * n
    answers = [2]

    # 2 is already in answers, so we only visit odd candidates

    for i in range(3, n, 2):
        if is_prime[i]:
            answers.append(i)
            for j in range(2 * i, n, i):
                is_prime[j] = False

    return answers


The results are very impressive:

%timeit sieve_iter(10000)
109 µs ± 5.15 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit sieve_iter(100000)
1.09 ms ± 53.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In practice, you would probably not sieve primes by hand; you would read pre-computed results or ask a highly optimized library. Still, curiosity is here to stay. Thank you for reading.

References and further reading

  1. http://blog.moertel.com/posts/2013-06-12-recursion-to-iteration-4-trampolines.html
  2. https://www.kylem.net/programming/tailcall.html
  3. https://wiki.haskell.org/Prime_numbers#Sieve_of_Eratosthenes
  4. https://www.cs.hmc.edu/~oneill/papers/Sieve-JFP.pdf