2553 words
13 minutes
Solving 64 bits ECDLP with GPU Acceleration

最初有这个想法是因为 seccon 2024 quals 里面的 xiyi 有一个 unintended solution,分析到最后发现需要 5 分钟跑 32 个 66 bits 的 DLP,然后就拼尽全力无法战胜了。加之之前每次做 ECDLP 都是用 sage 的或者现场写一个 bsgs,但这样到 48 bits 就开始慢得不行了,还容易炸内存。于是乎想着能不能用 GPU 来加速一下。

在折磨了一阵子之后,成功搞出来了一个能在半分钟内解决 64 bits ECDLP 的 GPU 程序。这里面也涉及到了很多 CTF 比赛里不会出现但是现实里很有用的密码学知识和 HPC 技术,因此写本文记录一下。1

Multi-Precision Integer on GPU#

首先我们需要一个 GPU 上的高精度整数库。在 CUDA 上跑这个的需求并不多(甚至在 AI 时代整数单元都被砍废了),所以可供我们选择的库非常有限。这里我选择了 CGBN。这是 NVIDIA 官方的实现,做的优化也比较好,而且比较新有30系的支持。缺点是,这个库只支持全局固定的 bit size,而且一个整数至少要在 4 个线程上跑,在 256-bit 的时候 sync 开销有点大。不过对于 ECDLP 来说已经够用了。

因为 CUDA 的整数运算单元都是 32 位的,所以在 GPU 上高精度整数是切分成 32 位的 limb 来存储和运算的。CGBN 的一个整数会被分成连续的 TPI 段,同时由 TPI 个线程存储和运算。大部分的高精度整数运算都已经被 CGBN 实现了,这里简单介绍提一下里面有意思的细节。

Addition and Multiplication#

对密码学量级的整数,朴素的加法和乘法是最快的。

单线程加法只需要从最低的一个 limb 开始相加,并把进位传递到更高的 limb 即可。CUDA 已经提供了 add with carry 的原语,可以实现高效的多线程加法。但 CGBN 是多线程的,如果要按照单线程的思路来实现加法,每个线程都需要等待前一个线程的进位传递过来,才能继续计算,这样就被串行化了。因此,我们需要利用加法是 AC0\text{AC}^0 的性质,通过 CUDA 的 warp sync 实现一次性的同步,从而避免串行化。

乘法就只是一个朴素的 O(n2)O(n^2) 乘法实现,多线程实现基本无法避免大量的同步开销。只能从最低位开始每次 sync 一个 limb 算乘法和进位最后累加。

Montgomery Multiplication#

对于椭圆曲线计算来说,真正的瓶颈在于模运算。除法和模运算涉及到大量的分支和循环,对 GPU 非常不友好。如果每次乘法都要做模运算,速度会非常慢。2 幸运的是,我们可以采用 Montgomery 乘法,它能在无需昂贵除法的情况下高效完成模乘,非常适合 GPU 无分支实现。用 Montgomery 乘法计算 ab(modN)a \cdot b\pmod{N} 可以分成三个步骤:

  1. 任意选一个 RR,先计算 (aRmodN)(aR\mod{N})(bRmodN)(bR\mod{N})
  2. 计算 (aRmodN)(bRmodN)(aR\mod{N})(bR\mod{N}),然后乘上 R1(modN)R^{-1}\pmod{N},得到 abRmodNabR\mod{N}
  3. 将结果乘上 R1(modN)R^{-1}\pmod{N},将结果还原成 abmodNab\mod{N}

这个过程并没有真正地解决取模的问题,第二步和第三步乘 R1(modN)R^{-1}\pmod{N} 之后仍然需要取模把结果限制到 NN 的范围内。Montgomery 乘法的关键在于提出了 REDC 算法,在 T<N2T<N^2 的情况下,可以高效地计算出 TR1(modN)TR^{-1}\pmod{N}

REDC 的具体算法如下:

  1. 选择一个 R>NR>N,且 RRNN 互质。预计算 N=N1(modR)N'= -N^{-1}\pmod{R}
  2. 计算 m=(TmodR)NmodRm = (T \bmod R) \cdot N' \bmod R
  3. 计算 t=(T+mN)/Rt = (T + mN) / R
  4. 如果 tNt\geq N,则返回 tNt - N,否则返回 tt

这个算法做了两件事:先加上 NN 的倍数使其能被 RR 整除,然后除以 RR 把值缩小到接近 NN 的范围内。我们可以证明 t=T+mNRN2+RNR<2Nt=\frac{T + mN}{R}\leq \frac{N^2 + RN}{R}< 2N,所以最多减一次 NN 就可以把结果限制到 NN 内。这样模 NN 的计算就被转化成了 RR 的取模和除法。这样看起来好像并没有什么意义,但当我们选择 RR22 的幂,就会发现 RR 的取模和除法可以通过位运算实现。同样,第一步把 aa 转换成 aRmodNaR\mod{N} 也可以乘上 R2modNR^2\mod{N} 后 REDC 实现。

因此,我们之后的计算都会以 aRmodNaR\mod{N} 的形式存储。3

Faster Modular Inverse#

如果我们想解决的是 DLP,那只用 Montgomery Multiplication 就已经足够了。而对 ECDLP 来说,一个重要的问题是点的加法运算中需要计算模逆。

对于 CUDA,每 32 个线程被称为一个 warp,这 32 个线程会同时执行同一条指令。如果在执行过程中有某些线程需要执行不同的指令(比如分支),就会导致 warp 内部的线程被串行化,极大地降低性能。像我们常用的 xgcd 算法,每次迭代都需要比较两个数的大小来决定下一步的操作,这就会引入大量的分支,而且每次比较大小和加减法都需要同步所有线程,性能就会非常糟糕。

这个问题对测信道敏感的场景也很麻烦,比如比特币的椭圆曲线运算。这里他们使用了一个非常魔法的算法,在不比较大小的情况下,只通过最低位的奇偶性就可以决定下一步的操作。具体如下:

def gcd(f, g):
"""Compute the GCD of an odd integer f and another integer g."""
assert f & 1 # require f to be odd
delta = 1 # additional state variable
while g != 0:
assert f & 1 # f will be odd in every iteration
if delta > 0 and g & 1:
delta, f, g = 1 - delta, g, (g - f) // 2
elif g & 1:
delta, f, g = 1 + delta, f, (g + f) // 2
else:
delta, f, g = 1 + delta, f, (g ) // 2
return abs(f)

大概思路是每次迭代就算 g,fg, f 符号不对,除2也可以确保两个数的绝对值不会增加。然后在这个 delta 的神秘操作下,可以保证我们有较高的概率能让 g,fg,f 符号恰好能让值变小 2 倍,于是就能在 O(logn)O(\log{n}) 的迭代内变成 1。具体证明可以看看原论文和这个 repo,大概就是依靠凸包控制范围,总之很神秘。而且因为算法只关注最低位的奇偶性,所以甚至可以很方便地预测每次迭代的操作,从而用乘法实现批量迭代。

最终在一堆魔法的加持下,这个算法在实现简单的同时性能也非常好,具体可以看 libsecp256k1 里面那些人的讨论。不过也有一个代价是这里的两个数是有符号的,而 CGBN 不支持有符号整数的运算(加减法和无符号整数等价,不过乘法就不行了),所以我们还是得稍微修改一下库。最后也是比 CGBN 给的 sample 快了 3 倍。

Pollard’s Kangaroo Algorithm#

下一个问题是传统的 BSGS 算法在 64 bits 往上的时候内存消耗就非常大了,所以这里需要一个更节省内存的算法。

Pollard’s Kangaroo Algorithm 的大概思路是让若干袋鼠在曲线上进行确定性的随机游走,比如有一个 6 bits 的 jump table,通过当前点的最后 6 bits 确定下一步加上 jump table 里对应的点。当 jump table 里面的 offset~=O(sqrt(n)) 的时候,袋鼠碰撞的概率会变成大约 1n\frac{1}{\sqrt{n}}。因为碰撞后两只袋鼠将会走同样的路径,所以我们不需要储存所有的点,只需要储存最后 k bit 都为 0 的点就可以拦截这两只袋鼠。关于 Pollard’s Kangaroo Algorithm 的介绍和实现可以参考 Kangaroo-256-bit-pythonKangaroo。不过他们是给 secp256k1 设计的,所以做了很多特殊优化,没法直接拿来用。

Batched Inverse#

对 ECDLP,就算我们用了我们的快速求逆,一次求逆还是比乘法慢了 60 倍左右。4 于是一个神奇的办法是把很多个数的求逆合并成一个批量求逆。也就是

M=i=1naimodp,ai1=M1jiajmodpM = \prod_{i=1}^n a_i\bmod{p}, \qquad a_i^{-1} = M^{-1}\prod_{j\neq i} a_j\bmod{p}

具体实现的时候可以先存储所有的累乘,然后求出 M1M^{-1},最后再通过乘法得到每个 ai1a_i^{-1}。这样一共需要一次求逆和 3n3n 次乘法。不过由于 GPU 的缓存和袋鼠个数的限制,批量求逆的数量不能太大。但不论如何 ECDLP 可以轻松用上 100k+ 袋鼠,可以说非常享受了。最终实现在这里 CGBN,不过没怎么优化过缓存,特别是 ECDLP 的超大 batch inverse。

Final#

在绕了这么大一个圈子后,我们可以狠狠地拿 GPU 来非预期炸掉 seccon 2024 quals 的 xiyi 了。这道题的问题是 g12(modp)g\equiv-\frac{1}{2}\pmod{p},所以我们可以寻找一个 pp 使得 gg 的阶非常小。这里唯一的办法是找 2k+12^k+1 的素因子中恰为 518 bits 的素数。经过一番 factordb 搜索,只能找到 21308+12^{1308}+122490+12^{2490}+1,在一堆精细优化后,最后需要解一个 61 bits 和 66 bits 的 DLP。然后就是 GPU 暴力了。

最终在 3070 Ti 上做到了大约 4 分钟解出 32 个 66 bits 的 DLP,成功拿到 flag。

solve.py
from subprocess import check_output
from sage.all import *
from sage.groups.generic import bsgs
from tqdm import tqdm
from pwn import remote, process
def run_kangaroo(target, g, order, bound=None, num_retry=3):
# solve target = g**k mod p
if bound is None:
bound = order
assert order % 2 == 1
p = g.parent().characteristic()
inp = f"{p} {g} {order} {target} {bound}\n"
print(f"Running kangaroo with {bound.bit_length()} bits...")
print(inp)
ret = None
for _ in range(num_retry):
try:
ret = check_output("./kangaroo_gpu", input=inp.encode(), shell=True).decode()
assert "Solution:" in ret
break
except Exception as e:
print(e)
continue
assert ret is not None, "Kangaroo failed"
print(ret)
k = ret.split("Solution:")[1]
k = k.split("\n")[0]
return ZZ(k)
def dlp_generic(target, g, order, max_bound, num_retry=3):
if max_bound.bit_length() > 32:
return run_kangaroo(target, g, order, max_bound, num_retry)
else:
return bsgs(g, target, bounds=(ZZ(0), max_bound))
def dlp_prime(target, g, order, num_retry=3):
if order.bit_length() > 34:
return run_kangaroo(target, g, order, num_retry=num_retry)
else:
os = matrix(ZZ, [[order, 1]])
return pari(target).znlog(pari(g), pari([order, os])).sage()
def dlp_primepower(target, g, prime, n, num_retry=3):
order = prime**n
k = 0
g0 = g ** (order // prime)
for i in range(n):
t0 = target * (g**(-k))
t0 = t0 ** (prime**(n-1-i))
assert t0 ** prime == 1
k += dlp_prime(t0, g0, prime, num_retry) * prime**i
return k
def dlp_main(target, g, order, max_bound, factors, num_retry=3):
# solve target = g**k mod p
cur_bound = max_bound
mods = []
for prime, n in factors:
if cur_bound <= 10:
break
if prime > cur_bound: # too large, fall back to generic
break
mod = prime**n
ord_ = order // mod
g0 = g ** ord_
t0 = target ** ord_
cur_bound = cur_bound // mod
k_ = dlp_primepower(t0, g0, prime, n, num_retry)
mods.append((k_, mod))
rs, ms = zip(*mods)
k0 = crt(list(rs), list(ms))
MOD = prod(ms) # k = k0 + ? * MOD
g1 = g ** MOD
t1 = target * (g ** (-k0))
k1 = dlp_generic(t1, g1, order // MOD, cur_bound+1, num_retry)
return k0 + k1 * MOD
def dlp(target, g, order=None, bounds=None, known_factors=None, num_retry=3):
"""
Solve DLP for target = g**k mod p
- **order**: order of g. A multiple of the order (e.g. p-1) is also accepted.
- **bounds**: The bound for k. If None, it will be set to `order-1`.
If integer, search range is [0, bound].
If list/tuple, represents an interval [min, max].
- **known_factors**: List of known factors of order. The rest part of the order will be handled by generic method.
Either (prime, exponent) pairs or prime lists are accepted.
- **num_retry**: Number of retries when kangaroo fails.
"""
assert g.parent().is_prime_field()
p = g.parent().characteristic()
if order is None:
order = p-1
else:
order = ZZ(order)
if bounds is None:
bounds = order
if known_factors is None:
known_factors = list(factor(order, algorithm="ecm"))
if isinstance(known_factors[0], (list, tuple)):
known_primes = [p for p, _ in known_factors]
else:
known_primes = known_factors
if 2 not in known_primes:
known_primes.append(2)
assert g ** order == 1
known_primes.sort()
for p in known_primes:
while order % p == 0 and (g ** (order // p)) == 1:
order = order // p
assert g ** order == 1
factors = []
ord_ = order
for p in known_primes:
if ord_ % p != 0:
continue
v = ord_.valuation(p)
factors.append((p, v))
ord_ = ord_ // (p**v)
if ord_ > 1:
factors.append((ord_, 1))
factors.sort()
print(factors)
if bounds is None:
bounds = (0, order-1)
elif isinstance(bounds, (list, tuple)):
bounds = tuple(bounds)
else:
bounds = (0, int(bounds))
min_bound, max_bound = bounds
t2 = target * (g ** (-min_bound))
range = max_bound - min_bound
return min_bound + dlp_main(t2, g, order, range, factors, num_retry), factors
import json
from secrets import randbelow
from lib2 import Cryptosystem, Privkey, Pubkey, Ct # same as lib.py
from params import L, M, N
p = 470772889347891753894848105243789217710301326704247389714521154467591283889037042443973128895041119713324021962506767036488831747747963841111160849037254561 # (2^1308+1)
q = 731433453306625773851831070026391128383727479375152752977676902094814479855682993608187229204658713734108641142148010876470120999557683935590321341245075861 # (2^2490+1)
p1_factors = [(2, 5), (3, 3), (5, 1), (109, 1), (457456471313, 1), (53400322888084232503, 1), (4952131943166985961003171, 1), (12483277974290346429122493663434783301283, 1), (662043754266558743926932151479961636112025140409089081, 1)]
q1_factors = [(2, 2), (3, 1), (5, 1), (61, 1), (83, 1), (2099, 1), (14071, 1), (361728968932127, 1), (225369400669558848086892569194138533812747585329213772218859604310602106229150875171288346031013036088614600109065387761801873239, 1)]
n_factors = list(sorted(set([pf[0] for pf in p1_factors] + [qf[0] for qf in q1_factors])))+[p]
p_used = [2, 3, 457456471313, 53400322888084232503]
q_used = [61, 2099, 14071, 361728968932127]
p_unused = prod([4952131943166985961003171, 12483277974290346429122493663434783301283, 662043754266558743926932151479961636112025140409089081])
q_unused = 225369400669558848086892569194138533812747585329213772218859604310602106229150875171288346031013036088614600109065387761801873239
n = p**2 * q
order = p*lcm(p-1,q-1)
for x in range(2, 10000):
x = mod(x, n)
assert x ** order == 1
good = True
for p0 in n_factors:
if (x ** (order // p0)) == 1:
good = False
break
if good:
gx = int(x)
break
bad_g = mod(n // 2, p*q)
bad_g_order = 1085640
assert bad_g ** bad_g_order == 1
priv_key = Privkey(p, q, Pubkey(n))
io = process(['sage', 'server.py'])
# initialize
C = Cryptosystem.from_privkey(priv_key)
assert C.privkey is not None
enc_xs = [Ct(pow(gx, 281**i, n)) for i in range(L)]
# 1: (client) --- n, enc_xs ---> (server)
io.sendlineafter(b"> ", json.dumps({"n": n, "enc_xs": enc_xs}).encode())
def decrypt_correct(c: int) -> int:
a = C.L(pow(c, p - 1, p**2))
b = C.L(pow(gx, p - 1, p**2))
return a * pow(b, -1, p) % p
# 3: (server) --- enc_alphas, beta_sum_mod_n ---> (client)
params = json.loads(io.recvline().strip().decode())
enc_alphas, beta_sum_mod_n = params["enc_alphas"], params["beta_sum_mod_n"]
beta_sum_mod_n = decrypt_correct(n//2) * beta_sum_mod_n % p
alphas = [decrypt_correct(enc_alpha) for enc_alpha in enc_alphas]
alpha_sum = sum(alphas) % p
inner_product = (alpha_sum + beta_sum_mod_n) % p
print(f"{inner_product = }")
ys = []
for i in tqdm(range(L)):
target = pow(enc_alphas[i], bad_g_order, p*q)
base = pow(gx, 281**i, n)
base = pow(base, bad_g_order, p*q)
# base ** y_i == target
target2 = pow(target, p_unused * q_unused, p*q)
base2 = pow(base, p_unused * q_unused, p*q)
p_i, f_p = dlp(mod(target2, p), mod(base2, p), (p-1)//p_unused, known_factors=p1_factors)
q_i, f_q = dlp(mod(target2, q), mod(base2, q), (q-1)//q_unused, known_factors=q1_factors)
# print(p_i, f_p)
# print(q_i, f_q)
mod1 = prod([pf[0]**pf[1] for pf in f_p])
mod2 = prod([qf[0]**qf[1] for qf in f_q])
ci = (inner_product - sum([ys[j] * 281**j for j in range(i)]))
assert ci % (281**i) == 0
inner_product_i = ci // (281**i)
y_i_0 = crt([p_i, q_i, inner_product_i%281], [mod1, mod2, 281])
MOD = mod1 * mod2 * 281
base3 = mod(base, p) ** ((p-1)//p_unused)
target3 = mod(target, p) ** ((p-1)//p_unused)
target3 = target3 / (base3 ** y_i_0)
base3 = base3 ** MOD
y_i_1 = dlp_generic(target3, base3, p_unused, M//MOD)
y_i = y_i_0 + y_i_1 * MOD
print(f"{i = }, {y_i = }")
ys.append(int(y_i))
# If, by any chance, you can guess ys, send it for the flag!
io.sendlineafter(b"> ", json.dumps({"ys": ys, "p": C.privkey.p, "q": C.privkey.q}).encode())
print(io.recvline().strip().decode()) # Congratz! or Wrong...
print(io.recvline().strip().decode()) # flag or ys

Footnotes#

  1. 但是写博客咕了半年(x

  2. 这对那些 secp256k1 的 GPU 库来说就不是问题,因为 p=2256232977p=2^{256}-2^{32}-977 的性质很好,可以把高位的值快速降到低位。

  3. Montgomery 算法在硬件设计和高精度整数库中也被广泛使用,对现代密码学的实际应用有着非常重要的意义。 顺带一提,其发明者 Peter L. Montgomery 还有两个非常有名的算法是 Montgomery Ladder 和 Block Lanczos。

  4. 在我的 3070 Ti 上测得,乘法 + Montgomery reduction 大概是 6B op/s,而求逆大概是 100M op/s。