missing </p>
[u/fanf2/talks/2014-03-nws42.git] / rsa-cheat
1 #! /usr/bin/python
2 #
3 # This program was written by Mark Wooding.
4 # You can do anything with it. It has no warranty.
5 # http://creativecommons.org/publicdomain/zero/1.0/
6
7 from sys import argv
8 from struct import pack, unpack
9 from itertools import izip, count
10
11 import base64 as B
12 import datetime as D
13 import optparse as OP
14 import os as OS
15
16 import gmpy as G
17
class ExpectedError(Exception):
    """Raised when the requested message prefix is malformed or cannot be
    honoured by the key generator."""
19
class struct(object):
    """A bare attribute bag: struct(a=1, b=2) has attributes .a and .b."""

    def __init__(self, **kw):
        # Store every keyword argument directly as an instance attribute.
        self.__dict__.update(kw)
22
def noctets(n):
    """Return how many whole octets are needed to hold the bits of N."""
    whole, spare = divmod(G.bit_length(n), 8)
    if spare:
        # A partial octet still occupies a full one.
        whole += 1
    return whole
def loadb(s):
    """Interpret the octet string S as a big-endian unsigned integer (mpz)."""
    as_hex = s.encode('hex')
    return G.mpz(as_hex, 16)
def storeb(n):
    """Render N as a big-endian octet string.

    The hex digits from gmpy are stripped of any `0x' prefix and padded to
    an even count, so each pair becomes one octet.
    """
    digits = G.digits(n, 16)
    if digits[:2] == '0x':
        digits = digits[2:]
    if len(digits) & 1:
        digits = '0' + digits
    return digits.decode('hex')
32
def rand_range(limit):
    """Return a uniformly random integer R with 0 <= R < LIMIT.

    Draws random octets from the OS and rejects values at or above the
    largest multiple of LIMIT that fits, so the final reduction is unbiased.
    """
    width = noctets(limit)
    bound = 1 << 8*width
    bound -= bound%limit
    r = loadb(OS.urandom(width))
    while r >= bound:
        r = loadb(OS.urandom(width))
    return r%limit
41
def modinv(a, m):
    """Return the inverse of A modulo M.

    Raises ValueError if gcd(A, M) != 1, i.e. no inverse exists.  This is
    an explicit check rather than an assert, which `python -O' would strip,
    silently letting a bogus value through.
    """
    g, i, _ = G.gcdext(a, m)
    if g != 1:
        raise ValueError('%s has no inverse modulo %s' % (a, m))
    # gcdext may hand back a negative Bezout coefficient; shift it up.
    if i < 0:
        i += m
    return i
47
def _decode_message(msg):
    """Decode the base64 text MSG, returning (OCTETS, NBITS).

    NBITS is 6*len(MSG), the number of leading bits of OCTETS that MSG
    actually determines.  OCTETS holds (NBITS + 7)//8 octets, and any bits
    past NBITS are zero.  Python's base64 decoder is picky, so if MSG is not
    a multiple of four characters long, we append `A' (flushing out the last
    octet and leaving zero bits pending) and `=' signs to taste.

    Raises ExpectedError if MSG is not plain base64 text.
    """
    nbits = 6*len(msg)
    padded = msg
    if len(padded)%4:
        padded += 'A' + (3 - len(padded)%4)*'='
    try:
        octets = B.b64decode(padded)[:(nbits + 7)//8]
    except (TypeError, ValueError):
        raise ExpectedError('invalid message (not base64)')
    ## The decoder silently discards characters outside the base64 alphabet,
    ## which would misalign everything after them.  Re-encoding must give
    ## back MSG as a prefix, or no key could ever display it.
    if not B.b64encode(octets).startswith(msg):
        raise ExpectedError('invalid message (not base64)')
    return octets, nbits


def _parse_exponent(octets, nbits):
    """Extract the public exponent from the front of OCTETS.

    The layout is the RSA key format used by dnskey_rdata (RFC 3110): one
    length octet, or a zero octet followed by a two-octet big-endian length,
    and then the exponent.  Only encodings that dnskey_rdata would reproduce
    exactly are accepted.  Otherwise, the generated key could not begin with
    the message.

    Returns (E, OCTETS, NBITS) with the consumed fields removed.  Raises
    ExpectedError if the field is malformed or the exponent is even.
    """
    if nbits < 8:
        raise ExpectedError('invalid exponent field (missing length)')
    elen, = unpack('B', octets[:1])
    octets, nbits = octets[1:], nbits - 8
    if elen == 0:
        if nbits < 16:
            raise ExpectedError(
                'invalid exponent field (missing two-byte length)')
        # unpack always returns a tuple, even for a single field.
        elen, = unpack('>H', octets[:2])
        octets, nbits = octets[2:], nbits - 16
        if elen < 256:
            # dnskey_rdata would use the one-octet form for this length.
            raise ExpectedError(
                'invalid exponent field (two-byte length too small)')
    if nbits < 8*elen:
        raise ExpectedError('invalid exponent field (not long enough)')
    if octets[:1] == '\0':
        # dnskey_rdata encodes the exponent minimally, dropping this zero.
        raise ExpectedError('invalid exponent field (leading zero)')
    e = loadb(octets[:elen])
    octets, nbits = octets[elen:], nbits - 8*elen
    if e%2 == 0:
        raise ExpectedError('exponent must be odd')
    return e, octets, nbits


def _prime_within_bounds(name, lo, hi, e):
    """Return a random prime P with LO <= P < HI and gcd(E, P - 1) == 1.

    NAME labels the prime in the error raised when the interval is empty.
    NOTE(review): keeps searching until it succeeds, so a nonempty interval
    containing no suitable prime would never terminate.
    """
    if hi <= lo:
        raise ExpectedError(
            'no room for %s (modulus constraint too long?)' % name)
    while True:
        start = lo + rand_range(hi - lo)
        p = G.next_prime(start | 1)
        if p < hi and G.gcd(e, p - 1) == 1:
            return p


def generate_key_with_message(modbits, msg):
    """Generate an RSA key whose base64 DNSKEY public-key field starts with MSG.

    MODBITS is the modulus size in bits.  MSG is base64 text encoding an
    exponent length, an odd exponent, and then at least a little of the
    modulus.  Returns a struct with attributes p, q, e, n, d, dp, dq and
    q_inv (the inverse of q modulo p).  Raises ExpectedError if MSG cannot
    be honoured.
    """
    octets, nbits = _decode_message(msg)
    e, octets, nbits = _parse_exponent(octets, nbits)

    ## The remaining bits are destined for the modulus.  We want at least one
    ## octet, which mustn't be zero (or there's no real point).  The
    ## constraint, rounded up to whole octets, must also fit in the modulus.
    if not nbits:
        raise ExpectedError(
            'no modulus constraints (use a less mad generator!)')
    if octets[:1] == '\0':
        raise ExpectedError('invalid modulus constraint (leading zero)')
    if 8*len(octets) > modbits:
        raise ExpectedError('invalid modulus constraint (no freedom)')

    ## Turn the constraint into an interval: we want LO <= n < HI, where
    ## the top NBITS bits are fixed and the rest are free.
    pat = loadb(octets)
    lo = pat << (modbits - 8*len(octets))
    hi = lo + (1 << (modbits - nbits))

    ## Pick p first.  If the modulus is a funny size, let p be the smaller
    ## factor.  This fixes the size of q, which in turn bounds p.
    pbits = modbits//2
    qbits = modbits - pbits
    qlo, qhi = 1 << (qbits - 1), 1 << qbits
    p = _prime_within_bounds('p', lo//qhi + 1, hi//qlo, e)

    ## Now q has easy bounds, which keep LO <= p*q < HI.
    q = _prime_within_bounds('q', lo//p + 1, hi//p, e)

    ## Build the private key.
    d = modinv(e, G.lcm(p - 1, q - 1))
    return struct(p=p, q=q, e=e, n=p*q, d=d,
                  dp=d%(p - 1), dq=d%(q - 1), q_inv=modinv(q, p))
124
def dnskey_rdata(p, k):
    """Return the wire-format DNSKEY RDATA for the RSA public key in K.

    The layout is flags (2 octets), protocol 3 (1 octet), algorithm
    (1 octet), then the RSA public key.  The key is the exponent length
    (a single octet when below 256, otherwise a zero octet and a two-octet
    length), then the exponent, then the modulus.  P supplies FLAGS and ALG;
    K supplies E and N.
    """
    elen = noctets(k.e)
    if elen < 256:
        elen_field = pack('B', elen)
    else:
        elen_field = pack('>BH', 0, elen)
    return (pack('>HBB', p.flags, 3, p.alg) + elen_field +
            storeb(k.e) + storeb(k.n))
131
def key_tag(rd):
    """Return the DNSSEC key tag (an int in 0..65535) of the DNSKEY RDATA RD.

    Octet 3 of RD is the algorithm.  For algorithm 1, the tag is the 16 bits
    just above the lowest octet of the key.  Otherwise, it is the
    ones'-complement-style checksum over the whole RDATA (RFC 4034, App. B).
    """
    octets = bytearray(rd)
    if octets[3] == 1:
        # Return a plain int (not unpack's 1-tuple) so callers can %d it.
        return (octets[-3] << 8) | octets[-2]
    ac = 0
    for i, c in enumerate(octets):
        # Even-indexed octets are high halves, odd-indexed the low halves.
        ac += c if i%2 else c << 8
    # Fold the carries back in and keep 16 bits.
    return (ac + (ac >> 16)) & 0xffff
141
## Algorithm numbers we know a label for, and the reverse label -> number map.
ALGMAP = {8: 'RSASHA256'}
RALGMAP = dict(zip(ALGMAP.values(), ALGMAP.keys()))
144
def save_key(p, k):
    """Write key K in BIND's file formats and print the files' base name.

    P supplies NAME (the owner name), ALG (algorithm number) and FLAGS
    (DNSKEY flags).  K is a key as returned by generate_key_with_message.
    Writes two files:
      * `K<name>.+<alg>+<tag>.private' -- created with mode 0600, since it
        holds the private key;
      * the matching `.key' file, containing the DNSKEY record.
    """
    rd = dnskey_rdata(p, k)
    tag = key_tag(rd)

    ## We add the final dot ourselves, so drop any the user supplied.  This
    ## avoids `example.com..' and makes the root zone `.' come out right.
    name = p.name.rstrip('.')
    base = 'K%s.+%03d+%05d' % (name, p.alg, tag)
    if p.flags & 1:
        keytype = 'key'
    else:
        keytype = 'zone'

    ## The private-key file holds secrets: don't let the umask decide who
    ## may read it (the mode applies when the file is newly created).
    fd = OS.open(base + '.private',
                 OS.O_WRONLY | OS.O_CREAT | OS.O_TRUNC, 0o600)
    with OS.fdopen(fd, 'w') as f:
        f.write('Private-key-format: v1.3\n')
        label = ALGMAP.get(p.alg)
        if label:
            f.write('Algorithm: %d (%s)\n' % (p.alg, label))
        else:
            f.write('Algorithm: %d\n' % p.alg)

        for kwd, attr in (('Modulus', 'n'),
                          ('PublicExponent', 'e'),
                          ('PrivateExponent', 'd'),
                          ('Prime1', 'p'),
                          ('Prime2', 'q'),
                          ('Exponent1', 'dp'),
                          ('Exponent2', 'dq'),
                          ('Coefficient', 'q_inv')):
            f.write('%s: %s\n' %
                    (kwd, B.b64encode(storeb(getattr(k, attr)))))

        stamp = D.datetime.utcnow().strftime('%Y%m%d%H%M%S')
        for kwd in ('Created', 'Publish', 'Activate'):
            f.write('%s: %s\n' % (kwd, stamp))

    with open(base + '.key', 'w') as f:
        f.write("; This is a %s-signing key, keyid %s, for %s\n" %
                (keytype, tag, p.name))
        ## The record shows everything after flags/protocol/algorithm,
        ## base64-encoded in 44-character lines.
        body = B.b64encode(rd[4:])
        bodylines = [body[i:i + 44] for i in range(0, len(body), 44)]
        f.write('%s. IN DNSKEY %d 3 %d (\n\t%s )\n' % (
            name, p.flags, p.alg, '\n\t'.join(bodylines)))

    print(base)
185
## Command-line handling.
op = OP.OptionParser(
    description='Generate RSA keys for DNSSEC with embedded messages.',
    usage='usage: %prog [-k] [-a ALG] [-b NBITS] NAME PREFIX')
op.add_option('-a', '--algorithm',
              metavar='ALG', dest='alg', default=8,
              help='algorithm (numeric, or known label)')
op.add_option('-b', '--bits',
              metavar='NBITS', dest='modbits', type='int', default=2048,
              help='size of modulus, in bits [default 2048]')
op.add_option('-k', '--ksk',
              dest='kskp', action='store_true', default=False,
              help='set the key-signing-key flag bit')

param, args = op.parse_args()
if len(args) != 2:
    op.error("wrong number of arguments")

## Accept a known algorithm label (in any case) or an algorithm number,
## and report anything else as a usage error rather than a traceback.
try:
    param.alg = RALGMAP[str(param.alg).upper()]
except KeyError:
    try:
        param.alg = int(param.alg)
    except ValueError:
        op.error("unknown algorithm `%s'" % param.alg)
if not 0 <= param.alg <= 255:
    # The algorithm is a single octet in the DNSKEY record.
    op.error("algorithm number out of range")

param.name, msg = args
## 256 is the DNSKEY Zone Key flag; -k adds bit 1, the key-signing flag.
param.flags = 256
if param.kskp:
    param.flags |= 1

## ExpectedError means the request itself can't be met: say why, tidily.
try:
    key = generate_key_with_message(param.modbits, msg)
except ExpectedError as err:
    op.exit(1, '%s: %s\n' % (op.get_prog_name(), err))
save_key(param, key)