1 |
|
2 |
|
3 | const { randomBytes } = require('crypto')
|
4 |
|
5 | const base64url = require('./base64url')
|
6 | const errors = require('../errors')
|
7 |
|
// BigInt constants shared by the arithmetic helpers in this module.
const [ZERO, ONE, TWO] = [0, 1, 2].map((value) => BigInt(value))
|
11 |
|
/**
 * Serializes a non-negative BigInt as a JWK member: its big-endian bytes,
 * encoded with base64url.encodeBuffer.
 *
 * @param {bigint} n - value to encode (assumed non-negative)
 * @returns {string} the encoded member
 */
const toJWKParameter = (n) => {
  let hex = n.toString(16)
  // Buffer.from(..., 'hex') only consumes whole bytes, so pad to an even digit count.
  if (hex.length % 2 !== 0) {
    hex = `0${hex}`
  }
  const bytes = Buffer.from(hex, 'hex')
  return base64url.encodeBuffer(bytes)
}
|
/**
 * Interprets a Buffer as an unsigned big-endian integer.
 * An empty buffer yields ZERO; `BigInt('0x')` would otherwise throw a
 * SyntaxError on an empty (e.g. malformed) JWK member.
 *
 * @param {Buffer} buf
 * @returns {bigint}
 */
const fromBuffer = (buf) => (buf.length === 0 ? ZERO : BigInt(`0x${buf.toString('hex')}`))
|
/**
 * Length of the base-2 rendering of n. This is the bit length for n > 0 and
 * 1 for n = 0. For a negative n the count also includes the '-' sign.
 *
 * @param {bigint} n
 * @returns {number}
 */
const bitLength = (n) => {
  const binaryDigits = n.toString(2)
  return binaryDigits.length
}
|
18 |
|
/**
 * Extended Euclidean algorithm that tracks only the Bézout coefficient of `a`.
 *
 * For non-negative a and positive b it returns x with a*x ≡ gcd(a, b) (mod b).
 * When gcd(a, b) = 1, x is therefore the inverse of a modulo b. The result
 * may be negative, so callers normalize it (e.g. with toZn).
 *
 * @param {bigint} a
 * @param {bigint} b
 * @returns {bigint} the coefficient x (ZERO when a is ZERO)
 */
const eGcdX = (a, b) => {
  // Invariant, with a0/b0 the original inputs: b ≡ x*a0 and a ≡ u*a0 (mod b0).
  // The coefficient of b0 is never needed, so it is not tracked.
  let x = ZERO
  let u = ONE

  while (a !== ZERO) {
    const q = b / a
    const r = b % a
    const m = x - (u * q)
    b = a
    a = r
    x = u
    u = m
  }
  return x
}
|
39 |
|
/**
 * Greatest common divisor via the binary (Stein's) algorithm.
 *
 * Operands are taken by absolute value; gcd(a, 0) = |a| and gcd(0, 0) = 0.
 * The shift/subtract loops below only terminate for positive operands, so
 * zero and negative inputs are normalized first instead of spinning forever.
 *
 * @param {bigint} a
 * @param {bigint} b
 * @returns {bigint} non-negative gcd of a and b
 */
const gcd = (a, b) => {
  let x = a < ZERO ? -a : a
  let y = b < ZERO ? -b : b
  if (x === ZERO) {
    return y
  }
  if (y === ZERO) {
    return x
  }

  // Pull out the power of two common to both operands.
  let shift = ZERO
  while (((x | y) & ONE) === ZERO) {
    x >>= ONE
    y >>= ONE
    shift++
  }
  // From here on x is kept odd, so further factors of two in y are not shared.
  while ((x & ONE) === ZERO) {
    x >>= ONE
  }
  do {
    while ((y & ONE) === ZERO) {
      y >>= ONE
    }
    if (x > y) {
      const t = x
      x = y
      y = t
    }
    y -= x
  } while (y !== ZERO)

  return x << shift
}
|
64 |
|
/**
 * Modular exponentiation by right-to-left binary square-and-multiply.
 *
 * @param {bigint} a - base, any sign (reduced into [0, n) first via toZn)
 * @param {bigint} b - exponent; a non-positive exponent is treated as 0
 * @param {bigint} n - modulus (assumed positive)
 * @returns {bigint} a^b mod n, in [0, n)
 */
const modPow = (a, b, n) => {
  let base = toZn(a, n)
  let exponent = b
  // Start from 1 mod n rather than 1 so that a modulus of 1 correctly yields 0.
  let result = ONE % n

  while (exponent > ZERO) {
    if ((exponent & ONE) === ONE) {
      result = (result * base) % n
    }
    base = (base * base) % n
    exponent >>= ONE
  }
  return result
}
|
81 |
|
/**
 * Random BigInt in the inclusive range [min, max], by rejection sampling over
 * randBits(bitLength(max - min)). The result is uniform provided randBits
 * returns uniformly random bits.
 *
 * @param {bigint} min
 * @param {bigint} max
 * @returns {bigint}
 * @throws {RangeError} if max < min. A negative interval can never be
 *   sampled, and the rejection loop would otherwise never terminate.
 */
const randBetween = (min, max) => {
  if (max < min) {
    throw new RangeError('randBetween: max must not be less than min')
  }
  const interval = max - min
  const bitLen = bitLength(interval)
  let rnd
  do {
    rnd = fromBuffer(randBits(bitLen))
  } while (rnd > interval)
  return rnd + min
}
|
91 |
|
/**
 * Returns ceil(bitCount / 8) random bytes, masked so that the big-endian value
 * they encode has at most bitCount significant bits.
 *
 * @param {number} bitCount - number of random bits wanted (non-negative integer)
 * @returns {Buffer}
 */
const randBits = (bitCount) => {
  const byteLength = Math.ceil(bitCount / 8)
  const rndBytes = randomBytes(byteLength)

  // Clear the surplus high bits of the leading byte. When bitCount is a
  // multiple of 8 every bit is wanted and no mask may be applied. The old
  // `2 ** 0 - 1` mask zeroed the whole leading byte in that case.
  const bitsInTopByte = bitCount % 8
  if (bitsInTopByte !== 0) {
    rndBytes[0] &= (1 << bitsInTopByte) - 1
  }
  return rndBytes
}
|
99 |
|
/**
 * Reduces a into [0, n), intended for a positive modulus n. BigInt `%` keeps
 * the sign of the dividend, so a negative remainder is shifted up by n.
 *
 * @param {bigint} a
 * @param {bigint} n
 * @returns {bigint}
 */
const toZn = (a, n) => {
  const remainder = a % n
  if (remainder < 0) {
    return remainder + n
  }
  return remainder
}
|
104 |
|
/**
 * Strips every factor of two from n, returning its odd part
 * (e.g. 46800 -> 2925).
 *
 * ZERO has no odd part. It is returned as-is rather than halved forever; this
 * input is reachable when e * d = 1 in an untrusted key.
 *
 * @param {bigint} n
 * @returns {bigint}
 */
const odd = (n) => {
  if (n === ZERO) {
    return ZERO
  }
  let r = n
  while (r % TWO === ZERO) {
    r /= TWO
  }
  return r
}
|
112 |
|
113 |
|
// Search budget for getPrimeFactors. Both counters are checked after being
// incremented, so at most maxCountWhileNoY - 1 random bases are tried. Each
// base gets at most maxCountWhileInot0 - 1 squarings.
const maxCountWhileNoY = 30
const maxCountWhileInot0 = 22

// Runs one round with a random base g in [2, n]. Returns a non-trivial factor
// of n, or undefined when this base is unusable and another should be tried.
// Throws when g is coprime to n yet g^(e*d-1) never reaches 1, because then
// e, d and n cannot belong to one RSA key.
const tryRandomBase = (r, n) => {
  let i = modPow(randBetween(TWO, n), r, n)
  let previous
  let squarings = 0

  while (i !== ONE) {
    squarings++
    if (squarings === maxCountWhileInot0) {
      // g = n reduces to 0 and carries no information; try another base.
      if (i === ZERO) {
        return undefined
      }
      // A g sharing a prime with n leaves i a non-zero multiple of that
      // prime, so the shared divisor is itself a factor of n.
      const shared = gcd(i, n)
      if (shared !== ONE) {
        return shared
      }
      throw new errors.JWKImportFailed('failed to calculate missing primes')
    }
    previous = i
    i = (i * i) % n
  }

  // Either g^r was already 1, or the value before 1 was the trivial root n - 1.
  if (previous === undefined || previous === n - ONE) {
    return undefined
  }
  // previous^2 ≡ 1 (mod n) with previous ≠ ±1, so n divides
  // (previous - 1)(previous + 1) but neither term on its own.
  return gcd(previous - ONE, n)
}

/**
 * Recovers the two prime factors of an RSA modulus from its exponents.
 *
 * Writes e*d - 1 = 2^t * r with r odd, then, for random bases g, looks for a
 * square root of 1 modulo n other than ±1 in the sequence g^r, g^(2r), ...
 * For a consistent key that root y yields the factor gcd(y - 1, n).
 *
 * @param {bigint} e - public exponent
 * @param {bigint} d - private exponent
 * @param {bigint} n - modulus
 * @returns {{p: bigint, q: bigint}} the factors, with p the larger one
 * @throws {errors.JWKImportFailed} on degenerate input, when e, d and n are
 *   inconsistent, or when the attempt budget runs out
 */
const getPrimeFactors = (e, d, n) => {
  const edMinusOne = e * d - ONE
  // e*d <= 1 can never satisfy e*d ≡ 1 (mod λ(n)), and n <= 1 leaves no range
  // for the random base. Fail fast instead of looping on such input.
  if (edMinusOne <= ZERO || n <= ONE) {
    throw new errors.JWKImportFailed('failed to calculate missing primes')
  }
  const r = odd(edMinusOne)

  for (let attempt = 1; attempt < maxCountWhileNoY; attempt++) {
    const factor = tryRandomBase(r, n)
    if (factor !== undefined) {
      const cofactor = n / factor
      return factor > cofactor
        ? { p: factor, q: cofactor }
        : { p: cofactor, q: factor }
    }
  }

  throw new errors.JWKImportFailed('failed to calculate missing primes')
}
|
149 |
|
150 | module.exports = (jwk) => {
|
151 | const e = fromBuffer(base64url.decodeToBuffer(jwk.e))
|
152 | const d = fromBuffer(base64url.decodeToBuffer(jwk.d))
|
153 | const n = fromBuffer(base64url.decodeToBuffer(jwk.n))
|
154 |
|
155 | if (d >= n) {
|
156 | throw new errors.JWKInvalid('invalid RSA private exponent')
|
157 | }
|
158 |
|
159 | const { p, q } = getPrimeFactors(e, d, n)
|
160 | const dp = d % (p - ONE)
|
161 | const dq = d % (q - ONE)
|
162 | const qi = toZn(eGcdX(toZn(q, p), p), p)
|
163 |
|
164 | return {
|
165 | ...jwk,
|
166 | p: toJWKParameter(p),
|
167 | q: toJWKParameter(q),
|
168 | dp: toJWKParameter(dp),
|
169 | dq: toJWKParameter(dq),
|
170 | qi: toJWKParameter(qi)
|
171 | }
|
172 | }
|