Really Fast Primality Testing

Using HPC resources to trivialize a homework problem

Background

The final course I took for my master's degree was MATH 673: Mathematics of Cryptography, which, as the name suggests, discusses the mathematical underpinnings of modern cryptography. For the uninitiated, cryptography deals with the problem of maintaining private communication over insecure channels. Modern techniques draw heavily from mathematics, particularly number theory. Prime numbers are one of the core building blocks for many cryptosystems, mainly due to the nice algebraic properties granted by doing arithmetic modulo $p$.

To make brute-force attacks infeasible, primes chosen for cryptographic purposes must be quite large. This is a somewhat nontrivial task, since primes have a natural density of $\frac{1}{\log n}$ (i.e. primes get "rarer" as $n$ increases). As such, I was unsurprised to see the following homework question:

"Write an algorithm to determine if $n$ is prime, with a false positive rate of no more than $2^{-40}$. Verify that 256866751887531116521374772376435790631 is prime."

The above number is quite small by cryptographic standards, fitting in a measly 128 bits. My initial strategy was to write a quick Python script, which should take no longer than a couple milliseconds for verification. I kept reading and found an additional footnote:

"Report the wall clock time, as well as your host machine's architecture and OS"

This statement gave me pause. I've been asked to benchmark my code before, but that was for my computer architecture and parallel computing courses. To receive such a request for a mathematics course was completely alien to me. Most math students treat programming as a necessary evil and are thoroughly uninterested in writing optimized software.

The most likely explanation for this addendum is to emphasize the role of hardware in security. A system that may take years to break on a personal laptop can be cracked in mere seconds on a supercomputer. With the current proliferation of hardware accelerators like GPUs and TPUs, this lesson is becoming more relevant by the day. However, I believed the message contained a second, hidden meaning. Buried within the discussion of the effects of compute scaling is an unspoken challenge: who can write the most performant algorithm?

Very well then. I accept this challenge.

The Naive Approach

Programming for mathematics courses tends to follow the path of least resistance. Performance optimizations and good software engineering practices take a backseat to ease of development. This naturally leads students to gravitate towards Python, so that's where I will begin. Below is a naive Python implementation of the Miller-Rabin primality test:


            from random import randint

            def power_of_two_and_odd(n):
                """ Factor n = 2^s * d, where d is odd """
                s, d = 0, n
                while d % 2 == 0:
                    d //= 2
                    s += 1
                return s, d
            
            def naive_miller_rabin(n, p):
                """ Determine if n is probably prime, with false positive rate < 2^-p """
                trials = p // 2
                s, d = power_of_two_and_odd(n-1)

                for _ in range(trials):
                    a = randint(2, n-2)
                    # Strong probable prime base a, go to next witness
                    if pow(a, d, n) == 1:
                        continue
                    
                    curr_pow = d # 2^r * d for 0 <= r < s
                    is_witness = True
                    for r in range(s):
                        # Strong probable prime base a, go to next witness
                        if pow(a, curr_pow, n) == n-1:
                            is_witness = False
                            break
                        curr_pow *= 2
                    
                    if is_witness:
                        return False # definitely composite
                
                return True # probably prime
            

This is effectively a one-to-one translation of the pseudocode given by Wikipedia. On a 2019 MacBook Pro with an i9 and 16 GB of RAM, I was able to verify the primality of 256866751887531116521374772376435790631 in just 1.37 ms.
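As a quick, self-contained sanity check (separate from the routine above), Fermat's little theorem gives a necessary condition for primality: if $n$ is prime, then $a^{n-1} \equiv 1 \bmod n$ for every base $a$ not divisible by $n$. Python's built-in three-argument pow makes this cheap to verify:

```python
# The homework number: 128 bits, as claimed
n = 256866751887531116521374772376435790631
assert n.bit_length() == 128

# Fermat test with a few fixed bases: necessary (not sufficient) for primality
for a in (2, 3, 5, 7):
    assert pow(a, n - 1, n) == 1
print("passes all Fermat checks")
```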

Such a short execution time means we need to up the ante for performance optimizations to be worthwhile. As such, we turn our attention towards the NIST standards for digital signatures. Contained within the appendices is the following table:

Parameters | M-R Only | M-R Only
$p$ and $q$: 1024 bits | Error probability = $2^{-100}$ | Error probability = $2^{-112}$
$p$ and $q$: 2048 bits | Error probability = $2^{-100}$ | Error probability = $2^{-144}$

These are the requirements for RSA, where $p$ and $q$ refer to the secret primes that are multiplied together as part of the public key. Since these standards are designed to withstand attacks from real adversaries as opposed to educating budding mathematicians, they are significantly more stringent than our homework problem. Indeed, even the weakest requirement, verifying 1024-bit primes with an error tolerance of $2^{-100}$, takes 409 ms using my naive script. That's nearly a 300x slowdown!
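The trial counts implied by these tolerances follow from the standard Miller-Rabin bound: a composite passes a single round with probability at most $\frac{1}{4}$, so $k$ independent rounds have a false positive rate of at most $4^{-k} = 2^{-2k}$. A tolerance of $2^{-p}$ therefore needs $\lceil p/2 \rceil$ rounds:

```python
import math

def rounds_needed(p):
    """Miller-Rabin rounds for a false positive rate of at most 2^-p.
    Each round errs with probability at most 1/4 = 2^-2."""
    return math.ceil(p / 2)

print(rounds_needed(100))  # 50
print(rounds_needed(144))  # 72
```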

Obvious Optimizations

There are two immediate optimizations that can be made to the naive Python routine. The first is to apply strength reduction to the power_of_two_and_odd function. Division and modulo are relatively expensive operations for the CPU and, in certain cases, can be replaced with cheaper alternatives 1. Lucky for us, since we only ever divide by two, we can substitute d % 2 with d & 1 and d //= 2 with d >>= 1 to rely on bitwise operations instead. These instructions take only a single clock cycle to execute on most modern processors. This should cut down execution time, but since this function is only called once, such gains are probably slim.
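A quick exhaustive check over small inputs confirms these substitutions are exact for non-negative integers:

```python
# Bitwise AND with 1 extracts the low bit (parity); a right shift by one
# is floor division by two. Both agree with % and // for n >= 0.
for n in range(10_000):
    assert (n & 1) == (n % 2)
    assert (n >> 1) == (n // 2)
print("substitutions verified")
```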

The second, and more notable, improvement can be made inside the inner for loop. For each potential witness $a$, we test: $$ a^{2^r \cdot d} \equiv -1 \bmod n \quad \forall r \in \{0, 1, \dots, s-1\} $$ which involves at most $s$ comparisons if none of the values are congruent to $-1 \bmod n$. At present, we keep track of $2^r \cdot d$ in the variable curr_pow and compute pow(a, curr_pow, n) each iteration. Each call to pow takes $\mathcal{O}(\log \text{curr\_pow})$ multiplications, giving the inner loop a runtime of $\mathcal{O}(s \cdot \log n)$. We can eliminate the repeated calls to pow by noting that the sequence $$ a_k = a_{k-1} \cdot a_{k-1}, \quad a_0 = a^d $$ computes the same values as before, but only needs one multiplication per loop iteration (this is effectively the fast squaring algorithm, done modulo $n$). Our new algorithm looks like:


                from random import randint

                def fast_power_of_two_and_odd(n):
                    """ Factor n = 2^s * d, where d is odd """
                    s, d = 0, n
                    while d & 1 == 0:
                        d >>= 1
                        s += 1
                    return s, d
                
                def fast_miller_rabin(n, p):
                    """ Determine if n is probably prime, with false positive rate < 2^-p """
                    trials = p // 2
                    s, d = fast_power_of_two_and_odd(n-1)

                    for _ in range(trials):
                        a = randint(2, n-2)
                        curr = pow(a, d, n)
                        # Strong probable prime base a, find next witness
                        if curr == 1:
                            continue
                        
                        is_witness = True
                        for r in range(s):
                            # Strong probable prime base a, find next witness
                            if curr == n-1:
                                is_witness = False
                                break
                            curr = (curr * curr) % n
                        
                        if is_witness:
                            return False
                    
                    return True
            
            

Running this on the same prime as before, I get a time of 182 ms, more than a 2x speedup over the naive approach. This is a great improvement for such low-hanging fruit, but we can do much better.
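For reference, the wall-clock timings reported throughout this post come from simple measurements around the call; a minimal harness using time.perf_counter might look like the sketch below (the helper name time_ms is my own):

```python
import time

def time_ms(fn, *args, repeats=5):
    """Median wall-clock time of fn(*args), in milliseconds."""
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        samples.append((time.perf_counter() - start) * 1e3)
    return sorted(samples)[len(samples) // 2]

# e.g. time_ms(fast_miller_rabin, some_1024_bit_prime, 100)
print(time_ms(sum, range(100_000)) > 0)  # True
```

Taking the median over several repeats keeps a single slow run (say, from a context switch) from skewing the result.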

The Obligatory C++ Rewrite

We could explore other potential optimizations, like unrolling the outer loop or generating all of our witnesses at once to avoid repeated calls to randint. However, if we're truly dedicated to performance, the path forward is obvious: ditch Python. There are a couple of reasons why this move is necessary: CPython interprets bytecode, so every operation carries dispatch overhead, and its Global Interpreter Lock rules out true multithreaded execution down the road.

I'll be rewriting the above kernel in C++, since it's a language I'm familiar with. Incidentally, C++, along with its predecessor C, is the de facto choice for performance critical code. All roads lead to Rome, I suppose.

There's really only one sticking point for the migration: integer representation. In Python, integers can have arbitrary width, so we don't need to worry about overflow. However, C++ only has fixed-width integers (unless you're building with the LLVM toolchain, which has llvm::APInt), so I'll need to find a third-party library to avoid rolling my own BigInt class.
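To see why fixed-width types fall short, consider that the product of two 1024-bit primes, the very first step of RSA key generation, already needs roughly 2048 bits, dwarfing C++'s 64-bit built-ins (and even compiler extensions like __int128). Python sidesteps this entirely:

```python
# Python ints grow without bound, so overflow is a non-issue
a = (1 << 1023) + 1   # a 1024-bit number
b = (1 << 1023) + 3   # another 1024-bit number
print(a.bit_length(), b.bit_length(), (a * b).bit_length())  # 1024 1024 2047
```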

I settled on NTL, which provides arbitrary precision integers and a wide array of number-theoretic functions. It also happens to be thread-safe and has a built-in thread pool for automatically parallelizing tasks (more on this later). Here's what the C++ equivalent looks like:


                    #include <NTL/ZZ_p.h>
                    using namespace NTL;

                    bool serial_miller_rabin(const ZZ& n, size_t p) {
                        // Set the modulus for Z/pZ to n
                        ZZ_p::init(n);
                        ZZ d = n - 1;
                        size_t s = 0;
                        while (d % 2 == 0) {
                            d >>= 1;
                            s++;
                        }
                        ZZ_p a;
                        ZZ_p curr;
                        const size_t trials = p / 2;
                        for (size_t i = 0; i < trials; i++) {
                            random(a);
                            // a = 0 will be mistakenly considered a witness, so skip it
                            // a = +/- 1 will never witness a composite, so skip it
                            if (a == 0 || a == 1 || a + 1 == 0) continue;
                            power(curr, a, d); // curr = a^d mod n
                            if (curr == 1) continue; // pseudoprime, go to next witness
                            
                            bool is_witness = true;
                            for (size_t j = 0; j < s; j++) {
                                if (curr + 1 == 0) {
                                    is_witness = false; // pseudoprime, go to next witness
                                    break;
                                }
                                curr *= curr;
                            }

                            if (is_witness) return false; // definitely composite
                        }

                        return true; // probably prime
                    }
                

ZZ and ZZ_p are NTL's arbitrary-precision integer classes, where the latter does arithmetic in $(\mathbb{Z}/p\mathbb{Z})^*$. This is especially convenient, since we don't have to keep track of the modulus ourselves. Other than that, this is very similar to our Python setup. How does it perform?

                    cendres@MacBookPro miller-rabin % cat prime1024.txt | ./miller-rabin
                    wall time: 15.87ms

Quite well, by the look of it. This is a tenfold improvement over the Python version, making our port well worth the effort.

Progress So Far

For the sake of tracking our progress, let's do a side-by-side comparison of the three algorithms we've written so far. We'll test the four parameter sets ($n$ = number of bits, $p$ = maximum failure tolerance), as outlined by NIST. Each algorithm will be run 100 times, so as to average out the effect of context switches 2, and we'll report a 95% confidence interval for the true mean running time. Below are the results of the experiment:

Algorithm | $n=1024$, $p=2^{-100}$ | $n=1024$, $p=2^{-112}$ | $n=2048$, $p=2^{-100}$ | $n=2048$, $p=2^{-144}$
Naive Python | $409.2 \pm 4.4$ ms | $449.1 \pm 4.6$ ms | $2520 \pm 27$ ms | $3694 \pm 32$ ms
Optimized Python | $182.6 \pm 1.5$ ms | $203.6 \pm 4.6$ ms | $1119 \pm 9$ ms | $1586 \pm 3$ ms
Serial C++ w/ NTL | $16.0 \pm 0.1$ ms | $17.9 \pm 0.2$ ms | $115.4 \pm 0.3$ ms | $166.8 \pm 0.7$ ms
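For completeness, the intervals above use the usual normal approximation: sample mean plus or minus $1.96 \cdot s/\sqrt{N}$ for $N = 100$ runs. A sketch of the computation (with made-up sample data):

```python
import math
import statistics

def mean_ci95(samples):
    """Sample mean and half-width of a normal-approximation 95% CI."""
    m = statistics.mean(samples)
    half = 1.96 * statistics.stdev(samples) / math.sqrt(len(samples))
    return m, half

# Hypothetical run times (ms); real data would come from the benchmark
m, h = mean_ci95([10.0, 11.0, 9.0, 10.5, 9.5] * 20)
print(f"{m:.1f} ± {h:.1f} ms")  # 10.0 ± 0.1 ms
```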

Our C++ application, unsurprisingly, sweeps the competition. There is still one avenue we have yet to pursue: multithreading.

Adding Threads

The difficulty of parallelizing a single-threaded program can range from trivial to incredibly complicated. Thankfully, our primality test is embarrassingly parallel, since no communication between threads is necessary. We can partition the witness candidates amongst all of our threads and have each thread set a shared flag should it find a witness. Even better, this partitioning is automatically handled by the NTL_EXEC_RANGE macro, which acts similarly to OpenMP's #pragma omp parallel for. All we have to do is move the initialization of our integers inside the parallel region to make them thread local. Here's our fully parallelized code:


                    #include <atomic>
                    #include <NTL/BasicThreadPool.h>
                    #include <NTL/ZZ_p.h>
                    using namespace NTL;

                    bool parallel_miller_rabin(const ZZ& n, size_t p) {
                        ZZ d = n - 1;
                        const size_t trials = p / 2;
                        size_t s = 0;
                        while (d % 2 == 0) {
                            d >>= 1;
                            s++;
                        }
                        
                        std::atomic_bool is_prob_prime{true};

                        // Start parallel region
                        NTL_EXEC_RANGE(trials, first, last)
                        ZZ_p::init(n); // install the modulus in this thread before constructing any ZZ_p
                        ZZ_p a;           // thread local
                        ZZ_p curr;        // thread local
                        for (size_t i = first; i < last; i++) {
                            random(a);
                            // a = 0 will be mistakenly considered a witness, so skip it
                            // a = +/- 1 will never witness a composite, so skip it
                            if (a == 0 || a == 1 || a + 1 == 0) continue;
                            power(curr, a, d); // curr = a^d mod n
                            if (curr == 1) continue; // pseudoprime, go to next witness
                            
                            bool is_witness = true;
                            for (size_t j = 0; j < s; j++) {
                                if (curr + 1 == 0) {
                                    is_witness = false; // pseudoprime, go to next witness
                                    break;
                                }
                                curr *= curr;
                            }

                            if (is_witness) {
                                // Sequential consistency is unnecessary
                                // i.e. agreeing on a write order is irrelevant
                                is_prob_prime.store(false, std::memory_order_relaxed);
                                break;
                            }
                        }

                        // End parallel region
                        NTL_EXEC_RANGE_END

                        return is_prob_prime.load(std::memory_order_relaxed);
                    }
                

The only synchronization primitive we need is a single std::atomic_bool, which a thread will atomically set to false should it find a witness. Doing this non-atomically is fine on paper, as any races to store false are benign. However, any data race involving at least one write is UB in C++, so we make this atomic to avoid the dreaded nasal demons.
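Incidentally, the same partitioning strategy works back in Python, provided we use processes rather than threads (the GIL serializes CPython threads). The sketch below is my own translation, not part of the benchmarked code; SystemRandom draws from OS entropy so that forked workers don't inherit identical RNG state:

```python
from concurrent.futures import ProcessPoolExecutor
from random import SystemRandom

def run_trials(args):
    """Test `trials` random witnesses against n; False means composite."""
    n, trials = args
    rng = SystemRandom()  # OS entropy: independent across worker processes
    s, d = 0, n - 1
    while d % 2 == 0:
        d >>= 1
        s += 1
    for _ in range(trials):
        a = rng.randint(2, n - 2)
        curr = pow(a, d, n)
        if curr == 1 or curr == n - 1:
            continue  # strong probable prime for this base
        for _ in range(s - 1):
            curr = (curr * curr) % n
            if curr == n - 1:
                break
        else:
            return False  # found a witness: definitely composite
    return True

def process_parallel_miller_rabin(n, p, workers=4):
    """Split p//2 trials across worker processes; all must report 'probably prime'."""
    trials = p // 2
    chunks = [trials // workers + (i < trials % workers) for i in range(workers)]
    with ProcessPoolExecutor(max_workers=workers) as pool:
        return all(pool.map(run_trials, [(n, c) for c in chunks]))
```

Unlike the NTL version, this pays process startup and pickling costs on every call, so it only wins for large workloads.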

Let's see how we did with a simple strong scaling experiment. Using the same parameters as last time, we successively double the number of threads and obtain:

# of Threads | $n=1024$, $p=2^{-100}$ | $n=1024$, $p=2^{-112}$ | $n=2048$, $p=2^{-100}$ | $n=2048$, $p=2^{-144}$
$t=2$ | $8.43 \pm 0.13$ ms | $9.46 \pm 0.11$ ms | $59.44 \pm 0.27$ ms | $87.32 \pm 0.81$ ms
$t=4$ | $4.65 \pm 0.08$ ms | $5.06 \pm 0.10$ ms | $32.99 \pm 0.39$ ms | $45.68 \pm 0.22$ ms
$t=8$ | $2.91 \pm 0.06$ ms | $3.07 \pm 0.08$ ms | $21.33 \pm 0.37$ ms | $27.56 \pm 0.39$ ms
$t=16$ | $3.07 \pm 0.05$ ms | $3.40 \pm 0.06$ ms | $21.02 \pm 0.34$ ms | $26.66 \pm 0.24$ ms

This is great: we have near-perfect strong scaling for up to eight threads! Further scaling is likely ineffective because there's not enough work to go around. For a failure threshold of $2^{-p}$, we only need to check $\frac{p}{2}$ witnesses 3. Even at the maximum tolerance of $2^{-144}$, each of the 16 threads only has around five candidates to check.
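Concretely, at the largest parameter set:

```python
import math

p, threads = 144, 16
trials = p // 2                        # 72 witness candidates in total
per_thread = math.ceil(trials / threads)
print(trials, per_thread)  # 72 5
```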

Counterarguments

There is an alternative explanation for the poor performance at $t=16$ threads. My MacBook Pro is equipped with an Intel i9-9880H CPU, which has eight physical cores. Additionally, Intel's hyperthreading allows two threads to run on a core at the same time, bumping the thread count to 16. Since these threads are truly parallel and not merely concurrent, we would expect to get performance gains from utilizing all 16 hyperthreads. This isn't the case in practice, and resource contention is likely to blame.

Two hyperthreads scheduled on the same core share a fair bit of infrastructure: L1 cache, bus interface, and functional units, to name a few. The most important resource for our application is the integer ALU, which handles all of the modular arithmetic. Take a look at the following hotspot analysis, generated by Intel's VTune:

Function | Module | CPU Time | % of CPU Time
__gmpz_powm | libgmp.so.10 | 22.880s | 76.8%
__gmpn_addmul_2_coreisbr | libgmp.so.10 | 4.040s | 13.6%
__gmpn_sqr_basecase_coreisbr | libgmp.so.10 | 1.720s | 5.8%
__gmpn_mul_basecase_coreisbr | libgmp.so.10 | 0.351s | 1.2%
__gmpn_redc_2_fat | libgmp.so.10 | 0.330s | 1.1%

CPU time is overwhelmingly dominated by __gmpz_powm, an optimized assembly routine in the GMP library responsible for computing $a^b \bmod m$. This naturally involves the integer ALU, which will become a bottleneck if two hyperthreads want to perform this calculation at the same time.

The HPC Sledgehammer

One way of circumventing this is to use a processor with more physical cores. Given enough physical cores, the scheduler could place each thread on a separate core, thereby eliminating resource contention from competing hyperthreads. This is all well and good, but where am I to find such a system?

As it so happens, I have access to Grace, a world-class supercomputer under the purview of the TAMU HPRC. I was granted a modest amount of service units on the cluster for a class in a previous semester. A fair amount of credits still remain on my account, and they're due to expire at the end of the fiscal year. It would be a shame to let them go to waste, so let's take one final crack at our primality test using upgraded hardware.

Grace is equipped with Intel's Xeon Gold processors. Each socket features 20 physical cores and the same hyperthreading setup. This gives us 40 hardware threads, and more importantly, enough cores to run 16 threads without having to share. Running the same scaling experiment as before, we get the following times:

# of Threads | $n=1024$, $p=2^{-100}$ | $n=1024$, $p=2^{-112}$ | $n=2048$, $p=2^{-100}$ | $n=2048$, $p=2^{-144}$
$t=2$ | $10.39 \pm 0.04$ ms | $11.30 \pm 0.07$ ms | $71.42 \pm 0.14$ ms | $102.40 \pm 0.13$ ms
$t=4$ | $5.48 \pm 0.01$ ms | $5.91 \pm 0.03$ ms | $37.19 \pm 0.07$ ms | $51.23 \pm 0.14$ ms
$t=8$ | $3.00 \pm 0.01$ ms | $3.00 \pm 0.06$ ms | $21.15 \pm 0.56$ ms | $27.01 \pm 0.79$ ms
$t=16$ | $1.84 \pm 0.02$ ms | $2.06 \pm 0.06$ ms | $12.17 \pm 0.05$ ms | $15.57 \pm 0.33$ ms
$t=32$ | $1.89 \pm 0.01$ ms | $1.89 \pm 0.02$ ms | $10.75 \pm 0.29$ ms | $15.73 \pm 0.20$ ms

Our hypothesis appears to be correct: scaling now persists to 16 threads! Moreover, the same stall we saw earlier now occurs at 32 threads, the first time the threads outnumber the physical cores.

Closing Thoughts

With the optimization efforts a resounding success, I submit the following results:

Algorithm | $n=1024$, $p=2^{-100}$ | $n=1024$, $p=2^{-112}$ | $n=2048$, $p=2^{-100}$ | $n=2048$, $p=2^{-144}$
Naive Python | $409.2 \pm 4.4$ ms | $449.1 \pm 4.6$ ms | $2520 \pm 27$ ms | $3694 \pm 32$ ms
NTL with $t=16$ (Grace) | $1.84 \pm 0.02$ ms | $2.06 \pm 0.06$ ms | $12.17 \pm 0.05$ ms | $15.57 \pm 0.33$ ms
Speedup Factor | $222\times$ faster | $218\times$ faster | $207\times$ faster | $237\times$ faster
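The speedup row is simply the ratio of the mean times:

```python
naive = [409.2, 449.1, 2520, 3694]   # naive Python means, ms
best = [1.84, 2.06, 12.17, 15.57]    # NTL with t=16 on Grace, ms
print([round(a / b) for a, b in zip(naive, best)])  # [222, 218, 207, 237]
```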

There is, undoubtedly, more I could do to keep pushing down the execution time. Maybe using separate processes instead of threads would be better 4. Perhaps I could rewrite the entire codebase in Rust, as that seems to be in vogue. However, I am content with my work and uninterested in pursuing diminishing returns.

Most importantly, I have carried on a time-honored tradition for developers: spending hours improving a task just to save a few hundred milliseconds. Time well spent.