Really Fast Primality Testing
Using HPC resources to trivialize a homework problem
Background
The final course I took for my master's degree was MATH 673: Mathematics of Cryptography, which, as the name suggests, discusses the mathematical underpinnings of modern cryptography. For the uninitiated, cryptography deals with the problem of maintaining private communication over insecure channels. Modern techniques draw heavily from mathematics, particularly number theory. Prime numbers are one of the core building blocks for many cryptosystems, mainly due to the nice algebraic properties granted by doing arithmetic modulo $p$.
To make brute-force attacks infeasible, primes chosen for cryptographic purposes must be quite large. This is a somewhat nontrivial task, since primes have a natural density of $\frac{1}{\log n}$ (i.e. primes get "rarer" as $n$ increases). As such, I was unsurprised to see the following homework question:
"Write an algorithm to determine if $n$ is prime, with a false positive rate of no more than $2^{-40}$. Verify that 256866751887531116521374772376435790631 is prime."
The above number is quite small by cryptographic standards, fitting in a measly 128 bits. My initial strategy was to write a quick Python script, which should take no longer than a couple milliseconds for verification. I kept reading and found an additional footnote:
"Report the wall clock time, as well as your host machine's architecture and OS"
This statement gave me pause. I've been asked to benchmark my code before, but that was for my computer architecture and parallel computing courses. To receive such a request for a mathematics course was completely alien to me. Most math students treat programming as a necessary evil and are thoroughly uninterested in writing optimized software.
The most likely explanation for this addendum is to emphasize the role of hardware in security. A system that may take years to break on a personal laptop can be cracked in mere seconds on a supercomputer. With the current proliferation of hardware accelerators like GPUs and TPUs, this lesson is becoming more relevant by the day. However, I believed the message contained a second, hidden meaning. Buried within the discussion of the effects of compute scaling is an unspoken challenge: who can write the most performant algorithm?
Very well then. I accept this challenge.
The Naive Approach
Programming for mathematics courses tends to follow the path of least resistance. Performance optimizations and good software engineering practices take a backseat to ease of development. This naturally leads students to gravitate towards Python, so that's where I will begin. Below is a naive Python implementation of the Miller-Rabin primality test:
```python
from random import randint

def power_of_two_and_odd(n):
    """ Factor n = 2^s * d, where d is odd """
    s, d = 0, n
    while d % 2 == 0:
        d //= 2
        s += 1
    return s, d

def naive_miller_rabin(n, p):
    """ Determine if n is probably prime, with false positive rate < 2^-p """
    trials = p // 2
    s, d = power_of_two_and_odd(n - 1)
    for _ in range(trials):
        a = randint(2, n - 2)
        # Strong probable prime base a, go to next witness
        if pow(a, d, n) == 1:
            continue
        curr_pow = d  # 2^r * d for 0 <= r < s
        is_witness = True
        for r in range(s):
            # Strong probable prime base a, go to next witness
            if pow(a, curr_pow, n) == n - 1:
                is_witness = False
                break
            curr_pow *= 2
        if is_witness:
            return False  # definitely composite
    return True  # probably prime
```
This is effectively a one-to-one translation of the pseudocode given by Wikipedia. On a 2019 MacBook Pro with an i9 and 16 GB of RAM, I was able to verify the primality of 256866751887531116521374772376435790631 in just 1.37 ms.
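For transparency, the timings throughout this post come from a simple wall-clock harness. The helper below is a sketch of my own (the name `time_ms` and the repeat count are arbitrary choices, not part of the assignment); here it times a single modular exponentiation at the homework's size as a stand-in for a full test:

```python
import time

def time_ms(fn, *args, repeats=5):
    """Return the best-of-`repeats` wall time of fn(*args) in milliseconds."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        best = min(best, (time.perf_counter() - start) * 1000.0)
    return best

n = 256866751887531116521374772376435790631
# A single a^d mod n at 128 bits is far below a millisecond
print(f"{time_ms(pow, 3, n - 1, n):.3f} ms")
```

Taking the best of several repeats, rather than a single run, filters out one-off scheduler noise.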
With such a short execution time already, we need to up the ante for performance optimizations to be worthwhile. As such, we turn our attention towards the NIST standards for digital signatures. Contained within the appendices is the following table:
| Parameters | M-R Only | M-R Only |
|---|---|---|
| $p$ and $q$: 1024 bits | Error probability = $2^{-100}$ | Error probability = $2^{-112}$ |
| $p$ and $q$: 2048 bits | Error probability = $2^{-100}$ | Error probability = $2^{-144}$ |
These are the requirements for RSA, where $p$ and $q$ refer to the secret primes that are multiplied together as part of the public key. Since these standards are designed to withstand attacks from real adversaries, as opposed to educating budding mathematicians, they are significantly more stringent than our homework problem. Indeed, even the weakest requirement, verifying 1024-bit primes with an error tolerance of $2^{-100}$, takes 409 ms using my naive script. That's nearly a 300x slowdown!
Obvious Optimizations
There are two immediate optimizations that can be made to the naive Python routine. The first is to apply strength reduction to the `power_of_two_and_odd` function. Division and modulo are relatively expensive operations for the CPU to conduct and, in certain cases, we can replace them with cheaper alternatives¹. Lucky for us, we can make the following substitutions:

- `while n % 2 == 0:` $\implies$ `while ~n & 1:`
- `n //= 2` $\implies$ `n >>= 1`
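A quick sanity check, in plain Python, that the bitwise forms agree with the originals on every input:

```python
# Verify the strength-reduced forms agree with division/modulo
for n in range(1, 1000):
    # ~n & 1 is truthy exactly when n is even (the low bit of n is 0)
    assert bool(~n & 1) == (n % 2 == 0)
    # A right shift by one is floor division by two
    assert n >> 1 == n // 2
print("substitutions verified")
```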
The second, and more notable, improvement can be made inside the inner for loop. For each potential witness $a$, we test:
$$
a^{2^r \cdot d} \equiv -1 \bmod n \quad \forall r \in \{0, 1, \dots, s-1\}
$$
which involves at most $s$ comparisons if none of the values are congruent to $-1 \bmod n$. At present, we keep track of $2^r \cdot d$ in the variable `curr_pow` and compute `pow(a, curr_pow, n)` each iteration. Each call to `pow` takes $\mathcal{O}(\log \text{curr\_pow})$ multiplications, giving the inner loop a runtime of $\mathcal{O}(s \cdot \log n)$. We can eliminate the expensive call to `pow` by noting the following:
$$
a_k = a_{k-1} \cdot a_{k-1}, \quad a_0 = a^d
$$
which computes the same values as before, but only needs one multiplication per loop iteration (this is effectively the fast squaring algorithm, done modulo $n$). Our new algorithm looks like:
```python
from random import randint

def fast_power_of_two_and_odd(n):
    """ Factor n = 2^s * d, where d is odd """
    s, d = 0, n
    while ~d & 1:
        d >>= 1
        s += 1
    return s, d

def fast_miller_rabin(n, p):
    """ Determine if n is probably prime, with false positive rate < 2^-p """
    trials = p // 2
    s, d = fast_power_of_two_and_odd(n - 1)
    for _ in range(trials):
        a = randint(2, n - 2)
        curr = pow(a, d, n)
        # Strong probable prime base a, find next witness
        if curr == 1:
            continue
        is_witness = True
        for r in range(s):
            # Strong probable prime base a, find next witness
            if curr == n - 1:
                is_witness = False
                break
            curr = (curr * curr) % n
        if is_witness:
            return False
    return True
```
Running this on the same prime as before, I get a time of 182 ms, more than a 2x speedup over the naive approach. This is a great improvement for such low-hanging fruit, but we can do much better.
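Before moving on, we can double-check that the squaring recurrence visits exactly the same residues as the repeated `pow` calls. This is a sanity check of my own, using the Carmichael number 561 as a small, interesting test case:

```python
n = 561             # 3 * 11 * 17, a Carmichael number
a, d, s = 2, 35, 4  # n - 1 = 560 = 2^4 * 35

curr = pow(a, d, n)  # a_0 = a^d mod n
for r in range(s):
    # a_r should equal a^(2^r * d) mod n, matching the old pow-based loop
    assert curr == pow(a, (2 ** r) * d, n)
    curr = (curr * curr) % n  # a_{r+1} = a_r * a_r mod n
print("recurrence matches pow at every step")
```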
The Obligatory C++ Rewrite
We could explore other potential optimizations, like unrolling the outer loop or generating all of our witnesses at once to avoid repeated calls to `randint`. However, if we're truly dedicated to performance, the path forward is obvious: ditch Python. There are a couple of reasons why this move is necessary:

- Python, by dint of being an interpreted language, has significantly higher runtime overhead than a compiled language. Bytecode translation, dynamic type checking, and garbage collection must all be done at runtime by the interpreter.
- Modern compilers offer a slew of optimization passes. The programmer can rely on the wisdom of decades of compiler engineers (or Claude, I guess) to generate optimized code for a particular architecture.
- True parallelism is impossible in Python, due to the GIL only permitting one thread to access the Python interpreter at a time. There is concurrency support via the `threading` library, but performance gains are minimal for CPU-bound tasks.
I'll be rewriting the above kernel in C++, since it's a language I'm familiar with. Incidentally, C++, along with its predecessor C, is the de facto choice for performance critical code. All roads lead to Rome, I suppose.
There's really only one sticking point for the migration: integer representation. In Python, integers can have arbitrary width, so we don't need to worry about overflow. However, C++ only has fixed-width integers (unless you're building with the LLVM toolchain, which has `llvm::APInt`), so I'll need to find a third-party library to avoid rolling my own BigInt class.
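To see why this matters, the homework prime alone already overflows every built-in C++ integer type. A quick check back in Python:

```python
n = 256866751887531116521374772376435790631

# Python handles this width transparently; C++'s widest standard
# unsigned type (64 bits) cannot hold it
print(n.bit_length())  # 128: needs two 64-bit words
print(n < 2 ** 64)     # False: does not fit in a 64-bit unsigned integer
```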
I settled on NTL, which provides arbitrary precision integers and a wide array of number-theoretic functions. It also happens to be thread-safe and has a built-in thread pool for automatically parallelizing tasks (more on this later). Here's what the C++ equivalent looks like:
```cpp
bool serial_miller_rabin(const ZZ& n, size_t p) {
    // Set the modulus for Z/nZ to n
    ZZ_p::init(n);
    ZZ d = n - 1;
    size_t s = 0;
    while (d % 2 == 0) {
        d >>= 1;
        s++;
    }
    ZZ_p a;
    ZZ_p curr;
    const size_t trials = p / 2;
    for (size_t i = 0; i < trials; i++) {
        random(a);
        // a = 0 will be mistakenly considered a witness, so skip it
        // a = +/- 1 will never witness a composite, so skip it
        if (a == 0 || a == 1 || a + 1 == 0) continue;
        power(curr, a, d);       // curr = a^d mod n
        if (curr == 1) continue; // pseudoprime, go to next witness
        bool is_witness = true;
        for (size_t j = 0; j < s; j++) {
            if (curr + 1 == 0) {
                is_witness = false; // pseudoprime, go to next witness
                break;
            }
            curr *= curr;
        }
        if (is_witness) return false; // definitely composite
    }
    return true; // probably prime
}
```
`ZZ` and `ZZ_p` are NTL's arbitrary-precision integer classes, where the latter does arithmetic in $\mathbb{Z}/p\mathbb{Z}$. This is especially convenient, since we don't have to keep track of the modulus ourselves. Other than that, this is very similar to our Python setup. How does it perform?
```shell
cendres@MacBookPro miller-rabin % cat prime1024.txt | ./miller-rabin
wall time: 15.87ms
```
Quite well, by the look of it. This is a tenfold improvement over the Python version, making our port well worth the effort.
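For reference, a build along these lines should work, assuming NTL (compiled with thread support) and GMP are installed in the default search paths. The file name and exact flags are placeholders; your system may differ:

```shell
g++ -std=c++17 -O2 -o miller-rabin miller-rabin.cpp -lntl -lgmp -pthread
```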
Progress So Far
For the sake of tracking our progress, let's do a side-by-side comparison of the three algorithms we've written so far. We'll test the four parameter sets ($n$ = number of bits, $p$ = maximum failure tolerance), as outlined by NIST. Each algorithm will be run 100 times, so as to average out the effect of context switches², and we'll report a 95% confidence interval for the true mean running time. Below are the results of the experiment:
| Algorithm | $n=1024$, $p=2^{-100}$ | $n=1024$, $p=2^{-112}$ | $n=2048$, $p=2^{-100}$ | $n=2048$, $p=2^{-144}$ |
|---|---|---|---|---|
| Naive Python | $409.2 \pm 4.4$ ms | $449.1 \pm 4.6$ ms | $2520 \pm 27$ ms | $3694 \pm 32$ ms |
| Optimized Python | $182.6 \pm 1.5$ ms | $203.6 \pm 4.6$ ms | $1119 \pm 9$ ms | $1586 \pm 3$ ms |
| Serial C++ w/ NTL | $16.0 \pm 0.1$ ms | $17.9 \pm 0.2$ ms | $115.4 \pm 0.3$ ms | $166.8 \pm 0.7$ ms |
Our C++ application, unsurprisingly, sweeps the competition. There is still one avenue we have yet to pursue: multithreading.
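As an aside, the intervals in the table are the usual normal approximation, $\bar{x} \pm 1.96 \cdot s/\sqrt{N}$ over the $N = 100$ runs. A sketch of the computation (the helper name is mine, and the sample data below is made up for illustration):

```python
import math
import statistics

def ci95(samples):
    """Mean and 95% confidence half-width, normal approximation."""
    mean = statistics.mean(samples)
    half = 1.96 * statistics.stdev(samples) / math.sqrt(len(samples))
    return mean, half

# Hypothetical timings (ms) standing in for 100 real measurements
times = [16.0, 15.8, 16.3, 15.9, 16.1, 16.2, 15.7, 16.0, 16.1, 15.9]
mean, half = ci95(times)
print(f"{mean:.1f} ± {half:.2f} ms")
```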
Adding Threads
The difficulty of parallelizing a single-threaded program can range from trivial to incredibly complicated. Thankfully, our primality test is embarrassingly parallel, since no communication between threads is necessary. We can partition the witness candidates amongst our threads and have each thread modify a global variable indicating whether a witness has been found. Even better, this partitioning is automatically handled by the `NTL_EXEC_RANGE` macro, which acts similarly to OpenMP's `#pragma omp parallel for`. All we have to do is move the initialization of our integers inside the parallel region to make them thread-local. Here's our fully parallelized code:
```cpp
bool parallel_miller_rabin(const ZZ& n, size_t p) {
    ZZ d = n - 1;
    const size_t trials = p / 2;
    size_t s = 0;
    while (d % 2 == 0) {
        d >>= 1;
        s++;
    }
    std::atomic_bool is_probably_prime{true};
    // Start parallel region
    NTL_EXEC_RANGE(trials, first, last)
    ZZ_p a;        // thread local
    ZZ_p curr;     // thread local
    ZZ_p::init(n); // reinit, since only main thread initialized modulus
    for (long i = first; i < last; i++) {
        random(a);
        // a = 0 will be mistakenly considered a witness, so skip it
        // a = +/- 1 will never witness a composite, so skip it
        if (a == 0 || a == 1 || a + 1 == 0) continue;
        power(curr, a, d);       // curr = a^d mod n
        if (curr == 1) continue; // pseudoprime, go to next witness
        bool is_witness = true;
        for (size_t j = 0; j < s; j++) {
            if (curr + 1 == 0) {
                is_witness = false; // pseudoprime, go to next witness
                break;
            }
            curr *= curr;
        }
        if (is_witness) {
            // Sequential consistency is unnecessary
            // i.e. agreeing on a write order is irrelevant
            is_probably_prime.store(false, std::memory_order_relaxed);
            break;
        }
    }
    // End parallel region
    NTL_EXEC_RANGE_END
    return is_probably_prime.load(std::memory_order_relaxed);
}
```
The only synchronization primitive we need is a single `std::atomic_bool`, which a thread will atomically set to `false` should it find a witness. Doing this non-atomically is fine on paper, as any races to store `false` are benign. However, any data race involving at least one write is UB in C++, so we make the flag atomic to avoid the dreaded nasal demons.
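For illustration, the same shared-flag early exit can be sketched back in Python, with `threading.Event` standing in for the atomic bool (the function names and the toy modulus 221 are my own; the GIL makes the flag safe here regardless):

```python
import threading

def find_witness(candidates, n, d, s, found: threading.Event):
    """Each worker tests its own slice of witness candidates."""
    for a in candidates:
        if found.is_set():  # another thread already found a witness
            return
        x = pow(a, d, n)
        if x == 1 or x == n - 1:
            continue        # not a witness, try the next candidate
        for _ in range(s - 1):
            x = x * x % n
            if x == n - 1:
                break       # not a witness after all
        else:
            found.set()     # a witnesses n's compositeness

n = 221       # 13 * 17, composite
d, s = 55, 2  # n - 1 = 220 = 2^2 * 55
found = threading.Event()
threads = [
    threading.Thread(target=find_witness, args=(range(lo, lo + 5), n, d, s, found))
    for lo in (2, 7, 12)  # partition candidates 2..16 across three workers
]
for t in threads: t.start()
for t in threads: t.join()
print("composite" if found.is_set() else "probably prime")  # composite
```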
Let's see how we did with a simple strong scaling experiment. Using the same parameters as last time, we successively double the number of threads and obtain:
| # of Threads | $n=1024$, $p=2^{-100}$ | $n=1024$, $p=2^{-112}$ | $n=2048$, $p=2^{-100}$ | $n=2048$, $p=2^{-144}$ |
|---|---|---|---|---|
| $t=2$ | $8.43 \pm 0.13$ ms | $9.46 \pm 0.11$ ms | $59.44 \pm 0.27$ ms | $87.32 \pm 0.81$ ms |
| $t=4$ | $4.65 \pm 0.08$ ms | $5.06 \pm 0.10$ ms | $32.99 \pm 0.39$ ms | $45.68 \pm 0.22$ ms |
| $t=8$ | $2.91 \pm 0.06$ ms | $3.07 \pm 0.08$ ms | $21.33 \pm 0.37$ ms | $27.56 \pm 0.39$ ms |
| $t=16$ | $3.07 \pm 0.05$ ms | $3.40 \pm 0.06$ ms | $21.02 \pm 0.34$ ms | $26.66 \pm 0.24$ ms |
This is great: we have perfect strong scaling for up to eight threads! Further scaling is likely ineffective because there's not enough work to go around. For a failure threshold of $2^{-p}$, we only need to check $\frac{p}{2}$ witnesses³. Even at the maximum tolerance of $2^{-144}$, each of the 16 threads only has around five candidates to check.
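To make the work-starvation argument concrete, here is the arithmetic (the helper name is mine):

```python
def witnesses_per_thread(p, t):
    """Miller-Rabin rounds needed for a 2^-p error bound, split across t threads."""
    trials = p // 2
    return trials / t

print(witnesses_per_thread(144, 16))  # 4.5 candidates per thread
print(witnesses_per_thread(100, 16))  # 3.125 candidates per thread
```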
Counterarguments
There is an alternative explanation for the poor performance at $t=16$ threads. My MacBook Pro is equipped with an Intel i9-9880H CPU, which has eight physical cores. Additionally, Intel's hyperthreading allows two threads to run on a core at the same time, bumping the thread count to 16. Since these threads are truly parallel and not merely concurrent, we would expect performance gains from utilizing all 16 hyperthreads. This isn't the case in practice, and resource contention is likely to blame.
Two hyperthreads scheduled on the same core share a fair bit of infrastructure: L1 cache, the bus interface, and functional units, to name a few. The most important resource for our application is the integer ALU, which handles all of the modular arithmetic. Take a look at the following hotspot analysis, generated by Intel's VTune:
| Function | Module | CPU Time | % of CPU Time |
|---|---|---|---|
| __gmpz_powm | libgmp.so.10 | 22.880s | 76.8% |
| __gmpn_addmul_2_coreisbr | libgmp.so.10 | 4.040s | 13.6% |
| __gmpn_sqr_basecase_coreisbr | libgmp.so.10 | 1.720s | 5.8% |
| __gmpn_mul_basecase_coreisbr | libgmp.so.10 | 0.351s | 1.2% |
| __gmpn_redc_2_fat | libgmp.so.10 | 0.330s | 1.1% |
CPU time is overwhelmingly dominated by `__gmpz_powm`, an optimized assembly routine in the GMP library responsible for computing $a^b \bmod m$. This naturally involves the integer ALU, which will become a bottleneck if two hyperthreads want to perform this calculation at the same time.
The HPC Sledgehammer
One way of circumventing this is to use a processor with more physical cores. Given enough physical cores, the scheduler could place each thread on a separate core, thereby eliminating resource contention from competing hyperthreads. This is all well and good, but where am I to find such a system?
As it so happens, I have access to Grace, a world-class supercomputer under the purview of the TAMU HPRC. I was granted a modest amount of service units on the cluster for a class in a previous semester. A fair amount of credits still remain on my account, and they're due to expire at the end of the fiscal year. It would be a shame to let them go to waste, so let's take one final crack at our primality test using upgraded hardware.
Grace is equipped with Intel's Xeon Gold processors. Each socket features 20 physical cores and the same hyperthreading setup. This gives us 40 hardware threads and, more importantly, enough cores to run 16 threads without any sharing. Running the same scaling experiment as before, we get the following times:
| # of Threads | $n=1024$, $p=2^{-100}$ | $n=1024$, $p=2^{-112}$ | $n=2048$, $p=2^{-100}$ | $n=2048$, $p=2^{-144}$ |
|---|---|---|---|---|
| $t=2$ | $10.39 \pm 0.04$ ms | $11.30 \pm 0.07$ ms | $71.42 \pm 0.14$ ms | $102.40 \pm 0.13$ ms |
| $t=4$ | $5.48 \pm 0.01$ ms | $5.91 \pm 0.03$ ms | $37.19 \pm 0.07$ ms | $51.23 \pm 0.14$ ms |
| $t=8$ | $3.00 \pm 0.01$ ms | $3.00 \pm 0.06$ ms | $21.15 \pm 0.56$ ms | $27.01 \pm 0.79$ ms |
| $t=16$ | $1.84 \pm 0.02$ ms | $2.06 \pm 0.06$ ms | $12.17 \pm 0.05$ ms | $15.57 \pm 0.33$ ms |
| $t=32$ | $1.89 \pm 0.01$ ms | $1.89 \pm 0.02$ ms | $10.75 \pm 0.29$ ms | $15.73 \pm 0.20$ ms |
Our hypothesis appears to be correct: scaling now persists to 16 threads! Moreover, the same stall we saw earlier now occurs at 32 threads, the first time the threads outnumber the physical cores.
Closing Thoughts
With the optimization efforts a resounding success, I submit the following results:
| Algorithm | $n=1024$, $p=2^{-100}$ | $n=1024$, $p=2^{-112}$ | $n=2048$, $p=2^{-100}$ | $n=2048$, $p=2^{-144}$ |
|---|---|---|---|---|
| Naive Python | $409.2 \pm 4.4$ ms | $449.1 \pm 4.6$ ms | $2520 \pm 27$ ms | $3694 \pm 32$ ms |
| NTL with $t=16$ (Grace) | $1.84 \pm 0.02$ ms | $2.06 \pm 0.06$ ms | $12.17 \pm 0.05$ ms | $15.57 \pm 0.33$ ms |
| Speedup Factor | $222\times$ faster | $218\times$ faster | $207\times$ faster | $237\times$ faster |
There is, undoubtedly, more I could do to keep pushing down the execution time. Maybe using separate processes instead of threads would be better⁴. Perhaps I could rewrite the entire codebase in Rust, as that seems to be in vogue. However, I am content with my work and am uninterested in pursuing diminishing returns.
Most importantly, I have carried on a time-honored tradition for developers: spending hours improving a task just to save a few hundred milliseconds. Time well spent.
□