upCTF Crypto Writeup
Challenge
I was given the following challenge:

```python
from Crypto.Util.number import *
from Crypto.Hash import SHA256
import secrets

def hsh(x):
    h = SHA256.new()
    h.update(x)
    return int.from_bytes(h.digest())

for i in range(10):
    S=set()
    p=getPrime(70)
    print(f"{p=}")
    x=secrets.randbelow(p)
    while True:
        m1=input("m1=")
        if "Stop" in m1:
            break
        m1=bytes.fromhex(m1)
        if m1 in S:
            print("NO!")
            exit(0)
        S.add(m1)
        m2=bytes.fromhex(input("m2="))
        if m2 in S:
            print("NO!")
            exit(0)
        S.add(m2)
        res=pow(x,hsh(m1),p)+pow(x,hsh(m2),p)
        res%=p
        print(f"{res=}")
    x2=int(input("x="))
    if x2!=x:
        print("FAIL!")
        exit(0)
print(open("flag.txt","r").read())
```
The flag I recovered was:
upCTF{Wagner_algorithm_is_very_OP-Aj97UrTu5e9639b1}
My first observations
I started with the obvious simplification:
- The exponents are SHA256(m) interpreted as a huge integer.
- Since the computation is pow(x, hsh(m), p), the exponent is really reduced modulo p-1 for nonzero x (Fermat's little theorem: x^(p-1) = 1 mod p).
- So I can define e(m) = SHA256(m) mod (p-1), and every query is just res = x^e1 + x^e2 mod p.
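A quick sanity check of this reduction, as a minimal sketch: it uses Python's `hashlib` instead of PyCryptodome's `SHA256` (both compute standard SHA-256), and the prime 1,000,003 and the value of x are illustrative choices, not challenge values. `e` is a hypothetical helper name for e(m).

```python
import hashlib

def hsh(m: bytes) -> int:
    # Same value as the challenge's hsh(): the SHA-256 digest read as a big-endian integer.
    return int.from_bytes(hashlib.sha256(m).digest(), "big")

def e(m: bytes, p: int) -> int:
    # The exponent residue that actually matters for pow(x, hsh(m), p) when x != 0.
    return hsh(m) % (p - 1)

# Illustrative parameters (not from the challenge): a small prime and a nonzero x.
p = 1_000_003
x = 123_456
for m in (b"\x00" * 8, b"hello", bytes.fromhex("1000000000000001")):
    full = pow(x, hsh(m), p)      # what the service computes
    reduced = pow(x, e(m, p), p)  # what Fermat's little theorem lets us use instead
    assert full == reduced
print("exponent reduction holds")
```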
I also checked for parser bugs and protocol bugs. There was nothing useful:
- bytes.fromhex() was strict enough.
- The message reuse check was real.
- "Stop" only stopped the query loop.
- There was no accidental leak of x.
So I needed an actual algebraic attack.
The baseline idea I rejected
The most obvious attack is an exact exponent collision.
If I find two distinct messages m1, m2 such that:
e(m1) = e(m2) mod (p-1),
then one query gives:
res = x^e + x^e = 2*x^e mod p.
If gcd(e, p-1) = 1, then I can recover:
x = (res / 2)^(e^{-1} mod (p-1)) mod p.
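A toy version of that recovery, assuming an exponent collision has already been found. The exponent 65537 and the small prime stand in for real values, and `recover_from_collision` is a hypothetical helper name:

```python
from math import gcd

def recover_from_collision(res: int, e: int, p: int) -> int:
    """Invert res = 2 * x^e mod p, assuming gcd(e, p - 1) == 1."""
    n = p - 1
    if gcd(e, n) != 1:
        raise ValueError("e is not invertible modulo p - 1")
    xe = res * pow(2, -1, p) % p      # divide by 2 modulo p
    return pow(xe, pow(e, -1, n), p)  # the unique e-th root

# Toy check: an illustrative prime and a made-up colliding exponent.
p = 1_000_003
x = 424_242
e = 65_537                            # gcd(65537, 1_000_002) == 1
res = (pow(x, e, p) + pow(x, e, p)) % p
assert recover_from_collision(res, e, p) == x
```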
This works, but it costs about a birthday collision in a 70-bit space:
- expected work: about 2^35 hashes,
- and that has to be done repeatedly across 10 rounds.
That was too expensive.
The better idea: I only needed a repeated difference
The real breakthrough was noticing that I did not need an exact collision at all.
Suppose I have four messages with exponents:
e_a, e_b, e_c, e_d
and they satisfy:
e_a + e_d = e_b + e_c mod (p-1).
Equivalently:
e_b - e_a = e_d - e_c = c mod (p-1).
Now I make two oracle queries:
- query (m_a, m_c): r1 = x^e_a + x^e_c mod p
- query (m_b, m_d): r2 = x^e_b + x^e_d mod p
Since e_b = e_a + c and e_d = e_c + c, I get:
r2 = x^(e_a + c) + x^(e_c + c)
r2 = x^c * (x^e_a + x^e_c)
r2 = x^c * r1 mod p
So:
x^c = r2 / r1 mod p
and if gcd(c, p-1) = 1, then:
x = (r2 / r1)^(c^{-1} mod (p-1)) mod p
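The algebra can be checked numerically without any hashing. This sketch picks synthetic exponents that share the difference c, simulates both oracle answers on an illustrative small prime, and recovers x exactly as derived. It assumes r1 is nonzero mod p, so it simply resamples in that unlikely case.

```python
import random
from math import gcd

p = 1_000_003          # illustrative prime, far smaller than the challenge's 70-bit p
n = p - 1
rng = random.Random(2024)
x = rng.randrange(1, p)

# A shift c that is invertible modulo n = p - 1.
c = rng.randrange(1, n)
while gcd(c, n) != 1:
    c = rng.randrange(1, n)

while True:
    e_a, e_c = rng.randrange(n), rng.randrange(n)
    r1 = (pow(x, e_a, p) + pow(x, e_c, p)) % p     # answer to query (m_a, m_c)
    if r1 != 0:                                    # r2 / r1 needs r1 invertible
        break
e_b, e_d = (e_a + c) % n, (e_c + c) % n            # same difference c on both sides
r2 = (pow(x, e_b, p) + pow(x, e_d, p)) % p         # answer to query (m_b, m_d)

xc = r2 * pow(r1, -1, p) % p        # x^c = r2 / r1
x_rec = pow(xc, pow(c, -1, n), p)   # unique c-th root because gcd(c, n) = 1
assert x_rec == x
```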
That was the key. I had turned the problem into:
find four messages whose exponent residues satisfy one modular 4-sum relation.
Turning it into a generalized birthday problem
Now I had a standard shape:
e_a + e_d - e_b - e_c = 0 mod (p-1)
or equivalently:
e_a + e_d = e_b + e_c mod (p-1).
This is exactly the kind of thing Wagner's generalized birthday algorithm is meant for.
For a random 70-bit modulus:
- the exact collision attack costs about 2^(70/2) = 2^35,
- while the 4-list generalized birthday attack costs about 2^(70/3),
- which is around 2^23.3.
That is a massive reduction.
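The gap between the two costs is plain arithmetic; this short sketch just evaluates both bounds for a 70-bit modulus:

```python
import math

bits = 70                          # size of the modulus p - 1
collision = 2 ** (bits / 2)        # birthday bound for an exact collision
wagner4 = 2 ** (bits / 3)          # 4-list generalized birthday bound

print(f"exact collision : 2^{bits / 2:.1f} ~ {collision:.3g}")
print(f"4-list Wagner   : 2^{bits / 3:.1f} ~ {wagner4:.3g}")
print(f"saving factor   : 2^{math.log2(collision / wagner4):.1f}")
```

2^(70/3) comes out at roughly 10.6 million, just below the list sizes that worked in practice.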
In practice on this box, that meant:
- roughly 12 million residues per list was often enough,
- 16 million residues handled the unlucky rounds,
- and one full round typically took a few tens of seconds.
That was absolutely practical against the live service.
The concrete Wagner-style construction I used
I split my candidates into four disjoint lists: A, B, C, and D.
I encoded messages as 8-byte values with the top nibble identifying the list, so all four messages were automatically distinct:
- A: 0x0...
- B: 0x1...
- C: 0x2...
- D: 0x3...
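The same encoding in Python, as a sketch for checking the searcher's output independently. `make_msg` and `residue` are hypothetical helpers that mirror `make_msg`, `hex_msg` and `sha256_mod_u128` in the C++ searcher; the hex string is what gets sent to the service.

```python
import hashlib

def make_msg(list_id: int, idx: int) -> bytes:
    """8-byte big-endian message: list id in the top nibble, index in the low bits."""
    assert 0 <= list_id < 16 and 0 <= idx < (1 << 60)
    return ((list_id << 60) | idx).to_bytes(8, "big")

def residue(msg: bytes, n: int) -> int:
    """e(m) = SHA-256(m) as a big-endian integer, reduced mod n = p - 1."""
    return int.from_bytes(hashlib.sha256(msg).digest(), "big") % n

# Messages from different lists can never coincide, because the top nibble differs.
msgs = [make_msg(lid, 7) for lid in range(4)]
print([m.hex() for m in msgs])
# ['0000000000000007', '1000000000000007', '2000000000000007', '3000000000000007']
```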
Then I searched for:
e_a + e_d = e_b + e_c mod n
where n = p - 1.
I did not compare all quadruples directly. Instead, I used a bucketed meet-in-the-middle.
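Step 1: build a right-hand list from B + C
For each e_b and e_c, I wanted sums whose low t bits were zero after either:
- no wrap: e_b + e_c, or
- one wrap: e_b + e_c - n.
To avoid touching all L^2 pairs (L is the size of each list), I first bucketed C by its low t bits. For a given e_b, the only useful partners are the e_c values whose low t bits equal (-e_b) mod 2^t (no wrap) or (n - e_b) mod 2^t (one wrap, and only when e_b + e_c >= n).
For every such pair, I stored:
- q = (e_b + e_c) >> t if the low t bits were zero, or
- q = (e_b + e_c - n) >> t if the low t bits were zero after subtracting n.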
That gave me a compressed right-hand list keyed by the high bits.
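Step 1 can be sketched in Python with a dictionary instead of the searcher's sorted vector. This is an illustrative, small-scale sketch: `build_right` is a hypothetical name, and the toy modulus and list sizes are arbitrary. It keeps the pair indices and the wrap flag, like the `meta` field in the C++ code.

```python
import random
from collections import defaultdict

def build_right(bvals, cvals, n, t):
    """Return {q: [(ib, ic, wrapped), ...]} for B + C sums whose low t bits vanish."""
    mask = (1 << t) - 1
    buckets = defaultdict(list)                 # low t bits of e_c -> indices into cvals
    for ic, ec in enumerate(cvals):
        buckets[ec & mask].append(ic)
    right = defaultdict(list)
    for ib, eb in enumerate(bvals):
        # No wrap: e_c must cancel the low bits of e_b, i.e. e_c = -e_b (mod 2^t).
        for ic in buckets.get((-eb) & mask, ()):
            right[(eb + cvals[ic]) >> t].append((ib, ic, False))
        # One wrap: e_b + e_c - n must have zero low bits, and the sum must reach n.
        for ic in buckets.get((n - eb) & mask, ()):
            if eb + cvals[ic] >= n:
                right[(eb + cvals[ic] - n) >> t].append((ib, ic, True))
    return right

# Toy usage with random residues modulo an arbitrary even n.
rng = random.Random(0)
n, t = 1_000_002, 6
B = [rng.randrange(n) for _ in range(300)]
C = [rng.randrange(n) for _ in range(300)]
right = build_right(B, C, n, t)
for q, entries in right.items():
    for ib, ic, wrapped in entries:
        # Every stored sum has zero low bits and reproduces its quotient exactly.
        assert B[ib] + C[ic] - (n if wrapped else 0) == q << t
```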
Step 2: scan A + D
Then I generated candidate sums on the left:
- e_a + e_d
- e_a + e_d - n
again keeping only the ones whose low t bits were zero.
For each one, I looked for the same quotient q in the right-hand list.
When the quotients matched, I checked the full equality:
e_a + e_d = e_b + e_c mod n
If that equality held and gcd(e_b - e_a, n) = 1, I was done.
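Putting Step 1 and Step 2 together gives a small-scale version of the whole search. This is a sketch under toy assumptions, not the C++ searcher: it uses the Mersenne prime 2^31 - 1 instead of a 70-bit prime, lists of 4096 residues with t = 10 (sized by the same 2^(bits/3) rule), and dictionaries instead of sorted vectors. `toy_relation_search` is a hypothetical name.

```python
import hashlib
from collections import defaultdict
from math import gcd

def make_msg(list_id, idx):
    return ((list_id << 60) | idx).to_bytes(8, "big")

def residue(msg, n):
    return int.from_bytes(hashlib.sha256(msg).digest(), "big") % n

def toy_relation_search(p, L, t):
    """Find e_a + e_d = e_b + e_c (mod p - 1) with gcd(e_b - e_a, p - 1) = 1.

    Returns (m_a, m_c, m_b, m_d, c) in the searcher's output order, or None.
    """
    n, mask = p - 1, (1 << t) - 1
    A, B, C, D = [[residue(make_msg(lid, i), n) for i in range(L)] for lid in range(4)]

    def pair_sums(xs, ys):
        """Yield (s, i, j) where xs[i] + ys[j], or that sum minus n, has t zero low bits."""
        buckets = defaultdict(list)
        for j, y in enumerate(ys):
            buckets[y & mask].append(j)
        for i, xv in enumerate(xs):
            for j in buckets.get((-xv) & mask, ()):        # no wrap
                yield xv + ys[j], i, j
            for j in buckets.get((n - xv) & mask, ()):     # one wrap
                if xv + ys[j] >= n:
                    yield xv + ys[j] - n, i, j

    right = defaultdict(list)                              # Step 1: B + C keyed by q
    for s, ib, ic in pair_sums(B, C):
        right[s >> t].append((s, ib, ic))
    for s, ia, i_d in pair_sums(A, D):                     # Step 2: scan A + D
        for rhs, ib, ic in right.get(s >> t, ()):
            if s != rhs:                                   # full equality check
                continue
            c = (B[ib] - A[ia]) % n
            if gcd(c, n) == 1:                             # reject non-invertible shifts
                return (make_msg(0, ia).hex(), make_msg(2, ic).hex(),
                        make_msg(1, ib).hex(), make_msg(3, i_d).hex(), c)
    return None

print(toy_relation_search(2**31 - 1, L=4096, t=10))
```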
Why the query order matters
My searcher prints five lines:
- m_a
- m_c
- m_b
- m_d
- c = e_b - e_a mod (p-1)
That ordering is deliberate.
I query:
- (m_a, m_c) to get r1
- (m_b, m_d) to get r2
Then I compute:
x = ((r2 * r1^{-1}) mod p) ^ (c^{-1} mod (p-1)) mod p
This is exactly the algebra derived above.
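To see the query order end to end without the live service, here is an offline sketch. `MockRound` is a hypothetical stand-in that applies the challenge's reuse check and `res` formula, and `naive_relation` is a deliberately simple quadratic pair-sum search (not Wagner's algorithm) that only works because the toy prime 10007 is tiny.

```python
import hashlib
import random
from math import gcd

def hsh(m: bytes) -> int:
    return int.from_bytes(hashlib.sha256(m).digest(), "big")

class MockRound:
    """One round of the service: secret x, reuse check, res = x^hsh(m1) + x^hsh(m2) mod p."""
    def __init__(self, p: int, seed: int = 0):
        self.p, self.S = p, set()
        self.x = random.Random(seed).randrange(1, p)   # nonzero, unlike randbelow(p)

    def query(self, m1: bytes, m2: bytes) -> int:
        for m in (m1, m2):
            if m in self.S:
                raise RuntimeError("NO!")              # the service prints NO! and exits
            self.S.add(m)
        return (pow(self.x, hsh(m1), self.p) + pow(self.x, hsh(m2), self.p)) % self.p

def naive_relation(p: int, L: int):
    """Quadratic search for e_a + e_d = e_b + e_c (mod p - 1) with gcd(e_b - e_a, p - 1) = 1."""
    n = p - 1
    msgs = [[((lid << 60) | i).to_bytes(8, "big") for i in range(L)] for lid in range(4)]
    e = [[hsh(m) % n for m in row] for row in msgs]
    right = {}
    for ib in range(L):
        for ic in range(L):
            right.setdefault((e[1][ib] + e[2][ic]) % n, []).append((ib, ic))
    for ia in range(L):
        for i_d in range(L):
            for ib, ic in right.get((e[0][ia] + e[3][i_d]) % n, []):
                c = (e[1][ib] - e[0][ia]) % n
                if gcd(c, n) == 1:
                    return msgs[0][ia], msgs[2][ic], msgs[1][ib], msgs[3][i_d], c
    raise RuntimeError("no relation found; increase L")

p = 10_007
m_a, m_c, m_b, m_d, c = naive_relation(p, 30)
oracle = MockRound(p, seed=1)
r1 = oracle.query(m_a, m_c)                  # first query: (m_a, m_c)
r2 = oracle.query(m_b, m_d)                  # second query: (m_b, m_d)
x = pow(r2 * pow(r1, -1, p) % p, pow(c, -1, p - 1), p)
assert x == oracle.x
```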
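Edge cases I handled
There were a few real edge cases.
1. gcd(c, p-1) != 1
If c is not invertible modulo p-1, then recovering x from x^c is ambiguous: the map x -> x^c on the nonzero residues mod p is gcd(c, p-1)-to-1, so there are gcd(c, p-1) candidate roots. Because p-1 is even, every even c falls into this case.
I avoided this by rejecting such matches in the searcher.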
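A brute-force illustration of that ambiguity on a toy prime (101, chosen only so the search is instant; `cth_roots` is a hypothetical helper): the number of c-th roots of x^c equals gcd(c, p-1).

```python
from math import gcd

def cth_roots(y, c, p):
    """All x in [1, p - 1] with x^c = y (mod p), found by brute force (tiny p only)."""
    return [x for x in range(1, p) if pow(x, c, p) == y]

p = 101      # toy prime, p - 1 = 100
x = 7
for c in (3, 4, 10):
    roots = cth_roots(pow(x, c, p), c, p)
    print(c, gcd(c, p - 1), len(roots), x in roots)
# 3 1 1 True
# 4 4 4 True
# 10 10 10 True
```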
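2. Some rounds miss at 12 million
The 2^(70/3) estimate is an expectation, not a guarantee.
Here L is the number of residues per list and t is the number of low bits forced to zero. Some rounds did not produce a usable relation with:
- L = 12000000, t = 24
So I added fallback settings:
- L = 16000000, t = 24
- L = 16000000, t = 23
That solved the unlucky rounds. (L has to stay below 2^24 = 16,777,216, because the searcher packs each list index into a 24-bit field.)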
3. Message reuse is forbidden
The service keeps a set of used byte strings, so I needed all four query messages to be distinct.
Using separate list identifiers in the high nibble made that automatic.
The attack workflow I used live
For each round:
- Read p.
- Run the local relation search against n = p - 1.
- Get four messages and the shift c.
- Query (m_a, m_c) and parse r1.
- Query (m_b, m_d) and parse r2.
- Compute: x = ((r2 / r1) mod p)^(c^{-1} mod (p-1)) mod p
- Send Stop.
- Send x.
I automated the whole thing in a small remote solver.
Complexity
The important comparison is:
- exact collision: about 2^35 SHA-256 evaluations,
- 4-list generalized birthday: about 2^(70/3) ~= 2^23.3.
In other words, Wagner's algorithm is exactly what made this challenge practical.
That is why the flag text saying Wagner is OP is accurate.
Build and run
I compiled the searcher like this:
Then I ran the remote solver against the service.
Solver 1: relation_search.cpp
This is the compiled Wagner-style 4-list searcher I used.

```cpp
// 4-list search for e_a + e_d = e_b + e_c (mod p - 1).
// Needs OpenSSL's libcrypto (SHA256) and a compiler with unsigned __int128 (GCC/Clang).
#include <openssl/sha.h>

#include <algorithm>
#include <array>
#include <chrono>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

using u128 = unsigned __int128;

// Right-hand list entry: q = sum >> t, meta = ib | (ic << 24) | (wrap << 48).
struct RightEntry {
    uint64_t q;
    uint64_t meta;
};

static inline std::string u128_to_string(u128 x) {
    if (x == 0) return "0";
    std::string s;
    while (x > 0) {
        s.push_back(char('0' + (x % 10)));
        x /= 10;
    }
    std::reverse(s.begin(), s.end());
    return s;
}

static inline u128 parse_u128(const std::string& s) {
    u128 x = 0;
    for (char c : s) {
        if (c < '0' || c > '9') continue;
        x = x * 10 + (c - '0');
    }
    return x;
}

// Message = 8 bytes, list id in the top nibble, index in the low bits.
static inline uint64_t make_msg(int list_id, uint32_t idx) {
    return (uint64_t(list_id) << 60) | uint64_t(idx);
}

static inline std::array<unsigned char, 8> to_be8(uint64_t x) {
    std::array<unsigned char, 8> out{};
    for (int i = 7; i >= 0; --i) {
        out[i] = static_cast<unsigned char>(x & 0xff);
        x >>= 8;
    }
    return out;
}

static inline std::string hex_msg(uint64_t x) {
    auto b = to_be8(x);
    static const char* H = "0123456789abcdef";
    std::string s;
    s.resize(16);
    for (int i = 0; i < 8; ++i) {
        s[2 * i] = H[b[i] >> 4];
        s[2 * i + 1] = H[b[i] & 15];
    }
    return s;
}

// e(m) = SHA-256(m), read as a big-endian integer, reduced mod `mod`.
static inline u128 sha256_mod_u128(uint64_t msg, u128 mod) {
    auto in = to_be8(msg);
    unsigned char digest[SHA256_DIGEST_LENGTH];
    SHA256(in.data(), in.size(), digest);
    u128 r = 0;
    for (int i = 0; i < SHA256_DIGEST_LENGTH; ++i) {
        r = (r * 256 + digest[i]) % mod;
    }
    return r;
}

// Hash `count` messages of one list, using at most two worker threads.
static void fill_residues(int list_id, uint32_t count, u128 mod, std::vector<u128>& out) {
    out.resize(count);
    unsigned threads = std::max(1u, std::thread::hardware_concurrency());
    threads = std::min(threads, 2u);
    std::vector<std::thread> pool;
    for (unsigned t = 0; t < threads; ++t) {
        uint32_t lo = (uint64_t(count) * t) / threads;
        uint32_t hi = (uint64_t(count) * (t + 1)) / threads;
        pool.emplace_back([=, &out]() {
            for (uint32_t i = lo; i < hi; ++i) {
                out[i] = sha256_mod_u128(make_msg(list_id, i), mod);
            }
        });
    }
    for (auto& th : pool) th.join();
}

static inline uint64_t lowbits(u128 x, uint32_t mask) {
    return static_cast<uint64_t>(x) & mask;
}

// Returns the full-width gcd; truncating it to 64 bits could hide a large common factor.
static inline u128 gcd_u128(u128 a, u128 b) {
    while (b != 0) {
        u128 r = a % b;
        a = b;
        b = r;
    }
    return a;
}

// Counting sort of indices by their low t bits: bucket k is idx[head[k] .. head[k + 1]).
static void build_bucket_index(
    const std::vector<u128>& vals,
    uint32_t t,
    std::vector<uint32_t>& head,
    std::vector<uint32_t>& idx
) {
    uint32_t B = 1u << t;
    uint32_t mask = B - 1;
    head.assign(B + 1, 0);
    for (u128 v : vals) {
        ++head[lowbits(v, mask)];
    }
    uint32_t sum = 0;
    for (uint32_t i = 0; i < B; ++i) {
        uint32_t c = head[i];
        head[i] = sum;
        sum += c;
    }
    head[B] = sum;
    idx.resize(vals.size());
    std::vector<uint32_t> cur(head.begin(), head.end());
    for (uint32_t i = 0; i < vals.size(); ++i) {
        idx[cur[lowbits(vals[i], mask)]++] = i;
    }
}

int main(int argc, char** argv) {
    if (argc < 2) {
        std::cerr << "usage: " << argv[0] << " <p> [L] [t]\n";
        return 1;
    }
    u128 p = parse_u128(argv[1]);
    u128 n = p - 1;
    uint32_t L = (argc >= 3) ? static_cast<uint32_t>(std::stoul(argv[2])) : 12000000u;
    uint32_t t = (argc >= 4) ? static_cast<uint32_t>(std::stoul(argv[3])) : 24u;
    uint32_t B = 1u << t;
    uint32_t mask = B - 1;
    auto t0 = std::chrono::steady_clock::now();
    std::cerr << "[*] p=" << u128_to_string(p) << " L=" << L << " t=" << t << "\n";

    std::vector<u128> bvals, cvals;
    std::cerr << "[*] hashing B/C\n";
    fill_residues(1, L, n, bvals);
    fill_residues(2, L, n, cvals);
    std::cerr << "[*] bucketing C\n";
    std::vector<uint32_t> chead, cidx;
    build_bucket_index(cvals, t, chead, cidx);

    // Step 1: right-hand list of B + C sums whose low t bits are zero.
    std::cerr << "[*] building right list\n";
    std::vector<RightEntry> right;
    right.reserve(L + L / 4);
    for (uint32_t ib = 0; ib < L; ++ib) {
        u128 eb = bvals[ib];
        // No wrap: e_c = -e_b (mod 2^t).
        uint32_t want0 = static_cast<uint32_t>((B - lowbits(eb, mask)) & mask);
        for (uint32_t pos = chead[want0]; pos < chead[want0 + 1]; ++pos) {
            uint32_t ic = cidx[pos];
            u128 s = eb + cvals[ic];
            if ((static_cast<uint64_t>(s) & mask) == 0) {
                uint64_t q = static_cast<uint64_t>(s >> t);
                uint64_t meta = uint64_t(ib) | (uint64_t(ic) << 24);
                right.push_back({q, meta});
            }
        }
        // One wrap: e_c = n - e_b (mod 2^t) and e_b + e_c >= n.
        uint32_t want1 = static_cast<uint32_t>((static_cast<uint64_t>(n) - lowbits(eb, mask)) & mask);
        for (uint32_t pos = chead[want1]; pos < chead[want1 + 1]; ++pos) {
            uint32_t ic = cidx[pos];
            u128 s = eb + cvals[ic];
            if (s >= n) {
                u128 z = s - n;
                if ((static_cast<uint64_t>(z) & mask) == 0) {
                    uint64_t q = static_cast<uint64_t>(z >> t);
                    uint64_t meta = uint64_t(ib) | (uint64_t(ic) << 24) | (1ull << 48);
                    right.push_back({q, meta});
                }
            }
        }
    }
    std::cerr << "[*] right size=" << right.size() << "\n";

    // C residues are recomputed on demand later, so free them now.
    cvals.clear();
    cvals.shrink_to_fit();
    chead.clear();
    chead.shrink_to_fit();
    cidx.clear();
    cidx.shrink_to_fit();

    std::cerr << "[*] sorting right list\n";
    std::sort(right.begin(), right.end(), [](const RightEntry& a, const RightEntry& b) {
        if (a.q != b.q) return a.q < b.q;
        return a.meta < b.meta;
    });

    std::vector<u128> dvals;
    std::cerr << "[*] hashing D\n";
    fill_residues(3, L, n, dvals);
    std::cerr << "[*] bucketing D\n";
    std::vector<uint32_t> dhead, didx;
    build_bucket_index(dvals, t, dhead, didx);

    auto lower_q = [&](uint64_t q) {
        return std::lower_bound(right.begin(), right.end(), q, [](const RightEntry& e, uint64_t val) {
            return e.q < val;
        });
    };

    // Step 2: stream A, pair it with D, and look up matching quotients.
    std::cerr << "[*] scanning A against right list\n";
    for (uint32_t ia = 0; ia < L; ++ia) {
        u128 ea = sha256_mod_u128(make_msg(0, ia), n);
        uint32_t want0 = static_cast<uint32_t>((B - lowbits(ea, mask)) & mask);
        for (uint32_t pos = dhead[want0]; pos < dhead[want0 + 1]; ++pos) {
            uint32_t id = didx[pos];
            u128 s = ea + dvals[id];
            if ((static_cast<uint64_t>(s) & mask) != 0) continue;
            uint64_t q = static_cast<uint64_t>(s >> t);
            auto it = lower_q(q);
            while (it != right.end() && it->q == q) {
                uint32_t ib = static_cast<uint32_t>(it->meta & ((1ull << 24) - 1));
                uint32_t ic = static_cast<uint32_t>((it->meta >> 24) & ((1ull << 24) - 1));
                u128 eb = bvals[ib];
                u128 ec = sha256_mod_u128(make_msg(2, ic), n);
                u128 rhs = eb + ec;
                if ((it->meta >> 48) & 1) rhs -= n;
                if (s == rhs) {
                    u128 cshift = (eb >= ea) ? (eb - ea) : (eb + n - ea);
                    if (gcd_u128(cshift, n) != 1) {
                        ++it;
                        continue;
                    }
                    auto t1 = std::chrono::steady_clock::now();
                    std::cerr << "[*] found in " << std::chrono::duration<double>(t1 - t0).count() << " sec\n";
                    std::cout << hex_msg(make_msg(0, ia)) << "\n";
                    std::cout << hex_msg(make_msg(2, ic)) << "\n";
                    std::cout << hex_msg(make_msg(1, ib)) << "\n";
                    std::cout << hex_msg(make_msg(3, id)) << "\n";
                    std::cout << u128_to_string(cshift) << "\n";
                    return 0;
                }
                ++it;
            }
        }
        uint32_t want1 = static_cast<uint32_t>((static_cast<uint64_t>(n) - lowbits(ea, mask)) & mask);
        for (uint32_t pos = dhead[want1]; pos < dhead[want1 + 1]; ++pos) {
            uint32_t id = didx[pos];
            u128 s = ea + dvals[id];
            if (s < n) continue;
            u128 z = s - n;
            if ((static_cast<uint64_t>(z) & mask) != 0) continue;
            uint64_t q = static_cast<uint64_t>(z >> t);
            auto it = lower_q(q);
            while (it != right.end() && it->q == q) {
                uint32_t ib = static_cast<uint32_t>(it->meta & ((1ull << 24) - 1));
                uint32_t ic = static_cast<uint32_t>((it->meta >> 24) & ((1ull << 24) - 1));
                u128 eb = bvals[ib];
                u128 ec = sha256_mod_u128(make_msg(2, ic), n);
                u128 rhs = eb + ec;
                if ((it->meta >> 48) & 1) rhs -= n;
                if (z == rhs) {
                    u128 cshift = (eb >= ea) ? (eb - ea) : (eb + n - ea);
                    if (gcd_u128(cshift, n) != 1) {
                        ++it;
                        continue;
                    }
                    auto t1 = std::chrono::steady_clock::now();
                    std::cerr << "[*] found in " << std::chrono::duration<double>(t1 - t0).count() << " sec\n";
                    std::cout << hex_msg(make_msg(0, ia)) << "\n";
                    std::cout << hex_msg(make_msg(2, ic)) << "\n";
                    std::cout << hex_msg(make_msg(1, ib)) << "\n";
                    std::cout << hex_msg(make_msg(3, id)) << "\n";
                    std::cout << u128_to_string(cshift) << "\n";
                    return 0;
                }
                ++it;
            }
        }
    }
    std::cerr << "[!] no relation found\n";
    return 2;
}
```
Solver 2: find_relation.py
This is the small wrapper I used to retry the search with stronger parameters on unlucky rounds.

```python
#!/usr/bin/env python3
import subprocess
import sys

CONFIGS = [
    (12000000, 24),
    (16000000, 24),
    (16000000, 23),
]


def main() -> int:
    if len(sys.argv) != 2:
        print(f"usage: {sys.argv[0]} <p>", file=sys.stderr)
        return 1
    p = sys.argv[1]
    for L, t in CONFIGS:
        proc = subprocess.run(
            ["./relation_search", p, str(L), str(t)],
            text=True,
            capture_output=True,
        )
        sys.stderr.write(proc.stderr)
        if proc.returncode == 0:
            sys.stdout.write(proc.stdout)
            return 0
    return 2


if __name__ == "__main__":
    raise SystemExit(main())
```
Solver 3: Remote solver
This is the full remote solver that ties everything together.

```python
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import re
import socket
import subprocess
import sys


def run_relation_search(p: int) -> tuple[str, str, str, str, int]:
    out = subprocess.check_output(["python3", "find_relation.py", str(p)], text=True)
    lines = [line.strip() for line in out.splitlines() if line.strip()]
    if len(lines) != 5:
        raise RuntimeError(f"unexpected relation_search output: {lines!r}")
    m_a, m_c, m_b, m_d, c = lines
    return m_a, m_c, m_b, m_d, int(c)


def recover_x(p: int, r1: int, r2: int, c: int) -> int:
    n = p - 1
    xc = (r2 * pow(r1, -1, p)) % p
    return pow(xc, pow(c, -1, n), p)


class Tube:
    def __init__(self, host: str, port: int):
        self.sock = socket.create_connection((host, port))
        self.buf = b""

    def close(self) -> None:
        self.sock.close()

    def sendline(self, data: str) -> None:
        self.sock.sendall(data.encode() + b"\n")

    def recv_until(self, marker: bytes) -> bytes:
        while marker not in self.buf:
            chunk = self.sock.recv(4096)
            if not chunk:
                out = self.buf
                self.buf = b""
                return out
            self.buf += chunk
        idx = self.buf.index(marker) + len(marker)
        out = self.buf[:idx]
        self.buf = self.buf[idx:]
        return out

    def recv_all(self) -> bytes:
        chunks = [self.buf]
        self.buf = b""
        while True:
            chunk = self.sock.recv(4096)
            if not chunk:
                break
            chunks.append(chunk)
        return b"".join(chunks)


def parse_p(blob: bytes) -> int:
    m = re.search(rb"p=(\d+)", blob)
    if not m:
        raise RuntimeError(f"could not parse p from: {blob!r}")
    return int(m.group(1))


def parse_res(blob: bytes) -> int:
    m = re.search(rb"res=(\d+)", blob)
    if not m:
        raise RuntimeError(f"could not parse res from: {blob!r}")
    return int(m.group(1))


def main() -> int:
    parser = argparse.ArgumentParser(description="Solve the upCTF oracle challenge remotely.")
    parser.add_argument("host")
    parser.add_argument("port", type=int)
    args = parser.parse_args()

    tube = Tube(args.host, args.port)
    try:
        banner = tube.recv_until(b"m1=")
        p = parse_p(banner)
        for rnd in range(10):
            print(f"[*] round {rnd + 1}/10 p={p}", file=sys.stderr)
            m_a, m_c, m_b, m_d, c = run_relation_search(p)

            tube.sendline(m_a)
            tube.recv_until(b"m2=")
            tube.sendline(m_c)
            r1 = parse_res(tube.recv_until(b"m1="))

            tube.sendline(m_b)
            tube.recv_until(b"m2=")
            tube.sendline(m_d)
            r2 = parse_res(tube.recv_until(b"m1="))

            x = recover_x(p, r1, r2, c)
            tube.sendline("Stop")
            tube.recv_until(b"x=")
            tube.sendline(str(x))

            if rnd == 9:
                final = tube.recv_all().decode(errors="replace")
                sys.stdout.write(final)
                return 0
            nxt = tube.recv_until(b"m1=")
            p = parse_p(nxt)
        return 0
    finally:
        tube.close()


if __name__ == "__main__":
    raise SystemExit(main())
```
Final note
The intended lesson here is that the exact-collision approach is not the right lens.
Once I rewrote the problem as:
e_a + e_d = e_b + e_c mod (p-1)
the challenge stopped being a plain birthday search and became a generalized birthday problem.
At that point, Wagner's algorithm was exactly the right hammer.