Writeups for the crypto challenges from SekaiCTF 2025. The writeup for unfairy-ring
will be in a separate post.
SSSS
Description
Shamir SendS the Secret to everyone
Author: Utaha
chall.py
import os
import random

p = 2 ** 256 - 189
FLAG = os.getenv("FLAG", "SEKAI{}")


def challenge(secret):
    """Share ``secret`` among t parties, then let the player guess it.

    Reads t (20..50) and t evaluation points from stdin and prints f(x) mod p
    for each point. Prints the flag and exits if the final guess equals
    ``secret``; otherwise prints ":<" and returns.
    """
    t = int(input())
    assert 20 <= t <= 50, "Number of parties not in range"

    f = gen(t, secret)

    for _ in range(t):
        x = int(input())
        assert 0 < x < p, "Bad input"
        print(poly_eval(f, x))

    guess = int(input())
    if guess != secret:
        print(":<")
        return
    print(FLAG)
    exit(0)


def gen(degree, secret):
    """Return ``degree + 1`` random coefficients mod p, one of them replaced by ``secret``.

    The replaced index is drawn uniformly from 0..degree (inclusive), so the
    secret is not necessarily the constant term.
    """
    coeffs = [random.randrange(0, p) for _ in range(degree + 1)]
    coeffs[random.randint(0, degree)] = secret
    return coeffs


def poly_eval(f, x):
    """Evaluate the polynomial with coefficients ``f`` (lowest degree first) at ``x`` mod p."""
    total = 0
    for power, coeff in enumerate(f):
        total += coeff * pow(x, power, p)
    return total % p


if __name__ == "__main__":
    secret = random.randrange(0, p)
    # The same secret is shared twice, each time with a fresh random polynomial.
    for _ in range(2):
        challenge(secret)
Solution
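In general, Shamir's scheme needs $\deg f + 1$ points to reconstruct the polynomial. However, suppose we query the powers $x = g^j$ of an element $g$ of multiplicative order $t$. Then $x^t = 1$ for every query, so the $t$-th coefficient overlaps with the constant term. The answers only depend on $f \bmod (x^t - 1)$, whose coefficients are $(c_0 + c_t, c_1, \dots, c_{t-1})$, and this polynomial can be reconstructed from $t$ points.

Notice that `(p-1) % 29 == 0`, so we can find $g$ with order $t = 29$, which is inside the allowed range $20 \le t \le 50$. The secret sits at a random index of $f$. Unless it lands at index $0$ or $t$, it appears among the recovered coefficients.

We get two queries with the same secret but fresh random polynomials. The secret is therefore the intersection of the coefficients of the two reconstructed polynomials.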
solve.py
from sage.all import *
from pwn import *

# Imported after the star imports so that `random` is the stdlib module
# (sage.all exports a `random()` function under the same name).
import random

context.log_level = 'debug'

p = 2 ** 256 - 189
R = PolynomialRing(GF(p), 'x')
# t must divide p - 1 (29 does) so that an element of order exactly t exists.
t = 29

io = process(["python3", "chall.py"])


def sample():
    """Run one challenge round and return the coefficients of f mod (x^t - 1).

    Queries f at g^0, ..., g^(t-1) for g of multiplicative order t and
    Lagrange-interpolates the answers. Sage's ``coefficients()`` returns only
    the non-zero coefficients.
    """
    io.sendline(str(t).encode())
    # t is prime, so any h^((p-1)/t) != 1 has order exactly t. Drawing h from
    # [2, p - 1] guarantees g != 0 (the server rejects x == 0).
    while True:
        g = pow(random.randrange(2, p), (p - 1) // t, p)
        if g != 1:
            break
    shares = []
    for i in range(t):
        x0 = pow(g, i, p)
        io.sendline(str(x0).encode())
        y0 = int(io.recvline().strip())
        shares.append((x0, y0))
    return R.lagrange_polynomial(shares).coefficients()


s0 = sample()
# Deliberately wrong guess to reach the second round (same secret, fresh
# polynomial); consume the ":<" reply.
io.sendline(b'1')
io.recvline()
s1 = sample()

# The server accepts only one more guess. The secret is missed if it sat at
# index 0 or t in either polynomial (it is then merged into c_0 + c_t), so
# bail out instead of guessing blindly.
candidates = set(s0) & set(s1)
if len(candidates) != 1:
    raise SystemExit(f"expected exactly one common coefficient, got {len(candidates)}; rerun")
io.sendline(str(candidates.pop()).encode())
io.interactive()
I Dream of Genni
Description
I had the strangest dream last night.
Author: Neobeo
chall.py
from hashlib import sha256

from Crypto.Cipher import AES

x = int(input('Enter an 8-digit multiplicand: '))
y = int(input('Enter a 7-digit multiplier: '))
assert 1e6 <= y < 1e7 <= x < 1e8, "Incorrect lengths"
assert x * y != 3_81_40_42_24_40_28_42, "Insufficient ntr-opy"


def dream_multiply(x, y):
    """Return x's first digit followed by the digit-wise products of x[1:] and y.

    Each product is written in decimal without padding (1*3 -> "3"), and the
    pieces are concatenated and parsed back as an integer.
    """
    xs, ys = str(x), str(y)
    assert len(xs) == len(ys) + 1
    pieces = [xs[0]] + [str(int(a) * int(b)) for a, b in zip(xs[1:], ys)]
    return int("".join(pieces))


assert dream_multiply(x, y) == x * y, "More like a nightmare"

ct = '75bd1089b2248540e3406aa014dc2b5add4fb83ffdc54d09beb878bbb0d42717e9cc6114311767dd9f3b8b070b359a1ac2eb695cd31f435680ea885e85690f89'
key = sha256(str((x, y)).encode()).digest()
cipher = AES.new(key, AES.MODE_ECB)
print(cipher.decrypt(bytes.fromhex(ct)).decode())
Solution
A simple branch-and-prune search is enough to solve this challenge. We could use meet-in-the-middle (MITM) to find larger pairs, but once the numbers grow beyond about 10 digits, it is nearly impossible to find a valid pair.
If we also accept digit products written with a leading zero (e.g. 1*3=03), we can still find some larger pairs, such as 8827954385162 * 987958187368 = 8721649812532036435033616.
solve.cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
// g++ -O3 solve.cpp -o solve && ./solve
typedef __int128 int128_t;

// Stream a signed 128-bit integer in base 10; std::ostream has no overload
// for __int128. The magnitude is computed in unsigned 128-bit arithmetic so
// that negating the most negative value is well defined (the previous
// `val *= -1` overflowed there).
std::ostream &operator<<(std::ostream &os, int128_t val)
{
    const bool is_negative = val < 0;
    unsigned __int128 mag = static_cast<unsigned __int128>(val);
    if (is_negative)
    {
        mag = -mag; // unsigned negation: two's-complement magnitude
    }

    std::string result;
    do
    {
        result.push_back(static_cast<char>('0' + static_cast<int>(mag % 10)));
        mag /= 10;
    } while (mag != 0);

    if (is_negative)
    {
        result.push_back('-');
    }

    std::reverse(result.begin(), result.end());
    return os << result;
}
// A (multiplicand digit, multiplier digit) pair.
using vec2 = std::pair<int64_t, int64_t>;
// A partial candidate: (x, y, "dream" product built from the digit products).
using vec3 = std::tuple<int64_t, int64_t, int128_t>;
// Digit pairs (a, b) with a, b in 1..9 whose product has two digits;
// filled in main().
std::vector<vec2> pairs;
// Enumerate the low `digits` digits of the multiplicand x and multiplier y.
// Every digit pair comes from `pairs`; the third tuple element is the "dream"
// product, where the digit pair at position k (from the right) contributes
// a*b * 100^k. A candidate is kept only if its true product x*y agrees with
// the dream product modulo 10^digits: the low digits of a product depend only
// on the low digits of its factors, so this is a necessary condition.
void gen1(int digits, std::shared_ptr<std::vector<vec3>> &result)
{
    if (digits == 1)
    {
        // Base case: one digit pair; the dream product is just a*b.
        for (const auto &pair : pairs)
        {
            int64_t a = pair.first;
            int64_t b = pair.second;
            result->emplace_back(a, b, a * b);
        }
        return;
    }

    // mul = 10^(digits-1): place value of the new (highest) digit of x and y.
    int128_t mul = std::round(std::pow(10, digits - 1));
    // mul1 = 10^digits: modulus for the low-digit consistency check.
    int128_t mul1 = std::round(std::pow(10, digits));
    // mul2 = 100^(digits-1): place value of the new two-digit product a*b.
    int128_t mul2 = mul * mul;
    // Sanity check on the floating-point powers (only verifies divisibility by 10).
    if (mul % 10 != 0 || mul1 % 10 != 0 || mul2 % 10 != 0)
    {
        std::cerr << "Error: mul is not a multiple of 10" << std::endl;
        exit(1);
    }

    // Extend every surviving (digits-1)-digit candidate by one more digit pair.
    std::shared_ptr<std::vector<vec3>> result0 = std::make_shared<std::vector<vec3>>();
    gen1(digits - 1, result0);
    for (const auto &item : *result0)
    {
        int64_t x = std::get<0>(item);
        int64_t y = std::get<1>(item);
        int128_t xy_mul = std::get<2>(item);
        for (const auto &pair : pairs)
        {
            int64_t a = pair.first;
            int64_t b = pair.second;
            int128_t x2 = a * mul + x;
            int128_t y2 = b * mul + y;
            int128_t xy_mul2 = xy_mul + (a * b) * mul2;
            // Keep it if x2*y2 and the dream product agree on the lowest
            // `digits` digits (a zero remainder also covers negative differences).
            if ((x2 * y2 - xy_mul2) % mul1 == 0)
            {
                result->emplace_back(x2, y2, xy_mul2);
            }
        }
    }
}
// Enumerate the high digits. x is a leading digit i (1..9) followed by
// `digits` more digits, and y has `digits` digits. The dream product is i
// followed by the two-digit products a*b.
// A prefix is kept only if the dream prefix lies in [x2*y2, (x2+1)*(y2+1)).
// For some choice of the remaining lower digits to make the true product
// match, the two value ranges must overlap, so this is necessary but not
// sufficient.
void gen2(int digits, std::shared_ptr<std::vector<vec3>> &result)
{
    if (digits == 1)
    {
        // Base case: x = (i, a) as two digits, y = b, dream product = i followed by a*b.
        for (const auto &pair : pairs)
        {
            int64_t a = pair.first;
            int64_t b = pair.second;
            for (int64_t i = 1; i < 10; i++)
            {
                int64_t x = 10 * i + a;
                int64_t y = b;
                result->emplace_back(x, y, 100 * i + a * b);
            }
        }
        return;
    }

    std::shared_ptr<std::vector<vec3>> result0 = std::make_shared<std::vector<vec3>>();
    gen2(digits - 1, result0);
    for (const auto &item : *result0)
    {
        int64_t x = std::get<0>(item);
        int64_t y = std::get<1>(item);
        int128_t xy_mul = std::get<2>(item);
        for (const auto &pair : pairs)
        {
            int64_t a = pair.first;
            int64_t b = pair.second;
            // Append one digit to x and y, and one two-digit product to the dream product.
            int128_t x2 = 10 * x + a;
            int128_t y2 = 10 * y + b;
            int128_t xy_mul2 = 100 * xy_mul + a * b;
            int128_t lower_bound = x2 * y2;
            int128_t upper_bound = (x2 + 1) * (y2 + 1);
            if (lower_bound <= xy_mul2 && xy_mul2 < upper_bound)
            {
                result->emplace_back(x2, y2, xy_mul2);
            }
        }
    }
}
// Digit split for the meet-in-the-middle search in main():
//   - the high half (gen2) covers the leading digit of x plus M1 + M2 digit pairs;
//   - the low half (gen1) covers M2 + M3 digit pairs;
//   - the halves are joined on their M2 overlapping digits.
// Each directive must be on its own line; fused onto one line, M1 would
// expand to the rest of that line.
#define M1 2
#define M2 3
#define M3 2
// Meet-in-the-middle search for (x, y) with x * y == dream product.
// The low half (gen1) and the high half (gen2) are joined on their M2
// overlapping digits, and each joined candidate is checked exactly.
int main()
{
    // All digit pairs whose product has exactly two digits.
    for (int64_t i = 1; i < 10; i++)
        for (int64_t j = 1; j < 10; j++)
        {
            if (i * j >= 10)
            {
                pairs.push_back({i, j});
            }
        }

    std::shared_ptr<std::vector<vec3>> result1 = std::make_shared<std::vector<vec3>>();
    std::shared_ptr<std::vector<vec3>> result2 = std::make_shared<std::vector<vec3>>();
    // Low half: the lowest M2 + M3 digits of x and y.
    gen1(M2 + M3, result1);

    std::cout << "Size1: " << result1->size() << std::endl;

    int128_t last_M3_digits = std::round(std::pow(10, M3));     // 10^M3
    int128_t last_M3_digits2 = last_M3_digits * last_M3_digits; // 100^M3
    int128_t last_M2_digits = std::round(std::pow(10, M2));     // 10^M2

    // Index the low half by its top M2 digits (x / 10^M3, y / 10^M3), packed
    // as (x2 << 32) | y2, so that it can be joined with the high half.
    std::shared_ptr<std::map<uint64_t, std::vector<vec3>>> result_map = std::make_shared<std::map<uint64_t, std::vector<vec3>>>();
    for (const auto &item : *result1)
    {
        int64_t x = std::get<0>(item);
        int64_t y = std::get<1>(item);
        int128_t xy_mul = std::get<2>(item);
        uint64_t x2 = x / last_M3_digits;
        uint64_t y2 = y / last_M3_digits;
        uint64_t key = (x2 << 32) | y2;
        if (result_map->find(key) == result_map->end())
        {
            (*result_map)[key] = std::vector<vec3>();
        }
        (*result_map)[key].emplace_back(x, y, xy_mul);
    }
    result1.reset(); // only the map is needed from here on
    std::cout << "Finished processing map" << std::endl;

    // High half: the leading digit of x plus M1 + M2 more digit pairs.
    gen2(M1 + M2, result2);
    std::cout << "Size2: " << result2->size() << std::endl;
    for (const auto &item : *result2)
    {
        int64_t x = std::get<0>(item);
        int64_t y = std::get<1>(item);
        int128_t xy_mul = std::get<2>(item);
        // Bottom M2 digits of the high half are the overlap with the low half.
        uint64_t x2 = x % last_M2_digits;
        uint64_t y2 = y % last_M2_digits;
        uint64_t key = (x2 << 32) | y2;
        auto it = result_map->find(key);
        if (it != result_map->end())
        {
            for (const auto &pair : it->second)
            {
                int64_t x1 = std::get<0>(pair);
                int64_t y1 = std::get<1>(pair);
                int128_t xy_mul1 = std::get<2>(pair);
                // Append the low half's last M3 digits (and the last 2*M3
                // digits of its dream product) to the high half.
                int128_t x_final = (int128_t)x * last_M3_digits + (int128_t)x1 % last_M3_digits;
                int128_t y_final = (int128_t)y * last_M3_digits + (int128_t)y1 % last_M3_digits;
                int128_t xy_mul_final = xy_mul * last_M3_digits2 + xy_mul1 % last_M3_digits2;
                // Exact check: the real product equals the dream product.
                if (x_final * y_final == xy_mul_final)
                {
                    std::cout << "Found: x=" << x_final << ", y=" << y_final << ", x*y=" << xy_mul_final << std::endl;
                }
            }
        }
    }
}
Literal Eval
Description
Trust me, it’s 1000% safe to use literal eval in crypto!
Author: Sceleri
chall.py
from Crypto.Util.number import bytes_to_longfrom hashlib import shake_128from ast import literal_evalfrom secrets import token_bytesfrom math import floor, ceil, log2import os
FLAG = os.getenv("FLAG", "SEKAI{}")

# WOTS parameters
m = 256  # digest length in bits
w = 21   # Winternitz base; each chain has w - 1 hash steps
n = 128  # secret-key / chain-value length in bits
l1 = ceil(m / log2(w))                         # base-w digits of the digest
l2 = floor(log2(l1 * (w - 1)) / log2(w)) + 1   # base-w digits of the checksum
l = l1 + l2                                    # total number of chains
class WOTS:
    """Winternitz one-time signature built on shake_128 hash chains.

    The secret key is ``l`` random ``n // 8``-byte values; the public key is
    each of them hashed ``w - 1`` times.
    """

    def __init__(self):
        self.sk = [token_bytes(n // 8) for _ in range(l)]
        self.pk = [WOTS.chain(secret, w - 1) for secret in self.sk]

    @staticmethod
    def _digits(digest: bytes) -> list[int]:
        """Return the base-w digits of ``digest`` followed by its checksum digits."""
        msg_digits = WOTS.pack(bytes_to_long(digest), l1, w)
        checksum = sum(w - 1 - digit for digit in msg_digits)
        return msg_digits + WOTS.pack(checksum, l2, w)

    def sign(self, digest: bytes) -> list[bytes]:
        """Sign an ``m``-bit digest: chain i is hashed ``w - 1 - d[i]`` times."""
        assert 8 * len(digest) == m
        digits = WOTS._digits(digest)
        return [WOTS.chain(self.sk[i], w - digits[i] - 1) for i in range(l)]

    def get_pubkey_hash(self) -> bytes:
        """Hash of the ``l`` public chain ends, domain-separated with 0x04."""
        return shake_128(b"\x04" + b"".join(self.pk[i] for i in range(l))).digest(16)

    @staticmethod
    def pack(num: int, length: int, base: int) -> list[int]:
        """Little-endian base-``base`` digits of ``num``, zero-padded to at least ``length``."""
        digits = []
        while num > 0:
            num, digit = divmod(num, base)
            digits.append(digit)
        digits.extend([0] * (length - len(digits)))
        return digits

    @staticmethod
    def chain(x: bytes, n: int) -> bytes:
        """Apply the 0x03-prefixed shake_128 step to ``x`` ``n`` times."""
        if n == 0:
            return x
        return WOTS.chain(shake_128(b"\x03" + x).digest(16), n - 1)

    @staticmethod
    def verify(digest: bytes, sig: list[bytes]) -> bytes:
        """Complete every chain of ``sig`` and return the implied public-key hash.

        The caller compares the result against the expected value.
        """
        digits = WOTS._digits(digest)
        return shake_128(
            b"\x04" + b"".join(WOTS.chain(sig[i], digits[i]) for i in range(l))
        ).digest(16)
class MerkleTree:
    """Merkle tree over the public-key hashes of ``2 ** height`` WOTS keys.

    ``self.tree[k]`` holds level k (level 0 = leaves, the last level = [root]).
    A node without a right sibling is paired with itself.
    """

    def __init__(self, height: int = 8):
        self.h = height
        self.keys = [WOTS() for _ in range(2 ** height)]
        self.tree = []
        leaf_hashes = [key.get_pubkey_hash() for key in self.keys]
        self.root = self.build_tree(leaf_hashes)

    def build_tree(self, leaves: list[bytes]) -> bytes:
        """Record ``leaves`` as the next level and recurse until one node remains."""
        self.tree.append(leaves)
        if len(leaves) == 1:
            return leaves[0]
        last = len(leaves) - 1
        parents = [
            shake_128(b"\x02" + leaves[i] + leaves[min(i + 1, last)]).digest(16)
            for i in range(0, len(leaves), 2)
        ]
        return self.build_tree(parents)

    def sign(self, index: int, digest: bytes) -> list:
        """Return [WOTS signature, (side, sibling) for each of the h levels].

        ``side`` is 0 when the sibling is hashed on the right, 1 when it is on
        the left.
        """
        assert 0 <= index < len(self.keys)
        sig = [self.keys[index].sign(digest)]
        for level in range(self.h):
            nodes = self.tree[level]
            pos = index >> level
            if pos % 2:
                sig.append((1, nodes[pos - 1]))
            else:
                sibling = pos + 1 if pos + 1 < len(nodes) else pos
                sig.append((0, nodes[sibling]))
        return sig

    @staticmethod
    def verify(sig: list, digest: bytes) -> bytes:
        """Recompute the root implied by ``sig`` for ``digest`` (the caller compares it)."""
        wots_sig, path = sig[0], sig[1:]
        node = WOTS.verify(digest, wots_sig)
        for side, sibling in path:
            pair = node + sibling if side == 0 else sibling + node
            node = shake_128(b"\x02" + pair).digest(16)
        return node
class Challenge:
    """Signing oracle over a Merkle tree of WOTS keys, plus a flag check."""

    def __init__(self, h: int = 8):
        self.h = h
        self.max_signs = 2 ** h - 1
        self.tree = MerkleTree(h)
        self.root = self.tree.root
        self.used = set()
        self.before_input = f"public key: {self.root.hex()}"

    def sign(self, num_sign: int, inds: list, messages: list):
        """Sign ``messages[i]`` with WOTS key ``inds[i]`` for i < num_sign.

        The annotations are not enforced. If ``inds`` is a dict, the checks
        below (len, set, isdisjoint) and ``used.update`` all see its keys,
        while ``inds[i]`` in the signing step returns its values. This is the
        key-reuse vulnerability described in the writeup.
        """
        assert num_sign + len(self.used) <= self.max_signs
        assert len(inds) == len(set(inds)) == len(messages) == num_sign
        assert self.used.isdisjoint(inds)
        assert all(b"flag" not in msg for msg in messages)
        sigs = [
            self.tree.sign(inds[i], shake_128(b"\x00" + messages[i]).digest(32))
            for i in range(num_sign)
        ]
        self.used.update(inds)
        return sigs

    def next(self):
        """Switch to a fresh Merkle tree and return its root with a signature."""
        new_tree = MerkleTree(self.h)
        digest = shake_128(b"\x01" + new_tree.root).digest(32)
        index = next(i for i in range(2 ** self.h) if i not in self.used)
        # NOTE(review): the new root is signed with new_tree itself rather than
        # the current self.tree, and `index` is not added to self.used --
        # looks unintended; confirm.
        sig = new_tree.sign(index, digest)
        self.tree = new_tree
        return {"root": new_tree.root, "sig": sig, "index": index}

    def verify(self, sig: list, message: bytes):
        """Verify a chain of Merkle signatures (innermost last) against self.root."""
        digest = shake_128(b"\x00" + message).digest(32)
        for depth, tree_sig in enumerate(reversed(sig)):
            if depth:
                digest = shake_128(b"\x01" + digest).digest(32)
            digest = MerkleTree.verify(tree_sig, digest)
        return digest == self.root

    def get_flag(self, sig: list):
        """Return the flag if ``sig`` is a valid signature of the flag message."""
        if self.verify(sig, b"Give me the flag"):
            return {"message": f"Congratulations! Here is your flag: {FLAG}"}
        return {"message": "Invalid signature"}

    def __call__(self, type: str, **kwargs):
        """Dispatch a request dict such as {"type": "sign", ...} to the named method."""
        assert type in ["sign", "next", "get_flag"]
        handler = getattr(self, type)
        return handler(**kwargs)
challenge = Challenge()
print(challenge.before_input)
try:
    while True:
        request = literal_eval(input("input: "))
        print(challenge(**request))
except:
    # Any error (malformed input, failed assertion, EOF) ends the session.
    exit(1)
Solution
The challenge is a hash-based signature scheme that supports an unlimited number of signatures. As the name suggests, the goal is to use literal_eval to craft malicious input.
The most suspicious function is sign, which signs many messages at once and lets us choose the indexes of the WOTS keys used for signing. A straightforward idea is to make the same key sign multiple messages, at which point the one-time signature is no longer secure.
Although the sign function performs several checks to prevent this, the type hints on its parameters aren't enforced, so we can pass a dictionary instead of a list for `inds`. `len`, `set`, `isdisjoint` and `used.update` all iterate over the dictionary's keys, which pass every check. The signing loop, however, reads `inds[i]`, which returns the dictionary's values. For example, {i: 0 for i in range(255)} passes the checks with the distinct keys 0–254, yet generates 255 signatures with WOTS key 0.
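The rest is a standard WOTS key-reuse attack. With 255 signatures under the same key, every chain value needed for the target message can be taken from a leaked one, or hashed forward from one.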
solve.py
from math import floor, ceil, log2from Crypto.Util.number import bytes_to_longimport osfrom hashlib import shake_128from ast import literal_evalfrom secrets import token_bytesfrom pwn import *
def pack(num: int, length: int, base: int) -> list[int]:
    """Little-endian base-``base`` digits of ``num``, zero-padded to at least ``length`` (as in chall.py)."""
    digits = []
    while num > 0:
        num, digit = divmod(num, base)
        digits.append(digit)
    digits.extend([0] * (length - len(digits)))
    return digits
def get_d(digest, m=256, w=21):
    """Return the WOTS chain digits of ``digest``: message digits, then checksum digits.

    Mirrors how WOTS.sign / WOTS.verify in chall.py derive ``d``.
    """
    l1 = ceil(m / log2(w))
    l2 = floor(log2(l1 * (w - 1)) / log2(w)) + 1
    d1 = pack(bytes_to_long(digest), l1, w)
    checksum = sum(w - 1 - i for i in d1)
    return d1 + pack(checksum, l2, w)
io = process(["python3", "chall.py"])
io.recvuntil(b"public key:")
root = bytes.fromhex(io.recvline().strip().decode())
# Number of messages signed under one WOTS key; Challenge.max_signs is 2**8 - 1.
k = 255
def send(msg):
    """Send one request dict to the challenge and parse its printed reply."""
    io.recvuntil(b"input:")
    io.sendline(str(msg).encode())
    reply = io.recvline().decode()
    if "Traceback" in reply:
        io.interactive()
    return literal_eval(reply)
msgs = [os.urandom(32) for _ in range(k)]
digests = [shake_128(b"\x00" + msg).digest(32) for msg in msgs]
ds = [get_d(digest) for digest in digests]
# The dict's keys 0..k-1 pass every uniqueness/"unused" check in
# Challenge.sign, but its values (all 0) are used as the WOTS key index.
sigs = send({
    "type": "sign",
    "num_sign": k,
    "inds": {i: 0 for i in range(k)},
    "messages": msgs,
})
# Digits we must forge: the digest Challenge.get_flag verifies.
target_digest = shake_128(b"\x00" + b"Give me the flag").digest(32)
target = get_d(target_digest)
def _chain(x, steps):
    """Hash ``x`` forward ``steps`` times with the WOTS chain step from chall.py."""
    for _ in range(steps):
        x = shake_128(b"\x03" + x).digest(16)
    return x


# A leaked chain value for digit d is chain(sk, w-1-d). For any target digit
# t <= d, hashing it d - t more times gives chain(sk, w-1-t), which is exactly
# what the verifier expects. Accepting every d >= t (instead of only d == t)
# makes the forgery succeed whenever any leaked signature is high enough.
wots_sign = []
for i, t in enumerate(target):
    usable = [(dd[i], sig[0][i]) for dd, sig in zip(ds, sigs) if dd[i] >= t]
    assert usable, f"no leaked signature can cover chain {i}"
    digit, value = min(usable)
    wots_sign.append(_chain(value, digit - t))
# All leaked signatures used WOTS key 0, so any of their Merkle
# authentication paths fits the forged WOTS signature.
forged_sig = [wots_sign, *sigs[0][1:]]
print(send({"type": "get_flag", "sig": [forged_sig]}))
Alter Ego
Description
Even so, You alone should drown in blue and vanish.
Before being consumed by those sorrow-filled eyes, divide the world by zero, humming with a hoarse voice.
The finale’s sound is ringing out.
Author: kanon
chall.sage
from Crypto.Util.number import *
from random import randint
import os

from montgomery_isogenies.kummer_line import KummerLine
from montgomery_isogenies.kummer_isogeny import KummerLineIsogeny

FLAG = os.getenv('flag', "SEKAI{here_is_test_flag_hehe}").encode()
# Disable Sage's proof-level arithmetic checks (e.g. provable primality) for speed.
proof.arithmetic(False)

# Alice's exponents are drawn from [MI + KU, MI * KU]; MIKU is the number of BEAM rounds.
MI = 3
KU = 9
MIKU = 39

# CSIDH-style prime: p + 1 = 4 * (product of the small odd primes in ells).
ells = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 587]
p = 4 * prod(ells) - 1

Fp = GF(p)
F = GF(p**2, modulus=x**2 + 1, names='i')  # i^2 = -1
i = F.gen(0)
# Starting curve y^2 = x^3 + x, with (p + 1)^2 points over F_{p^2}.
E0 = EllipticCurve(F, [1, 0])
E0.set_order((p + 1)**2)
def group_action(_C, priv, G):
    # Apply the exponent vector `priv` (one entry per prime in `ells`) to the
    # Kummer line `_C`, pushing the Kummer point `G` through every isogeny.
    # Returns (codomain Kummer line, image of G).
    es = priv[:]  # copy: the exponents are consumed below
    while any(es):
        # Random x in F_p. The point is in E(F_p) when x^3 + A x^2 + x is a
        # square in F_p (sign +1), otherwise on the twist (sign -1).
        # In .sage files `^` is exponentiation.
        x = Fp.random_element()
        P = _C(x)
        A = _C.curve().a2()
        s = 1 if Fp(x ^ 3 + A * x ^ 2 + x).is_square() else -1

        # Indices whose remaining exponent has the same sign as this point.
        S = [i for i, e in enumerate(es) if sign(e) == s and e != 0]
        k = prod([ells[i] for i in S])
        # Clear the cofactor so that Q's order divides k (assumes P's order
        # divides p + 1, as for these supersingular curves).
        Q = int((p + 1) // k) * P
        for i in S:
            R = (k // ells[i]) * Q  # order ells[i], or zero if P lacked that part
            if R.is_zero():
                continue  # retry this prime with a later random point

            phi = KummerLineIsogeny(_C, R, ells[i])
            _C = phi.codomain()
            Q, G = phi(Q), phi(G)
            es[i] -= s
            k //= ells[i]

    return _C, G
def BEAM(base_alice_priv):
    """Run MIKU rounds, each leaking Alice's codomain and a pushed-through random point.

    Round structure:
      1. Start from the Montgomery curve with a2 = pub (0 in the first round,
         then the previous round's output).
      2. Apply Alice's current exponents to a random point G.
      3. Print the resulting a2 and the image of G.
      4. Add the player's vector (entries in {-1, 0, 1}) to Alice's exponents.

    Only the local name is rebound, so the caller's list keeps the original
    exponents.
    """
    alice_priv = base_alice_priv
    pub = 0

    for _ in range(MIKU):
        E = EllipticCurve(F, [0, pub, 0, 1, 0])
        omae_E = KummerLine(E)
        _G = omae_E(E.random_point())

        _final_E1, _final_G = group_action(omae_E, alice_priv, _G)
        print(f"final_a2 = {_final_E1.curve().a2()}")
        print(f"{_final_G=}")

        omae_priv = list(map(int, input("your priv >").split(", ")))
        assert all([abs(pi) < 2 for pi in omae_priv])
        assert len(omae_priv) == len(ells)

        alice_priv = [ai + yi for ai, yi in zip(alice_priv, omae_priv)]
        print("updated")

        pub = _final_E1.curve().a2()

    print("FIN!")
if __name__ == "__main__":
    print("And now, it's time for the moment you've been waiting for!")

    alice_priv = [randint(MI + KU, MI * KU) for _ in ells]
    BEAM(alice_priv)

    # The player must supply a different, all-negative exponent vector
    # (entries in [-27, -1]) that reaches the same curve from E0.
    alter_ego = list(map(int, input('ready?! here is the "alter ego" >').split(", ")))
    assert alice_priv != alter_ego
    assert len(alice_priv) == len(alter_ego)
    assert all([-MI * KU <= ai < 0 for ai in alter_ego])

    _E0 = KummerLine(E0)
    _G = _E0(E0.random_point())

    _alter_ego_E1, _ = group_action(_E0, alter_ego, _G)
    _alice_E1, __ = group_action(_E0, alice_priv, _G)

    if _alter_ego_E1.curve().a2() != _alice_E1.curve().a2():
        print("YOU CANT FIND MY ALTER EGO....")
        exit()
    print("There you are... I've been waiting and waiting for you to come to me.")
    print(FLAG)
Solution
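Every point $G \in E(\mathbb{F}_{p^2})$ can be decomposed into $G = G_+ + G_-$, where $G_+ \in E(\mathbb{F}_p)$ and $G_-$ lies on the quadratic twist of $E$ (its $x$-coordinate is in $\mathbb{F}_p$ but its $y$-coordinate is not). The $\ell_i$-isogeny in function group_action
can reduce the order of either $G_+$ or $G_-$ by a factor of $\ell_i$, depending on the sign.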
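For any odd prime $\ell_i$, the order of $G_+$ can't have factor $\ell_i$ if the sign of alice_priv[i] is positive. Symmetrically, the order of $G_-$ can't have factor $\ell_i$ if it is negative. This holds because $\ell_i$ divides $p + 1 = 4\prod_j \ell_j$ exactly once, so a single step in that direction removes the whole $\ell_i$-part. The orders therefore reveal the sign of alice_priv[i]. Hence, we can keep sending $(-1, \dots, -1)$, which lowers every exponent by one per round, and observe the orders of $G_+$ and $G_-$.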
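After retrieving the original alice_priv, the next step is to find a negative priv that walks to the same curve. BEAM only rebinds its local copy, so the final check still uses the original exponents. The idea is the same as su_auth from SUCTF 2025. So simply [e-36 for e in alice_priv] passes the check: alice_priv lies in [12, 27], so the shifted exponents lie in [-24, -9], inside the required range [-27, -1].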
solve.py
from sage.all import *
from pwn import *
from tqdm import tqdm

MI = 3
KU = 9
MIKU = 39

ells = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 587]
p = 4 * prod(ells) - 1

Fp = GF(p)
F = GF(p**2, modulus=[1, 0, 1], names='i')  # modulus 1 + x^2, i.e. i^2 = -1
i = F.gen(0)

# Starting curve y^2 = x^3 + x, the same as E0 in chall.sage.
E0 = EllipticCurve(F, [0, 0, 0, 1, 0])

io = process(['sage', 'chall.sage'])
def group_action(E0, priv):
    # Solver-side version of the challenge's group_action, using full
    # Weierstrass curves and Sage's generic E.isogeny instead of Kummer lines.
    # Returns only the codomain curve (not necessarily in Montgomery form).
    E = E0
    es = priv[:]  # copy: the exponents are consumed below
    while any(es):
        x = Fp.random_element()
        P = E.lift_x(x)
        # y in F_p: P is in E(F_p) (sign +1); otherwise it is a twist point (sign -1).
        s = 1 if P[1] in Fp else -1
        S = [i for i, e in enumerate(es) if sign(e) == s and e != 0]
        k = prod([ells[i] for i in S])
        # Clear the cofactor so that Q's order divides k.
        Q = ((p + 1) // k) * P

        for i in S:
            R = (k // ells[i]) * Q  # order ells[i], or zero
            if R.is_zero():
                continue
            phi = E.isogeny(R)
            E = phi.codomain()
            Q = phi(Q)
            es[i] -= s
            k //= ells[i]
    return E
def read_curve():
    """Parse the next "final_a2 = ..." line into y^2 = x^3 + a2*x^2 + x over F."""
    io.recvuntil(b'final_a2 = ')
    a2 = int(io.recvline().strip())
    curve = EllipticCurve(F, [0, a2, 0, 1, 0])
    curve.set_order((p + 1)**2)
    return curve
def read_point(EC):
    # Parse the challenge's `_final_G=...` line into a point of EC.
    # SECURITY NOTE(review): `eval` executes text received from the challenge
    # process; acceptable only because we trust the process we spawned.
    # Presumably the printed form is "(x : z)" with x, z written in terms of
    # the generator `i` -- TODO confirm against the Kummer point repr.
    # lift_x returns one of +-G; the component orders used later are the same
    # for both.
    io.recvuntil(b'_final_G=')
    point_str = io.recvline().strip().decode()
    x, z = eval(point_str.replace(':', ', '))
    x = F(x)
    z = F(z)
    G = EC.lift_x(x / z)
    return G
def get_orders(G):
    """Return (order of e(G, gen2), order of e(G, gen1)) using Weil pairings.

    gen1 is a point of E(F_p) and gen2 a twist point, each with every odd
    prime of `ells` in its order. For an odd ell, ell divides the first value
    iff G's E(F_p) component has an ell-part, and the second value likewise
    for G's twist component.
    """
    EC = G.curve()
    gen1 = gen2 = None
    while gen1 is None or gen2 is None:
        P = EC.lift_x(Fp.random_element())
        # Reject P if its order misses any odd prime ell.
        if any((((p + 1) // ell) * P).is_zero() for ell in ells):
            continue
        if P.y() in Fp:
            gen1 = P
        else:
            gen2 = P
    pairing1 = G.weil_pairing(gen1, p + 1)
    pairing2 = G.weil_pairing(gen2, p + 1)
    return pairing2.multiplicative_order(), pairing1.multiplicative_order()
orders = []
curves = []

# Each round leaks Alice's curve and a pushed-through point. Answering with
# all -1 lowers every exponent by one, so round j uses alice_priv - j.
for _ in tqdm(range(MIKU)):
    curve = read_curve()
    curves.append(curve)
    orders.append(get_orders(read_point(curve)))
    io.sendline(", ".join(["-1"] * len(ells)).encode())
io.recvuntil(b"FIN!\n")
guess_priv = []

# Round j used exponent a - j. If G's E(F_p) component still has an
# ell-part, that exponent was <= 0 (so a <= j). If the twist component has
# one, it was >= 0 (so a >= j). Entries are stored as (lower bound, width).
for ell in ells:
    lo, hi = MI + KU, MI * KU
    for j, (o_plus, o_minus) in enumerate(orders):
        if o_plus % ell == 0:
            hi = min(hi, j)
        if o_minus % ell == 0:
            lo = max(lo, j)
    guess_priv.append((lo, hi - lo))

print("guess_priv =", guess_priv)
# BEAM's first round starts from pub = 0 (y^2 = x^3 + x, i.e. E0) and applies
# the original exponents, so curves[0] is the target of the search below.
# (The unused `G0 = E0.random_point()` was dropped.)
E1 = curves[0]

left, offset = zip(*guess_priv)
left = list(left)
offset = list(offset)
print(offset)
def gen_all_possible_privs(offset):
    """Yield every list v with 0 <= v[j] <= offset[j], first coordinate varying fastest.

    An empty ``offset`` yields a single empty list; the previous recursive
    version recursed forever on it. The loop variable no longer shadows ``os``.
    """
    from itertools import product

    # product() varies its last factor fastest, so feed the ranges in reverse
    # and flip each tuple back to keep the original enumeration order.
    ranges = [range(width + 1) for width in reversed(offset)]
    for combo in product(*ranges):
        yield list(reversed(combo))
# Walk to the curve of the lower-bound exponents once; the group action is
# commutative, so every candidate is only its (small) offset away from here.
E1_base = group_action(E0, left)
def check():
    """Return the first ``left + delta`` whose curve has E1's j-invariant, else None."""
    for delta in gen_all_possible_privs(offset):
        candidate = group_action(E1_base, delta).montgomery_model()
        if candidate.j_invariant() == E1.j_invariant():
            real_priv = [lo + d for lo, d in zip(left, delta)]
            print("Found valid private key:", real_priv)
            return real_priv
real_priv = check()
if real_priv is None:
    # Previously this fell through to a TypeError on `None` below.
    raise SystemExit("no exponent vector within the recovered bounds matches curves[0]; rerun")

# Shift every exponent by -36 (the su_auth trick). real_priv lies in [12, 27],
# so the result lies in [-24, -9], inside the required range [-27, -1].
real_priv2 = [x - 36 for x in real_priv]
io.sendline(", ".join(map(str, real_priv2)).encode())

io.recvuntil(b'There you are...')
io.recvline()
print(io.recvline().strip().decode())
APES
Description
The Advanced Permutation Encryption Standard (APES) offers 512-bit security for encrypting your permutations.
Author: Neobeo
chall.py
import os

# Flag from the environment, with a placeholder for local testing.
FLAG = os.getenv("FLAG", "SEKAI{TESTFLAG}")
def play():
    """One attempt: encrypt all 256 bytes with the player's permutation under a fresh key.

    Reads a permutation of 0..255 (hex), draws a random 64-byte key, prints
    the resulting cipher permutation and returns whether the player's hex
    guess equals the key.
    """
    plainperm = bytes.fromhex(input('Plainperm: '))
    assert sorted(plainperm) == list(range(256)), "Invalid permutation"

    key = os.urandom(64)
    *round_keys, final_key = key

    def f(i):
        # 63 rounds of "XOR round key, then substitute", then a final XOR.
        for k in round_keys:
            i = plainperm[i ^ k]
        return i ^ final_key

    cipherperm = bytes(f(i) for i in range(256))
    print(f'Cipherperm: {cipherperm.hex()}')
    print('Do you know my key?')
    return bytes.fromhex(input()) == key
if __name__ == '__main__':
    # Retries happen on the same connection; per the author, this is meant to
    # keep brute-forcers from opening excessive connections.
    while not play():
        print('Bad luck, try again.')
    print(f'APES TOGETHER STRONG! {FLAG}')
Solution
This is a tiny block cipher with an 8-bit block size. The encryption function consists of 63 rounds of round-key addition (XOR) followed by substitution, plus a final key addition. We choose the substitution table (the permutation) and need to recover the exact 64-byte key.
Theoretically, if we choose a random permutation, we get about 1,700 bits of information ($\log_2 256! \approx 1684$), far more than the 512-bit key. But a random permutation also makes the key hard to find. We need a more structured permutation, and a linear mapping seems to be a good choice.
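Let's say we use the linear mapping $L(x) = a x$ over $\mathbb{F}_{2^8}$, where $a$ is a generator of the field. Since XOR is addition in $\mathbb{F}_{2^8}$, with key bytes $k_0, \dots, k_{63}$ the encryption becomes
$$E(x) = a^{63} x + c, \qquad c = \sum_{j=0}^{62} a^{63-j} k_j + k_{63}.$$
As we can see, the only information we get is $c$, which isn't enough to recover the key.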
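The solution is simple: we add some error to the linear mapping. Take $P = L \circ \sigma$ for a 3-cycle $\sigma$ (solve.py uses $\sigma = (0\ 1\ 255)$), so only three entries of the table differ from $L$. Each round then becomes $P(x + k) = L(\sigma_k(x)) + a k$, where $\sigma_k(x) = \sigma(x + k) + k$ is another 3-cycle. Conjugating a 3-cycle by an affine map gives another 3-cycle, so all the affine parts can be moved to the left. Therefore,
$$E = A \circ \rho_{62} \circ \cdots \circ \rho_0, \qquad A(x) = a^{63} x + c,$$
where $A$ is the same affine map as in the linear case and each $\rho_j$ is a 3-cycle whose support depends on the key.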
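These 3-cycles can affect at most $3 \cdot 63 = 189$ numbers, so at least $67$ inputs still satisfy $E(x) = A(x)$, and we can recover $A$ from those remaining numbers. The problem then reduces to decomposing the permutation $A^{-1} \circ E$ into 3-cycles.

We can simplify this by only accepting the special case where every 3-cycle makes the cycles "larger". Each time a 3-cycle is added, it merges cycles instead of splitting one into smaller cycles. solve.py checks this by requiring $\sum_C \lfloor (|C| - 1)/2 \rfloor = 63$ over the non-trivial cycles $C$. With this property, we can ensure that these cycles don't contain the numbers that satisfy $E(x) = A(x)$. Hence, we can filter all possible supports of each $\rho_j$ (equivalently, the candidate round keys) and enumerate the survivors to find the real key.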
The attack only works for a fraction of keys. On failure, solve.py sends a wrong key and retries with a fresh one; chall.py allows unlimited attempts.
solve.py
from sage.all import *
from pwn import *
from tqdm import tqdm
from collections import Counter

io = process(['python3', 'chall.py'])

F = GF(2**8, 'a')
a = F.gen()

# Plainperm: multiplication by the generator a, composed with a 3-cycle on
# its inputs, i.e. perm1 = L o sigma with L(x) = a*x and sigma = (0 1 255).
perm1 = [(a * F.from_integer(i)).to_integer() for i in range(256)]
perm1[0], perm1[1], perm1[255] = perm1[1], perm1[255], perm1[0]
def fail():
    """Give up on this key: send a (wrong) one-byte key and wait for the retry prompt."""
    wrong_key = b'00'
    io.sendline(wrong_key)
    io.recvuntil(b'Bad luck, try again.')
def try_solve():
    # One attempt against a fresh random key:
    #   1. send perm1 and read the cipher permutation;
    #   2. peel the affine part off it;
    #   3. decompose the rest into 63 key-dependent 3-cycles;
    #   4. rebuild and send the key.
    # Calls fail() and returns early whenever this key lacks the needed structure.
    io.sendlineafter(b'Plainperm: ', bytes(perm1).hex().encode())
    io.recvuntil(b'Cipherperm: ')
    cipherperm = bytes.fromhex(io.recvline().strip().decode())
    cipherperm_F = [F.from_integer(x) for x in cipherperm]

    # Undo the linear factor a^63, then brute-force the additive constant b:
    # the right b leaves many fixed points (inputs untouched by every 3-cycle).
    # Values are shifted by +1 because Sage permutations are 1-based.
    cipherperm_F = [x / (a**63) for x in cipherperm_F]
    for b in range(257):
        assert b < 256, "This should not happen"
        b0 = F.from_integer(b)
        cipherperm_F2 = [(x+b0).to_integer()+1 for x in cipherperm_F]
        if len(Permutation(cipherperm_F2).fixed_points()) > 20:
            break

    # perm2: what remains after removing the affine part; keep its non-trivial cycles.
    perm2 = Permutation(cipherperm_F2)
    cycles = perm2.cycle_tuples()
    cycles = [cycle for cycle in cycles if len(cycle) > 1]
    # u == 63 is the test that each of the 63 3-cycles merged cycles rather
    # than splitting one (each merge adds 2 to sum(len(cycle) - 1)).
    u = sum([(len(cycle)-1)//2 for cycle in cycles])

    if u != 63 or max([len(cycle) for cycle in cycles]) > 51:
        fail()
        return

    # Back to field elements (undo the 1-based shift).
    cycles = [tuple(F.from_integer(x-1) for x in cycle) for cycle in cycles]
    print(([len(cycle) for cycle in cycles]), u)

    # possible_k0[i]: candidate shifts for round i. The code models the round-i
    # 3-cycle as acting on (t / a^i + k0 for t in (0, 1, 255)).
    all_numbers = [F.from_integer(i) for i in range(256)]
    possible_k0 = [all_numbers[:] for _ in range(63)]

    def check(us):
        # A candidate 3-cycle is plausible only if its three elements lie in
        # one cycle of perm2 in an order with an even number of inversions,
        # i.e. a cyclic rotation of their order within that cycle.
        for cycle in cycles:
            if all(u in cycle for u in us):
                inds = [cycle.index(u) for u in us]
                inversions = 0
                if inds[0] > inds[1]:
                    inversions += 1
                if inds[0] > inds[2]:
                    inversions += 1
                if inds[1] > inds[2]:
                    inversions += 1
                if inversions % 2 == 1:
                    return False
                return True
        return False

    # Keep only the k0 values whose 3-cycle passes check().
    for i in range(63):
        filtered = []
        a0 = a ** i
        ts = [F.from_integer(t)/a0 for t in [0, 1, 255]]

        for k0 in possible_k0[i]:
            us = [u + k0 for u in ts]
            if check(us):
                filtered.append(k0)
        possible_k0[i] = filtered

    def update():
        # Every moved position must be covered by some round's 3-cycle. If a
        # position is covered by exactly one remaining (round, k0) candidate,
        # fix that round to that k0.
        all_pos = sum([list(cycle) for cycle in cycles], [])
        c = {pos: [] for pos in all_pos}

        for i in range(63):
            a0 = a ** i
            ts = [F.from_integer(t)/a0 for t in [0, 1, 255]]
            for k0 in possible_k0[i]:
                us = [u + k0 for u in ts]
                for u in us:
                    c[u].append((i, k0))

        for pos in c:
            if len(c[pos]) == 1:
                i, k0 = c[pos][0]
                possible_k0[i] = [k0]

    # Propagate the unique-cover deductions a fixed number of times.
    for _ in range(16):
        update()

    # Too many combinations left to enumerate: try another key.
    if prod([len(x) for x in possible_k0]) > 2**14:
        fail()
        return

    print([len(x) for x in possible_k0], prod([len(x) for x in possible_k0]).bit_length())

    # Materialise every candidate as a 1-based Sage 3-cycle permutation.
    possible_maps = [[] for _ in range(63)]
    for i in range(63):
        a0 = a ** i
        ts = [F.from_integer(t)/a0 for t in [0, 1, 255]]
        for k0 in possible_k0[i]:
            us = [u + k0 for u in ts]
            us = [x.to_integer() + 1 for x in us]
            perm3 = Permutation([tuple(us)])
            possible_maps[i].append((k0, perm3))

    def search(cur_map, i):
        # Depth-first search choosing one candidate per round, multiplying the
        # 3-cycles in round order with Sage's `*`. Returns every list of k0
        # values whose product equals perm2, or None if there is none.
        if i == 63:
            if cur_map == perm2:
                return [[]]
            else:
                return None

        all_results = []
        for k0, perm3 in possible_maps[i]:
            new_map = cur_map * perm3
            result = search(new_map, i + 1)
            if result is not None:
                for r in result:
                    all_results.append([k0] + r)

        if all_results:
            return all_results

    # Start from the identity permutation on 1..256.
    pathes = search(cur_map=Permutation(list(range(1, 257))), i=0)
    if pathes is None or len(pathes) == 0:
        fail()
        return

    print(len(pathes), pathes)
    path = pathes[0]

    # Turn the per-round shifts back into key bytes. c0 accumulates the affine
    # constant; the last key byte absorbs b0 * a^63 + c0.
    real_ks = []
    a0 = 1
    c0 = 0
    for i in range(63):
        k0 = path[i]
        ts = [F.from_integer(t)/a0 for t in [0, 1, 255]]
        us = [u + k0 for u in ts]
        c1 = a0 * us[0]
        real_k = (c0 - c1)
        real_ks.append(real_k.to_integer())
        a0 *= a
        c0 += real_k
        c0 *= a
    real_ks.append((b0 * (a ** 63) + c0).to_integer())
    print(bytes(real_ks).hex())
    io.sendlineafter(b'Do you know my key?', bytes(real_ks).hex().encode())
    io.interactive()
# Keep requesting fresh keys until one has the structure the attack needs;
# a successful attempt ends in io.interactive() inside try_solve().
for _ in tqdm(range(1000)):
    try_solve()
law and order
Description
My friends were the FROST to implement sharing signatures for flag and ensured to always include me but somehow its still not working?
Author: deuterium
chall.py
from py_ecc.secp256k1 import P, G as G_lib, Nfrom py_ecc.secp256k1.secp256k1 import multiply, addimport randomfrom secrets import randbelowimport osfrom hashlib import sha256import sys
CONTEXT = os.urandom(69)  # random per-session context bound into every proof hash
NUM_PARTIES = 9
THRESHOLD = (2 * NUM_PARTIES) // 3 + 1  # 7 of 9 with the values above
NUM_SIGNS = 100
FLAG = os.environ.get("FLAG", "FLAG{I_AM_A_NINCOMPOOP}")
# IO
# The section comment was fused onto the `def send` line, which commented out
# the whole definition; it is restored here on its own line.
def send(msg):
    """Print ``msg`` and flush immediately so the client sees it."""
    print(msg, flush=True)


def handle_error(msg):
    """Report ``msg`` to the client and terminate with exit status 1."""
    send(msg)
    sys.exit(1)


def receive_line():
    """Read one line from stdin with surrounding whitespace stripped."""
    try:
        return sys.stdin.readline().strip()
    except BaseException:
        handle_error("Connection closed by client. Exiting.")


def input_int(range=P):
    """Read an integer in [0, range - 1]; on any failure report and exit.

    The parameter name shadows the builtin ``range`` inside this function.
    NOTE(review): validation uses ``assert``, which is stripped under
    ``python -O``.
    """
    try:
        x = int(receive_line())
        assert 0 <= x <= range - 1
        return x
    except BaseException:
        handle_error("Invalid input integer")


def input_point():
    """Read two integers as (x, y); only (0, 0) is rejected.

    The pair is not checked to lie on the curve here.
    """
    try:
        x = input_int()
        y = input_int()
        assert not (x == 0 and y == 0)
        return Point(x, y)
    except BaseException:
        handle_error("Invalid input Point")
# Helper class
# The section comment was fused onto the class header, which commented out
# `class Point:`; it is restored here on its own line.
class Point:
    """easy operator overloading"""

    def __init__(self, x, y):
        self.point = (x, y)

    def __add__(self, other):
        return Point(*add(self.point, other.point))

    def __mul__(self, scalar):
        return Point(*multiply(self.point, scalar))

    def __rmul__(self, scalar):
        return Point(*multiply(self.point, scalar))

    def __neg__(self):
        # NOTE(review): y is negated without reducing mod P; presumably
        # py_ecc's add tolerates this -- confirm.
        return Point(self.point[0], -self.point[1])

    def __eq__(self, other):
        # Equal when self - other has x-coordinate 0, presumably how py_ecc
        # represents the point at infinity here -- confirm.
        return (self + (-other)).point[0] == 0

    def __repr__(self):
        return str(self.point)
# secp256k1 base point from py_ecc, wrapped for operator overloading.
G = Point(*G_lib)
def random_element(P):
    """Uniform integer in [1, P - 1] from the ``secrets`` CSPRNG."""
    return 1 + randbelow(P - 1)


def H(*args):
    """SHA-256 of ``str`` of the argument tuple, as a big-endian integer."""
    digest = sha256(str(args).encode()).digest()
    return int.from_bytes(digest, "big")


def sum_pts(points):
    """Sum a sequence of Points, starting from ``0 * G``."""
    total = 0 * G
    for pt in points:
        total = total + pt
    return total


def sample_poly(t):
    """Return ``t`` random coefficients in [1, N - 1]."""
    return [random_element(N) for _ in range(t)]
def gen_proof(secret, i):
    """Schnorr-style proof of knowledge of ``secret`` for party ``i``: (R, mu mod N).

    NOTE(review): the nonce is drawn below P (field size), not N (group
    order); confirm this is intended.
    """
    k = random_element(P)
    R = k * G
    challenge = H(CONTEXT, i, secret * G, R)
    return R, (k + secret * challenge) % N


def verify_proof(C, R, mu, i):
    """Accept iff R == mu*G - c*C, where c = H(CONTEXT, i, C, R)."""
    c = H(CONTEXT, i, C, R)
    expected = mu * G + (-c) * C
    return R == expected
def gen_poly_comms(coeffs):
    """Commit to each coefficient as coeff * G."""
    return [coeff * G for coeff in coeffs]


def eval_poly(coeffs, x):
    """Evaluate sum(coeffs[k] * x**k) mod N with Horner's rule."""
    acc = 0
    for coeff in reversed(coeffs):
        acc = (acc * x + coeff) % N
    return acc


def gen_shares(n, coeffs):
    """Map each party id 1..n to its share f(id)."""
    return {party: eval_poly(coeffs, party) for party in range(1, n + 1)}


def poly_eval_comms(comms, i):
    """Return sum(comms[k] * i**k) for k < THRESHOLD, i.e. f(i) * G from the commitments.

    Only the first THRESHOLD commitments are used: extra ones are ignored,
    and fewer raise IndexError.
    """
    return sum_pts([comms[k] * pow(i, k, N) for k in range(THRESHOLD)])


def check_shares(comms, shares, i):
    """Check that shares[i] * G matches the commitment evaluation at i."""
    return G * shares[i] == poly_eval_comms(comms, i)
def gen_nonsense():
    """Sample two random scalars (d, e) and return them with (d*G, e*G).

    Presumably the per-signing nonce pair of the FROST protocol -- confirm
    against the signing code.
    """
    d = random_element(N)
    e = random_element(N)
    return (d, e), (d * G, e * G)


def lamb(i, S):
    """Lagrange coefficient of party ``i`` at x = 0 over signer set ``S``, mod N."""
    num = den = 1
    for j in S:
        if j != i:
            num *= j
            den *= j - i
    return num * pow(den, -1, N) % N
def main():
    """Run the interactive threshold-signing challenge.

    Reference: https://eprint.iacr.org/2020/852.pdf

    Flow, as implemented below:
      1. Keygen: parties 1..NUM_PARTIES-1 publish coefficient commitments and
         a proof of knowledge of their constant term; the client (party
         NUM_PARTIES) then supplies its own commitments and proof.
      2. Share distribution: the client's shares for every party are checked
         against its commitments; signing shares, the group public key and
         public verification shares are derived.
      3. Up to NUM_SIGNS signing rounds over a randomly drawn signer set; the
         flag is sent if the aggregated signature verifies.
    """
    send("=" * 50)
    send("=== Law and Order! You should always include your friends to sign and you are mine <3 ===")
    send("=" * 50)
    send(
        f"We have {NUM_PARTIES - 1} parties here, and you will be party #{NUM_PARTIES}.")
    send(f"Idk why is our group signing not working")
    send("\n--- Round 1: Commitment ---")

    # Keygen

    send(f"context string {CONTEXT.hex()}")

    your_id = NUM_PARTIES

    all_coeffs = {}
    all_comms = {}
    for i in range(1, NUM_PARTIES):
        # 1.1
        coeffs = sample_poly(THRESHOLD)
        # 1.2
        zero_proof = gen_proof(coeffs[0], i)
        # 1.3
        comms = gen_poly_comms(coeffs)
        # 1.5
        if not verify_proof(comms[0], zero_proof[0], zero_proof[1], i):
            handle_error(f"[-] Party {i} secret PoK invalid")
        all_coeffs[i] = coeffs
        all_comms[i] = comms
        send(f"[+] Commitments from party {i}:")
        for k, C_ik in enumerate(comms):
            send(f" C_{i},{k} = {C_ik}")

    # Client-supplied commitments and proof. Points come from input_point(),
    # which does not check that they lie on the curve.
    send("\n[?] Now, provide the commitments (points) for your coefficients.")
    your_comms = [input_point() for _ in range(THRESHOLD)]
    send("\n[?] Finally, provide your proof-of-knowledge for your secret share (c_i,0).")
    send("[>] Send Point R:")
    your_zero_proof_R = input_point()
    send("[>] Send mu:")
    your_zero_proof_mu = input_int()
    your_zero_proof = (your_zero_proof_R, your_zero_proof_mu)

    if not verify_proof(your_comms[0], your_zero_proof[0], your_zero_proof[1], your_id):
        handle_error(f"[-] party {your_id} secret PoK invalid")
    all_comms[your_id] = your_comms
    send("[+] Your commitments and proof have been accepted.")
    send("\n--- Round 2: Share Distribution ---")

    send(f"[?] Please provide your shares for the other {NUM_PARTIES} parties.")
    # 2.1
    your_shares = {}
    for i in range(1, NUM_PARTIES + 1):
        send(f"[>] Send share for party {i}:")
        your_shares[i] = input_int(N)

    # 2.2
    for i in range(1, NUM_PARTIES + 1):
        if not check_shares(your_comms, your_shares, i):
            handle_error(f"[-] party {your_id} shares for party {i} invalid")
    send("[+] Your shares have been verified")

    # Honest parties' shares, verified against their own commitments.
    all_shares = {}
    for l in range(1, NUM_PARTIES):
        shares_l = gen_shares(NUM_PARTIES, all_coeffs[l])
        for i in range(1, NUM_PARTIES + 1):
            if not check_shares(all_comms[l], shares_l, i):
                handle_error(f"[-] party {l} shares for party {i} invalid")
        all_shares[l] = shares_l
        send(f"[+] Share for you from party {l}: {shares_l[your_id]}")
    all_shares[your_id] = your_shares

    # 2.3
    # signing_shares[i] = sum over all parties j of the share j dealt to i.
    signing_shares = {}
    for i in range(1, NUM_PARTIES + 1):
        signing_shares[i] = 0
        for j in range(1, NUM_PARTIES + 1):
            signing_shares[i] += all_shares[j][i]

    # 2.4
    group_public_key = sum_pts([all_comms[i][0] for i in range(1, NUM_PARTIES + 1)])
    send(f"\n[+] Group Public Key: {group_public_key}")

    # The client's public share is taken from input and only compared (with
    # Point.__eq__) against the value implied by all commitments.
    send("[?] Provide your public verification share `Y_i`.")
    your_public_share = input_point()
    if your_public_share != sum_pts([poly_eval_comms(all_comms[j], your_id) for j in range(1, NUM_PARTIES + 1)]):
        handle_error(f"[-] party {your_id} public share invalid")

    public_shares = {i: v * G for i, v in signing_shares.items()}
    public_shares[your_id] = your_public_share
    send("[+] Public verification shares have been computed.")

    send(f"\n--- Phase 3: Presign and Sign ({NUM_SIGNS} rounds) ---")
    for _ in range(NUM_SIGNS):
        # presign
        send("[?] Provide your nonces (D_i, E_i) for this round.")
        your_D = input_point()
        your_E = input_point()
        your_nonsense = (your_D, your_E)

        nonsense_sec, nonsense = {}, {}
        for i in range(1, NUM_PARTIES + 1):
            # 1.a, 1.b & 1.c
            (d, e), (D, E) = gen_nonsense()
            nonsense_sec[i] = (d, e)
            nonsense[i] = (D, E)
        # The server-generated nonce pair for your_id is replaced by the client's.
        nonsense[your_id] = your_nonsense

        # Sign

        # S is set of alpha, alpha is in t to n right?
        # NOTE(review): S is a set of NUM_PARTIES - THRESHOLD random draws from
        # [THRESHOLD, NUM_PARTIES]; duplicates collapse, so S can be {your_id}.
        S = {random.randint(THRESHOLD, NUM_PARTIES) for _ in range(NUM_PARTIES - THRESHOLD)}
        S.add(your_id)  # should always include you <3
        send(f"[+] Set of signers for this round: {S}")

        m = "GIVE ME THE FLAG PLEASE"
        combined_nonsense = {}

        group_nonsense = 0 * G
        nonsense = {i: nonsense[i] for i in S}
        # 2
        nonsense_ordered = sorted([(i, Di, Ei) for i, (Di, Ei) in nonsense.items()])
        # 4
        rhos = {i: H(i, m, nonsense_ordered) for i in S}
        for i, (D, E) in nonsense.items():
            send(f"[+] Party {i} nonces: D={D}, E={E}")
            D, E = nonsense[i]
            # Per-signer combined nonce D_i + rho_i * E_i, accumulated into R.
            nonsense_i = D + rhos[i] * E
            group_nonsense = group_nonsense + nonsense_i
            combined_nonsense[i] = nonsense_i

        group_challenge = H(group_nonsense, group_public_key, m)
        send(f"[+] Group challenge `c`: {group_challenge}")
        send("[?] Provide your signature share `z_i`.")

        your_zi = input_int(N)

        # 7.b
        if your_zi * G != combined_nonsense[your_id] + \
                public_shares[your_id] * group_challenge * lamb(your_id, S):
            handle_error(f"[-] party {your_id} signature shares invalid")

        # Honest signers' shares are computed and checked server-side.
        final_signing_shares = {your_id: your_zi}
        for i in S - {your_id}:
            si = signing_shares[i]
            di, ei = nonsense_sec[i]
            # 5
            zi = di + (ei * rhos[i]) + si * group_challenge * lamb(i, S)
            Ri, Yi = combined_nonsense[i], public_shares[i]
            if Yi != sum_pts([poly_eval_comms(all_comms[j], i) for j in range(1, NUM_PARTIES + 1)]):
                handle_error(f"[-] party {i} public share invalid")
            if zi * G != Ri + Yi * group_challenge * lamb(i, S):
                handle_error(f"[-] party {i} signature share invalid")
            final_signing_shares[i] = zi

        # 7.c
        z = sum(final_signing_shares.values()) % N

        if z * G == group_nonsense + group_public_key * group_challenge:
            send("[+] Signature verification successful")
            send(f"[+] Here is your flag: {FLAG}")
            sys.exit(0)
    # NOTE(review): indentation reconstructed from a flattened listing. This is
    # assumed to run once all NUM_SIGNS rounds have failed to verify.
    handle_error("[-] We are out of signing ink")
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Report the unexpected exception's message to the client. The f-prefix
        # makes "{e}" interpolate instead of being sent literally. SystemExit
        # from sys.exit() is not an Exception subclass, so it is not caught here.
        handle_error(f"[-] An error occured: {e}")
Solution
We have created many variants of this challenge. There are two major modifications:
- Replace the point equality check with `(self + (-other)).point == (0, 0)`. This blocks points like `(0, y)`.
- Remove the input of `your_public_share` and compute it directly from `all_comms` instead. Then we need either `NUM_PARTIES % 3 != 1` or `NUM_PARTIES = 15` to solve it.
Although we decided to use the easiest version, the handout still contains the second modification, and `NUM_PARTIES` is still 9. I'm very sorry about that.
The bug is that neither `py_ecc` nor `Point` checks whether a point lies on the curve. Hence we can send invalid points and try to pass all the assertions.
Let's first take a look at the point operations in `py_ecc`.
Jacobian coordinates
def from_jacobian(p: "PlainPoint3D") -> "PlainPoint2D":
    """
    Convert a Jacobian point back to its corresponding 2D point representation.

    :param p: the point to convert
    :type p: PlainPoint3D

    :return: the 2D point representation
    :rtype: PlainPoint2D
    """
    # (X, Y, Z) -> (X / Z^2, Y / Z^3) mod P, computed with a single inverse of Z.
    z = inv(p[2], P)
    return cast("PlainPoint2D", ((p[0] * z**2) % P, (p[1] * z**3) % P))
`py_ecc` first converts the points to Jacobian coordinates and then does the calculation. A Jacobian point $(X, Y, Z)$ corresponds to the 2D point $(x, y) = (X / Z^2,\ Y / Z^3)$.
Double and add
def jacobian_add(p: "PlainPoint3D", q: "PlainPoint3D") -> "PlainPoint3D":
    """
    Add two points in Jacobian coordinates and return the result.

    :param p: the first point to add
    :type p: PlainPoint3D
    :param q: the second point to add
    :type q: PlainPoint3D

    :return: the resulting Jacobian point
    :rtype: PlainPoint3D
    """
    # Any point whose y-coordinate is 0 is treated as the point at infinity.
    # Nothing here checks that p or q actually lies on the curve.
    if not p[1]:
        return q
    if not q[1]:
        return p
    U1 = (p[0] * q[2] ** 2) % P
    U2 = (q[0] * p[2] ** 2) % P
    S1 = (p[1] * q[2] ** 3) % P
    S2 = (q[1] * p[2] ** 3) % P
    # Equal x-coordinates: different y gives infinity (q == -p on a valid
    # curve); equal y means the same point, so fall back to doubling.
    if U1 == U2:
        if S1 != S2:
            return cast("PlainPoint3D", (0, 0, 1))
        return jacobian_double(p)
    # General case. Neither curve coefficient A nor B appears below.
    H = U2 - U1
    R = S2 - S1
    H2 = (H * H) % P
    H3 = (H * H2) % P
    U1H2 = (U1 * H2) % P
    nx = (R**2 - H3 - 2 * U1H2) % P
    ny = (R * (U1H2 - nx) - S1 * H3) % P
    nz = (H * p[2] * q[2]) % P
    return cast("PlainPoint3D", (nx, ny, nz))
Addition doesn't use $A$ or $B$ at all, so it is effectively adding on the unique curve $y^2 = x^3 + Ax + B$ that passes through both input points.
def jacobian_double(p: "PlainPoint3D") -> "PlainPoint3D":
    """
    Double a point in Jacobian coordinates and return the result.

    :param p: the point to double
    :type p: PlainPoint3D

    :return: the resulting Jacobian point
    :rtype: PlainPoint3D
    """
    # y == 0 is treated as the point at infinity (no on-curve check is made).
    if not p[1]:
        return cast("PlainPoint3D", (0, 0, 0))
    ysq = (p[1] ** 2) % P
    S = (4 * p[0] * ysq) % P
    # Only the curve coefficient A is used here; B never appears.
    M = (3 * p[0] ** 2 + A * p[2] ** 4) % P
    nx = (M**2 - 2 * S) % P
    ny = (M * (S - nx) - 8 * ysq**2) % P
    nz = (2 * p[1] * p[2]) % P
    return cast("PlainPoint3D", (nx, ny, nz))
The doubling formula uses $A = 0$ (the secp256k1 value) but never $B$. Similarly, it is effectively doubling on the unique curve $y^2 = x^3 + B$ that passes through the input point.
Point at infinity
In `py_ecc`, the point at infinity is encoded as `(0, 0)`. This wouldn't be a problem if points were properly checked to be on the curve. However, as we can see above, both doubling and addition treat any point whose y-coordinate is zero as the point at infinity. This is very helpful, because it lets us create points at infinity much more easily.
Also, if the two x-coordinates are equal in an addition and the y-coordinates don't match, the result is the point at infinity as well.
Low-order point attack
Even though we can create invalid points, it is still impossible to recover the shared secret or to control `group_challenge`. The only way to pass the final signature check is to make `group_public_key` an invalid point of sufficiently low order. Similarly, to pass check 7.b, we want `your_public_share` to be a low-order point. To achieve this, we must pass all the earlier checks.
The only way to control `group_public_key` is to send a bad `your_comms[0]`. But there's a proof-of-knowledge check on that commitment, which is hard to pass for points not on the curve. So we need `your_comms[0]` to be a low-order point, so that the term $c \cdot$ `your_comms[0]` in the check can become the point at infinity.
The next problem is passing `check_shares(your_comms, your_shares, i)` for all parties. This is the hardest part, because `your_comms[0]` is now not on the curve. There are many feasible constructions.

The intended solution is to find the bug in `(self + (-other)).point[0] == 0`, which makes `(0, y) == (0, 0)` for any `y`. Also, `(0, y)` is a low-order point, so everything is fine.

My idea is to use low-order points and Fermat's little theorem instead. Details are in the harder version.²
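The last problem is `your_public_share`. Since it comes from our input, the check in 2.4 is easy to pass. We fix its x-coordinate to be the same as that of `sum_pts([poly_eval_comms(all_comms[j], your_id) for j in range(1, NUM_PARTIES + 1)])`, and choose the y-coordinate so that it becomes a low-order point.

This works because two points with the same x-coordinate compare as equal unless one is the negation of the other: their "difference" is the point at infinity.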
solve.py
from sage.all import *from sage.rings.generic import ProductTreefrom pwn import *from hashlib import sha256from ast import literal_evalfrom tqdm import tqdmfrom py_ecc.secp256k1 import P, G as G_lib, Nfrom py_ecc.secp256k1.secp256k1 import multiply, add
# Easy-version challenge parameters.
NUM_PARTIES = 9
THRESHOLD = 7
# secp256k1 field prime and group order.
p = 2**256 - 2**32 - 977
n = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141
# Bivariate polynomial ring over GF(p) used to solve for unknown coordinates.
R = PolynomialRing(GF(p), ['x', 'y'])
xbar, ybar = R.gen(0), R.gen(1)
def send_point(p):
    """Send the two affine coordinates of Point p to the service, one per line."""
    x, y = p.point
    io.sendline(str(x).encode())
    io.sendline(str(y).encode())


def send_int(n):
    """Send the integer n to the service as one decimal line."""
    line = str(n).encode()
    io.sendline(line)
def jacobian_double(p):
    """Symbolic Jacobian point doubling (A = 0) without modular reduction.

    Same formulas as py_ecc's jacobian_double, but without the "y == 0 means
    infinity" shortcut and without ``% P``. That lets the coordinates be Sage
    polynomials (or plain integers). Returns the tuple (nx, ny, nz).
    """
    x, y, z = p
    y_sq = y ** 2
    s = 4 * x * y_sq
    a = 0
    m = 3 * x ** 2 + a * z ** 4
    nx = m ** 2 - 2 * s
    ny = m * (s - nx) - 8 * y_sq ** 2
    nz = 2 * y * z
    return (nx, ny, nz)
def jacobian_add(p, q, return_H=False):
    """Symbolic Jacobian point addition without modular reduction.

    Same formulas as py_ecc's jacobian_add, but without the infinity and
    equal-x special cases and without ``% P``, so the coordinates may be Sage
    polynomials. Returns (nx, ny, nz); with return_H=True it returns
    ((nx, ny, nz), H), where H = U2 - U1 vanishes exactly when the two
    inputs share an x-coordinate.
    """
    x1, y1, z1 = p
    x2, y2, z2 = q
    u1 = x1 * z2 ** 2
    u2 = x2 * z1 ** 2
    s1 = y1 * z2 ** 3
    s2 = y2 * z1 ** 3
    h = u2 - u1
    r = s2 - s1
    h2 = h * h
    h3 = h * h2
    u1h2 = u1 * h2
    nx = r ** 2 - h3 - 2 * u1h2
    ny = r * (u1h2 - nx) - s1 * h3
    nz = h * z1 * z2
    result = (nx, ny, nz)
    return (result, h) if return_H else result
class Point:
    """Affine py_ecc point wrapper with operator overloading.

    Unlike the challenge's Point, equality requires the whole difference to be
    (0, 0), and to_jacobian() exposes the coordinates for symbolic work.
    """

    def __init__(self, x, y):
        self.point = (x, y)

    def __add__(self, other):
        summed = add(self.point, other.point)
        return Point(*summed)

    def __mul__(self, scalar):
        product = multiply(self.point, scalar)
        return Point(*product)

    # Scalar multiplication is symmetric: scalar * point == point * scalar.
    __rmul__ = __mul__

    def __neg__(self):
        x, y = self.point
        return Point(x, -y)

    def __eq__(self, other):
        difference = self + (-other)
        return difference.point == (0, 0)

    def to_jacobian(self):
        x, y = self.point
        return (x, y, 1)

    def __repr__(self):
        return f"{self.point}"
# secp256k1 base point from py_ecc, wrapped so the Point operators apply.
G = Point(G_lib[0], G_lib[1])
def H(*args):
    """Hash the textual form of the argument tuple to an integer.

    Mirrors the challenge's H: SHA-256 over ``str(args).encode()``, with the
    digest read as a big-endian unsigned integer.
    """
    encoded = str(args).encode()
    return int.from_bytes(sha256(encoded).digest(), byteorder="big")
def verify_proof(C, R, mu, i):
    """Local copy of the server's PoK check: R == mu*G - c*C, c = H(CONTEXT, i, C, R)."""
    c = H(CONTEXT, i, C, R)
    expected = mu * G + (-c) * C
    return R == expected


def sum_pts(points):
    """Add up an iterable of Points, starting from the identity 0 * G."""
    acc = 0 * G
    for pt in points:
        acc += pt
    return acc


def poly_eval_comms(comms, i):
    """Evaluate committed coefficients at i: sum(comms[k] * i^k) for k < THRESHOLD."""
    terms = [comms[k] * pow(i, k, N) for k in range(THRESHOLD)]
    return sum_pts(terms)


def check_shares(comms, shares, i):
    """Local copy of the server's share check: shares[i]*G == comms evaluated at i."""
    lhs = G * shares[i]
    return lhs == poly_eval_comms(comms, i)
def start():
    """Run one exploit attempt against a fresh local challenge process.

    Parses the honest parties' commitments, forges low-order commitments,
    shares, a public share and nonces, then plays one signing round.
    Returns early (closing the connection) if no suitable your_comm0 is found
    or if the server picks more than one signer. Otherwise it hands the
    connection to the user.
    """
    global CONTEXT, io
    io = process(["python3", "chall.py"])
    io.recvuntil(b"context string")

    CONTEXT = bytes.fromhex(io.recvline().strip().decode())
    all_comms = {}

    # Parse the THRESHOLD commitments printed for parties 1..8.
    for i in range(1, 9):
        io.recvuntil(b"[+] Commitments from party ")
        io.recvline()
        comm = []
        for j in range(THRESHOLD):
            c_j = literal_eval(io.recvline().strip().decode().split("=")[-1].strip())
            comm.append(Point(*c_j))
        all_comms[i] = comm

    # Sum of the honest parties' constant-term commitments.
    pt0 = sum([all_comms[i][0] for i in range(1, NUM_PARTIES)], Point(0, 0))

    P0 = (xbar, ybar, 1)
    PA = jacobian_add(P0, pt0.to_jacobian())
    PB = jacobian_add(jacobian_double(P0), P0)  # NOTE(review): PB appears unused

    # poly1 == 0: (your_comm0 + pt0) has y == 0, so the group public key is
    # treated as infinity by py_ecc.
    # poly2 == 0: 4y^2 == 3x^3, the order-3 condition that the assert on
    # 3 * your_comm0 below checks.
    poly1 = PA[1]
    poly2 = 4 * ybar**2 - 3 * xbar**3

    # Eliminate y to get a univariate polynomial in x.
    p3 = poly1.change_ring(ZZ).resultant(poly2.change_ring(ZZ), ybar.change_ring(ZZ)).change_ring(GF(p)).univariate_polynomial()

    def find():
        # Return (x0, y0) satisfying both conditions, or None if no root works.
        for x0 in p3.roots(multiplicities=False):
            if x0 == 0:
                continue
            p1 = PA[1](x=x0).univariate_polynomial()
            for y0 in p1.roots(multiplicities=False):
                if poly2(x=x0, y=y0) != 0:
                    continue
                your_comm0 = Point(int(x0), int(y0))
                if (pt0 + your_comm0).point[1] != 0:
                    continue
                return int(x0), int(y0)

    res = find()
    if res is None:
        io.close()
        return
    x0, y0 = res
    print(f"Found x0: {x0}, y0: {y0}")
    assert poly2(x=x0, y=y0)==0

    your_comm0 = Point(int(x0), int(y0))

    assert (pt0 + your_comm0).point[1] == 0, (pt0 + your_comm0)
    assert (3 * your_comm0).point == (0, 0), (3 * your_comm0)

    def find_mu():
        # Brute-force a small mu with R0 = mu*G passing the PoK check. This
        # presumably succeeds when c * your_comm0 collapses to infinity.
        for mu in range(1, 10000):
            R0 = mu * G
            if verify_proof(your_comm0, R0, mu, NUM_PARTIES):
                print(f"Found mu: {mu}")
                return R0, mu

    your_zero_proof = find_mu()

    def find_7():
        # Find your_comm6 with H7 == 0 (x(Q7) == x(6*Q7), intended as an
        # order-7 point) such that your_comm0 + your_comm6 has the same
        # x-coordinate as r*G for some random r.
        Q7 = (xbar, ybar, 1)
        _, H7 = jacobian_add(Q7, jacobian_double(jacobian_add(Q7, jacobian_double(Q7))), return_H=True)
        T = jacobian_add(Q7, jacobian_double(Q7))  # NOTE(review): T appears unused
        # NOTE(review): presumably strips a spurious y^8 factor from H7.
        H7 = H7//(ybar**8)
        Q0 = jacobian_add(your_comm0.to_jacobian(), Q7)

        while True:
            r = randint(1, p)
            R0 = r * G
            poly2 = R0.point[0] * Q0[2]**2 - Q0[0]
            poly1 = H7
            M = poly1.sylvester_matrix(poly2, ybar)
            # R0 is rebound here to the univariate ring used for the resultant.
            R0 = PolynomialRing(GF(P), 'x')
            M = matrix(R0, M)
            p3 = M.determinant()
            for x0 in p3.roots(multiplicities=False):
                if x0 == 0:
                    continue
                p1 = poly2(x=x0).univariate_polynomial()
                for y0 in p1.roots(multiplicities=False):
                    if poly1(x=x0, y=y0) != 0:
                        continue
                    your_comm6 = Point(int(x0), int(y0))
                    if (your_comm0+your_comm6)==r*G:
                        return r, your_comm6

    def find_mid(your_comm6):
        # Pick random r1, r7 and a point mid with x(mid) == x(r7*G) such that
        # mid + your_comm6 has the same x-coordinate as r1*G.
        while True:
            r7 = randint(1, p)
            r1 = randint(1, p)
            R7 = r7 * G
            R1 = r1 * G

            S0 = (R7.point[0], ybar, 1)
            S1 = jacobian_add(S0, your_comm6.to_jacobian())
            poly = S1[0] - R1.point[0] * S1[2]**2
            poly = poly.univariate_polynomial()
            for y0 in poly.roots(multiplicities=False):
                if y0 == 0:
                    continue
                mid = Point(R7.point[0], int(y0))
                if R1 == (mid + your_comm6) and R7 == mid:
                    return r1, r7, mid

    def try_find_3(your_comm0, mid):
        # Find your_comm2 = (12 r0^2, 36 r0^3) and your_comm4 = (12 r1^2, 36 r1^3)
        # (both satisfy 4y^2 == 3x^3) with your_comm0 + your_comm2 + your_comm4
        # == mid, by solving your_comm0 + M2 == mid + M4 where M4 = -your_comm4.
        R2 = PolynomialRing(GF(P), ['r0', 'r1'])
        r0, r1 = R2.gens()

        M2 = (12 * r0**2, 36 * r0**3, 1)
        M4 = (12 * r1**2, -36 * r1**3, 1)

        mid1 = jacobian_add(your_comm0.to_jacobian(), M2)
        mid2 = jacobian_add(mid.to_jacobian(), M4)

        # Jacobian equality of mid1 and mid2, cross-multiplied by the Z's.
        rpoly1 = mid1[0] * mid2[2]**2 - mid2[0] * mid1[2]**2
        rpoly2 = mid1[1] * mid2[2]**3 - mid2[1] * mid1[2]**3

        M = rpoly1.sylvester_matrix(rpoly2, r1)
        R0 = PolynomialRing(GF(P), 'r0')
        rbar = R0.gen()
        M = matrix(R0, M)

        det = M.determinant()
        print(f"det degree: {det.degree()}")

        for r0_sol in det.roots(multiplicities=False):
            if r0_sol == 0:
                continue
            if rpoly1(r0=r0_sol) == 0:
                continue
            for r1_sol in rpoly1(r0=r0_sol).univariate_polynomial().roots(multiplicities=False):
                if rpoly2(r0=r0_sol, r1=r1_sol) != 0:
                    continue
                if r1_sol == 0:
                    continue
                your_comm2 = Point(int(12 * r0_sol**2), int(36 * r0_sol**3))
                your_comm4 = Point(int(12 * r1_sol**2), int(36 * r1_sol**3))
                aaa = your_comm0 + your_comm2 + your_comm4
                if aaa.point == mid.point:
                    print(f"Found your_comm2: {your_comm2}, your_comm4: {your_comm4}")
                    return your_comm2, your_comm4

    def try_find_2():
        # Retry find_7 until our expected public share's x-coordinate admits a
        # y with 4y^2 == 3x^3, then search for mid, your_comm2 and your_comm4.
        while True:
            r3, your_comm6 = find_7()
            target_point = sum_pts([poly_eval_comms(all_comms[j], NUM_PARTIES) for j in range(1, NUM_PARTIES)])
            target_point = target_point + (your_comm0+your_comm6)
            poly = 4 * ybar**2 - 3 * target_point.point[0] ** 3
            y_fake = poly.univariate_polynomial().roots(multiplicities=False)
            if len(y_fake) > 0:
                break

        while True:
            r1, r7, mid = find_mid(your_comm6)
            try:
                res = try_find_3(your_comm0, mid)
                if res is not None:
                    your_comm2, your_comm4 = res
                    return r1, r3, r7, your_comm2, your_comm4, your_comm6
            except Exception as e:
                print(f"Error: {e}, retrying with new r1, r7")
                continue

    r1, r3, r7, your_comm2, your_comm4, your_comm6 = try_find_2()
    # Unused coefficients are (1, 0): y == 0 makes py_ecc treat them as infinity.
    your_comms = [your_comm0, Point(1, 0), your_comm2, Point(1, 0), your_comm4, Point(1, 0), your_comm6]
    # shares[i] for party i (index 0 unused): r3 when 3 | i, r7 when 7 | i,
    # otherwise r1.
    shares = [None, r1, r1, r3, r1, r1, r3, r7, r1, r3]

    # Local sanity check of every share against the forged commitments.
    for i in range(1, NUM_PARTIES + 1):
        print(i, check_shares(your_comms, shares, i))

    for your_comm in your_comms:
        send_point(your_comm)
    all_comms[NUM_PARTIES] = your_comms

    send_point(your_zero_proof[0])
    send_int(your_zero_proof[1])

    io.recvuntil(b"[+] Your commitments and proof have been accepted.")

    for i in range(1, NUM_PARTIES + 1):
        send_int(shares[i])

    my_shares = {}
    for i in range(1, NUM_PARTIES):
        io.recvuntil(b"[+] Share for you from party ")
        my_shares[i] = int(io.recvline().strip().decode().split(":")[-1].strip())
    my_shares[NUM_PARTIES] = r3

    target_point = sum_pts([poly_eval_comms(all_comms[j], NUM_PARTIES) for j in range(1, NUM_PARTIES + 1)])

    # Fake public share: same x-coordinate as the expected one, with y chosen
    # so that 4y^2 == 3x^3 (order 3). The assert checks it compares equal.
    poly = 4 * ybar**2 - 3 * target_point.point[0] ** 3

    y_fake = poly.univariate_polynomial().roots(multiplicities=False)[0]
    P_fake = Point(target_point.point[0], y_fake)

    assert P_fake == target_point
    group_public_key = pt0 + your_comm0

    send_point(P_fake)

    io.recvuntil(b"[+] Public verification shares have been computed.")

    def gen_D_E():
        # Pick E = (12 r^2, 36 r^3) and D = (1337, y0) such that D + E has
        # y == 0. Retry until rho % 3 == 1 and the group challenge is divisible
        # by 3. rho is computed assuming we are the only signer.
        while True:
            r = randint(1, p)
            E = Point(12*r**2 % p, 36*r**3 % p)
            D = (1337, ybar, 1)
            nonce = jacobian_add(D, E.to_jacobian())
            poly = nonce[1]
            poly = poly.univariate_polynomial()
            for y0 in poly.roots(multiplicities=False):
                if y0 == 0:
                    continue
                D = Point(1337, int(y0))
                nonsense_ordered = [(NUM_PARTIES, D, E)]
                m = "GIVE ME THE FLAG PLEASE"
                rho = H(NUM_PARTIES, m, nonsense_ordered)
                if rho % 3 == 1:
                    group_nonsense = D + E
                    group_challenge = H(group_nonsense, group_public_key, m)
                    if group_challenge % 3 != 0:
                        continue
                    return D, E

    D, E = gen_D_E()
    send_point(D)
    send_point(E)

    io.recvuntil(b"[+] Set of signers for this round: ")
    signers = literal_eval(io.recvline().strip().decode())
    print(f"Signers: {signers}")
    # gen_D_E assumed we are the sole signer; otherwise give up on this attempt.
    if len(signers) > 1:
        io.close()
        return

    io.recvuntil(b"[+] Group challenge `c`: ")
    c = int(io.recvline().strip().decode())
    print(f"Challenge {c}, {c % 3}")
    send_int(0)

    io.interactive()
# Each start() call spawns a fresh challenge process and returns early when the
# attempt cannot succeed, so simply retry up to 100 times.
for _attempt in range(100):
    start()
Harder version
Here’s the harder version of the challenge.
--- chall.py
+++ chall_clean5.py
@@ -7,8 +7,8 @@
 import sys
 
 CONTEXT = os.urandom(69)
-NUM_PARTIES = 9
-THRESHOLD = (2 * NUM_PARTIES) // 3 + 1
+NUM_PARTIES = 15
+THRESHOLD = 13
 NUM_SIGNS = 100
 FLAG = os.environ.get("FLAG", "FLAG{I_AM_A_NINCOMPOOP}")
 
@@ -68,7 +68,7 @@
         return Point(self.point[0], -self.point[1])
 
     def __eq__(self, other):
-        return (self + (-other)).point[0] == 0
+        return (self + (-other)).point == (0, 0)
 
     def __repr__(self):
         return str(self.point)
@@ -229,12 +229,8 @@
     group_public_key = sum_pts([all_comms[i][0] for i in range(1, NUM_PARTIES + 1)])
     send(f"\n[+] Group Public Key: {group_public_key}")
 
-    send("[?] Provide your public verification share `Y_i`.")
-    your_public_share = input_point()
-    if your_public_share != sum_pts([poly_eval_comms(all_comms[j], your_id) for j in range(1, NUM_PARTIES + 1)]):
-        handle_error(f"[-] party {your_id} public share invalid")
-
     public_shares = {i: v * G for i, v in signing_shares.items()}
+    your_public_share = sum_pts([poly_eval_comms(all_comms[j], your_id) for j in range(1, NUM_PARTIES + 1)])
     public_shares[your_id] = your_public_share
     send("[+] Public verification shares have been computed.")
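The solution is similar to the easier one, except for the construction of `your_comms`.

For simplicity, denote `your_comms[k]` as $C_k$. Party $i$'s share is checked against $\sum_k i^k C_k$, and every coefficient we don't need is set to the "infinite" point $(1, 0)$.

Let's assume $C_2$ and $C_4$ are points of order 3. Hence $i^2 C_2 + i^4 C_4$ takes only two values. If $3 \mid i$, it is $\mathcal{O}$, and if $3 \nmid i$, it is $C_2 + C_4$, since by Fermat's little theorem $i^2 \equiv i^4 \equiv 1 \pmod 3$. It is intuitive that $C_0 + C_2 + C_4$ can reach almost any point.

Similarly, we can set $C_6$ and $C_{12}$ to be points of order 7. Then $i^6 C_6 + i^{12} C_{12}$ is $\mathcal{O}$ if $7 \mid i$ and $C_6 + C_{12}$ otherwise. For our own id $i = 15$ ($3 \mid 15$, $7 \nmid 15$), the evaluation is $C_0 + C_6 + C_{12}$, hence we can control the value of `your_public_share`.

The final step is to make sure $C_0 + C_2 + C_4$ (parties with $7 \mid i$) and $C_0 + C_2 + C_4 + C_6 + C_{12}$ (parties with $3 \nmid i$ and $7 \nmid i$) can be verified. This is easy because $C_2$ and $C_4$ are still unused.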
solve_harder.py
from sage.all import *from sage.rings.generic import ProductTreefrom pwn import *from hashlib import sha256from ast import literal_evalfrom tqdm import tqdmfrom py_ecc.secp256k1 import P, G as G_lib, Nfrom py_ecc.secp256k1.secp256k1 import multiply, add
# Harder-version challenge parameters.
NUM_PARTIES = 15
THRESHOLD = 13
# secp256k1 field prime and group order.
p = 2**256 - 2**32 - 977
n = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141
# Bivariate polynomial ring over GF(p) used to solve for unknown coordinates.
R = PolynomialRing(GF(p), ['x', 'y'])
xbar, ybar = R.gen(0), R.gen(1)
def jacobian_double(p):
    """Symbolic Jacobian point doubling (A = 0) without modular reduction.

    Same formulas as py_ecc's jacobian_double, but without the "y == 0 means
    infinity" shortcut and without ``% P``. That lets the coordinates be Sage
    polynomials (or plain integers). Returns the tuple (nx, ny, nz).
    """
    x, y, z = p
    y_sq = y ** 2
    s = 4 * x * y_sq
    a = 0
    m = 3 * x ** 2 + a * z ** 4
    nx = m ** 2 - 2 * s
    ny = m * (s - nx) - 8 * y_sq ** 2
    nz = 2 * y * z
    return (nx, ny, nz)
def send_point(p):
    """Send the two affine coordinates of Point p to the service, one per line."""
    x, y = p.point
    io.sendline(str(x).encode())
    io.sendline(str(y).encode())


def send_int(n):
    """Send the integer n to the service as one decimal line."""
    line = str(n).encode()
    io.sendline(line)
def jacobian_add(p, q, return_H=False):
    """Symbolic Jacobian point addition without modular reduction.

    Same formulas as py_ecc's jacobian_add, but without the infinity and
    equal-x special cases and without ``% P``, so the coordinates may be Sage
    polynomials. Returns (nx, ny, nz); with return_H=True it returns
    ((nx, ny, nz), H), where H = U2 - U1 vanishes exactly when the two
    inputs share an x-coordinate.
    """
    x1, y1, z1 = p
    x2, y2, z2 = q
    u1 = x1 * z2 ** 2
    u2 = x2 * z1 ** 2
    s1 = y1 * z2 ** 3
    s2 = y2 * z1 ** 3
    h = u2 - u1
    r = s2 - s1
    h2 = h * h
    h3 = h * h2
    u1h2 = u1 * h2
    nx = r ** 2 - h3 - 2 * u1h2
    ny = r * (u1h2 - nx) - s1 * h3
    nz = h * z1 * z2
    result = (nx, ny, nz)
    return (result, h) if return_H else result
class Point:
    """Affine py_ecc point wrapper with operator overloading.

    Equality requires the whole difference to be (0, 0), matching the harder
    challenge, and to_jacobian() exposes the coordinates for symbolic work.
    """

    def __init__(self, x, y):
        self.point = (x, y)

    def __add__(self, other):
        summed = add(self.point, other.point)
        return Point(*summed)

    def __mul__(self, scalar):
        product = multiply(self.point, scalar)
        return Point(*product)

    # Scalar multiplication is symmetric: scalar * point == point * scalar.
    __rmul__ = __mul__

    def __neg__(self):
        x, y = self.point
        return Point(x, -y)

    def __eq__(self, other):
        difference = self + (-other)
        return difference.point == (0, 0)

    def to_jacobian(self):
        x, y = self.point
        return (x, y, 1)

    def __repr__(self):
        return f"{self.point}"
# secp256k1 base point from py_ecc, wrapped so the Point operators apply.
G = Point(G_lib[0], G_lib[1])
def H(*args):
    """Hash the textual form of the argument tuple to an integer.

    Mirrors the challenge's H: SHA-256 over ``str(args).encode()``, with the
    digest read as a big-endian unsigned integer.
    """
    encoded = str(args).encode()
    return int.from_bytes(sha256(encoded).digest(), byteorder="big")
def verify_proof(C, R, mu, i):
    """Local copy of the server's PoK check: R == mu*G - c*C, c = H(CONTEXT, i, C, R)."""
    c = H(CONTEXT, i, C, R)
    expected = mu * G + (-c) * C
    return R == expected


def sum_pts(points):
    """Add up an iterable of Points, starting from the identity 0 * G."""
    acc = 0 * G
    for pt in points:
        acc += pt
    return acc


def poly_eval_comms(comms, i):
    """Evaluate committed coefficients at i: sum(comms[k] * i^k) for k < THRESHOLD."""
    terms = [comms[k] * pow(i, k, N) for k in range(THRESHOLD)]
    return sum_pts(terms)


def check_shares(comms, shares, i):
    """Local copy of the server's share check: shares[i]*G == comms evaluated at i."""
    lhs = G * shares[i]
    return lhs == poly_eval_comms(comms, i)
def gen_7_point():
    """Return ``torsion_basis(7)`` of a random curve y^2 = x^3 + B over GF(p).

    Keeps sampling B in [1, p - 1] until the curve order is divisible by 7.

    NOTE(review): torsion_basis(7) needs the full 7-torsion to be rational;
    for these j = 0 curves with p = 1 (mod 7), 7 | #E appears to imply that.
    Confirm before reusing with other parameters.
    """
    order = 1
    while order % 7 != 0:
        curve = EllipticCurve(GF(p), [0, randint(1, p - 1)])
        order = curve.order()
    return curve.torsion_basis(7)
def start():
    """Run one exploit attempt against a fresh harder-version challenge process.

    Same flow as the easy solver, except that your_public_share is computed
    by the server. It is steered through two order-7 commitments
    (your_comm6, your_comm12). Returns early (closing the connection) when no
    suitable your_comm0 exists or the server picks more than one signer.
    """
    global CONTEXT, io
    io = process(["python3", "chall_clean5.py"])
    io.recvuntil(b"context string")

    CONTEXT = bytes.fromhex(io.recvline().strip().decode())

    all_comms = {}

    # Parse the THRESHOLD commitments printed for parties 1..NUM_PARTIES-1.
    for i in range(1, NUM_PARTIES):
        io.recvuntil(b"[+] Commitments from party ")
        io.recvline()
        comm = []
        for j in range(THRESHOLD):
            c_j = literal_eval(io.recvline().strip().decode().split("=")[-1].strip())
            comm.append(Point(*c_j))
        all_comms[i] = comm

    # Sum of the honest parties' constant-term commitments.
    pt0 = sum([all_comms[i][0] for i in range(1, NUM_PARTIES)], Point(0, 0))

    P0 = (xbar, ybar, 1)

    PA = jacobian_add(P0, pt0.to_jacobian())
    PB = jacobian_add(jacobian_double(P0), P0)  # NOTE(review): PB appears unused

    RZ = PolynomialRing(ZZ, ['x', 'y'])  # NOTE(review): RZ appears unused

    # poly1 == 0: (your_comm0 + pt0) has y == 0 (group key treated as infinity).
    # poly2 == 0: 4y^2 == 3x^3, the order-3 condition asserted below.
    poly1 = PA[1]
    poly2 = 4 * ybar**2 - 3 * xbar**3

    # Eliminate y to get a univariate polynomial in x.
    p3 = poly1.change_ring(ZZ).resultant(poly2.change_ring(ZZ), ybar.change_ring(ZZ)).change_ring(GF(p)).univariate_polynomial()

    def find():
        # Return (x0, y0) satisfying both conditions, or None if no root works.
        for x0 in p3.roots(multiplicities=False):
            if x0 == 0:
                continue
            p1 = PA[1](x=x0).univariate_polynomial()
            for y0 in p1.roots(multiplicities=False):
                if poly2(x=x0, y=y0) != 0:
                    continue
                your_comm0 = Point(int(x0), int(y0))
                if (pt0 + your_comm0).point[1] != 0:
                    continue
                return int(x0), int(y0)

    res = find()
    if res is None:
        io.close()
        return
    x0, y0 = res
    print(f"Found x0: {x0}, y0: {y0}")
    assert poly2(x=x0, y=y0)==0

    your_comm0 = Point(int(x0), int(y0))

    assert (pt0 + your_comm0).point[1] == 0, (pt0 + your_comm0)
    assert (3 * your_comm0).point == (0, 0), (3 * your_comm0)

    def find_mu():
        # Brute-force a small mu with R0 = mu*G passing the PoK check. This
        # presumably succeeds when c * your_comm0 collapses to infinity.
        for mu in range(1, 10000):
            R0 = mu * G
            if verify_proof(your_comm0, R0, mu, NUM_PARTIES):
                print(f"Found mu: {mu}")
                return R0, mu

    your_zero_proof = find_mu()

    # Honest parties' contribution to our (server-computed) public share.
    your_public_share_pt0 = sum_pts([poly_eval_comms(all_comms[j], NUM_PARTIES) for j in range(1, NUM_PARTIES)])

    def find_7(your_comm0):
        # Find scaled 7-torsion points your_comm6 = (r0^2 x_A, r0^3 y_A) and
        # your_comm12 = (r1^2 x_B, r1^3 y_B) such that your_comm0 + your_comm6 +
        # your_comm12 equals R_fake. R_fake shares its x-coordinate with r3*G,
        # and its y is chosen so that the resulting public share satisfies
        # 4y^2 == 3x^3.
        R2 = PolynomialRing(GF(P), ['r0', 'r1'])
        r0, r1 = R2.gens()

        A7_, B7_ = gen_7_point()

        A7 = (A7_.x() * r0 ** 2, A7_.y() * r0 ** 3, 1)
        B7 = (B7_.x() * r1 ** 2, B7_.y() * r1 ** 3, 1)

        C = jacobian_add(your_comm0.to_jacobian(), A7)
        C = jacobian_add(C, B7)

        while True:
            r3 = randint(1, p)
            R = r3 * G
            R_fake = (R.point[0], ybar, 1)
            fake_share = jacobian_add(your_public_share_pt0.to_jacobian(), R_fake)
            poly0 = 4 * fake_share[1] ** 2 - 3 * fake_share[0] ** 3
            y_fake = poly0.univariate_polynomial().roots(multiplicities=False)
            if len(y_fake) == 0:
                continue
            y_fake = y_fake[-1]
            R_fake = (int(R.point[0]), int(y_fake))

            # C (Jacobian, in r0 and r1) must equal the affine R_fake.
            poly1 = R_fake[0] * C[2] ** 2 - C[0]
            poly2 = R_fake[1] * C[2] ** 3 - C[1]

            M = poly1.sylvester_matrix(poly2, r1)
            R0 = PolynomialRing(GF(P), 'r0')
            M = matrix(R0, M)
            det = M.determinant()
            print(f"det degree: {det.degree()}")

            for r0_sol in det.roots(multiplicities=False):
                if r0_sol == 0:
                    continue
                if poly1(r0=r0_sol) == 0:
                    continue
                for r1_sol in poly1(r0=r0_sol).univariate_polynomial().roots(multiplicities=False):
                    if poly2(r0=r0_sol, r1=r1_sol) != 0:
                        continue
                    if r1_sol == 0:
                        continue
                    your_comm6 = Point(int(A7_.x() * r0_sol ** 2), int(A7_.y() * r0_sol ** 3))
                    your_comm12 = Point(int(B7_.x() * r1_sol ** 2), int(B7_.y() * r1_sol ** 3))
                    aaa = your_comm0 + your_comm6 + your_comm12
                    if aaa.point == R_fake:
                        return r3, your_comm6, your_comm12

    def find_mid(your_comm6, your_comm12):
        # Pick random r1, r7 and a point mid with x(mid) == x(r7*G) such that
        # mid + your_comm6 + your_comm12 has the same x-coordinate as r1*G.
        while True:
            r7 = randint(1, p)
            r1 = randint(1, p)
            R7 = r7 * G
            R1 = r1 * G

            S0 = (R7.point[0], ybar, 1)
            S1 = jacobian_add(S0, your_comm6.to_jacobian())
            S1 = jacobian_add(S1, your_comm12.to_jacobian())
            poly = S1[0] - R1.point[0] * S1[2]**2
            poly = poly.univariate_polynomial()
            for y0 in poly.roots(multiplicities=False):
                if y0 == 0:
                    continue
                mid = Point(R7.point[0], int(y0))
                if R1 == (mid + your_comm6 + your_comm12) and R7 == mid:
                    return r1, r7, mid

    def try_find_3(your_comm0, mid):
        # Find your_comm2 = (12 r0^2, 36 r0^3) and your_comm4 = (12 r1^2, 36 r1^3)
        # (both satisfy 4y^2 == 3x^3) with your_comm0 + your_comm2 + your_comm4
        # == mid, by solving your_comm0 + M2 == mid + M4 where M4 = -your_comm4.
        R2 = PolynomialRing(GF(P), ['r0', 'r1'])
        r0, r1 = R2.gens()

        M2 = (12 * r0**2, 36 * r0**3, 1)
        M4 = (12 * r1**2, -36 * r1**3, 1)

        mid1 = jacobian_add(your_comm0.to_jacobian(), M2)
        mid2 = jacobian_add(mid.to_jacobian(), M4)

        # Jacobian equality of mid1 and mid2, cross-multiplied by the Z's.
        rpoly1 = mid1[0] * mid2[2]**2 - mid2[0] * mid1[2]**2
        rpoly2 = mid1[1] * mid2[2]**3 - mid2[1] * mid1[2]**3

        M = rpoly1.sylvester_matrix(rpoly2, r1)
        R0 = PolynomialRing(GF(P), 'r0')
        rbar = R0.gen()
        M = matrix(R0, M)

        det = M.determinant()
        print(f"det degree: {det.degree()}")

        for r0_sol in det.roots(multiplicities=False):
            if r0_sol == 0:
                continue
            if rpoly1(r0=r0_sol) == 0:
                continue
            for r1_sol in rpoly1(r0=r0_sol).univariate_polynomial().roots(multiplicities=False):
                if rpoly2(r0=r0_sol, r1=r1_sol) != 0:
                    continue
                if r1_sol == 0:
                    continue
                your_comm2 = Point(int(12 * r0_sol**2), int(36 * r0_sol**3))
                your_comm4 = Point(int(12 * r1_sol**2), int(36 * r1_sol**3))
                aaa = your_comm0 + your_comm2 + your_comm4
                if aaa.point == mid.point:
                    print(f"Found your_comm2: {your_comm2}, your_comm4: {your_comm4}")
                    return your_comm2, your_comm4

    def try_find_2():
        # Fix the order-7 part first, then retry mid / order-3 searches
        # until try_find_3 succeeds.
        r3, your_comm6, your_comm12 = find_7(your_comm0)
        print(f"Found r3: {r3}, your_comm6: {your_comm6}, your_comm12: {your_comm12}")

        while True:
            r1, r7, mid = find_mid(your_comm6, your_comm12)
            try:
                res = try_find_3(your_comm0, mid)
                if res is not None:
                    your_comm2, your_comm4 = res
                    return r1, r3, r7, your_comm2, your_comm4, your_comm6, your_comm12
            except Exception as e:
                print(f"Error: {e}, retrying with new r1, r7")
                continue

    r1, r3, r7, your_comm2, your_comm4, your_comm6, your_comm12 = try_find_2()
    # 13 coefficients; unused ones are (1, 0), which py_ecc treats as infinity.
    your_comms = [your_comm0, Point(1, 0), your_comm2, Point(1, 0), your_comm4, Point(1, 0), your_comm6] + [Point(1, 0)] * 5 + [your_comm12]
    # shares[i] for party i (index 0 unused): r3 when 3 | i, r7 when 7 | i,
    # otherwise r1.
    shares = [None, r1, r1, r3, r1, r1, r3, r7, r1, r3, r1, r1, r3, r1, r7, r3]

    # Local sanity check of every share against the forged commitments.
    for i in range(1, NUM_PARTIES + 1):
        print(i, check_shares(your_comms, shares, i))

    for your_comm in your_comms:
        send_point(your_comm)
    all_comms[NUM_PARTIES] = your_comms

    send_point(your_zero_proof[0])
    send_int(your_zero_proof[1])

    io.recvuntil(b"[+] Your commitments and proof have been accepted.")

    for i in range(1, NUM_PARTIES + 1):
        send_int(shares[i])

    my_shares = {}
    for i in range(1, NUM_PARTIES):
        io.recvuntil(b"[+] Share for you from party ")
        my_shares[i] = int(io.recvline().strip().decode().split(":")[-1].strip())
    my_shares[NUM_PARTIES] = r3

    group_public_key = pt0 + your_comm0
    io.recvuntil(b"[+] Public verification shares have been computed.")

    def gen_D_E():
        # Pick E = (12 r^2, 36 r^3) and D = (1337, y0) such that D + E has
        # y == 0. Retry until rho % 3 == 1 and the group challenge is divisible
        # by 3. rho is computed assuming we are the only signer.
        while True:
            r = randint(1, p)
            E = Point(12*r**2 % p, 36*r**3 % p)
            D = (1337, ybar, 1)
            nonce = jacobian_add(D, E.to_jacobian())
            poly = nonce[1]
            poly = poly.univariate_polynomial()
            for y0 in poly.roots(multiplicities=False):
                if y0 == 0:
                    continue
                D = Point(1337, int(y0))
                nonsense_ordered = [(NUM_PARTIES, D, E)]
                m = "GIVE ME THE FLAG PLEASE"
                rho = H(NUM_PARTIES, m, nonsense_ordered)
                if rho % 3 == 1:
                    group_nonsense = D + E
                    group_challenge = H(group_nonsense, group_public_key, m)
                    if group_challenge % 3 != 0:
                        continue
                    return D, E

    D, E = gen_D_E()

    send_point(D)
    send_point(E)

    io.recvuntil(b"[+] Set of signers for this round: ")
    signers = literal_eval(io.recvline().strip().decode())
    print(f"Signers: {signers}")
    # gen_D_E assumed we are the sole signer; otherwise give up on this attempt.
    if len(signers) > 1:
        io.close()
        return

    io.recvuntil(b"[+] Group challenge `c`: ")
    c = int(io.recvline().strip().decode())
    print(f"Challenge {c}, {c % 3}")
    send_int(0)

    io.interactive()
# Each start() call spawns a fresh challenge process and returns early when the
# attempt cannot succeed, so simply retry up to 100 times.
for _attempt in range(100):
    start()