;;;; modes.lisp -- using encryption modes with block ciphers

(in-package :crypto)

;;; internal entry points to assure speed
(defgeneric encrypt-with-mode
    (cipher mode plaintext ciphertext
     &key plaintext-start plaintext-end ciphertext-start)
  (:documentation "Encrypt the bytes of PLAINTEXT from PLAINTEXT-START up to
PLAINTEXT-END using CIPHER operating in mode MODE, writing the output into
CIPHERTEXT starting at CIPHERTEXT-START.  PLAINTEXT and CIPHERTEXT may be the
same array.  Returns the number of bytes actually encrypted, which can be
smaller than the requested range."))

;;; Internal decryption entry point; ENCRYPT-WITH-MODE is its counterpart.
(defgeneric decrypt-with-mode (cipher mode ciphertext plaintext
                                      &key ciphertext-start
                                      ciphertext-end
                                      plaintext-start)
  (:documentation "Decrypt CIPHERTEXT, beginning at CIPHERTEXT-START and
continuing until CIPHERTEXT-END, according to CIPHER in mode MODE.  Place
the result in PLAINTEXT, beginning at PLAINTEXT-START.  CIPHERTEXT and
PLAINTEXT are allowed to be the same array.  Return the number of bytes
decrypted, which may be less than specified."))

;;; Base class of every mode object.  CIPHER holds the keyed cipher
;;; context; the generated ENCRYPT-WITH-MODE / DECRYPT-WITH-MODE methods
;;; specialize on both that context's class and the mode class.
(defclass encryption-mode ()
  ((cipher :reader cipher :initarg :cipher)))
;;; Electronic codebook: the ECB methods read no state from the mode object.
(defclass ecb-mode (encryption-mode) ())
;;; State shared by the IV-based modes.
;;;   IV       - the initialization vector.  The mode methods update this
;;;              array in place as data is processed.  As far as this file
;;;              shows, the array passed as :INITIALIZATION-VECTOR is stored
;;;              without copying, so the caller's array is mutated
;;;              (NOTE(review): confirm whether a constructor elsewhere copies it).
;;;   POSITION - (accessor IV-POSITION) index of the next unused byte of the
;;;              current keystream block.  The byte-oriented CFB, OFB and CTR
;;;              methods save it between calls, so input need not be
;;;              block-aligned.  CBC does not use it.
;;; NOTE: the class name is misspelled ("inititialization").  It is kept as
;;; is because other code may refer to the class by this name.
(defclass inititialization-vector-mixin ()
  ((iv :reader iv :initarg :initialization-vector)
   (position :accessor iv-position :initform 0)))
(defclass cbc-mode (encryption-mode inititialization-vector-mixin) ())
(defclass ofb-mode (encryption-mode inititialization-vector-mixin) ())
(defclass cfb-mode (encryption-mode inititialization-vector-mixin) ())
;;; CTR additionally keeps ENCRYPTED-IV, a buffer that receives the
;;; encryption of the current counter block (the keystream).  The
;;; INITIALIZE-INSTANCE :AFTER method for CTR-MODE always overwrites this
;;; slot with a copy of the IV, so a value given via :ENCRYPTED-IV is not
;;; retained.
(defclass ctr-mode (encryption-mode inititialization-vector-mixin)
  ((encrypted-iv :reader encrypted-iv :initarg :encrypted-iv)))

(defmethod initialize-instance :after ((mode ctr-mode) &key)
  "Give every CTR-MODE object its own keystream buffer: a fresh array with the
same length and element type as the initialization vector."
  (let ((buffer (copy-seq (iv mode))))
    (setf (slot-value mode 'encrypted-iv) buffer)))

(defvar *supported-modes* (list :ecb :cbc :ofb :cfb :ctr)
  "Keywords naming every encryption mode implemented in this file.")

(defun mode-supported-p (name)
  "Return true when NAME is one of the supported mode keywords.  The true
value is the tail of *SUPPORTED-MODES* that starts with NAME; NIL otherwise."
  (member name *supported-modes* :test #'eql))

(defun list-all-modes ()
  "Return a fresh list of the supported mode keywords that callers may modify."
  (copy-list *supported-modes*))

(defmethod encrypt (context plaintext ciphertext
                    &key (plaintext-start 0) plaintext-end
                    (ciphertext-start 0))
  "Encrypt PLAINTEXT into CIPHERTEXT through the mode object CONTEXT by
forwarding to ENCRYPT-WITH-MODE together with CONTEXT's underlying cipher."
  (let ((underlying-cipher (cipher context)))
    (encrypt-with-mode underlying-cipher context plaintext ciphertext
                       :ciphertext-start ciphertext-start
                       :plaintext-start plaintext-start
                       :plaintext-end plaintext-end)))

(defmethod decrypt (context ciphertext plaintext
                    &key (ciphertext-start 0) ciphertext-end
                    (plaintext-start 0))
  "Decrypt CIPHERTEXT into PLAINTEXT through the mode object CONTEXT by
forwarding to DECRYPT-WITH-MODE together with CONTEXT's underlying cipher."
  (let ((underlying-cipher (cipher context)))
    (decrypt-with-mode underlying-cipher context ciphertext plaintext
                       :plaintext-start plaintext-start
                       :ciphertext-start ciphertext-start
                       :ciphertext-end ciphertext-end)))

;;; Code generator used to produce the per-cipher, per-mode methods below.
(defgeneric generate-cipher-mode-functions (mode context-name block-length
                                                 encryption-function
                                                 decryption-function)
  (:documentation "Return a list of DEFMETHOD forms which implement
encryption and decryption for the cipher with CONTEXT-NAME in MODE.
BLOCK-LENGTH is the block length of the cipher in bytes; ENCRYPTION-FUNCTION
and DECRYPTION-FUNCTION are symbols FBOUND to the appropriate functionality
for the given cipher.

This is an internal interface to the CRYPTO package."))

;;; ECB: every BLOCK-LENGTH-byte block is run through the raw cipher
;;; function independently.  No state is read from or written to the mode
;;; object.  Only whole blocks are processed, so the returned byte count is
;;; less than requested when the range is not a multiple of BLOCK-LENGTH.
(defmethod generate-cipher-mode-functions ((mode (eql :ecb)) context-name
                                           block-length
                                           encryption-function
                                           decryption-function)
  (list
   `(defmethod encrypt-with-mode ((context ,context-name) (mode ecb-mode)
                                  plaintext ciphertext
                                  &key (plaintext-start 0) plaintext-end
                                  (ciphertext-start 0))
     (declare (type (simple-array (unsigned-byte 8) (*)) plaintext ciphertext))
     ;; Stop before any trailing partial block; the return value tells the
     ;; caller how many plaintext bytes were consumed.
     (loop with offset = plaintext-start
           with plaintext-end = (or plaintext-end (length plaintext))
           while (<= (+ offset ,block-length) plaintext-end)
           do (,encryption-function context plaintext offset
                                    ciphertext ciphertext-start)
           (incf offset ,block-length)
           (incf ciphertext-start ,block-length)
           finally (return (- offset plaintext-start))))
   `(defmethod decrypt-with-mode ((context ,context-name) (mode ecb-mode)
                                  ciphertext plaintext
                                  &key (ciphertext-start 0) ciphertext-end
                                  (plaintext-start 0))
     (declare (type (simple-array (unsigned-byte 8) (*)) ciphertext plaintext))
     ;; Mirror of the encryptor: whole blocks only, driven by the ciphertext range.
     (loop with offset = ciphertext-start
           with ciphertext-end = (or ciphertext-end (length ciphertext))
           while (<= (+ offset ,block-length) ciphertext-end)
           do (,decryption-function context ciphertext offset
                                    plaintext plaintext-start)
           (incf offset ,block-length)
           (incf plaintext-start ,block-length)
           finally (return (- offset ciphertext-start))))))

(declaim (inline xor-block))
(defun xor-block (block-length input-block1 input-block2 input-block2-start
                               output-block output-block-start)
  "Store into OUTPUT-BLOCK, starting at OUTPUT-BLOCK-START, the bytewise XOR
of the first BLOCK-LENGTH bytes of INPUT-BLOCK1 with the BLOCK-LENGTH bytes of
INPUT-BLOCK2 that start at INPUT-BLOCK2-START.  Each output byte is written
right after its inputs are read, so OUTPUT-BLOCK may alias INPUT-BLOCK2 at the
same start.  Returns NIL."
  (declare (type (simple-array (unsigned-byte 8) (*)) input-block1 input-block2 output-block))
  (declare (type index block-length input-block2-start output-block-start))
  ;; A word-at-a-time XOR or an unrolled loop would be faster, but it would
  ;; need care at the block boundaries.  BLOCK-LENGTH is expected to be a
  ;; constant at every call site, so a compiler macro could specialize this.
  (loop for k from 0 below block-length
        for src from input-block2-start
        for dst from output-block-start
        do (setf (aref output-block dst)
                 (logxor (aref input-block1 k)
                         (aref input-block2 src)))))

;;; Increment the CTR-mode counter block in place.
;;;
;;; The counter is treated as a big-endian unsigned integer: the last byte
;;; is the least significant.  Bytes are processed one at a time, so the
;;; result does not depend on the host's word size or byte order.  A
;;; whole-word shortcut that reads native machine words would increment the
;;; wrong byte on little-endian hosts.  It would also skip trailing bytes
;;; when the block length is not a multiple of the word size.
(defun increment-counter-block (block)
  "Add one to BLOCK, a vector of octets read as a big-endian unsigned integer,
modifying it in place.  When every byte is #xFF the counter wraps around to
all zeros.  Returns no values."
  (loop for i from (1- (length block)) downto 0
        do (let ((byte (aref block i)))
             (cond ((= byte #xff)
                    ;; This byte overflows; clear it and carry into the
                    ;; next more significant byte.
                    (setf (aref block i) 0))
                   (t
                    (setf (aref block i) (1+ byte))
                    ;; No carry left to propagate.
                    (return)))))
  (values))
          
;;; CBC: each plaintext block is XORed with the previous ciphertext block
;;; (initially the IV) before encryption.  The chaining value lives in the
;;; mode's IV array, which is overwritten after every block, so successive
;;; calls continue the same chain.  Only whole blocks are processed.
(defmethod generate-cipher-mode-functions ((mode (eql :cbc)) context-name
                                           block-length
                                           encryption-function
                                           decryption-function)
  (list
   `(defmethod encrypt-with-mode ((context ,context-name) (mode cbc-mode)
                                  plaintext ciphertext
                                  &key (plaintext-start 0) plaintext-end
                                  (ciphertext-start 0))
     (declare (type (simple-array (unsigned-byte 8) (*)) plaintext ciphertext))
     (let ((plaintext-end (or plaintext-end (length plaintext)))
           (iv (iv mode))
           (offset plaintext-start))
       (declare (type (simple-array (unsigned-byte 8) (,block-length)) iv))
       ;; Per block: CIPHERTEXT <- PLAINTEXT xor IV, encrypt that in place,
       ;; then copy the resulting ciphertext block into IV for the next round.
       (loop while (<= (+ offset ,block-length) plaintext-end)
         do (xor-block ,block-length iv plaintext offset ciphertext ciphertext-start)
         (,encryption-function context ciphertext ciphertext-start
                               ciphertext ciphertext-start)
         ;; Chain: the block just produced becomes the next IV.  REPLACE
         ;; copies element by element, which may be slow.
         (replace iv ciphertext :start1 0 :end1 ,block-length :start2 ciphertext-start)
         (incf offset ,block-length)
         (incf ciphertext-start ,block-length)
         finally (return (- offset plaintext-start)))))
   `(defmethod decrypt-with-mode ((context ,context-name) (mode cbc-mode)
                                  ciphertext plaintext
                                  &key (ciphertext-start 0) ciphertext-end
                                  (plaintext-start 0))
     (declare (type (simple-array (unsigned-byte 8) (*)) plaintext ciphertext))
     (let ((ciphertext-end (or ciphertext-end (length ciphertext)))
           (iv (iv mode))
           (offset ciphertext-start)
           ;; stash the ciphertext block in case (EQ PLAINTEXT CIPHERTEXT)
           (temp-block (make-array ,block-length :element-type '(unsigned-byte 8))))
       (declare (type (simple-array (unsigned-byte 8) (,block-length)) iv temp-block))
       (declare (dynamic-extent temp-block))
       ;; Per block: save the ciphertext (decrypting in place may clobber
       ;; it), decrypt into PLAINTEXT, XOR with IV, then make the saved
       ;; ciphertext block the next IV.
       (loop while (<= (+ offset ,block-length) ciphertext-end)
         do (replace temp-block ciphertext :start1 0 :end1 ,block-length :start2 offset)
         (,decryption-function context ciphertext offset
                               plaintext plaintext-start)
         (xor-block ,block-length iv plaintext plaintext-start plaintext plaintext-start)
         ;; Chain: the saved ciphertext block becomes the next IV.  REPLACE
         ;; copies element by element, which may be slow.
         (replace iv temp-block :end1 ,block-length :end2 ,block-length)
         (incf offset ,block-length)
         (incf plaintext-start ,block-length)
         finally (return (- offset ciphertext-start)))))))

;;; CFB (full-block feedback, processed a byte at a time): the keystream
;;; block is the encryption of the IV, and each ciphertext byte is written
;;; back into the IV so that it feeds the next keystream block.
;;; IV-POSITION is saved in the mode between calls, so input of any length
;;; is accepted and the whole requested range is always processed.
(defmethod generate-cipher-mode-functions ((mode (eql :cfb)) context-name
                                           block-length
                                           encryption-function
                                           decryption-function)
  ;; CFB runs the block cipher forwards in both directions.
  (declare (ignore decryption-function))
  (list
   `(defmethod encrypt-with-mode ((context ,context-name) (mode cfb-mode)
                                  plaintext ciphertext
                                  &key (plaintext-start 0) plaintext-end
                                  (ciphertext-start 0))
     (declare (type (simple-array (unsigned-byte 8) (*)) plaintext ciphertext))
     (let ((plaintext-end (or plaintext-end (length plaintext)))
           (iv (iv mode))
           (iv-position (iv-position mode)))
       (declare (type (simple-array (unsigned-byte 8) (,block-length)) iv))
       (declare (type (integer 0 ,block-length) iv-position))
       ;; I walks the plaintext, J walks the ciphertext.
       (do ((i plaintext-start (1+ i))
            (j ciphertext-start (1+ j)))
           ((>= i plaintext-end)
            (setf (iv-position mode) iv-position)
            (- plaintext-end plaintext-start))
         ;; At a block boundary, encrypt the IV (which by now holds the
         ;; previous ciphertext block) in place to get the next keystream block.
         (when (zerop iv-position)
           (,encryption-function context iv 0 iv 0))
         ;; The ciphertext byte is both emitted and fed back into the IV.
         (let ((b (logxor (aref plaintext i) (aref iv iv-position))))
           (setf (aref ciphertext j) b)
           (setf (aref iv iv-position) b)
           (setf iv-position (mod (1+ iv-position) ,block-length))))))
   `(defmethod decrypt-with-mode ((context ,context-name) (mode cfb-mode)
                                  ciphertext plaintext
                                  &key (ciphertext-start 0) ciphertext-end
                                  (plaintext-start 0))
     (declare (type (simple-array (unsigned-byte 8) (*)) ciphertext plaintext))
     (let ((ciphertext-end (or ciphertext-end (length ciphertext)))
           (iv (iv mode))
           (iv-position (iv-position mode)))
       (declare (type (simple-array (unsigned-byte 8) (,block-length)) iv))
       (declare (type (integer 0 ,block-length) iv-position))
       ;; I walks the ciphertext, J walks the plaintext.
       (do ((i ciphertext-start (1+ i))
            (j plaintext-start (1+ j)))
           ((>= i ciphertext-end)
            (setf (iv-position mode) iv-position)
            (- ciphertext-end ciphertext-start))
         (when (zerop iv-position)
           (,encryption-function context iv 0 iv 0))
         ;; Both reads of (AREF CIPHERTEXT I) happen before PLAINTEXT is
         ;; written, so in-place decryption (same array, same start) works.
         (let ((b (logxor (aref ciphertext i) (aref iv iv-position))))
           (setf (aref iv iv-position) (aref ciphertext i))
           (setf (aref plaintext j) b)
           (setf iv-position (mod (1+ iv-position) ,block-length))))))))

;;; OFB, processed a byte at a time: the keystream is produced by repeatedly
;;; encrypting the IV in place, independent of the data.  Each data byte is
;;; XORed with the next keystream byte.  IV-POSITION is saved in the mode
;;; between calls, so input of any length is accepted and the whole
;;; requested range is processed.  Encryption and decryption are the same
;;; operation; only the roles of the two arrays differ.
(defmethod generate-cipher-mode-functions ((mode (eql :ofb)) context-name
                                           block-length
                                           encryption-function
                                           decryption-function)
  ;; OFB runs the block cipher forwards in both directions.
  (declare (ignore decryption-function))
  (list
   `(defmethod encrypt-with-mode ((context ,context-name) (mode ofb-mode)
                                  plaintext ciphertext
                                  &key (plaintext-start 0) plaintext-end
                                  (ciphertext-start 0))
     (declare (type (simple-array (unsigned-byte 8) (*)) plaintext ciphertext))
     (let ((plaintext-end (or plaintext-end (length plaintext)))
           (iv (iv mode))
           (iv-position (iv-position mode)))
       (declare (type (simple-array (unsigned-byte 8) (,block-length)) iv))
       (declare (type (integer 0 ,block-length) iv-position))
       ;; I walks the plaintext (the input), J walks the ciphertext (the output).
       (do ((i plaintext-start (1+ i))
            (j ciphertext-start (1+ j)))
           ((>= i plaintext-end)
            (setf (iv-position mode) iv-position)
            (- plaintext-end plaintext-start))
         ;; At a block boundary, encrypt the previous keystream block in
         ;; place to obtain the next one.
         (when (zerop iv-position)
           (,encryption-function context iv 0 iv 0))
         (setf (aref ciphertext j) (logxor (aref plaintext i)
                                           (aref iv iv-position)))
         (setf iv-position (mod (1+ iv-position) ,block-length)))))
   `(defmethod decrypt-with-mode ((context ,context-name) (mode ofb-mode)
                                  ciphertext plaintext
                                  &key (ciphertext-start 0) ciphertext-end
                                  (plaintext-start 0))
     (declare (type (simple-array (unsigned-byte 8) (*)) plaintext ciphertext))
     (let ((ciphertext-end (or ciphertext-end (length ciphertext)))
           (iv (iv mode))
           (iv-position (iv-position mode)))
       (declare (type (simple-array (unsigned-byte 8) (,block-length)) iv))
       (declare (type (integer 0 ,block-length) iv-position))
       ;; I walks the ciphertext (the input, bounded by CIPHERTEXT-END) and
       ;; J walks the plaintext (the output).  Keeping the bound and the
       ;; cursor on the same array is what makes this correct when
       ;; CIPHERTEXT-START and PLAINTEXT-START differ.
       (do ((i ciphertext-start (1+ i))
            (j plaintext-start (1+ j)))
           ((>= i ciphertext-end)
            (setf (iv-position mode) iv-position)
            (- ciphertext-end ciphertext-start))
         (when (zerop iv-position)
           (,encryption-function context iv 0 iv 0))
         (setf (aref plaintext j) (logxor (aref ciphertext i)
                                          (aref iv iv-position)))
         (setf iv-position (mod (1+ iv-position) ,block-length)))))))

;;; CTR, processed a byte at a time: the IV holds the counter block.
;;; Encrypting it yields a keystream block in ENCRYPTED-IV, after which the
;;; counter is incremented.  Each data byte is XORed with the next keystream
;;; byte.  IV-POSITION is saved in the mode between calls, so input of any
;;; length is accepted and the whole requested range is processed.
;;; Encryption and decryption are the same operation; only the roles of the
;;; two arrays differ.
(defmethod generate-cipher-mode-functions ((mode (eql :ctr)) context-name
                                           block-length
                                           encryption-function
                                           decryption-function)
  ;; CTR runs the block cipher forwards in both directions.
  (declare (ignore decryption-function))
  (list
   `(defmethod encrypt-with-mode ((context ,context-name) (mode ctr-mode)
                                  plaintext ciphertext
                                  &key (plaintext-start 0) plaintext-end
                                  (ciphertext-start 0))
     (declare (type (simple-array (unsigned-byte 8) (*)) plaintext ciphertext))
     (let ((plaintext-end (or plaintext-end (length plaintext)))
           (iv (iv mode))
           (iv-position (iv-position mode))
           (encrypted-iv (encrypted-iv mode)))
       (declare (type (simple-array (unsigned-byte 8) (,block-length))
                      iv encrypted-iv))
       (declare (type (integer 0 ,block-length) iv-position))
       ;; I walks the plaintext (the input), J walks the ciphertext (the output).
       (do ((i plaintext-start (1+ i))
            (j ciphertext-start (1+ j)))
           ((>= i plaintext-end)
            (setf (iv-position mode) iv-position)
            (- plaintext-end plaintext-start))
         ;; At a block boundary, produce the next keystream block from the
         ;; current counter, then advance the counter.
         (when (zerop iv-position)
           (,encryption-function context iv 0 encrypted-iv 0)
           (increment-counter-block iv))
         (setf (aref ciphertext j) (logxor (aref plaintext i)
                                           (aref encrypted-iv iv-position)))
         (setf iv-position (mod (1+ iv-position) ,block-length)))))
   `(defmethod decrypt-with-mode ((context ,context-name) (mode ctr-mode)
                                  ciphertext plaintext
                                  &key (ciphertext-start 0) ciphertext-end
                                  (plaintext-start 0))
     (declare (type (simple-array (unsigned-byte 8) (*)) plaintext ciphertext))
     (let ((ciphertext-end (or ciphertext-end (length ciphertext)))
           (iv (iv mode))
           (iv-position (iv-position mode))
           (encrypted-iv (encrypted-iv mode)))
       (declare (type (simple-array (unsigned-byte 8) (,block-length))
                      iv encrypted-iv))
       (declare (type (integer 0 ,block-length) iv-position))
       ;; I walks the ciphertext (the input, bounded by CIPHERTEXT-END) and
       ;; J walks the plaintext (the output).  Keeping the bound and the
       ;; cursor on the same array is what makes this correct when
       ;; CIPHERTEXT-START and PLAINTEXT-START differ.
       (do ((i ciphertext-start (1+ i))
            (j plaintext-start (1+ j)))
           ((>= i ciphertext-end)
            (setf (iv-position mode) iv-position)
            (- ciphertext-end ciphertext-start))
         (when (zerop iv-position)
           (,encryption-function context iv 0 encrypted-iv 0)
           (increment-counter-block iv))
         (setf (aref plaintext j) (logxor (aref ciphertext i)
                                          (aref encrypted-iv iv-position)))
         (setf iv-position (mod (1+ iv-position) ,block-length)))))))
