LAProof.accuracy_proofs.gemm_acc
Matrix Multiplication Forward and Mixed Error Analysis
Error Factors
Error Bound Taxonomy
Key Results
- MMC_error _(mixed)_: Shows that the floating-point matrix product
\(\mathtt{fl}(AB)\) equals the exact product \(AB\) plus a structured
perturbation \(E\) and a small entry-wise absolute error. Each column
\(k\) of \(E\) has the form \(E_0 B_{:,k}\) with
\(|(E_0)_{ij}| \le g(n)\,|A_{ij}|\), so column \(k\) of the result is the
exact product of a slightly perturbed \(A\) with column \(k\) of \(B\).
The absolute residual is bounded by \(g_1(n,n)\) per entry.
- scaleM_error _(mixed)_: Shows that floating-point scalar-matrix
multiplication equals exact scaling of a slightly perturbed matrix plus
a small entry-wise absolute error. The relative perturbation is bounded
by \(\mathbf{u}\) and the absolute residual by \(\eta\).
- sMMC_error _(mixed)_: Composes MMC_error and scaleM_error to give
a structured decomposition of \(\mathtt{fl}(x \cdot (AB))\)
with backward perturbations from both the matrix product and the scaling
step, together with forward absolute errors from each.
- mat_sum_error _(pure backward)_: Shows that floating-point matrix
addition equals the exact sum of two slightly perturbed matrices, with
each entry perturbed by a relative factor bounded by \(\mathbf{u}\).
No forward error term appears.
- mat_axpby_error _(mixed)_: Decomposes \(\mathtt{fl}(xA + yB)\) by
combining the mixed error from each scaling step with the backward error
from the floating-point addition, yielding relative perturbations of the
inputs and small absolute forward errors.
- GEMM_error _(mixed)_: Master theorem for \(\mathtt{fl}(s_1(AB) + s_2 Y)\). Decomposes the full GEMM result into backward perturbation components and forward absolute errors from matrix multiplication, scalar scaling, and matrix addition.
From LAProof.accuracy_proofs Require Import
preamble common dotprod_model sum_model dot_acc
float_acc_lems mv_mathcomp gemv_acc vec_op_acc.
Section MMERROR.
We work in an abstract floating-point context NAN (specifying NaN
behavior) and over an abstract floating-point type t.
Context {NAN : FPCore.Nans} {t : FPStdLib.type}.
Notation g := (@common.g t).
Notation g1 := (@common.g1 t).
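For orientation, the two error factors can be written out as below. This is a sketch that assumes the definitions in common (not reproduced on this page) have their usual form, with \(\mathbf{u}\) standing for default_rel and \(\eta\) for default_abs; consult common for the authoritative definitions.

```latex
g(n) = (1 + \mathbf{u})^{n} - 1,
\qquad
g_1(n_1, n_2) = n_1 \, \eta \, \bigl(1 + g(n_2)\bigr)
```

Under this reading, \(g(n)\) accumulates \(n\) relative rounding errors and \(g_1(n_1, n_2)\) accumulates \(n_1\) underflow errors, each amplified by at most \(n_2\) later roundings.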
Matrix Multiplication Error
Theorem MMC_error :
∀ m n p
(A : 'M[ftype t]_(m, n))
(B : 'M[ftype t]_(n, p))
(Hfin : F.finitemx (F.mulmx A B)),
∃ (E eta : 'M[R]_(m, p)),
map_mx FT2R (F.mulmx A B)
= (map_mx FT2R A ×m map_mx FT2R B + E + eta)%Ri
∧ (∀ k : 'I_p,
∃ E0 : 'M[R]_(m, n),
col k E = E0 ×m col k (map_mx FT2R B)
∧ (∀ i j,
Rabs (E0 i j) ≤ g n × Rabs (map_mx FT2R A i j)))
∧ (∀ i j, Rabs (eta i j) ≤ g1 n n).
Apply the induction hypothesis to the right submatrix of B.
Apply the matrix-vector mixed error lemma to the left submatrix.
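MMC_error implies the entry-wise forward bound \(|\mathtt{fl}(AB) - AB|_{ij} \le g(n)\,(|A||B|)_{ij} + g_1(n,n)\), since \(|E_0| \le g(n)|A|\) entry-wise. The following Python sketch checks that consequence numerically with exact rational arithmetic. It rests on several assumptions that are not taken from the source: binary64 stands in for the abstract type t, U and ETA play the roles of default_rel and default_abs, g and g1 have the forms shown in the comments, and the dot product is accumulated left to right without fused multiply-add (the summation order used by F.mulmx may differ).

```python
from fractions import Fraction
import random

# Assumptions (not from the source): binary64 stands in for t, with
# u = 2^-53 in the role of default_rel and eta = 2^-1075 in the role
# of default_abs.
U = Fraction(1, 2**53)
ETA = Fraction(1, 2**1075)

def g(n):
    """Assumed form of common.g: (1 + u)^n - 1."""
    return (1 + U) ** n - 1

def g1(n1, n2):
    """Assumed form of common.g1: n1 * eta * (1 + g(n2))."""
    return n1 * ETA * (1 + g(n2))

def fl_dot(x, y):
    """Left-to-right floating-point dot product."""
    acc = 0.0
    for xk, yk in zip(x, y):
        acc = acc + xk * yk  # two separately rounded operations, no FMA
    return acc

def fl_matmul(A, B):
    """Floating-point matrix product built from fl_dot."""
    n, p = len(B), len(B[0])
    return [[fl_dot(row, [B[k][j] for k in range(n)]) for j in range(p)]
            for row in A]

def mmc_forward_bound_holds(A, B):
    """Check |fl(AB) - AB|_ij <= g(n) * (|A||B|)_ij + g1(n, n) exactly."""
    n, p = len(B), len(B[0])
    C = fl_matmul(A, B)
    gn, g1nn = g(n), g1(n, n)
    for i, row in enumerate(A):
        for j in range(p):
            terms = [Fraction(row[k]) * Fraction(B[k][j]) for k in range(n)]
            exact = sum(terms)
            abs_sum = sum(abs(t) for t in terms)
            if abs(Fraction(C[i][j]) - exact) > gn * abs_sum + g1nn:
                return False
    return True

random.seed(0)
A = [[random.uniform(-1, 1) for _ in range(4)] for _ in range(3)]
B = [[random.uniform(-1, 1) for _ in range(2)] for _ in range(4)]
print(mmc_forward_bound_holds(A, B))
```

Converting each float to a Fraction is exact, so the comparison involves no further rounding. A passing check illustrates the bound on sample data; it does not replace the proof.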
Scalar-Matrix Multiplication Error
Theorem scaleM_error :
∀ m n
(A : 'M[ftype t]_(m, n))
(x : ftype t)
(Hfin : F.finitemx (F.scalemx x A)),
∃ (E eta : 'M[R]_(m, n)),
map_mx FT2R (F.scalemx x A)
= scalemx (FT2R x) (map_mx FT2R A + E) + eta
∧ (∀ i j,
Rabs (E i j) ≤ @default_rel t × Rabs (map_mx FT2R A i j))
∧ (∀ i j,
Rabs (eta i j) ≤ @default_abs t).
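To see why scaleM_error needs the absolute term eta in addition to the relative perturbation E, the sketch below checks one multiplication in the normal range and one that lands in the subnormal range. It assumes binary64 as a stand-in for t, with U and ETA as assumed values of default_rel and default_abs; the helper names are hypothetical.

```python
from fractions import Fraction

U = Fraction(1, 2**53)      # assumed value of default_rel for binary64
ETA = Fraction(1, 2**1075)  # assumed value of default_abs for binary64

def scale_error(x, a):
    """Exact error fl(x * a) - x * a of one binary64 multiplication."""
    return Fraction(x * a) - Fraction(x) * Fraction(a)

def mixed_bound_holds(x, a):
    """Check |fl(x*a) - x*a| <= u * |x*a| + eta."""
    exact = Fraction(x) * Fraction(a)
    return abs(scale_error(x, a)) <= U * abs(exact) + ETA

def relative_bound_holds(x, a):
    """Check the purely relative bound |fl(x*a) - x*a| <= u * |x*a|."""
    exact = Fraction(x) * Fraction(a)
    return abs(scale_error(x, a)) <= U * abs(exact)

# Normal range: the relative term alone suffices.
normal = (0.1, 0.3)
# Subnormal range: x * a = 3 * 2^-1075 is not representable, so rounding
# incurs an absolute error that a relative bound of size u cannot cover.
tiny = (2.0 ** -600, 3.0 * 2.0 ** -475)

for x, a in (normal, tiny):
    print(mixed_bound_holds(x, a), relative_bound_holds(x, a))
```

The mixed bound holds in both cases, while the purely relative bound fails for the subnormal product. This is the scalar analogue of the split between E and eta in the theorem.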
Scaled Matrix Product Error
Theorem sMMC_error :
∀ m n p
(A : 'M[ftype t]_(m, n))
(B : 'M[ftype t]_(n, p))
(x : ftype t)
(Hfin : F.finitemx (F.scalemx x (F.mulmx A B))),
∃ E1 E eta1 eta : 'M[R]_(m, p),
map_mx FT2R (F.scalemx x (F.mulmx A B))
= scalemx (FT2R x)
(((map_mx FT2R A ×m map_mx FT2R B + E1) + eta1) + E) + eta
∧ (∀ k : 'I_p,
∃ E0,
col k E1 = E0 ×m col k (map_mx FT2R B)
∧ (∀ i j,
Rabs (E0 i j) ≤ g n × Rabs (map_mx FT2R A i j)))
∧ (∀ i j, Rabs (eta1 i j) ≤ g1 n n)
∧ (∀ i j, Rabs (eta i j) ≤ @default_abs t)
∧ (∀ i j,
Rabs (E i j) ≤
@default_rel t ×
Rabs (((map_mx FT2R A ×m map_mx FT2R B + E1) + eta1)%Ri i j)).
Decompose the outer scaling error for x * (A*B).
Decompose the matrix-multiplication error for A*B,
propagating finiteness from F.scalemx x (F.mulmx A B).
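In conventional notation, the statement of sMMC_error reads as below. This is only a transcription of the theorem above; \(\mathtt{eta1}\) and \(\mathtt{eta}\) denote the residual matrices of the theorem, while \(\eta\) denotes the scalar default_abs.

```latex
\mathtt{fl}\bigl(x \cdot \mathtt{fl}(AB)\bigr)
  = x \,\bigl(AB + E_1 + \mathtt{eta1} + E\bigr) + \mathtt{eta},
\qquad
|E_{ij}| \le \mathbf{u}\,\bigl|(AB + E_1 + \mathtt{eta1})_{ij}\bigr|,
\quad
|\mathtt{eta1}_{ij}| \le g_1(n,n),
\quad
|\mathtt{eta}_{ij}| \le \eta
```

Because the proof obtains \(E_1\) and \(\mathtt{eta1}\) from the matrix-multiplication decomposition, the sum \(AB + E_1 + \mathtt{eta1}\) plays the role of the computed product, so \(E\) is a relative perturbation of that computed product rather than of the exact \(AB\).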
Matrix Addition Error
Theorem mat_sum_error :
∀ m n
(A B : 'M[ftype t]_(m, n))
(Hfin : F.finitemx (F.addmx A B)),
∃ EA EB : 'M[R]_(m, n),
map_mx FT2R (F.addmx A B)
= (map_mx FT2R A + EA) + (map_mx FT2R B + EB)
∧ (∀ i j, ∃ d,
EA i j = map_mx FT2R A i j × d ∧ Rabs d ≤ @default_rel t)
∧ (∀ i j, ∃ d,
EB i j = map_mx FT2R B i j × d ∧ Rabs d ≤ @default_rel t).
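One way to exhibit witnesses for mat_sum_error entry by entry is to use the same relative error d for both operands, since \((a + b)(1 + d) = a(1 + d) + b(1 + d)\). The sketch below does this with exact rationals. It assumes binary64 as a stand-in for t and U as the value of default_rel; sum_witnesses is a hypothetical helper, and the Coq proof may construct its witnesses differently.

```python
from fractions import Fraction
import random

U = Fraction(1, 2**53)  # assumed value of default_rel for binary64

def sum_witnesses(a, b):
    """Return (EA, EB, d) with fl(a + b) = (a + EA) + (b + EB),
    where EA = a * d and EB = b * d, all computed exactly."""
    exact = Fraction(a) + Fraction(b)
    computed = Fraction(a + b)
    d = (computed - exact) / exact if exact != 0 else Fraction(0)
    return Fraction(a) * d, Fraction(b) * d, d

random.seed(1)
pairs = [(random.uniform(-1e3, 1e3), random.uniform(-1e-3, 1e-3))
         for _ in range(100)]
pairs.append((1.0, -1.0))  # exact cancellation: d = 0 works

ok = all(
    abs(d) <= U and (Fraction(a) + EA) + (Fraction(b) + EB) == Fraction(a + b)
    for a, b in pairs
    for EA, EB, d in [sum_witnesses(a, b)]
)
print(ok)
```

No absolute term is needed here, which matches the pure backward form of the theorem.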
Scaled Matrix Sum Error
Theorem mat_axpby_error :
∀ [m n]
(A B : 'M[ftype t]_(m, n))
(x y : ftype t)
(Hfin : F.finitemx
(F.addmx (F.scalemx x A) (F.scalemx y B))),
∃ EA EB ea eb eta1 eta2 : 'M[R]_(m, n),
map_mx FT2R (F.addmx (F.scalemx x A) (F.scalemx y B))
= scalemx (FT2R x) (map_mx FT2R A + EA) + eta1 + ea
+ scalemx (FT2R y) (map_mx FT2R B + EB) + eta2 + eb
∧ (∀ i j,
Rabs (EA i j) ≤ @default_rel t × Rabs (map_mx FT2R A i j))
∧ (∀ i j,
Rabs (EB i j) ≤ @default_rel t × Rabs (map_mx FT2R B i j))
∧ (∀ i j, ∃ d,
ea i j
= (scalemx (FT2R x) (map_mx FT2R A + EA) + eta1) i j × d
∧ Rabs d ≤ @default_rel t)
∧ (∀ i j, ∃ d,
eb i j
= (scalemx (FT2R y) (map_mx FT2R B + EB) + eta2) i j × d
∧ Rabs d ≤ @default_rel t)
∧ (∀ i j, Rabs (eta1 i j) ≤ @default_abs t)
∧ (∀ i j, Rabs (eta2 i j) ≤ @default_abs t).
Decompose the outer addition as a pure backward error.
Decompose the mixed error for the scaling x * A.
Decompose the mixed error for the scaling y * B.
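Substituting the characterizations of ea and eb into the main equation gives a compact entry-wise form. This is a direct regrouping of the statement above, with \(\delta_{ij}\) and \(\delta'_{ij}\) naming the existentially quantified factors d.

```latex
\mathtt{fl}(xA + yB)_{ij}
  = \bigl(x\,(A + E_A) + \mathtt{eta1}\bigr)_{ij}\,(1 + \delta_{ij})
  + \bigl(y\,(B + E_B) + \mathtt{eta2}\bigr)_{ij}\,(1 + \delta'_{ij}),
\qquad
|\delta_{ij}|,\ |\delta'_{ij}| \le \mathbf{u}
```

Each bracket is the mixed-error form of one scaling step, and the factors \(1 + \delta_{ij}\), \(1 + \delta'_{ij}\) carry the backward error of the final addition.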
General GEMM Error
Theorem GEMM_error :
∀ [m n p]
(A : 'M[ftype t]_(m, n))
(B : 'M[ftype t]_(n, p))
(Y : 'M[ftype t]_(m, p))
(s1 s2 : ftype t)
(Hfin : F.finitemx
(F.addmx (F.scalemx s1 (F.mulmx A B)) (F.scalemx s2 Y))),
∃ ab1 ab2 ab3 ab4 ab5 y1 y2 y3 : 'M[R]_(m, p),
map_mx FT2R
(F.addmx (F.scalemx s1 (F.mulmx A B)) (F.scalemx s2 Y))
= (scalemx (FT2R s1)
((((map_mx FT2R A ×m map_mx FT2R B) + ab1) + ab2) + ab3)
+ ab4) + ab5
+ ((scalemx (FT2R s2) (map_mx FT2R Y + y1) + y2) + y3)
∧ (∀ k : 'I_p,
∃ E0,
col k ab1 = E0 ×m col k (map_mx FT2R B)
∧ (∀ i j,
Rabs (E0 i j) ≤ g n × Rabs (map_mx FT2R A i j)))
∧ (∀ i j, Rabs (ab2 i j) ≤ g1 n n)
∧ (∀ i j,
Rabs (ab3 i j) ≤
@default_rel t ×
Rabs ((((map_mx FT2R A ×m map_mx FT2R B) + ab1) + ab2)%Ri i j))
∧ (∀ i j,
Rabs (y1 i j) ≤ @default_rel t × Rabs (map_mx FT2R Y i j))
∧ (∀ i j, ∃ d,
ab5 i j
= (scalemx (FT2R s1)
((((map_mx FT2R A ×m map_mx FT2R B) + ab1) + ab2) + ab3)
+ ab4) i j × d
∧ Rabs d ≤ @default_rel t)
∧ (∀ i j, ∃ d,
y3 i j
= (scalemx (FT2R s2) (map_mx FT2R Y + y1) + y2) i j × d
∧ Rabs d ≤ @default_rel t)
∧ (∀ i j, Rabs (ab4 i j) ≤ @default_abs t)
∧ (∀ i j, Rabs (y2 i j) ≤ @default_abs t).
Decompose the axpby structure for s1*(A*B) + s2*Y, obtaining
backward addition errors ab5, y3 and forward scaling errors
ab4, y2, together with backward scaling perturbations ab3, y1.
Decompose the matrix-multiplication error for A*B, propagating
finiteness from the s1*(A*B) factor of Hfin.
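The components of GEMM_error can be combined by the triangle inequality into a single forward bound per entry. The Python sketch below checks such a bound against an exact rational reference. The same assumptions apply as before and none of them come from the source: binary64 stands in for t, U and ETA are assumed values of default_rel and default_abs, g and g1 have the assumed forms in the comments, the dot product runs left to right without fused multiply-add, and gemm_entry_bound is a hypothetical helper that is not part of LAProof.

```python
from fractions import Fraction
import random

# Assumptions: binary64 stands in for t; U and ETA play the roles of
# default_rel and default_abs; g and g1 follow the assumed forms
# g(n) = (1 + u)^n - 1 and g1(n1, n2) = n1 * eta * (1 + g(n2)).
U = Fraction(1, 2**53)
ETA = Fraction(1, 2**1075)

def g(n):
    return (1 + U) ** n - 1

def g1(n1, n2):
    return n1 * ETA * (1 + g(n2))

def fl_gemm_entry(s1, a_row, b_col, s2, y):
    """fl(s1 * fl(a_row . b_col) + s2 * y), left-to-right dot product."""
    acc = 0.0
    for a, b in zip(a_row, b_col):
        acc = acc + a * b
    return s1 * acc + s2 * y

def gemm_entry_bound(s1, a_row, b_col, s2, y):
    """Forward bound derived from the GEMM_error components by the
    triangle inequality (hypothetical helper)."""
    n = len(a_row)
    S = sum(abs(Fraction(a) * Fraction(b)) for a, b in zip(a_row, b_col))
    s1a, s2a, ya = abs(Fraction(s1)), abs(Fraction(s2)), abs(Fraction(y))
    P = (1 + g(n)) * S + g1(n, n)   # bounds |AB + ab1 + ab2|
    T1 = s1a * (1 + U) * P + ETA    # bounds |s1 (AB + ab1 + ab2 + ab3) + ab4|
    T2 = s2a * (1 + U) * ya + ETA   # bounds |s2 (Y + y1) + y2|
    D1 = s1a * (g(n) * S + g1(n, n)) + s1a * U * P + ETA  # ab1 .. ab4
    D2 = s2a * U * ya + ETA                               # y1, y2
    return D1 + U * T1 + D2 + U * T2  # last two terms cover ab5 and y3

def gemm_entry_ok(s1, a_row, b_col, s2, y):
    """Compare the computed entry with the exact one using rationals."""
    exact = (Fraction(s1) * sum(Fraction(a) * Fraction(b)
                                for a, b in zip(a_row, b_col))
             + Fraction(s2) * Fraction(y))
    computed = Fraction(fl_gemm_entry(s1, a_row, b_col, s2, y))
    return abs(computed - exact) <= gemm_entry_bound(s1, a_row, b_col, s2, y)

random.seed(2)
n = 5
trials = [([random.uniform(-1, 1) for _ in range(n)],
           [random.uniform(-1, 1) for _ in range(n)],
           random.uniform(-2, 2), random.uniform(-2, 2),
           random.uniform(-1, 1)) for _ in range(50)]
print(all(gemm_entry_ok(s1, a, b, s2, y) for a, b, s1, s2, y in trials))
```

The bound is deliberately loose: it tracks each of the eight error matrices separately, as the theorem does, rather than simplifying to first order in \(\mathbf{u}\).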