Incremental Prime Sieve in Python

Finding primes is one of the classic problems in computer science. The difficulty of factoring large numbers into primes underpins some cryptographic protocols, such as RSA. We're going to look at some ways of generating prime numbers.

This article uses Python 3.8 and assumes familiarity with the following standard library functions:

from heapq import heappush, heappop
from itertools import accumulate, count, cycle, islice, takewhile, tee
import math

# utility
def heapmin(heap):
    return heap[0]

See itertools and heapq documentation for more details.

The easy, slow way - testing for factors

We can determine if a number n is prime using this algorithm:

for i in 2 to sqrt(n)
    if i divides evenly into n
        n is not prime
n is prime

In Python:

def is_prime(n):
    """Check if number is prime by testing all numbers to sqrt(n)
    """
    assert n > 0, 'only integers > 0'
    if n == 1:
        return False
    for i in range(2, math.floor(math.sqrt(n)) + 1):
        if n % i == 0:
            return False
    return True

We can use this to generate all primes less than 100:

def primes_brute1(n):
    """Generator which yields primes less than n
    """
    return (i for i in range(2, n) if is_prime(i))

list(primes_brute1(100))
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]

This uses a fixed amount of memory, since no arrays are allocated. Because it is a generator, we can pull as many items from it as we want without worrying about running out of memory.

In fact, we can even modify it so that it doesn't have any upper bound, using itertools.count.

def primes_brute2():
    """Generator which yields all primes
    """
    return (i for i in count(1) if is_prime(i))

list(islice(primes_brute2(), 10))
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29]

Here we've taken the first 10 primes using itertools.islice.

The problem is that it's slow - it has a time complexity of $O\left(N^{\frac{3}{2}}\right)$*

Let's see how long it takes to find the largest prime smaller than 10 million

def last(s):
    """Get the last element of an iterable (None if it is empty)
    """
    item = None
    for item in s:
        pass
    return item

%%time
last(takewhile(lambda x: x<int(1e7), primes_brute2()))
CPU times: user 1min 50s, sys: 41.4 ms, total: 1min 50s
Wall time: 1min 50s
9999991

2 minutes - let's try to beat this

The fast way - prime sieves

The sieve of Eratosthenes generates prime numbers with a time complexity of $O(n \log{}\log{}n)$

primes = array of True, indexed from 2 to n

for i in 2 to n
    if primes[i] is True
        i is prime
        for j in i*i to n, increasing by i
            primes[j] = False

The indexes of the elements of primes that are still True are the primes.

We can get an idea of why this method is faster from the animation on the Wikipedia page for the sieve:

https://en.wikipedia.org/wiki/Sieve_of_Eratosthenes

We can see that the amount of work we have to do for each iteration is greatly reduced.

  • For composite numbers we don't have to do anything at all. This is great, because there are far fewer primes than composites.
  • For primes, we cross off multiples. As the primes become larger we do less work, because their multiples are spaced further apart and because the inner loop starts further along, at the square of the prime. In the animation the array holds 120 numbers and the square of 11 is 121, so 11 doesn't have to cross anything off at all.
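To put numbers on the second point, here is a small sketch using a hypothetical helper, `crossings_per_prime`, which is not part of the original notebook. It follows the pseudocode above (composites are skipped, primes cross off from p*p in steps of p) and records how many entries each prime crosses off for an array of 120 numbers.

```python
def crossings_per_prime(n):
    """Count how many entries each prime crosses off in a sieve of size n.

    Hypothetical helper for illustration: it follows the sieve pseudocode,
    skipping composites and starting each prime's inner loop at p*p.
    """
    is_prime = [True] * n
    counts = {}
    for p in range(2, n):
        if not is_prime[p]:
            continue  # composites do no work at all
        marked = range(p * p, n, p)
        counts[p] = len(marked)
        for j in marked:
            is_prime[j] = False
    return counts


counts = crossings_per_prime(120)
print({p: c for p, c in counts.items() if p <= 11})
# {2: 58, 3: 37, 5: 19, 7: 11, 11: 0}
```

The work per prime shrinks quickly, and every prime from 11 upwards crosses off nothing because its square lies past the end of the array.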
def primes_sieve1(n):
    """Primes using the sieve of Eratosthenes
    Generate all primes less than n
    """
    n = int(n)
    prime_sieve = [True] * n
    for i in range(2, math.ceil(math.sqrt(n))):
        for j in range(i*i, n, i):
            prime_sieve[j] = False
    return [
        idx for idx, is_prime 
        in enumerate(prime_sieve) 
        if is_prime and idx>1
    ]
primes_sieve1(100)
[2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97]
%%time
last(primes_sieve1(1e7))
CPU times: user 5.63 s, sys: 88 ms, total: 5.72 s
Wall time: 5.73 s
9999991

Much faster - we've gone from 2 minutes to 5 seconds.

A further optimization would be to store the sieve in a more compact container than a list, such as a bytearray, the standard library's array module, or a NumPy array.
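As a sketch of the first suggestion, here is a variant that stores the sieve in a `bytearray` (one byte per number, rather than a pointer to a Python object in every list slot) and crosses off multiples with slice assignment, so the inner loop runs in C instead of in the interpreter. The name `primes_sieve_bytearray` is our own, we haven't benchmarked it here, and it uses `math.isqrt` (new in Python 3.8) for an exact integer square root.

```python
import math


def primes_sieve_bytearray(n):
    """Sieve of Eratosthenes using one byte per number in a bytearray.

    Returns a list of all primes less than n.
    """
    n = int(n)
    if n < 3:
        return []
    sieve = bytearray([1]) * n  # 1 = "still possibly prime"
    sieve[0] = sieve[1] = 0
    for i in range(2, math.isqrt(n - 1) + 1):
        if sieve[i]:
            # Cross off i*i, i*i + i, ... with a single slice assignment
            sieve[i * i::i] = bytes(len(range(i * i, n, i)))
    return [i for i, flag in enumerate(sieve) if flag]
```

Like `primes_sieve1`, it returns all primes less than `n`, so `primes_sieve_bytearray(100)` gives the same 25 primes as before.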

However, there are two trade-offs compared to the brute force algorithm:

  • We are now using O(N) memory, where N is the upper limit of the search. In this example we had to allocate a list of 10 million elements.
  • primes_brute2 was a generator which could continue forever. Here we are limited to a predetermined size.

Incremental Sieve

Our goal now is to get the best of both worlds - as fast as a sieve, but without allocating an array for every number, and with no maximum size. Is this possible?

The key insight is that we don't have to store all the numbers up to n. Instead, we only keep track of pairs of composites and primes: for each prime found so far, the next multiple of it that we haven't reached yet, e.g. [(6, 2), (9, 3), (25, 5)]. As we will frequently remove the smallest composite, we store these pairs in a priority queue (heap).
Now the algorithm looks like this:

2 is prime
heap = [(4, 2)]
for i in 3 to infinity
    while i > smallest composite in heap
        np, p = remove_min(heap)
        insert (np+p, p) into heap
    if i != smallest composite in heap
        i is prime
        insert (i*i, i) into heap

Instead of crossing out all multiples of a prime as soon as we find it, we cross them out as we "overtake" the composites. For example, say that i = 26 and the heap contains (25, 5). We see that 26 > 25, so we remove (25, 5) from the heap and replace it with (30, 5).

Instead of storing every number, we now store one entry per prime found so far. Because there are far fewer primes than composites (roughly n / ln n primes below n), we are much less concerned about running out of space.

Note: in Python, a heap can be a list of tuples. Tuples are compared element by element, so the heap is ordered by the first item (ties are broken by the second); this is why we put the composite first.
https://docs.python.org/3/library/heapq.html#basic-examples
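A quick illustration of the note above, using made-up `(composite, prime)` pairs: pushing them in arbitrary order still leaves the smallest composite at index 0, and popping it and pushing the next multiple is exactly the "overtake" step of the incremental sieve.

```python
from heapq import heappush, heappop

# Example heap of (next composite, prime) pairs, pushed in arbitrary order
heap = []
for pair in [(25, 5), (6, 2), (9, 3)]:
    heappush(heap, pair)

print(heap[0])  # (6, 2): the smallest composite is always at index 0

# "Overtake" step: remove the smallest composite and push the next multiple
smallest = heappop(heap)
heappush(heap, (smallest[0] + smallest[1], smallest[1]))
print(sorted(heap))  # [(8, 2), (9, 3), (25, 5)]
```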

def primes_sieve2():
    """Incremental Sieve of Eratosthenes
    """
    yield 2
    pqueue = [(4, 2)]
    for i in count(3):
        while i > heapmin(pqueue)[0]:
            np, p = heappop(pqueue)
            heappush(pqueue, (np + p, p))
        if i != heapmin(pqueue)[0]:
            yield i
            heappush(pqueue, (i*i, i))
%%time
last(takewhile(lambda x: x<int(1e7), primes_sieve2()))
CPU times: user 34.6 s, sys: 68 ms, total: 34.6 s
Wall time: 34.6 s
9999991

This is still faster than the brute force method, but slower than the array-based version of the sieve. This is because each heap push or pop takes O(log k) time, where k is the number of entries in the heap, whereas reading or writing an array element takes constant time.

We can see this algorithm in action in the animation below.

[Animation: Incremental Sieve]

Speeding it up - wheels

Here's a trick to speed it up - you don't need to test every number, only odd ones. You can generate odd numbers just as quickly as integers - just start with 1, and keep adding 2. That means the incremental sieve only has to examine half as many numbers, so it should be roughly twice as fast.
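A minimal sketch of the odds-only idea, keeping the heap-based design of `primes_sieve2`. The name `primes_sieve_odd` is hypothetical. One detail matters once we only visit odd i: each heap entry should step by 2p rather than p. Every stored composite then stays odd, whereas stepping by p would push even multiples that the loop never visits.

```python
from heapq import heappush, heappop
from itertools import count, islice


def primes_sieve_odd():
    """Incremental sieve that only examines odd candidates.

    Heap entries are (next odd multiple, step) with step = 2*p, so every
    stored composite is odd and even multiples are never pushed.
    """
    yield 2
    pqueue = []
    for i in count(3, 2):  # 3, 5, 7, 9, ...
        # Advance every composite that i has overtaken
        while pqueue and i > pqueue[0][0]:
            np, step = heappop(pqueue)
            heappush(pqueue, (np + step, step))
        if pqueue and i == pqueue[0][0]:
            continue  # i is an odd multiple of a smaller prime
        yield i
        # i*i is odd; stepping by 2*i skips the even multiples of i
        heappush(pqueue, (i * i, 2 * i))


print(list(islice(primes_sieve_odd(), 10)))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```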

How about using the same trick for multiples of both two and three? We can illustrate this by putting the first 120 natural numbers in a table with six columns:

1 2 3 4 5 6
7 8 9 10 11 12
13 14 15 16 17 18
19 20 21 22 23 24
25 26 27 28 29 30
31 32 33 34 35 36
37 38 39 40 41 42
43 44 45 46 47 48
49 50 51 52 53 54
55 56 57 58 59 60
61 62 63 64 65 66
67 68 69 70 71 72
73 74 75 76 77 78
79 80 81 82 83 84
85 86 87 88 89 90
91 92 93 94 95 96
97 98 99 100 101 102
103 104 105 106 107 108
109 110 111 112 113 114
115 116 117 118 119 120

We can see that, apart from 2 and 3 themselves, there's no point in checking anything in the second, third, fourth or sixth column, because every number in those columns is a multiple of 2 or 3. 1 and 5 are "relatively prime" to both 2 and 3, so we only need to check numbers of the form $1+6n$ and $5+6n$.

Let's list these "coprimes", and then get the difference between consecutive elements in the sequence:

sequence: 1, 5, 7, 11, 13, 17, 19, 23, 25, 29...
diffs:    -, 4, 2,  4,  2,  4,  2,  4,  2,  4...

If we start with 1, we can keep adding 4 and 2 alternately to generate this sequence.

def wheel23():
    """Wheel generates numbers coprime to 2 and 3
    """
    yield from accumulate(cycle([4, 2]), initial=1)

list(islice(wheel23(), 10))
[1, 5, 7, 11, 13, 17, 19, 23, 25, 29]

Now we only have to check one third of all integers (2 out of every 6), compared with one half when we only skipped even numbers.

Larger Wheels

How about a wheel for 2, 3, 5 or 2, 3, 5, 7? We can find larger wheels using this algorithm:

Find the first few primes (e.g. 2, 3, 5)
The size of the wheel is their product (e.g. 2*3*5 = 30)
For 1...size, find the elements which are coprime to all of those primes
Get the differences between consecutive coprimes (wrapping around from the last one back to the first), then cycle through the differences, keeping a running sum

# https://docs.python.org/3/library/itertools.html
def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = tee(iterable)
    next(b, None)
    return zip(a, b)

def coprime(a, b):
    """Check if a and b have any factors in common other than 1
    """
    return math.gcd(a, b) == 1

def wheel(size):
    """Generate a wheel of numbers coprime to the first `size` primes
    Returns tuple of primes, then iterator of wheel
    """
    assert size > 1
    initial_primes = list(islice(primes_brute2(), size))
    n = math.prod(initial_primes)
    coprimes = [i for i in range(1, n+1) if coprime(i, n)]
    diffs = [
        (b - a)
        for a, b
        in pairwise(coprimes)
    ]
    diffs.append((coprimes[0] - coprimes[-1]) % n)
    coprimes = accumulate(cycle(diffs), initial=1)
    next(coprimes) # skip 1
    return initial_primes, coprimes
first_primes, wheel23 = wheel(2)
print(first_primes)
print(list(islice(wheel23, 10)))
[2, 3]
[5, 7, 11, 13, 17, 19, 23, 25, 29, 31]
first_primes, wheel235 = wheel(3)
print(first_primes)
print(list(islice(wheel235, 10)))
[2, 3, 5]
[7, 11, 13, 17, 19, 23, 29, 31, 37, 41]

Unfortunately, we cannot keep adding more primes to the wheel indefinitely. As we increase the number of primes, the size of the wheel (the product of the primes) grows faster than exponentially, while the marginal gains in speed become smaller and smaller.

primes in wheel   speed-up   wheel size
1                 2.00x      2
2                 3.00x      6
3                 3.75x      30
4                 4.38x      210
5                 4.81x      2,310
6                 5.21x      30,030
7                 5.54x      510,510
8                 5.85x      9,699,690
9                 6.11x      223,092,870
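The speed-up column can be derived directly. For a wheel built from distinct primes, the fraction of integers that are coprime to the wheel size is the product of (1 - 1/p) over those primes, so the speed-up over checking every integer is the product of p/(p - 1). Each extra prime p therefore only multiplies the speed-up by p/(p - 1), which gets closer to 1 as p grows. A short sketch with a hypothetical helper, `wheel_stats`, that computes both columns and uses `Fraction` to keep the speed-up exact:

```python
import math
from fractions import Fraction


def wheel_stats(k):
    """Return (speed-up, wheel size) for a wheel built from the first k primes.

    The speed-up is size / phi(size); because size is a product of distinct
    primes, that equals the product of p / (p - 1) over the wheel primes.
    Only supports k <= 25 (the primes below 100).
    """
    small_primes = [p for p in range(2, 100) if all(p % d for d in range(2, p))]
    primes = small_primes[:k]
    size = math.prod(primes)
    speed_up = math.prod(Fraction(p, p - 1) for p in primes)
    return speed_up, size


for k in range(1, 10):
    speed_up, size = wheel_stats(k)
    print(f"{k}  {float(speed_up):.2f}x  {size:,}")
```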

Prime Sieve with Wheel

def primes_sieve3(wheel_size):
    """Incremental Sieve of Eratosthenes, using wheel
    """
    first_primes, wheel_ = wheel(wheel_size)
    yield from first_primes
    
    p = next(wheel_)
    yield p
    pqueue = [(p*p, p)]

    for i in wheel_:
        while i > heapmin(pqueue)[0]:
            np, p = heappop(pqueue)
            heappush(pqueue, (np + p, p))
        if i != heapmin(pqueue)[0]:
            yield i
            heappush(pqueue, (i*i, i))
%%time
last(takewhile(lambda x: x<int(1e7), primes_sieve3(5)))
CPU times: user 11.2 s, sys: 40 ms, total: 11.3 s
Wall time: 11.3 s
9999991

[Animation: Incremental Prime Sieve using Wheel with 3 primes]

Using primes_sieve3(5), it takes about 11 seconds to find all primes below 10 million. This is an improvement on the version without the wheel, but still slower than the array-based sieve.

Benchmark - primesieve

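Let's compare this to one of the fastest prime sieve libraries, primesieve, which uses a segmented sieve of Eratosthenes with wheel factorization and a bucket sieve for large sieving primes.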
https://github.com/kimwalisch/primesieve

from primesieve import primes

%%time
primes(int(1e7))
CPU times: user 5.35 ms, sys: 234 µs, total: 5.59 ms
Wall time: 5.5 ms

Our best result (about 5.7 seconds for the array-based sieve) is roughly 1000 times slower than this. This is largely because primesieve is compiled C++ rather than interpreted Python, and because its algorithms are designed to make efficient use of the CPU cache.
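primesieve's internals are far more elaborate than anything here, but the basic cache-friendly idea can be sketched in a few lines as a segmented sieve. We first sieve the small primes up to √n, then sieve the range in fixed-size blocks small enough to stay in the CPU cache, crossing off multiples of the small primes in each block. This is a minimal pure-Python sketch of that general technique, not primesieve's algorithm. The function name and the default `segment_size` are arbitrary choices of ours, and in Python the interpreter overhead dominates, so it will not approach primesieve's speed.

```python
import math


def primes_segmented(n, segment_size=32768):
    """Yield all primes less than n, sieving one fixed-size segment at a time."""
    if n < 3:
        return
    # Step 1: ordinary sieve for the small "sieving primes" up to sqrt(n - 1)
    limit = math.isqrt(n - 1)
    base = bytearray([1]) * (limit + 1)
    base[0] = base[1] = 0
    for i in range(2, math.isqrt(limit) + 1):
        if base[i]:
            base[i * i::i] = bytes(len(range(i * i, limit + 1, i)))
    sieving_primes = [i for i, flag in enumerate(base) if flag]

    # Step 2: sieve [low, high) one block at a time, reusing the small primes
    for low in range(2, n, segment_size):
        high = min(low + segment_size, n)
        segment = bytearray([1]) * (high - low)
        for p in sieving_primes:
            if p * p >= high:
                break  # p (and every larger prime) has no multiples to cross here
            # First multiple of p inside the block, never below p*p
            start = max(p * p, (low + p - 1) // p * p)
            segment[start - low::p] = bytes(len(range(start, high, p)))
        for offset, flag in enumerate(segment):
            if flag:
                yield low + offset
```

Besides being cache-friendly, it only needs memory for the primes up to √n plus one segment. For example, `list(primes_segmented(50, segment_size=10))` returns the 15 primes below 50.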

Notes

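* The brute force method tries at most $\sqrt{i}$ divisors for each number $i$, so the total work for the first $N$ numbers is bounded by

$$ \sum_{i=1}^{N}\sqrt{i} \le \sum_{i=1}^{N}\sqrt{N} = N^{\frac{3}{2}}, \quad \text{so the running time is } O\left( N^{\frac{3}{2}} \right) $$

It's possible to get a tighter bound (composite numbers usually find a divisor early), but this is good enough for our purposes.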


Code

All code used in this document can be found in this notebook

