最初有这个想法是因为 seccon 2024 quals 里面的 xiyi 有一个 unintended solution,分析到最后发现需要 5 分钟跑 32 个 66 bits 的 DLP,然后就拼尽全力无法战胜了。加之之前每次做 ECDLP 都是用 sage 的或者现场写一个 bsgs,但这样到 48 bits 就开始慢得不行了,还容易炸内存。于是乎想着能不能用 GPU 来加速一下。
在折磨了一阵子之后,成功搞出来了一个能在半分钟内解决 64 bits ECDLP 的 GPU 程序。这里面也涉及到了很多 CTF 比赛里不会出现但是现实里很有用的密码学知识和 HPC 技术,因此写本文记录一下。1
Multi-Precision Integer on GPU
首先我们需要一个 GPU 上的高精度整数库。在 CUDA 上跑这个的需求并不多(甚至在 AI 时代整数单元都被砍废了),所以可供我们选择的库非常有限。这里我选择了 CGBN。这是 NVIDIA 官方的实现,做的优化也比较好,而且比较新有30系的支持。缺点是,这个库只支持全局固定的 bit size,而且一个整数至少要在 4 个线程上跑,在 256-bit 的时候 sync 开销有点大。不过对于 ECDLP 来说已经够用了。
因为 CUDA 的整数运算单元都是 32 位的,所以在 GPU 上高精度整数是切分成 32 位的 limb 来存储和运算的。CGBN 的一个整数会被分成连续的 TPI 段,同时由 TPI 个线程存储和运算。大部分的高精度整数运算都已经被 CGBN 实现了,这里简单介绍提一下里面有意思的细节。
Addition and Multiplication
对密码学量级的整数,朴素的加法和乘法是最快的。
单线程加法只需要从最低的一个 limb 开始相加,并把进位传递到更高的 limb 即可。CUDA 已经提供了 add with carry 的原语,可以实现高效的多线程加法。但 CGBN 是多线程的,如果要按照单线程的思路来实现加法,每个线程都需要等待前一个线程的进位传递过来,才能继续计算,这样就被串行化了。因此,我们需要利用加法是 的性质,通过 CUDA 的 warp sync 实现一次性的同步,从而避免串行化。
乘法就只是一个朴素的 乘法实现,多线程实现基本无法避免大量的同步开销。只能从最低位开始每次 sync 一个 limb 算乘法和进位最后累加。
Montgomery Multiplication
对于椭圆曲线计算来说,真正的瓶颈在于模运算。除法和模运算涉及到大量的分支和循环,对 GPU 非常不友好。如果每次乘法都要做模运算,速度会非常慢。2 幸运的是,我们可以采用 Montgomery 乘法,它能在无需昂贵除法的情况下高效完成模乘,非常适合 GPU 无分支实现。用 Montgomery 乘法计算 可以分成三个步骤:
- 任意选一个 ,先计算 和 。
- 计算 ,然后乘上 ,得到 。
- 将结果乘上 ,将结果还原成 。
这个过程并没有真正地解决取模的问题,第二步和第三步乘 之后仍然需要取模把结果限制到 的范围内。Montgomery 乘法的关键在于提出了 REDC 算法,在 的情况下,可以高效地计算出 。
REDC 的具体算法如下:
- 选择一个 ,且 与 互质。预计算 。
- 计算 。
- 计算 。
- 如果 ,则返回 ,否则返回 。
这个算法做了两件事:先加上 的倍数使其能被 整除,然后除以 把值缩小到接近 的范围内。我们可以证明 ,所以最多减一次 就可以把结果限制到 内。这样模 的计算就被转化成了 的取模和除法。这样看起来好像并没有什么意义,但当我们选择 为 的幂,就会发现 的取模和除法可以通过位运算实现。同样,第一步把 转换成 也可以乘上 后 REDC 实现。
因此,我们之后的计算都会以 的形式存储。3
Faster Modular Inverse
如果我们想解决的是 DLP,那只用 Montgomery Multiplication 就已经足够了。而对 ECDLP 来说,一个重要的问题是点的加法运算中需要计算模逆。
对于 CUDA,每 32 个线程被称为一个 warp,这 32 个线程会同时执行同一条指令。如果在执行过程中有某些线程需要执行不同的指令(比如分支),就会导致 warp 内部的线程被串行化,极大地降低性能。像我们常用的 xgcd 算法,每次迭代都需要比较两个数的大小来决定下一步的操作,这就会引入大量的分支,而且每次比较大小和加减法都需要同步所有线程,性能就会非常糟糕。
这个问题对测信道敏感的场景也很麻烦,比如比特币的椭圆曲线运算。这里他们使用了一个非常魔法的算法,在不比较大小的情况下,只通过最低位的奇偶性就可以决定下一步的操作。具体如下:
def gcd(f, g): """Compute the GCD of an odd integer f and another integer g.""" assert f & 1 # require f to be odd delta = 1 # additional state variable while g != 0: assert f & 1 # f will be odd in every iteration if delta > 0 and g & 1: delta, f, g = 1 - delta, g, (g - f) // 2 elif g & 1: delta, f, g = 1 + delta, f, (g + f) // 2 else: delta, f, g = 1 + delta, f, (g ) // 2 return abs(f)大概思路是每次迭代就算 符号不对,除2也可以确保两个数的绝对值不会增加。然后在这个 delta 的神秘操作下,可以保证我们有较高的概率能让 符号恰好能让值变小 2 倍,于是就能在 的迭代内变成 1。具体证明可以看看原论文和这个 repo,大概就是依靠凸包控制范围,总之很神秘。而且因为算法只关注最低位的奇偶性,所以甚至可以很方便地预测每次迭代的操作,从而用乘法实现批量迭代。
最终在一堆魔法的加持下,这个算法在实现简单的同时性能也非常好,具体可以看 libsecp256k1 里面那些人的讨论。不过也有一个代价是这里的两个数是有符号的,而 CGBN 不支持有符号整数的运算(加减法和无符号整数等价,不过乘法就不行了),所以我们还是得稍微修改一下库。最后也是比 CGBN 给的 sample 快了 3 倍。
Pollard’s Kangaroo Algorithm
下一个问题是传统的 BSGS 算法在 64 bits 往上的时候内存消耗就非常大了,所以这里需要一个更节省内存的算法。
Pollard’s Kangaroo Algorithm 的大概思路是让若干袋鼠在曲线上进行确定性的随机游走,比如有一个 6 bits 的 jump table,通过当前点的最后 6 bits 确定下一步加上 jump table 里对应的点。当 jump table 里面的 offset~=O(sqrt(n)) 的时候,袋鼠碰撞的概率会变成大约 。因为碰撞后两只袋鼠将会走同样的路径,所以我们不需要储存所有的点,只需要储存最后 k bit 都为 0 的点就可以拦截这两只袋鼠。关于 Pollard’s Kangaroo Algorithm 的介绍和实现可以参考 Kangaroo-256-bit-python 和 Kangaroo。不过他们是给 secp256k1 设计的,所以做了很多特殊优化,没法直接拿来用。
Batched Inverse
对 ECDLP,就算我们用了我们的快速求逆,一次求逆还是比乘法慢了 60 倍左右。4 于是一个神奇的办法是把很多个数的求逆合并成一个批量求逆。也就是
具体实现的时候可以先存储所有的累乘,然后求出 ,最后再通过乘法得到每个 。这样一共需要一次求逆和 次乘法。不过由于 GPU 的缓存和袋鼠个数的限制,批量求逆的数量不能太大。但不论如何 ECDLP 可以轻松用上 100k+ 袋鼠,可以说非常享受了。最终实现在这里 CGBN,不过没怎么优化过缓存,特别是 ECDLP 的超大 batch inverse。
Final
在绕了这么大一个圈子后,我们可以狠狠地拿 GPU 来非预期炸掉 seccon 2024 quals 的 xiyi 了。这道题的问题是 ,所以我们可以寻找一个 使得 的阶非常小。这里唯一的办法是找 的素因子中恰为 518 bits 的素数。经过一番 factordb 搜索,只能找到 和 ,在一堆精细优化后,最后需要解一个 61 bits 和 66 bits 的 DLP。然后就是 GPU 暴力了。
最终在 3070 Ti 上做到了大约 4 分钟解出 32 个 66 bits 的 DLP,成功拿到 flag。
solve.py
from subprocess import check_outputfrom sage.all import *from sage.groups.generic import bsgsfrom tqdm import tqdmfrom pwn import remote, process
def run_kangaroo(target, g, order, bound=None, num_retry=3): # solve target = g**k mod p if bound is None: bound = order assert order % 2 == 1 p = g.parent().characteristic() inp = f"{p} {g} {order} {target} {bound}\n" print(f"Running kangaroo with {bound.bit_length()} bits...") print(inp) ret = None for _ in range(num_retry): try: ret = check_output("./kangaroo_gpu", input=inp.encode(), shell=True).decode() assert "Solution:" in ret break except Exception as e: print(e) continue assert ret is not None, "Kangaroo failed" print(ret) k = ret.split("Solution:")[1] k = k.split("\n")[0] return ZZ(k)
def dlp_generic(target, g, order, max_bound, num_retry=3): if max_bound.bit_length() > 32: return run_kangaroo(target, g, order, max_bound, num_retry) else: return bsgs(g, target, bounds=(ZZ(0), max_bound))
def dlp_prime(target, g, order, num_retry=3): if order.bit_length() > 34: return run_kangaroo(target, g, order, num_retry=num_retry) else: os = matrix(ZZ, [[order, 1]]) return pari(target).znlog(pari(g), pari([order, os])).sage()
def dlp_primepower(target, g, prime, n, num_retry=3): order = prime**n k = 0 g0 = g ** (order // prime) for i in range(n): t0 = target * (g**(-k)) t0 = t0 ** (prime**(n-1-i)) assert t0 ** prime == 1 k += dlp_prime(t0, g0, prime, num_retry) * prime**i return k
def dlp_main(target, g, order, max_bound, factors, num_retry=3): # solve target = g**k mod p cur_bound = max_bound mods = [] for prime, n in factors: if cur_bound <= 10: break if prime > cur_bound: # too large, fall back to generic break mod = prime**n ord_ = order // mod g0 = g ** ord_ t0 = target ** ord_ cur_bound = cur_bound // mod k_ = dlp_primepower(t0, g0, prime, n, num_retry) mods.append((k_, mod)) rs, ms = zip(*mods) k0 = crt(list(rs), list(ms)) MOD = prod(ms) # k = k0 + ? * MOD g1 = g ** MOD t1 = target * (g ** (-k0)) k1 = dlp_generic(t1, g1, order // MOD, cur_bound+1, num_retry) return k0 + k1 * MOD
def dlp(target, g, order=None, bounds=None, known_factors=None, num_retry=3): """ Solve DLP for target = g**k mod p
- **order**: order of g. A multiple of the order (e.g. p-1) is also accepted. - **bounds**: The bound for k. If None, it will be set to `order-1`. If integer, search range is [0, bound]. If list/tuple, represents an interval [min, max]. - **known_factors**: List of known factors of order. The rest part of the order will be handled by generic method. Either (prime, exponent) pairs or prime lists are accepted. - **num_retry**: Number of retries when kangaroo fails. """ assert g.parent().is_prime_field() p = g.parent().characteristic() if order is None: order = p-1 else: order = ZZ(order) if bounds is None: bounds = order if known_factors is None: known_factors = list(factor(order, algorithm="ecm")) if isinstance(known_factors[0], (list, tuple)): known_primes = [p for p, _ in known_factors] else: known_primes = known_factors if 2 not in known_primes: known_primes.append(2) assert g ** order == 1 known_primes.sort() for p in known_primes: while order % p == 0 and (g ** (order // p)) == 1: order = order // p assert g ** order == 1 factors = [] ord_ = order for p in known_primes: if ord_ % p != 0: continue v = ord_.valuation(p) factors.append((p, v)) ord_ = ord_ // (p**v) if ord_ > 1: factors.append((ord_, 1)) factors.sort() print(factors) if bounds is None: bounds = (0, order-1) elif isinstance(bounds, (list, tuple)): bounds = tuple(bounds) else: bounds = (0, int(bounds)) min_bound, max_bound = bounds t2 = target * (g ** (-min_bound)) range = max_bound - min_bound return min_bound + dlp_main(t2, g, order, range, factors, num_retry), factors
import jsonfrom secrets import randbelow
from lib2 import Cryptosystem, Privkey, Pubkey, Ct # same as lib.pyfrom params import L, M, N
p = 470772889347891753894848105243789217710301326704247389714521154467591283889037042443973128895041119713324021962506767036488831747747963841111160849037254561 # (2^1308+1)q = 731433453306625773851831070026391128383727479375152752977676902094814479855682993608187229204658713734108641142148010876470120999557683935590321341245075861 # (2^2490+1)
p1_factors = [(2, 5), (3, 3), (5, 1), (109, 1), (457456471313, 1), (53400322888084232503, 1), (4952131943166985961003171, 1), (12483277974290346429122493663434783301283, 1), (662043754266558743926932151479961636112025140409089081, 1)]q1_factors = [(2, 2), (3, 1), (5, 1), (61, 1), (83, 1), (2099, 1), (14071, 1), (361728968932127, 1), (225369400669558848086892569194138533812747585329213772218859604310602106229150875171288346031013036088614600109065387761801873239, 1)]n_factors = list(sorted(set([pf[0] for pf in p1_factors] + [qf[0] for qf in q1_factors])))+[p]p_used = [2, 3, 457456471313, 53400322888084232503]q_used = [61, 2099, 14071, 361728968932127]p_unused = prod([4952131943166985961003171, 12483277974290346429122493663434783301283, 662043754266558743926932151479961636112025140409089081])q_unused = 225369400669558848086892569194138533812747585329213772218859604310602106229150875171288346031013036088614600109065387761801873239
n = p**2 * qorder = p*lcm(p-1,q-1)
for x in range(2, 10000): x = mod(x, n) assert x ** order == 1 good = True for p0 in n_factors: if (x ** (order // p0)) == 1: good = False break if good: gx = int(x) break
bad_g = mod(n // 2, p*q)bad_g_order = 1085640assert bad_g ** bad_g_order == 1
priv_key = Privkey(p, q, Pubkey(n))
io = process(['sage', 'server.py'])
# initializeC = Cryptosystem.from_privkey(priv_key)assert C.privkey is not Noneenc_xs = [Ct(pow(gx, 281**i, n)) for i in range(L)]
# 1: (client) --- n, enc_xs ---> (server)io.sendlineafter(b"> ", json.dumps({"n": n, "enc_xs": enc_xs}).encode())
def decrypt_correct(c: int) -> int: a = C.L(pow(c, p - 1, p**2)) b = C.L(pow(gx, p - 1, p**2)) return a * pow(b, -1, p) % p
# 3: (server) --- enc_alphas, beta_sum_mod_n ---> (client)params = json.loads(io.recvline().strip().decode())enc_alphas, beta_sum_mod_n = params["enc_alphas"], params["beta_sum_mod_n"]beta_sum_mod_n = decrypt_correct(n//2) * beta_sum_mod_n % palphas = [decrypt_correct(enc_alpha) for enc_alpha in enc_alphas]alpha_sum = sum(alphas) % pinner_product = (alpha_sum + beta_sum_mod_n) % pprint(f"{inner_product = }")
ys = []for i in tqdm(range(L)): target = pow(enc_alphas[i], bad_g_order, p*q) base = pow(gx, 281**i, n) base = pow(base, bad_g_order, p*q) # base ** y_i == target target2 = pow(target, p_unused * q_unused, p*q) base2 = pow(base, p_unused * q_unused, p*q)
p_i, f_p = dlp(mod(target2, p), mod(base2, p), (p-1)//p_unused, known_factors=p1_factors) q_i, f_q = dlp(mod(target2, q), mod(base2, q), (q-1)//q_unused, known_factors=q1_factors) # print(p_i, f_p) # print(q_i, f_q) mod1 = prod([pf[0]**pf[1] for pf in f_p]) mod2 = prod([qf[0]**qf[1] for qf in f_q]) ci = (inner_product - sum([ys[j] * 281**j for j in range(i)])) assert ci % (281**i) == 0 inner_product_i = ci // (281**i) y_i_0 = crt([p_i, q_i, inner_product_i%281], [mod1, mod2, 281]) MOD = mod1 * mod2 * 281 base3 = mod(base, p) ** ((p-1)//p_unused) target3 = mod(target, p) ** ((p-1)//p_unused) target3 = target3 / (base3 ** y_i_0) base3 = base3 ** MOD y_i_1 = dlp_generic(target3, base3, p_unused, M//MOD) y_i = y_i_0 + y_i_1 * MOD print(f"{i = }, {y_i = }") ys.append(int(y_i))
# If, by any chance, you can guess ys, send it for the flag!io.sendlineafter(b"> ", json.dumps({"ys": ys, "p": C.privkey.p, "q": C.privkey.q}).encode())print(io.recvline().strip().decode()) # Congratz! or Wrong...print(io.recvline().strip().decode()) # flag or ys
