之前参加了 imaginaryCTF 的练习赛,解了 Complex Curve Crypto
这道在复数域上的椭圆曲线题,很好玩。做这道题的时候也学习了一下复数域上的椭圆曲线,所以会顺带介绍一下。
椭圆曲线与模形式
Theorem. 对所有椭圆曲线 ,都有 ,其中 是一个复数域上的格。
这个定理证明很复杂,我自己也搞不太懂。大概是 Weierstrass 函数同时满足了 和 ,所以可以构造一个映射。
这个定理说明我们可以把椭圆曲线上的离散对数问题变成一个格上的离散对数问题。而格上的离散对数就是一个有理逼近问题,所以可以用 LLL 解决。
接下来是 -invariant 的定义。这个东西最初是用来分类格的同构类的。首先一个格在旋转下是同构的,所以我们可以把一个格简化为 。但即便是这样, 也不能唯一确定格的同构类,例如 也表示了同样的格。实际上,一个格的全部同构类为 ,其中 。
Prop. -invariant 是唯一一个满足 且 ,定义在复数域上半平面的全纯函数。
同样,这个的证明也很难,而且也牵扯到了模形式的一些东西。所以只做密码学的话可以跳过了,总之先相信(x
对于所有的格,-invariant 都唯一地确定了它的同构类。基于前面的定理,我们可以把椭圆曲线的同构类和格的同构类联系起来,这自然给出了椭圆曲线的 -invariant。1
通过各种变换,我们可以把所有 变换到 的区域上(图自己去 wikipedia 看)。
然后,我们来看一下椭圆曲线在复数域上的 isogeny。首先,我们可以仿照有限域,利用 torsion point 来构造 isogeny。而 torsion point 在格上就变成了类似 等分点的东西。因此从格的角度来看,isogeny 是两个共享很多 等分点的格之间的部分同态(或者说嵌入?)。
Prop. classical modular polynomial 的一组解是一个 isogeny 的对应格的 -invariant,即 是 classical modular polynomial 的一组解。
在椭圆曲线 isogeny 的 mitm 攻击中,我们也会用到 classical modular polynomial 来加速遍历。这就是用到了它的解集恰好是 isogeny 的两个椭圆曲线的 -invariant 这个性质。对于一个 ,它的 次 isogeny 有 个,分别是 和 ,恰好是格的所有不同构的 等分点。
所以说,isogeny 在复数域上的性质是非常好的,后面这道题就是一个例子。
Complex Curve Crypto
from sage.all import *phi = classical_modular_polynomial(l:=3)flag = ZZ.from_bytes(os.environ.get('FLAG', 'jctf{fakeflag}').encode())j = jp = ComplexField(1337)(1337)for d in flag.digits(l): roots = phi(X=j).univariate_polynomial().roots(multiplicities=False) roots.sort(key=jp.dist) del roots[0] # no turning back roots.sort(key=arg) j, jp = roots[d], jprint(j)
这道题通过 classical modular polynomial 实现了复数域上的 isogeny。我们的目标是找到通往最终 -invariant 的路径。
我们先调用 pari 的库算出起点和终点的 。但-invariant 的逆不唯一,我们实际获取的那个是 normalized 的那一个解。
根据前面的说明,我们知道起始值是 ,而每次 isogeny 都会将其变成 。2 因此在最终 会变成 。可以看到,这个过程中 的虚部会不断变小,所以最终的解一定无法轻易地 normalize 到 的区域上。
于是,问题转化为了给定 ,怎么找到 使得 。我们可以把这个写成 的变换,即 。
这个式子的实部因为 的存在,是非常难解的,但好消息是虚部是不变的,而 的值大概是 的量级,所以如果精度足够高的话,我们可以用 LLL 来找小整数解。我们先利用 化简一下虚部。设,
在一通消元后,我们只需要把 当做未知数,构造一个矩阵,然后用 LLL 求解。这个 LLL 还是有一些不稳定,所以需要用 的关系稍微调一下。
在解出 之后,我们可以直接算出每一步对应的 -invariant,然后稍微验证一下就可以知道是第几个了。
solve.py
from sage.all import *from mpmath import mp
def babai_cvp(B, t, perform_reduction=True): if perform_reduction: B = B.LLL(delta=0.75)
G = B.gram_schmidt()[0] b = t for i in reversed(range(B.nrows())): c = ((b * G[i]) / (G[i] * G[i])).round() b -= c * B[i]
return b
l = 3phi = classical_modular_polynomial(l)flag = ZZ.from_bytes(os.environ.get('FLAG', 'jctf{ghrf65fakefd2lag}').encode())# flag = ZZ.from_bytes(b'jctf{fakef212113lag}')j = jp = ComplexField(1337)(1337)print(len(flag.digits(l)), flag.digits(l))ddd = [2, 1, 2, 2, 1, 2, 2, 0, 1, 1, 1, 2, 0, 0, 0, 1, 0, 2, 0, 2, 1, 0, 1, 0, 1, 2, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 2, 1, 0, 1, 0, 0, 2, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 1, 1, 0, 1, 0, 2, 2, 1, 0, 2, 0, 0, 1, 2, 0, 2, 2, 2, 2, 0, 0, 1, 0, 1, 2, 1, 1, 2, 0, 0, 2, 2, 1, 2, 0, 1, 2, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 2, 1, 2, 2, 0, 0, 1, 0, 1, 2, 0, 0, 2, 0, 1, 1, 1, 2, 1, 2, 1, 1, 2, 0, 1, 2, 1, 2, 2, 0, 1]for d in ddd: roots = phi(X=j).univariate_polynomial().roots(multiplicities=False)
roots.sort(key=jp.dist) del roots[0] # no turning back roots.sort(key=arg) j, jp = roots[d], jprint(j)target_j = j
mp.dps = 8000pari("\p 8000")def get_tau(j_val): """ solve j(tau) = j_val """ j_val = CC(j_val) output = pari(f"ellinit([{str(j_val)}]).omega") w1, w2 = output tau = w1 / w2 return CC(tau)
CC = ComplexField(8000)# target_j = CC(target_j)target_j = CC(["-6380.29156755903034687704660000282270448709570700984709657263854520138341881550084848649839410922876936678108641559731906139407589458824525200739437427675429842545943862350406898342718600643307574207936681432539822432021043321858999100455988010088297166788530685045590301521392833041157761230553808660966639987227996254364651040683587985017755013287627806460764447918220597797227073500295101888328945378", "-40214.6672669145274363733990190234269396915010860735171342583778810557492152506583335099397606096949655915384682940536670398063862158837139959936293817852537495478495004482851179789975853315750442943613549504080182872285152393975211173312274904654221951964258004673123691758298620284682558712820558193891787303533093197420998108945366704043133890680849252774739771188873183182185315971489373696493547256"]) # i=136t = get_tau(CC(1337))t = (t - 4) / (t - 3)
R = PolynomialRing(RR, 'x')xbar = R.gen()
def try_get_transform(t1, t2): """ t2 = (a*t1 + b) / (c*t1 + d), a*d - b*c = 1 """ x0 = t1.real() x1 = t2.real() y0 = t1.imag() y1 = t2.imag() u = y0 / y1 v1 = x0 ** 2 + y0 ** 2 v2 = 2 * x0 S1 = 2**3800 S2 = 2**3000 M = matrix(ZZ, [[S1, -S2, 0, 0], [(v1*S1).round(), 0, -S2, 0], [(v2*S1).round(), 0, 0, -S2]]) target = vector(ZZ, [(u*S1).round(), 0, 0, 0]) M = M.LLL() short_vectors = [] v=M[0] v = [v[1] // S2, v[2] // S2, v[3] // S2] sol = babai_cvp(M, target, perform_reduction=False)
dd = sol[1] // S2 cc = sol[2] // S2 cd = sol[3] // S2 print(M[0][0].bit_length(), abs(sol[0]).bit_length(), sol[1].bit_length(), sol[2].bit_length(), sol[3].bit_length()) print(cc, cd, dd) print(v)
f = (dd+v[0]*xbar)*(cc+v[1]*xbar) - (cd+v[2]*xbar)**2 rrr = f.roots(multiplicities=False) find = False print("rrr", rrr) for r in rrr[1:]: rr = r.round() cc2 = cc + rr * v[1] dd2 = dd + rr * v[0] cd2 = cd + rr * v[2] if cc2.is_square() and dd2.is_square() and gcd(cc2, dd2) == 1: print("Found a solution:", cc2, cd2, dd2) find = True cc, cd, dd = cc2, cd2, dd2 break if not find: return None
# print(cc * v1 + cd * v2 + dd - u) if not (cc.is_square() and dd.is_square()): return None c = sqrt(cc) d = sqrt(dd) if cd < 0: d = -d assert c * d == cd print(f"c = {c}, d = {d}") print(d**2+c**2 * v1 + cd * v2 - u) # c**2 * (x0**2+y0**2) + 2cd * x0 + d**2 = u return c, d
tau1 = ttau2 = get_tau(target_j)
print()print(target_j)print(elliptic_j((t+1)/3))
for i in range(10, 180): print("checking i =", i) k = 3**i ret = try_get_transform(tau2, tau1/k) if ret is None: continue c, d = ret print(gcd(c, d))
_, a, b = xgcd(d, c) b = -b
assert a * d - b * c == 1
tau2_transformed = (a * tau2 + b) / (c * tau2 + d) u = tau2_transformed.real_part().floor() a -= u * c b -= u * d tau2_transformed = (a * tau2 + b) / (c * tau2 + d) print(tau2_transformed)
u = (tau2_transformed*k).real() v = tau1.real() if abs(u.frac() - v.frac()) > abs(u.frac() - (1 - v.frac())): raise tau2_transformed = 1-tau2_transformed.conjugate()
print((tau2_transformed*k - tau1)) print((tau2_transformed*k - tau1).real().round(), k) u0 = (tau2_transformed*k - tau1).real().round() break
def get_j(tau): while True: tau -= tau.real_part().round() if tau.norm() <= 1: tau = -1 / tau else: break return elliptic_j(tau)
print(get_j((tau1+u0) / k))digits = []test_j = test_jp = ComplexField(1337)(1337)
def walk(j, jp, d): roots = phi(X=j).univariate_polynomial().roots(multiplicities=False) roots.sort(key=jp.dist) del roots[0] # no turning back roots.sort(key=arg) return roots[d], j
for i in range(400): if k <= 3**i: break
jp = get_j((tau1+u0) / 3**(i-1)) j0 = get_j((tau1+u0) / 3**i) j1 = get_j((tau1+u0) / 3**(i+1))
roots = phi(X=ComplexField(1337)(j0)).univariate_polynomial().roots(multiplicities=False) roots.sort(key=jp.dist) del roots[0] # no turning back roots.sort(key=j1.dist) target = roots[0] # print(j1.dist(target)) assert j1.dist(target) < 1e-10, j1.dist(target) roots.sort(key=arg) d = roots.index(target) digits.append(d)
assert test_j.dist(get_j((tau1+u0) / 3**(i))) < 1e-10, i test_j, test_jp = walk(test_j, test_jp, d)
print(test_j)print("Digits:", digits)
flag = sum(d * l**i for i, d in enumerate(digits))print(bytes.fromhex(hex(flag)[2:]))