cl-quantum/math.lisp

468 lines
16 KiB
Common Lisp
Raw Normal View History

2024-12-06 07:54:58 -08:00
(in-package :cl-quantum)
(defmacro domatrix ((var matrix &optional retval) &body body)
  "Evaluate BODY once for each element of the two-dimensional array MATRIX,
visiting rows in order and, within each row, columns in order; then evaluate
and return RETVAL. MATRIX is evaluated once, before the iteration. VAR may take
one of three forms:
- A symbol, bound to the current element on each iteration
- A list (VALUE ROW COLUMN) of three symbols, bound to the current element, its
  row index, and its column index
- A list (ROW COLUMN) of two symbols, bound to the current row and column
  indices
Any other VAR signals an error at macroexpansion time."
  (let ((mat (gensym "MATRIX")))
    (multiple-value-bind (row col inner)
        (cond
          ((symbolp var)
           ;; The indices are not visible to BODY, so use fresh symbols.
           (let ((r (gensym "ROW"))
                 (c (gensym "COL")))
             (values r c `(let ((,var (aref ,mat ,r ,c)))
                            ,@body))))
          ((and (listp var) (= (length var) 2))
           (values (first var) (second var) `(progn ,@body)))
          ((and (listp var) (= (length var) 3))
           (destructuring-bind (value r c) var
             (values r c `(let ((,value (aref ,mat ,r ,c)))
                            ,@body))))
          (t (error "Malformed VAR spec: ~s" var)))
      `(loop with ,mat = ,matrix
             for ,row below (array-dimension ,mat 0) do
               (loop for ,col below (array-dimension ,mat 1) do
                 ,inner)
             finally (return ,retval)))))
(defun mapmatrix (function matrix)
  "Return a fresh matrix with the same dimensions as MATRIX. Each element of the
result is the value of calling FUNCTION with three arguments: the corresponding
element of MATRIX, its row index, and its column index. FUNCTION is called once
per element, row by row."
  (let ((result (make-array (array-dimensions matrix))))
    (dotimes (row (array-dimension matrix 0) result)
      (dotimes (col (array-dimension matrix 1))
        (setf (aref result row col)
              (funcall function (aref matrix row col) row col))))))
;; Matrix subroutines
(defun mat-minor (mat i j)
  "Return a fresh matrix equal to the two-dimensional MAT with row I and column J
deleted, so it has one fewer row and one fewer column than MAT. The remaining
elements keep their relative order."
  (destructuring-bind (rows cols) (array-dimensions mat)
    (let ((minor (make-array (list (1- rows) (1- cols)))))
      (loop for src-row below rows
            unless (= src-row i)
              do (loop with dst-row = (if (< src-row i) src-row (1- src-row))
                       for src-col below cols
                       unless (= src-col j)
                         do (setf (aref minor
                                        dst-row
                                        (if (< src-col j) src-col (1- src-col)))
                                  (aref mat src-row src-col))))
      minor)))
(defun cofactor-sgn (i j)
  "Return the checkerboard sign of the cofactor at row I, column J: 1 when
I + J is even and -1 when it is odd."
  (if (evenp (+ i j)) 1 -1))
(defun cofactor (mat i j)
  "Return the (I, J) cofactor of MAT: the determinant of the minor obtained by
deleting row I and column J, multiplied by the sign from `cofactor-sgn'."
  (let ((minor-det (det (mat-minor mat i j))))
    (* (cofactor-sgn i j) minor-det)))
2024-12-06 22:24:29 -08:00
(defun first-column-cofactors (mat)
  "Return a new vector with one entry per row of MAT; entry I is the cofactor
of MAT at row I, column 0."
  (let ((cofactors (make-array (array-dimension mat 0))))
    (loop for row below (length cofactors)
          do (setf (aref cofactors row) (cofactor mat row 0)))
    cofactors))
2024-12-06 07:54:58 -08:00
(defun det2x2 (mat)
  "Return the determinant of the 2x2 matrix MAT = [[a b] [c d]], i.e. ad - bc."
  (- (* (aref mat 0 0) (aref mat 1 1))
     (* (aref mat 0 1) (aref mat 1 0))))
(defun det (mat &key first-column-cofactors)
  "Return the determinant of the square matrix MAT, computed by cofactor
expansion along the first column.
If the cofactors of the first column have already been computed (for example
with `first-column-cofactors'), they can be supplied as a vector in
FIRST-COLUMN-COFACTORS to avoid recomputing them; they are not consulted for
matrices of size 2 or smaller.
The empty 0x0 matrix has determinant 1, and the determinant of a 1x1 matrix is
its only element. Signals an error if MAT is not square."
  (assert (= (array-dimension mat 0) (array-dimension mat 1))
          (mat)
          "Not a square matrix: ~s" mat)
  (let ((size (array-dimension mat 0)))
    (case size
      ;; Explicit base cases: without them a 1x1 matrix would expand into a
      ;; 0x0 minor whose (empty-sum) determinant is 0, giving a wrong result.
      (0 1)
      (1 (aref mat 0 0))
      (2 (det2x2 mat))
      (t
       (loop for i below size
             summing (* (aref mat i 0)
                        (if first-column-cofactors
                            (aref first-column-cofactors i)
                            (cofactor mat i 0))))))))
(defun invert2x2 (mat)
  "Return the inverse of the 2x2 matrix MAT = [[a b] [c d]] as a new matrix:
(1 / (ad - bc)) * [[d -b] [-c a]]. An exact zero determinant signals
`division-by-zero'."
  (destructuring-bind (a b c d)
      (list (aref mat 0 0) (aref mat 0 1)
            (aref mat 1 0) (aref mat 1 1))
    (let ((inv-det (/ (- (* a d) (* b c)))))
      (make-array '(2 2)
                  :initial-contents
                  (list (list (* inv-det d) (* inv-det (- b)))
                        (list (* inv-det (- c)) (* inv-det a)))))))
(defun invert (mat)
  "Return the inverse of the square matrix MAT as a new matrix; MAT itself is
not modified. Matrices larger than 2x2 are inverted through the adjugate: each
entry of the result is a cofactor of MAT divided by the determinant of MAT.
Signals `division-by-zero' if MAT is singular (for floating-point input this
depends on the implementation's float traps). Signals an error if MAT is not
square."
  (assert (= (array-dimension mat 0) (array-dimension mat 1))
          (mat)
          "Not a square matrix: ~s" mat)
  (case (array-dimension mat 0)
    ;; Small sizes are handled directly; the cofactor expansion below needs at
    ;; least a 3x3 matrix to produce meaningful minors.
    (0 (make-array '(0 0)))
    (1 (make-array '(1 1) :initial-element (/ (aref mat 0 0))))
    (2 (invert2x2 mat))
    (t
     (let* ((first-column-cofactors (first-column-cofactors mat))
            (one-over-det (/ (det mat :first-column-cofactors
                                  first-column-cofactors))))
       (mapmatrix (lambda (val row col)
                    (declare (ignore val))
                    ;; adjugate[row][col] = cofactor(col, row); row 0 of the
                    ;; adjugate is exactly the first-column cofactors, which
                    ;; were already computed for the determinant.
                    (* one-over-det
                       (if (zerop row)
                           (aref first-column-cofactors col)
                           (cofactor mat col row))))
                  mat)))))
(defun transpose (mat)
  "Return a new matrix that is the transpose of the two-dimensional MAT: the
element at (ROW, COL) of MAT is placed at (COL, ROW) of the result."
  (let* ((rows (array-dimension mat 0))
         (cols (array-dimension mat 1))
         (result (make-array (list cols rows))))
    (loop for r below rows
          do (loop for c below cols
                   do (setf (aref result c r) (aref mat r c))))
    result))
(defun norm (vec)
  "Return the Euclidean norm of VEC: the square root of the sum of the squared
magnitudes of its elements. A complex element z contributes |z|^2 = z * conj(z),
so the result is a real number even for complex vectors. For real elements this
is the familiar sqrt(sum x^2)."
  (sqrt (reduce (lambda (sum elt)
                  ;; z * conj(z) has a zero imaginary part; REALPART drops it
                  ;; (it may be a float 0.0) so that SQRT sees a real number.
                  (+ sum (realpart (* elt (conjugate elt)))))
                vec :initial-value 0)))
(defun dot-row-col (mat1 mat2 row col)
  "Return the dot product of row ROW of MAT1 with column COL of MAT2, i.e. the
sum over N of MAT1[ROW, N] * MAT2[N, COL]. N ranges over the row indices of
MAT2."
  (loop for n below (array-dimension mat2 0)
        sum (* (aref mat1 row n) (aref mat2 n col))))
(defun *mm (mat1 mat2)
  "Return the matrix product MAT1 * MAT2 as a new matrix.
MAT1 must be M x N and MAT2 N x P; the result is M x P, and its element (I, J)
is the dot product of row I of MAT1 with column J of MAT2. Signals an error if
the inner dimensions differ."
  (assert (= (array-dimension mat1 1)
             (array-dimension mat2 0))
          (mat1 mat2)
          "Cannot multiply ~s by ~s." mat1 mat2)
  (let* ((height (array-dimension mat1 0)) ; rows of the result
         (width (array-dimension mat2 1))  ; columns of the result
         (inner (array-dimension mat1 1))
         (out-mat (make-array (list height width))))
    (dotimes (i height out-mat)
      (dotimes (j width)
        (setf (aref out-mat i j)
              (loop for k below inner
                    sum (* (aref mat1 i k) (aref mat2 k j))))))))
(defun *mv (mat vec)
  "Return a new vector holding the matrix-vector product MAT * VEC: element ROW
of the result is the sum over COL of VEC[COL] * MAT[ROW, COL]. Signals an error
if the number of columns of MAT differs from the length of VEC."
  (assert (= (array-dimension mat 1)
             (length vec))
          (mat vec)
          "Cannot multiply ~s by ~s." mat vec)
  (let* ((n-rows (array-dimension mat 0))
         (n-cols (array-dimension mat 1))
         (result (make-array n-rows)))
    (loop for row below n-rows
          do (let ((acc 0))
               (dotimes (col n-cols)
                 (incf acc (* (aref vec col) (aref mat row col))))
               (setf (aref result row) acc)))
    result))
(defun *vm (vec mat)
  "Return a new vector holding the row-vector product VEC * MAT.
VEC must have one element per row of MAT; the result has one element per
column of MAT, element COL being the sum over ROW of VEC[ROW] * MAT[ROW, COL].
Signals an error if the length of VEC differs from the number of rows of MAT."
  (assert (= (array-dimension mat 0)
             (length vec))
          (vec mat)
          "Cannot multiply ~s by ~s." vec mat)
  (let* ((width (array-dimension mat 1))
         (height (array-dimension mat 0))
         ;; One output element per column, not per row.
         (out-vec (make-array width)))
    (dotimes (col width out-vec)
      (setf (aref out-vec col)
            (loop for row below height
                  summing (* (aref vec row) (aref mat row col)))))))
(defun dot (vec1 vec2)
  "Return the scalar product of VEC1 and VEC2: the sum of the products of their
corresponding elements. Elements are not conjugated. Signals an error if the
vectors differ in length."
  (assert (= (length vec1) (length vec2))
          (vec1 vec2)
          "Cannot multiply ~s by ~s." vec1 vec2)
  (let ((total 0))
    (dotimes (i (length vec1) total)
      (incf total (* (aref vec1 i) (aref vec2 i))))))
(defun +mm (mat1 mat2)
  "Return a new matrix holding the element-wise sum of MAT1 and MAT2. Signals an
error if their dimensions differ."
  (assert (equal (array-dimensions mat1)
                 (array-dimensions mat2))
          (mat1 mat2)
          "Cannot add ~s and ~s." mat1 mat2)
  (flet ((add-cell (val row col)
           (+ val (aref mat2 row col))))
    (mapmatrix #'add-cell mat1)))
(defun +vv (vec1 vec2)
  "Return a new vector whose element I is VEC1[I] + VEC2[I]. Signals an error if
the vectors differ in length."
  (assert (= (length vec1)
             (length vec2))
          (vec1 vec2)
          "Cannot add ~s and ~s." vec1 vec2)
  (let* ((n (length vec1))
         (result (make-array n)))
    (loop for i below n
          do (setf (aref result i)
                   (+ (aref vec1 i) (aref vec2 i))))
    result))
(defun -mm (mat1 mat2)
  "Return a new matrix holding MAT1 - MAT2, computed element-wise. Signals an
error if the dimensions of MAT1 and MAT2 differ."
  (assert (equal (array-dimensions mat1)
                 (array-dimensions mat2))
          (mat1 mat2)
          ;; The subtrahend is printed first so the message reads
          ;; "Cannot subtract MAT2 from MAT1."
          "Cannot subtract ~s from ~s." mat2 mat1)
  (mapmatrix (lambda (val row col)
               (- val (aref mat2 row col)))
             mat1))
(defun -vv (vec1 vec2)
  "Return a new vector whose element I is VEC1[I] - VEC2[I]. Signals an error if
the vectors differ in length."
  (assert (= (length vec1)
             (length vec2))
          (vec1 vec2)
          "Cannot subtract ~s from ~s." vec2 vec1)
  (let* ((n (length vec1))
         (diff (make-array n)))
    (loop for i below n
          do (setf (aref diff i)
                   (- (aref vec1 i) (aref vec2 i))))
    diff))
(defun *ms (mat scalar)
  "Return a new matrix with every element of MAT multiplied by SCALAR."
  (flet ((scale-cell (val row col)
           (declare (ignore row col))
           (* val scalar)))
    (mapmatrix #'scale-cell mat)))
(defun *vs (vec scalar)
  "Return a new simple vector with every element of VEC multiplied by SCALAR.
VEC may be any sequence."
  (map-into (make-array (length vec))
            (lambda (x) (* x scalar))
            vec))
(defun /ms (mat scalar)
  "Return a new matrix with every element of MAT divided by SCALAR, computed as
multiplication by the reciprocal of SCALAR."
  (let ((reciprocal (/ 1 scalar)))
    (*ms mat reciprocal)))
(defun /vs (vec scalar)
  "Return a new vector with every element of VEC divided by SCALAR, computed as
multiplication by the reciprocal of SCALAR."
  (let ((reciprocal (/ 1 scalar)))
    (*vs vec reciprocal)))
(defun mconj (mat)
  "Return a new matrix whose elements are the complex conjugates of the elements
of MAT (no transposition is performed)."
  (flet ((conj-cell (val row col)
           (declare (ignore row col))
           (conjugate val)))
    (mapmatrix #'conj-cell mat)))
(defun vconj (vec)
  "Return a new simple vector holding the complex conjugate of each element of
VEC."
  (map-into (make-array (length vec)) #'conjugate vec))
(defun squarep (mat)
  "Return true if MAT is a two-dimensional array with as many rows as columns,
false otherwise."
  (and (= (array-rank mat) 2)
       (= (array-dimension mat 0)
          (array-dimension mat 1))))
(defun singularp (mat)
  "Return true if the determinant of MAT is zero, i.e. MAT has no inverse."
  (let ((determinant (det mat)))
    (zerop determinant)))
(defun mtrace (mat)
  "Return the trace of the square matrix MAT, i.e. the sum of its diagonal
elements. Signals an error if MAT is not square."
  (assert (squarep mat)
          (mat)
          "Not a square matrix: ~s" mat)
  (let ((total 0))
    (dotimes (i (array-dimension mat 0) total)
      (incf total (aref mat i i)))))
(defun make-identity-matrix (n)
  "Return a fresh N by N identity matrix: ones on the diagonal and zeros
everywhere else."
  ;; MAKE-ARRAY without :INITIAL-ELEMENT leaves elements uninitialized, and the
  ;; consequences of reading them are undefined, so request the zeros
  ;; explicitly.
  (let ((mat (make-array (list n n) :initial-element 0)))
    (dotimes (i n mat)
      (setf (aref mat i i) 1))))
2024-12-06 22:24:29 -08:00
(defun copy-matrix (mat)
  "Return a fresh matrix with the same dimensions and elements as MAT. The copy
is shallow: the elements themselves are shared, not copied."
  (let* ((rows (array-dimension mat 0))
         (cols (array-dimension mat 1))
         (copy (make-array (array-dimensions mat))))
    (loop for r below rows
          do (loop for c below cols
                   do (setf (aref copy r c) (aref mat r c))))
    copy))
(defun nswap-rows (mat r1 r2)
  "Destructively exchange rows R1 and R2 (0-indexed) of MAT, then return MAT."
  (loop for col below (array-dimension mat 1)
        do (rotatef (aref mat r1 col) (aref mat r2 col))
        finally (return mat)))
(defun swap-rows (mat r1 r2)
  "Return a new matrix equal to MAT with rows R1 and R2 (0-indexed) exchanged.
MAT itself is not modified."
  (mapmatrix (lambda (val row col)
               (declare (ignore val))
               ;; Pick which row of MAT this cell should be copied from.
               (let ((source-row (cond ((= row r1) r2)
                                       ((= row r2) r1)
                                       (t row))))
                 (aref mat source-row col)))
             mat))
(defun nscale-row (mat row scale)
  "Destructively multiply every element of row ROW (0-indexed) of MAT by SCALE,
then return MAT."
  (loop for col below (array-dimension mat 1)
        do (setf (aref mat row col) (* scale (aref mat row col)))
        finally (return mat)))
(defun scale-row (mat row scale)
  "Return a new matrix equal to MAT except that every element of row ROW
(0-indexed) is multiplied by SCALE. MAT itself is not modified."
  (flet ((cell (val irow col)
           (declare (ignore col))
           (if (/= irow row)
               val
               (* val scale))))
    (mapmatrix #'cell mat)))
(defun nreplace-row-with-sum (mat r1 r2 &key (scale 1))
  "Destructively replace row R1 of MAT with R1 + SCALE * R2, then return MAT.
R1 and R2 are 0-indexed row numbers. Row R2 is left unchanged unless it is the
same row as R1."
  (let ((width (array-dimension mat 1)))
    (dotimes (i width mat)
      (incf (aref mat r1 i) (* scale (aref mat r2 i))))))
(defun replace-row-with-sum (mat r1 r2 &key (scale 1))
  "Return a new matrix equal to MAT except that row R1 is replaced by
R1 + SCALE * R2 (rows are 0-indexed). MAT itself is not modified."
  (mapmatrix (lambda (val row col)
               (if (/= row r1)
                   val
                   (+ val (* scale (aref mat r2 col)))))
             mat))
(defun tensor-mm (m1 m2)
  "Return the tensor (Kronecker) product of the matrices M1 and M2 as a new
matrix. If M2 is P x Q, the element at (ROW, COL) of the result is
M1[floor(ROW/P), floor(COL/Q)] * M2[ROW mod P, COL mod Q]."
  (let* ((p (array-dimension m2 0))
         (q (array-dimension m2 1))
         (out (make-array (list (* (array-dimension m1 0) p)
                                (* (array-dimension m1 1) q)))))
    (dotimes (row (array-dimension out 0) out)
      (multiple-value-bind (outer-row inner-row) (floor row p)
        (dotimes (col (array-dimension out 1))
          (multiple-value-bind (outer-col inner-col) (floor col q)
            (setf (aref out row col)
                  (* (aref m1 outer-row outer-col)
                     (aref m2 inner-row inner-col)))))))))
(defun tensor-vv (v1 v2)
  "Return the tensor (Kronecker) product of V1 and V2 as a new simple vector of
length (* (length V1) (length V2)). The result is V2 scaled by the first element
of V1, followed by V2 scaled by the second element of V1, and so on. V1 and V2
may be any sequences."
  ;; Fill the result directly instead of APPLYing CONCATENATE to one piece per
  ;; element of V1, which can exceed CALL-ARGUMENTS-LIMIT (allowed to be as low
  ;; as 50) for long vectors.
  (let* ((v1 (coerce v1 'vector))
         (v2 (coerce v2 'vector))
         (n2 (length v2))
         (out (make-array (* (length v1) n2))))
    (dotimes (i (length v1) out)
      (dotimes (j n2)
        (setf (aref out (+ (* i n2) j))
              (* (aref v2 j) (aref v1 i)))))))
2024-12-06 07:54:58 -08:00
(defun round-to-place (num places &key (base 10))
  "Round NUM to PLACES digits after the radix point in BASE, rounding halves
upward (toward positive infinity). A negative PLACES rounds to the left of the
radix point. With an integer BASE the result is an exact rational."
  (let* ((factor (expt base places))
         (shifted (+ (* num factor) 1/2)))
    (/ (floor shifted) factor)))
(defun count-digits (num &key (base 10))
  "Count the number of digits in NUM when written in BASE. If NUM is zero,
return 1. If NUM is negative, return the number of digits in its absolute value.
Integers (with an integer BASE of at least 2) are counted exactly; any other
NUM falls back to a logarithm and is subject to floating-point rounding."
  (cond
    ((zerop num) 1)
    ((and (integerp num) (integerp base) (> base 1))
     ;; Strip one digit per iteration. This avoids LOG, whose float result can
     ;; land just below an integer (e.g. for 1000 in base 10) and undercount.
     (loop for n = (abs num) then (floor n base)
           count t
           until (< n base)))
    (t
     ;; throw out the extra values
     (values (floor (1+ (log (abs num) base)))))))
(defun build-float (int dec &optional (digits (count-digits dec)))
  "Return the exact rational number whose integer part is INT and whose
fractional digits are DEC; for example INT = 3, DEC = 14 gives 314/100.
DIGITS is the number of decimal places DEC stands for. It defaults to the
number of digits in DEC, and must be supplied when the fractional part has
leading zeros (INT = 1, DEC = 5, DIGITS = 2 stands for 1.05).
The fractional part takes the sign of INT. An integer zero carries no sign, so
a negative value with a zero integer part (such as -0.5) cannot be expressed
through INT and must be negated by the caller."
  (let ((fraction (/ dec (expt 10 digits))))
    ;; Branch on the sign instead of multiplying by (SIGNUM INT): SIGNUM of 0
    ;; is 0, which would wipe out the fraction entirely when INT is zero.
    (if (minusp int)
        (- int fraction)
        (+ int fraction))))
;; DEFPARAMETER rather than DEFCONSTANT: CREATE-SCANNER returns a fresh closure
;; each time it runs, so the value is never EQL to the previous one, and
;; re-evaluating a DEFCONSTANT with a non-EQL value is undefined (SBCL signals an
;; error, e.g. when the file is compiled and then loaded in the same image).
(defparameter +parse-real-regexp+
  (ppcre:create-scanner
   "^(\\s*([-+]?[0-9]+)(?:/([0-9]+)|\\.?([0-9]*)(?:[eE]([-+]?[0-9]+))?)\\s*)"
   :extended-mode t)
  "The regexp scanner used in `parse-real'.")
(defun parse-real (string &key (start 0) end junk-allowed)
  "Parse STRING into a real number, which is always an exact integer or ratio.
Accepted forms are an optionally signed integer, a fraction such as \"3/4\", or
a decimal such as \"-0.25\", optionally followed by an exponent as in
\"1.5e-3\". Surrounding whitespace is accepted.
Parsing starts at START and ends at END. If END is nil, the end of the string is
used. If JUNK-ALLOWED is non-nil, don't signal an error if an unexpected
character is encountered; if nothing at all can be parsed, 0 is returned.
Two values are returned: the value parsed and the index of the first un-parsed
character (equal to START when nothing was parsed)."
  (values-list
   (or
    (ppcre:register-groups-bind (whole main denom decim exp)
        (+parse-real-regexp+ string :start start :end end :sharedp t)
      (unless (or junk-allowed
                  (= (length whole) (- (or end (length string)) start)))
        (error "Malformed number: ~s" (subseq string start end)))
      (let ((num
              (cond
                (denom
                 (/ (parse-integer main)
                    (parse-integer denom)))
                ((/= (length decim) 0)
                 ;; Parse integer and fractional digits as one integer so that
                 ;; leading zeros in the fraction (\"1.05\") and the sign of a
                 ;; zero integer part (\"-0.5\") are both preserved.
                 (/ (parse-integer (concatenate 'string main decim))
                    (expt 10 (length decim))))
                (t
                 (parse-integer main)))))
        (list (if exp
                  (* num (expt 10 (parse-integer exp)))
                  num)
              ;; WHOLE is matched from START, so offset its length by START
              ;; to get an index into STRING.
              (+ start (length whole)))))
    (if junk-allowed
        (list 0 start)
        (error "Malformed number: ~s" (subseq string start end))))))
;; DEFPARAMETER rather than DEFCONSTANT: CREATE-SCANNER returns a fresh closure
;; each time it runs, so the value is never EQL to the previous one, and
;; re-evaluating a DEFCONSTANT with a non-EQL value is undefined (SBCL signals an
;; error, e.g. when the file is compiled and then loaded in the same image).
(defparameter +parse-complex-regexp+
  (ppcre:create-scanner
   "^\\s*([-+])?\\s*([-+]?)([0-9/.]+(?:[eE][-+]?[0-9]+)?)?(i)?"
   :extended-mode t)
  "The regexp scanner used in `parse-complex'.")
(defun parse-complex (string &key (start 0) end junk-allowed)
  "Parse STRING into a complex number made of up to two terms, each an optional
real coefficient optionally followed by `i' to mark it imaginary, e.g. \"1.5-2i\",
\"3i\", or \"-i\". Parsing starts at START and ends at END. If END is nil, the
end of the string is used. If JUNK-ALLOWED is non-nil, don't signal an error if
an unexpected character is encountered. Two values are returned, the first
being the value parsed and the second being the index at which parsing stopped.
That is, the index of the first un-parsed character."
  (unless end (setq end (length string)))
  (let ((pos start)
        (num 0))
    ;; Read at most two terms (a real part and an imaginary part), stopping
    ;; early when the scanner finds no match. Written as an explicit loop
    ;; because ANSI LOOP does not allow FOR clauses after a WHILE clause.
    (loop repeat 2
          do (multiple-value-bind (whole matches)
                 (ppcre:scan-to-strings +parse-complex-regexp+
                                        string
                                        :start pos
                                        :end end)
               (unless whole
                 (return))
               (let ((coef (cond
                             ((aref matches 2)
                              (parse-real (concatenate 'string (aref matches 1)
                                                       (aref matches 2))))
                             ((aref matches 3)
                              ;; A bare \"i\" or \"-i\": the coefficient is 1.
                              (if (equal (aref matches 1) "-") -1 1))
                             (t 0)))
                     (sign (if (equal (aref matches 0) "-") -1 1)))
                 (incf num (if (aref matches 3)
                               (complex 0 (* sign coef))
                               (* sign coef))))
               (incf pos (length whole))))
    (if (and (not junk-allowed)
             (< pos end))
        (error "Junk in string: ~s" (subseq string start end))
        (values num pos))))