;;;; bn.scm: A user-level unsigned bignum arithmetic library
;;;; J. Welsh, August 2017

;; A bignum is a list of words, least significant first. It must not have
;; trailing zeros. Thus each number has a unique representation, and zero is
;; the empty list.

;;; Package parameters

;; error:        procedure used to signal failures, called as
;;               (error message irritant ...)
;; base-nibbles: hex digits per word; the word base is 16^base-nibbles.

(lambda (error base-nibbles)

;;; Constants

(let* ((base-bits (* base-nibbles 4))
       (base/2 (expt 2 (- base-bits 1)))
       (base (* 2 base/2))) ;; must be <= sqrt of largest fixnum
(let ((base-1 (- base 1))
      (hex "0123456789abcdef")
      (char0 (char->integer #\0))
      (char10-A (- 10 (char->integer #\A)))
      (bn0 '())
      (bn1 '(1))
      (bn2 '(2))
      (karatsuba-threshold 20)
      (karatsuba-sqr-threshold 30))

;;; Helpers

;; Lowercase hex digits of a non-negative fixnum, without padding.
(define (fix->hex n) ;; note 0 -> empty string
  (do ((n n (quotient n 16))
       (acc '() (cons (string-ref hex (remainder n 16)) acc)))
      ((zero? n) (list->string acc))))

;; True only for the ASCII digits #\0..#\9. char-numeric? may also accept
;; other Unicode decimal digits, whose code points are not contiguous with
;; #\0 and would yield out-of-range digit values below.
(define (ascii-digit? c)
  (and (char<=? #\0 c) (char<=? c #\9)))

;; Value (0..15) of a hex digit character, either case; errors otherwise.
(define (hexdigit->fix c)
  (if (ascii-digit? c)
      (- (char->integer c) char0)
      (let ((i (+ (char->integer (char-upcase c)) char10-A)))
        (if (and (<= 10 i) (< i 16))
            i
            (error "bad hex digit:" c)))))

;; Value (0..9) of a decimal digit character; errors otherwise.
(define (decdigit->fix c)
  (if (ascii-digit? c)
      (- (char->integer c) char0)
      (error "bad decimal digit:" c)))

;; Prepend copies of char to s until it has length len (s must not be
;; longer than len).
(define (left-pad s len char)
  (string-append (make-string (- len (string-length s)) char) s))

;; One word as exactly base-nibbles hex digits.
(define (bn-pad-word->hex w)
  (left-pad (fix->hex w) base-nibbles #\0))

;; A single word (0 <= w < base) as a bignum.
(define (word->bn w)
  (if (zero? w) '() (list w)))

;; Construct bignum from big-endian, vector-like sequence of nibbles.
;; nibble-ref: procedure k -> k-th nibble (k = 0 is most significant).
;; len:        number of nibbles. Leading zero nibbles are allowed.
;; Ugly, but linear time.
(define (nibbles->bn nibble-ref len)
  ;; Read whole words starting at nibble index start. Each word is consed
  ;; onto acc, so the least significant word ends up first.
  (define (loop-words start acc)
    (if (= start len)
        acc
        (let* ((next (+ start base-nibbles))
               (word (get-word start (- next 1))))
          (loop-words next
                      (if (and (null? acc) (zero? word))
                          acc ;; still in leading zeros: keep the result normalized
                          (cons word acc))))))
  ;; Value of nibbles start..stop (inclusive), big-endian.
  (define (get-word start stop)
    (define (loop start acc)
      (if (> start stop)
          acc
          (loop (+ start 1) (+ (* 16 acc) (nibble-ref start)))))
    (loop (+ start 1) (nibble-ref start)))
  (if (zero? len)
      '()
      ;; The most significant word takes the leftover nibbles (1 to
      ;; base-nibbles), so every later word is aligned to a full word.
      (let* ((msw-end (remainder (- len 1) base-nibbles))
             (msw (get-word 0 msw-end)))
        (loop-words (+ msw-end 1) (word->bn msw)))))

;; Rather than "shift left/right", which unnecessarily invoke endianness,
;; I'm using "shift" for multiplications and "unshift" for divisions.

;; a * base^k
(define (shift-words a k)
  (if (null? a) a (shift-words-nz a k)))

;; a * base^k for nonzero a: prepend k zero words.
(define (shift-words-nz a k)
  (if (zero? k) a (shift-words-nz (cons 0 a) (- k 1))))

;; floor(a / base^k): drop up to k low words.
(define (unshift-words a k)
  (if (or (zero? k) (null? a)) a (unshift-words (cdr a) (- k 1))))

;;; Type conversion

;; Lowercase hex string without leading zeros; zero -> "0".
(define (bn->hex n)
  (let ((n (reverse n)))
    (if (null? n)
        "0"
        (apply string-append
               (fix->hex (car n)) ;; optimize?
               (map bn-pad-word->hex (cdr n))))))

;; Big-endian hex string (either case) to bignum.
(define (hex->bn s)
  (nibbles->bn (lambda (k) (hexdigit->fix (string-ref s k)))
               (string-length s)))

;; Vector of byte values (0..255), big-endian, to bignum.
(define (bytes->bn v)
  (nibbles->bn (lambda (k)
                 (let ((byte (vector-ref v (quotient k 2))))
                   (if (even? k)
                       (quotient byte 16) ;; big-endian
                       (remainder byte 16))))
               (* 2 (vector-length v))))

;; Decimal string without leading zeros; zero -> "0", matching bn->hex.
;; Digits are produced by repeated division by 10, least significant first.
;; ~Cubic algorithm!
(define (bn->dec n)
  (if (null? n)
      "0"
      (let loop ((n n) (acc '()))
        (if (null? n)
            (list->string acc)
            (bn-divrem n '(10)
                       (lambda (q r)
                         (loop q (cons (string-ref hex (bn->fix r)) acc))))))))

;; Decimal string to bignum (empty string -> zero).
;; Quadratic algorithm!
(define (dec->bn s)
  (do ((i 0 (+ i 1))
       (acc bn0 (bn+fix (bn*fix acc 10)
                        (decdigit->fix (string-ref s i)))))
      ((= i (string-length s)) acc)))

;; Bignum to exact integer.
;; Can overflow (obviously)
(define (bn->fix n)
  (do ((n (reverse n) (cdr n))
       (acc 0 (+ (* acc base) (car n))))
      ((null? n) acc)))

;; Non-negative exact integer to bignum; negative input is an error, since
;; this is an unsigned library.
(define (fix->bn n)
  (if (negative? n) (error "fix->bn: negative argument:" n))
  (do ((n n (quotient n base))
       (acc '() (cons (remainder n base) acc)))
      ((zero? n) (reverse acc))))

;;; Predicates

(define bn-zero? null?)

;; The base is even, so the parity of a is the parity of its low word.
(define (bn-even? a) (or (null? a) (even? (car a))))

(define (bn-odd? a) (not (bn-even? a)))

;; Three-way comparison of fixnums: -1, 0 or 1.
(define (cmp a b) (cond ((< a b) -1) ((< b a) 1) (else 0)))

;; Three-way comparison of bignums: -1, 0 or 1. The more significant words
;; are compared first, by recursing on the tails. Because there are no
;; trailing zeros, a shorter list is a smaller number.
(define (bn-cmp a b)
  (cond ((null? a) (if (null? b) 0 -1))
        ((null? b) 1)
        (else (let ((c (bn-cmp (cdr a) (cdr b))))
                (if (zero? c) (cmp (car a) (car b)) c)))))

;; Structural equality suffices because representations are unique.
(define bn= equal?)
;; Ordering predicates, all defined via bn-cmp.
(define (bn< a b) (< (bn-cmp a b) 0))
(define (bn> a b) (> (bn-cmp a b) 0))
(define (bn<= a b) (<= (bn-cmp a b) 0))
(define (bn>= a b) (>= (bn-cmp a b) 0))

;;; Addition

;; a + 1. The carry propagates through any run of base-1 low words.
(define (bn+1 a)
  (if (null? a)
      bn1
      (let ((head (car a)))
        (if (= head base-1)
            (cons 0 (bn+1 (cdr a)))
            (cons (+ head 1) (cdr a))))))

;; a + b, word by word from the least significant end. When a word sum
;; reaches the base, continue with bn+carry on the tails.
(define (bn+ a b)
  (cond ((null? a) b)
        ((null? b) a)
        (else
         (let ((sum (+ (car a) (car b))))
           (if (< sum base)
               (cons sum (bn+ (cdr a) (cdr b)))
               (cons (- sum base) (bn+carry (cdr a) (cdr b))))))))

;; a + b + 1 (addition with an incoming carry).
(define (bn+carry a b)
  (cond ((null? a) (bn+1 b))
        ((null? b) (bn+1 a))
        (else
         (let ((sum (+ (car a) (car b) 1)))
           (if (< sum base)
               (cons sum (bn+ (cdr a) (cdr b)))
               (cons (- sum base) (bn+carry (cdr a) (cdr b))))))))

;; a + b for a single-word fixnum b.
;; CAUTION: assumes 0 <= b < base
(define (bn+fix a b)
  (cond ((zero? b) a)
        ((null? a) (list b))
        (else
         (let ((sum (+ (car a) b)))
           (if (< sum base)
               (cons sum (cdr a))
               (cons (- sum base) (bn+1 (cdr a))))))))

;;; Subtraction

;; a - 1; signals an error for zero. A zero low word borrows from the rest.
;; A result of zero is returned as '() so no trailing zero is left.
(define (bn-1 a)
  (if (null? a) (error "bn-1: subtract from zero"))
  (let ((head (car a)) (tail (cdr a)))
    (cond ((zero? head) (cons base-1 (bn-1 tail)))
          ((and (= head 1) (null? tail)) '())
          (else (cons (- head 1) tail)))))

;; a - b; signals an error when b > a.
;; A zero word is dropped when everything above it is also zero. This keeps
;; the no-trailing-zeros invariant.
(define (bn- a b)
  (cond ((null? a)
         (if (null? b) b (error "bn-: subtract from zero")))
        ((null? b) a)
        (else
         (let ((diff (- (car a) (car b))))
           (if (< diff 0)
               ;; diff + base >= 1 here, so this word is never a zero to drop.
               (cons (+ diff base) (bn-sub-borrow (cdr a) (cdr b)))
               (let ((tail (bn- (cdr a) (cdr b))))
                 (if (and (= diff 0) (null? tail))
                     '()
                     (cons diff tail))))))))

;; a - b - 1 (subtraction with an incoming borrow); errors when b >= a.
(define (bn-sub-borrow a b)
  (cond ((null? a) (error "bn-: subtract from zero"))
        ((null? b) (bn-1 a))
        (else
         (let ((diff (- (car a) (car b) 1)))
           (if (< diff 0)
               ;; diff + base may be 0 here (car a = 0, car b = base-1).
               (let ((tail (bn-sub-borrow (cdr a) (cdr b)))
                     (diff (+ diff base)))
                 (if (and (= diff 0) (null? tail))
                     '()
                     (cons diff tail)))
               (let ((tail (bn- (cdr a) (cdr b))))
                 (if (and (= diff 0) (null? tail))
                     '()
                     (cons diff tail))))))))

;;; Multiplication

;; 2a
(define (bn*2 a)
  (if (null? a)
      '()
      (let ((product (* (car a) 2)))
        (if (< product base)
            (cons product (bn*2 (cdr a)))
            (cons (- product base) (bn*2+carry (cdr a)))))))

;; 2a + 1
(define (bn*2+carry a)
  (if (null? a)
      bn1
      (let ((product (+ (* (car a) 2) 1)))
        (if (< product base)
            (cons product (bn*2 (cdr a)))
            (cons (- product base) (bn*2+carry (cdr a)))))))

;; a * scale for a single-word fixnum scale.
;; CAUTION: assumes 0 <= scale < base
(define (bn*fix a scale)
  (if (or (null? a) (zero? scale))
      '()
      (let ((product (* (car a) scale)))
        (if (< product base)
            (cons product (bn*fix (cdr a) scale))
            (cons (remainder product base)
                  (bn*fix+carry (cdr a) scale (quotient product base)))))))

;; a * scale + carry. carry is the high word of the previous word product,
;; so it is positive when this is entered from bn*fix.
(define (bn*fix+carry a scale carry)
  (if (or (null? a) (zero? scale))
      (list carry)
      (let ((product (+ (* (car a) scale) carry)))
        (if (< product base)
            (cons product (bn*fix (cdr a) scale))
            (cons (remainder product base)
                  (bn*fix+carry (cdr a) scale (quotient product base)))))))

;; a * 2^bits. Whole words are prepended first; the remaining
;; (bits mod base-bits) is a single-word multiplication.
(define (bn-shift a bits)
  (cond ((< bits 0) (error "bn-shift: negative bits"))
        ((null? a) a)
        (else (bn*fix (shift-words-nz a (quotient bits base-bits))
                      (expt 2 (remainder bits base-bits))))))

;; Schoolbook product: a*b = a*(car b) + base * (a * (cdr b)).
(define (simple* a b)
  (define (a* b)
    (if (null? b)
        b
        (bn+ (bn*fix a (car b)) (shift-words (a* (cdr b)) 1))))
  (if (null? a) a (a* b)))

;; a^2, splitting a = hd + base*tl:
;;   a^2 = hd^2 + base*(2*hd*tl + base*tl^2)
;; Still quadratic, but ~30% faster than generic multiplication
(define (simple^2 a)
  (if (null? a)
      '()
      (let* ((hd (car a))
             (tl (cdr a))
             (hd^2 (* hd hd))
             ;; hd^2 as a bignum; it may need two words.
             (hd^2 (if (< hd^2 base)
                       (word->bn hd^2)
                       (list (remainder hd^2 base) (quotient hd^2 base)))))
        (if (null? tl)
            hd^2
            (bn+ hd^2
                 (cons 0 (bn+ (cons 0 (simple^2 tl))
                              (bn*fix (bn*2 tl) hd))))))))

;; Drop leading zero elements of a list. bn-split applies this to a
;; most-significant-first word list, i.e. it drops high zero words.
(define (strip-leading-zeros l)
  (cond ((null? l) l)
        ((zero? (car l)) (strip-leading-zeros (cdr l)))
        (else l)))

;; Split a into a0 = its k low words (normalized) and a1 = the rest, so that
;; a = a0 + base^k * a1. Calls (cont a0 a1).
(define (bn-split a k cont)
  (do ((head '() (cons (car tail) head))
       (tail a (cdr tail))
       (k k (- k 1)))
      ((or (null? tail) (zero? k))
       (cont (reverse (strip-leading-zeros head)) tail))))

;; Karatsuba multiplication
;; With a = a0 + B*a1, b = b0 + B*b1 and B = base^k:
;;   a*b = c0 + B*(c0 + c1 - (a0-a1)(b0-b1)) + B^2*c1
;; where c0 = a0*b0 and c1 = a1*b1. c2 is |a0-a1| * |b0-b1|.
;; If sa and sb agree, (a0-a1)(b0-b1) >= 0 and c2 is subtracted; otherwise
;; it is <= 0 and c2 is added.
;; Below karatsuba-threshold words, use simple* instead.
(define (bn* a b)
  (if (or (null? a) (null? b))
      '()
      (let ((n (max (length a) (length b))))
        (if (< n karatsuba-threshold)
            (simple* a b)
            (let ((k (ceil-quotient n 2)))
              (define (with-split a0 a1 b0 b1)
                (let ((c0 (bn* a0 b0))
                      (c1 (bn* a1 b1))
                      (sa (bn> a0 a1))
                      (sb (bn> b0 b1)))
                  (let ((c2 (bn* (if sa (bn- a0 a1) (bn- a1 a0))
                                 (if sb (bn- b0 b1) (bn- b1 b0)))))
                    (bn+ (bn+ c0 (shift-words ((if (eq? sa sb) bn- bn+)
                                               (bn+ c0 c1)
                                               c2)
                                              k))
                         (shift-words c1 (* 2 k))))))
              (bn-split a k
                        (lambda (a0 a1)
                          (bn-split b k
                                    (lambda (b0 b1)
                                      (with-split a0 a1 b0 b1))))))))))

;; Karatsuba squaring. With a = a0 + B*a1:
;;   a^2 = c0 + B*(c0 + c1 - (a0-a1)^2) + B^2*c1
;; Below karatsuba-sqr-threshold words, use simple^2 instead.
;; I suspect this isn't optimal Karatsuba squaring, but at least it gets
;; to the base case of simple^2.
(define (bn^2 a)
  (let ((n (length a)))
    (if (< n karatsuba-sqr-threshold)
        (simple^2 a)
        (let ((k (ceil-quotient n 2)))
          (define (with-split a0 a1)
            (let ((c0 (bn^2 a0))
                  (c1 (bn^2 a1)))
              (let ((c2 (bn^2 (if (bn> a0 a1) (bn- a0 a1) (bn- a1 a0)))))
                (bn+ (bn+ c0 (shift-words (bn- (bn+ c0 c1) c2) k))
                     (shift-words c1 (* 2 k))))))
          (bn-split a k with-split)))))

;;; Division

;; floor(a/2): multiply by base/2, then drop the low word.
(define (bn/2 a)
  (if (null? a) a (cdr (bn*fix a base/2))))

;; floor(a / 2^bits). Whole words are dropped first. The remaining
;; extra-bits are removed by multiplying by 2^(base-bits - extra-bits) and
;; dropping the low word.
(define (bn-unshift a bits)
  (if (< bits 0) (error "bn-unshift: negative bits"))
  (let* ((full-words (quotient bits base-bits))
         (extra-bits (remainder bits base-bits))
         (a (unshift-words a full-words)))
    (if (or (null? a) (zero? extra-bits))
        a
        (cdr (bn*fix a (expt 2 (- base-bits extra-bits)))))))

;; Smallest n >= 0 with start * 2^n >= target.
;; Does not terminate if start is 0 and target is positive.
(define (num-bit-shifts start target) ;; optimize?
  (define (loop start n)
    (if (>= start target) n (loop (* start 2) (+ n 1))))
  (loop start 0))

;; Last element of a non-empty list; for a bignum, its most significant word.
(define (last l)
  (if (null? (cdr l)) (car l) (last (cdr l))))

;; Divide a by b and call (return quotient remainder). Errors if b is zero.
(define (bn-divrem a b return)
  (if (null? b) (error "division by zero"))
  ;; Normalize: most sig. bit of most sig. word of divisor must be 1
  ;; Scaling both operands by s = 2^k leaves the quotient unchanged and
  ;; multiplies the remainder by s. That is undone below by multiplying by
  ;; base/s and dropping the low word.
  (let ((msw (last b)))
    (if (>= msw base/2)
        (div-normalized a b return)
        (let ((s (expt 2 (num-bit-shifts msw base/2))))
          (div-normalized (bn*fix a s) (bn*fix b s) ;; optimize
                          (lambda (q r)
                            (return q
                                    (if (null? r)
                                        r
                                        (cdr (bn*fix r (quotient base s)))))))))))

;; Continuations for bn-divrem that select one of its two results.
(define (divrem-q q r) q)
(define (divrem-r q r) r)
(define (bn-quotient a b) (bn-divrem a b divrem-q)) ;; optimize?
(define (bn-remainder a b) (bn-divrem a b divrem-r))

;; Words k and k+1 of a combined into one fixnum (word k is the low part).
;; Missing words count as 0.
(define (slice-2 a k)
  (if (null? a)
      0
      (let ((tail (cdr a)))
        (if (zero? k)
            (if (null? tail) (car a) (+ (car a) (* (car tail) base)))
            (slice-2 tail (- k 1))))))

;; Most significant word of a bignum.
;; NOTE(review): returns the same value as `last`.
(define (most-sig-word a) ;; assumes not null
  (define (loop a tail)
    (if (null? tail) (car a) (loop tail (cdr tail))))
  (loop a (cdr a)))

;; Schoolbook long division of A by a normalized B (most significant word
;; >= base/2). Calls (return Q R) and produces one quotient word per step.
;; Each word qi is estimated from the top two words of the current remainder
;; divided by B's top word, capped at base-1. It is then decremented until
;; B*qi*base^i no longer exceeds the remainder.
(define (div-normalized A B return)
  (let* ((n (length B))
         (m (- (length A) n)))
    (if (< m 0)
        (return '() A)
        (let ((n-1 (- n 1))
              (B-msw (most-sig-word B)))
          ;; i: index of the quotient word being computed.
          ;; Q: quotient words found so far, consed so the lowest ends first.
          ;; A: current partial remainder.
          (define (get-qi i Q A)
            (define (try-qi qi)
              (let ((prod (shift-words (bn*fix B qi) i)))
                (if (bn< A prod)
                    (try-qi (- qi 1))
                    (get-qi (- i 1) (cons qi Q) (bn- A prod)))))
            (if (< i 0)
                (return Q A)
                (try-qi (min base-1 (quotient (slice-2 A (+ n-1 i)) B-msw)))))
          ;; Quotient word m can only be 0 or 1 (A < base^(n+m) <= 2*B*base^m
          ;; because B is normalized). One comparison decides it.
          (let ((B-shift (shift-words B m)))
            (if (bn< A B-shift)
                (get-qi (- m 1) '() A)
                (get-qi (- m 1) '(1) (bn- A B-shift))))))))

;; Quotient and remainder of a/b where (<= (bn-bits (- a 1)) max-bits).
;; For accelerating repeated divides by the same number by converting to
;; multiply-and-shift with a precomputed inverse.
(define (bn-divrem-by b max-bits)
  ;; floor(2^max-bits / b) is computed once.
  ;; For a <= 2^max-bits, q = floor(a * n/b / 2^max-bits) is either
  ;; floor(a/b) or one less, so a single correction step makes it exact.
  ;; Returns a procedure (lambda (a cont) ...) that calls (cont q r).
  (let* ((n (bn-shift '(1) max-bits))
         (n/b (bn-quotient n b)))
    (lambda (a cont)
      (if (bn> a n)
          (error "bn-divrem-by: precision too low for dividend")
          (let* ((q (bn-unshift (bn* a n/b) max-bits))
                 (r (bn- a (bn* q b))))
            (if (bn< r b)
                (cont q r)
                (cont (bn+1 q) (bn- r b))))))))

;; Multiplicative inverse of a mod n:
;; (bn-remainder (bn* a (bn-mod-inverse a n)) n) -> bn1
;; Assumes reduced input (0 < a < n)
(define (bn-mod-inverse a n)
  ;; Extended Euclidean algorithm: find x where ax + by = gcd(a, b)
  ;; If the gcd is 1, it follows that ax == 1 mod b
  ;; See [HAC] Algorithm 2.142 / 2.107
  ;; Simplified / adjusted for unsigned bignums: x holds the magnitude of
  ;; the coefficient and neg holds its sign.
  ;; Invariants:
  ;; [1] 0 <= r < b (loop terminates when b would reach 0)
  ;; [2] 0 < b < a (by [1], as b is last r and a is last b)
  ;; [3] q > 0 (by [2])
  ;; [4] If x > 0, then last x <= 0 and next x < 0
  ;;     If x < 0, then last x > 0 and next x > 0
  ;;     (by [3], as next x is last x - qx)
  ;; Full proofs (mostly by induction) left as an exercise to the reader.
  (define (loop a b x neg last-x)
    (bn-divrem a b
               (lambda (q r)
                 (if (bn-zero? r)
                     (if (bn= b bn1)
                         (if neg (bn- n x) x)
                         (error "not invertible (modulus not prime?)"))
                     (loop b r (bn+ last-x (bn* q x)) (not neg) x)))))
  (loop n a bn1 #f bn0))

;;; Exponentials

;; a^b for a non-negative fixnum b (b = 0 gives 1); negative b is an error.
;; Even exponents square the base; odd ones peel off one factor of a.
(define (bn-expt a b)
  (cond ((= b 0) bn1)
        ((< b 0) (error "negative exponent"))
        ((even? b) (bn-expt (bn^2 a) (quotient b 2)))
        (else (bn* a (bn-expt a (- b 1))))))

;; a^b mod n, all bignums (n nonzero). Result is always reduced, in [0, n).
;; Right-to-left binary exponentiation. Every value passed to modn is a
;; product of two residues, so it is below n^2 <= 2^(2*bits(n)), as
;; bn-divrem-by requires.
(define (bn-expmod a b n)
  (let* ((/n (bn-divrem-by n (* 2 (bn-bits n))))
         (modn (lambda (x) (/n x divrem-r))))
    (do ((b b (bn/2 b))
         (a^2^bits (bn-remainder a n) (modn (bn^2 a^2^bits)))
         ;; Start from (1 mod n) rather than 1, so that b = 0 with n = 1
         ;; yields 0 instead of an unreduced 1.
         (acc (bn-remainder bn1 n)
              (if (bn-even? b) acc (modn (bn* acc a^2^bits)))))
        ((bn-zero? b) acc))))

;; Number of significant bits
;; = least integer b such that 2^b > a
;; = ceil(log_2(a+1))
(define (bn-bits a)
  (if (null? a)
      0
      (+ (num-bit-shifts 1 (+ (last a) 1))
         (* (- (length a) 1) base-bits))))

;;; Random number generation

;; Read n characters from port into a vector of their codes. Each code must
;; fit in a byte: bytes->bn splits values into two nibbles and would
;; silently corrupt larger ones. Hitting end of input before n characters
;; is an error.
(define (read-bytes n port)
  (let ((v (make-vector n)))
    (do ((k 0 (+ k 1)))
        ((= k n) v)
      (let ((c (read-char port)))
        (cond ((eof-object? c)
               (error "read-bytes: unexpected end of input"))
              ((> (char->integer c) 255)
               (error "read-bytes: character is not a byte:" c))
              (else
               (vector-set! v k (char->integer c))))))))

;; ceil(a/b) for non-negative a and positive b.
(define (ceil-quotient a b) (quotient (+ a b -1) b))

;; Unbiased random integer generator in the interval [0, n).
;; Returns a procedure taking a byte source port. It rejection-samples
;; nbytes-byte integers below the largest multiple of n that fits in that
;; range, then reduces mod n.
;; Collecting one more byte than strictly necessary avoids cases where a
;; large part of the range is invalid (e.g. n=130).
(define (rand-bn n)
  (if (bn-zero? n) (error "rand-bn: zero range"))
  (let* ((nbytes (+ (ceil-quotient (bn-bits (bn-1 n)) 8) 1))
         (rand-range (bn-expt bn2 (* nbytes 8)))
         (unbiased-range (bn- rand-range (bn-remainder rand-range n))))
    (lambda (rng-port)
      (define (retry)
        (let ((r (bytes->bn (read-bytes nbytes rng-port))))
          (if (bn< r unbiased-range)
              (bn-remainder r n)
              (retry))))
      (retry))))

(export bn0 bn1 bn2
        hexdigit->fix decdigit->fix ;; not strictly bignum ops, but handy
        bn->hex hex->bn bytes->bn bn->dec dec->bn bn->fix fix->bn
        bn-zero? bn-even? bn-odd? bn= bn< bn> bn<= bn>=
        bn+1 bn+ bn-1 bn- bn*2 bn-shift bn*fix bn* bn^2
        bn/2 bn-unshift bn-divrem bn-quotient bn-remainder bn-divrem-by
        bn-mod-inverse bn-expt bn-expmod bn-bits rand-bn))))