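(** * LAProof.accuracy_proofs.mv_mathcomp

    Norms on real MathComp vectors and matrices, floating-point matrix
    operations (module [F]), and conversions between MathComp matrices and
    lists of lists. *)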


From LAProof.accuracy_proofs Require Import preamble common
    dotprod_model sum_model dot_acc float_acc_lems.

From mathcomp.algebra_tactics Require Import ring.

Open Scope ring_scope.
Open Scope order_scope.

(** Absolute row sum: the sum over all columns [j] of [Rabs (A i j)]. *)
Definition sum_abs {m n} (A: 'M[R]_(m,n)) i : R:= \sum_j (Rabs (A i j)).
(** Max norm of a column vector: the largest [Rabs] of its entries, folded
    with [maxr] from the seed [0] (so an empty vector has norm [0]). *)
Definition normv {m} (v: 'cV[R]_m) : R:= \big[maxr/0]_(i < m) Rabs (v i 0%Ri).
(** Matrix norm: the maximum over rows [i] of [sum_abs A i], i.e. the
    largest absolute row sum, folded with [maxr] from the seed [0]. *)
Definition normM {m n} (A: 'M[R]_(m,n)) : R:= \big[maxr/0]_i (sum_abs A i).
(** The entries of row vector [x] as a [seq], listed in the order of
    [ord_enum n]. *)
Definition seq_of_rV {T}[n] (x: 'rV[T]_n) := map (x ord0) (ord_enum n).

(** Given a variable [i] of type [Z] or [nat], replace it everywhere with a
    variable [i] of type ['I_n], appropriately coerced.

    The bound [i < n] must be provable by [lia] in the current context.
    The match is on the [hnf] of the type of [i], so the tactic fails when
    that is neither [Z] nor [nat].  Afterwards [i] is an opaque ordinal:
    its defining body and bound proof are cleared. *)
Ltac ordify n i :=
  let Hi := fresh "H" i in
  (* [fresh] only avoids names already in the context.  Two [fresh "H" i]
     calls made before either hypothesis exists would therefore return the
     same identifier, and the second [assert] below would fail.  Use a
     different prefix for the equation. *)
  let Hj := fresh "E" i in
  let j := fresh "i" in
  match type of i with ?t => let t' := eval hnf in t in match t' with
    | Z => assert (Hi: Datatypes.is_true (ssrnat.leq (S (Z.to_nat i)) n)) by lia;
           set (j := @Ordinal n (Z.to_nat i) Hi);
           assert (Hj : i = Z.of_nat (nat_of_ord j)) by (simpl; lia)
    | nat => assert (Hi: Datatypes.is_true (ssrnat.leq (S i) n)) by lia;
             set (j := @Ordinal n i Hi);
             assert (Hj : i = nat_of_ord j) by (simpl; lia)
   end end;
   (* forget the body of j and its bound proof, then eliminate the old i
      using the equation Hj *)
   clearbody j; clear Hi;
   subst i;
   rename j into i.

(** [case_splitP j]: for a variable [j : 'I_(addn a b)], case-split with
    [splitP].  In the first branch [j] is replaced by [lshift a b j'], and in
    the second by [rshift a b j'], where [j'] is the new ordinal produced by
    [splitP].  That new ordinal is then renamed back to [j].  Fails when [j]
    is not a variable. *)
Ltac case_splitP j :=
  first [is_var j | fail 1 "case_splitP requires a variable, but got" j];
 match type of j with 'I_(addn ?a ?b) =>
  let i := fresh "j" in let H := fresh in
  destruct (splitP j) as [i H | i H];
  (* each [replace] leaves a side goal equating [j] with the shifted
     ordinal; it is closed through [ord_inj] and [lia], using [H] *)
 [replace j with (@lshift a b i); [ | apply ord_inj; simpl; lia]
 |replace j with (@rshift a b i); [ | apply ord_inj; simpl; lia]];
 clear j H; rename i into j
 end.

(** Example of how to use [case_splitP]: matrix multiplication on the left
    distributes over horizontal concatenation with [row_mx]. *)
Local Remark mul_mx_row' [R : GRing.SemiRing.type] m n p1 p2
    (A: 'M[R]_(m,n)) (Bl: 'M[R]_(n,p1)) (Br: 'M[R]_(n,p2)):
  A *m row_mx Bl Br = row_mx (A *m Bl) (A *m Br).

(** Example of how the MathComp experts do this another way, taken from
    [mathcomp.algebra.matrix]. *)
Local Remark mul_mx_row'' [R : GRing.SemiRing.type] m n p1 p2 (A : 'M[R]_(m, n)) (Bl : 'M_(n, p1)) (Br : 'M_(n, p2)) :
  A *m row_mx Bl Br = row_mx (A *m Bl) (A *m Br).

(** MathComp's [seq.nth] and the standard library's [List.nth] agree; they
    only differ in argument order. *)
Lemma nth_List_nth: forall {A: Type} (d: A) (l: seq.seq A) (n: nat),
  seq.nth d l n = List.nth n l d.

(** For positive [n], the predecessor [n.-1] is below [n]. *)
Lemma pred_lt: forall [n: nat], (0 < n -> n.-1 < n)%nat.

(** [n.-1] as an element of ['I_n], given a proof [Hn] that [n] is positive. *)
Definition pred_ord [n: nat] (Hn: (0 < n)%nat) : 'I_n := Ordinal (pred_lt Hn).

(** The enumeration of the finite type [ordinal n] has exactly [n] elements. *)
Lemma ordinal_enum_size: forall n: nat,
  size (Finite.enum (ordinal n)) = n.

(** [ord_enum n] has length [n]. *)
Lemma size_ord_enum: forall n, size (ord_enum n) = n.

(** Indexing [index_enum (ordinal n)] at position [x] returns [x], for any
    default [y]. *)
Lemma nth_index_enum: forall {n: nat} (x: 'I_n) y,
  seq.nth y (index_enum (ordinal n)) x = x.

(** Indexing [ord_enum n] at position [i] returns [i], for any default [x]. *)
Lemma nth_ord_enum': forall n (i: 'I_n) x, seq.nth x (ord_enum n) i = i.

(** The generic enumeration [index_enum (ordinal n)] coincides with [ord_enum n]. *)
Lemma index_ord_enum: forall (n: nat), (index_enum (ordinal n)) = ord_enum n.

(** [seq_of_rV] of a length-[n] row vector has [n] elements. *)
Lemma size_seq_of_rV : forall {T} [n] x, size (@seq_of_rV T n x) = n.

(** Element [i] of [seq_of_rV x] is entry [i] of [x], for any default [d]. *)
Lemma nth_seq_of_rV: forall {T}[n](d: T)(x: 'rV[T]_n) (i: 'I_n), nth d (seq_of_rV x) i = x ord0 i.

(** [maxr] on [R] is commutative. *)
Lemma maxrC : @commutative R R maxr.

(** [maxr] on [R] is associative. *)
Lemma maxrA : @associative R maxr.

(** Suppose right-multiplication by [a] distributes over [op] (the first
    hypothesis) and [a] is nonnegative.  Then multiplying a [\big[op/0]]
    fold by [a] equals folding the scaled terms. *)
Lemma big_mul {n:nat} (F : ordinal n -> R) op a:
(forall i b, op (F i) b * a = op (F i * a) (b * a)) ->
R0 <= a -> \big[op/0]_(i0 < n) (F i0) * a = \big[op/0]_(i0 < n) (F i0 * a).

(** For nonnegative [a], scaling a [maxr]-fold (seeded with [0]) by [a]
    equals the [maxr]-fold of the scaled terms. *)
Lemma big_max_mul {n:nat} (F : ordinal n -> R) a:
R0 <= a -> \big[maxr/0]_(i0 < n) (F i0) * a = \big[maxr/0]_(i0 < n) (F i0 * a).


(** [normv] is nonnegative. *)
Lemma normv_pos {m} (v: 'cV[R]_m) : R0 <= normv v.

(** [normM] is nonnegative. *)
Lemma normM_pos [m n] (A: 'M[R]_(m,n)) : R0 <= normM A.

(** Triangle inequality for finite sums: the absolute value of a sum is at
    most the sum of the absolute values. *)
Lemma Rabs_sum (n:nat) : forall (F : ordinal n -> R),
Rabs (\sum_j F j) <= \sum_j Rabs (F j).

(** [normM] is compatible with [normv]: the norm of [A *m u] is bounded by
    [normM A * normv u]. *)
Lemma subMultNorm m n (A: 'M[R]_(m,n)) (u : 'cV_n) :
  normv ( A *m u ) <= normM A * normv u.

(** Triangle inequality for [normv]. *)
Lemma normv_triang m (u v: 'cV_m) :
  normv ( u + v ) <= normv u + normv v.

(** A function from the empty ordinal type ['I_0] into an arbitrary type
    [T].  It is closed with [Defined] rather than [Qed], so the term stays
    transparent. *)
Local Definition crazy (T: Type): 'I_0 -> T.
Defined.

(** Choice for matrices: if every position [(i, j)] admits some [x] with
    [F i j x], then some matrix [A] satisfies [F] at every entry. *)
Lemma exists_mx: forall {T} [m n] (F: 'I_m -> 'I_n -> T -> Prop),
  (forall i j, exists x, F i j x) ->
   exists A: 'M[T]_(m,n), forall i j, F i j (A i j).

(** Reversing [ord_enum n] is the same as mapping [rev_ord] over it. *)
Lemma rev_ord_enum: forall n, rev (ord_enum n) = map (@rev_ord n) (ord_enum n).

(** Any sequence [u] is recovered by mapping [nth d u] over the ordinals
    below [size u]. *)
Lemma nth_ord_enum_lemma:
  forall [T] (d: T) (u: seq T),
   u = map (nth d u \o @nat_of_ord (size u)) (ord_enum (size u)).

(** [sumR x] equals the MathComp big sum of [nth R0 x i] over the indices
    [i] of [x]. *)
Lemma sumR_sum: forall (x: seq R), sumR x = \sum_(i in 'I_(size x)) nth R0 x (nat_of_ord i).

(** Floating-point counterparts of MathComp vector and matrix operations on
    [ftype t].  Inside section [WithNAN], each definition is parameterized
    by [NAN : FPCore.Nans] and the type index [t]. *)
Module F.
Section WithNAN.
Context {NAN: FPCore.Nans} {t : type}.

(** Floating-point sum of [x 0], ..., [x (n-1)] using [BPLUS].  [\big]
    folds from the right, so indexing by [rev_ord i] makes [x 0] the first
    term combined with the seed [neg_zero].  [sum_sumF] relates this to
    [sumF]. *)
Definition sum [n: nat] (x: 'I_n -> ftype t) : ftype t :=
    \big[BPLUS / neg_zero]_i x (rev_ord i).

(** Floating-point dot product of row vector [x] and column vector [y].
    Each product is rounded by [BMULT], and the products are accumulated
    with [BPLUS] from the seed [pos_zero].  As in [sum], [rev_ord] makes
    the index-0 product the first one combined with the seed. *)
Definition dotprod [n: nat] (x: 'rV[ftype t]_n) (y: 'cV[ftype t]_n) : ftype t :=
   \big[BPLUS / pos_zero]_i (BMULT (x ord0 (rev_ord i)) (y (rev_ord i) ord0)).

(** Dot product computed by [fma_dotprod] on the entries of [x] and of the
    transpose of [y], both listed by [seq_of_rV]. *)
Definition FMA_dotprod [n: nat] (x: 'rV[ftype t]_n) (y: 'cV[ftype t]_n) : ftype t :=
   fma_dotprod (seq_of_rV x) (seq_of_rV y^T).

(** Floating-point matrix product: entry [(i, k)] is [dotprod] of row [i]
    of [A] with column [k] of [B]. *)
Definition mulmx [m n p] (A: 'M[ftype t]_(m,n)) (B: 'M[ftype t]_(n,p)) :=
 \matrix_(i,k) dotprod (row i A) (col k B).

(** Matrix product built from [FMA_dotprod]: entry [(i, k)] combines row
    [i] of [A] with column [k] of [B]. *)
Definition FMA_mulmx [m n p] (A: 'M[ftype t]_(m,n)) (B: 'M[ftype t]_(n,p)) :=
 \matrix_(i,k) FMA_dotprod (row i A) (col k B).

(** Entrywise scaling: each entry [M i j] becomes [BMULT a (M i j)]. *)
Definition scalemx [m n] (a: ftype t) (M: 'M[ftype t]_(m,n)) :=
  map_mx (BMULT a) M.

(** Entrywise floating-point addition with [BPLUS]. *)
Definition addmx [m n] (A B: 'M[ftype t]_(m,n)) : 'M[ftype t]_(m,n) :=
  \matrix_(i,j) BPLUS (A i j) (B i j).

(** [mulmx] distributes over [row_mx] on the right. *)
Lemma mulmx_row:
  forall m n p1 p2 (A: 'M[ftype t]_(m,n)) (Bl: 'M_(n,p1)) (Br: 'M_(n,p2)),
  mulmx A (row_mx Bl Br) = row_mx (mulmx A Bl) (mulmx A Br).

(** [FMA_mulmx] distributes over [row_mx] on the right. *)
Lemma FMA_mulmx_row:
  forall m n p1 p2 (A: 'M[ftype t]_(m,n)) (Bl: 'M_(n,p1)) (Br: 'M_(n,p2)),
  FMA_mulmx A (row_mx Bl Br) = row_mx (FMA_mulmx A Bl) (FMA_mulmx A Br).

(** [mulmx] distributes over [col_mx] on the left. *)
Lemma mulmx_col:
  forall m1 m2 n p (Au: 'M[ftype t]_(m1,n)) (Ad: 'M[ftype t]_(m2,n)) (B: 'M_(n,p)),
  mulmx (col_mx Au Ad) B = col_mx (mulmx Au B) (mulmx Ad B).

(** [FMA_mulmx] distributes over [col_mx] on the left. *)
Lemma FMA_mulmx_col:
  forall m1 m2 n p (Au: 'M[ftype t]_(m1,n)) (Ad: 'M[ftype t]_(m2,n)) (B: 'M_(n,p)),
  FMA_mulmx (col_mx Au Ad) B = col_mx (FMA_mulmx Au B) (FMA_mulmx Ad B).

(** [sum] agrees with [sumF] applied to the values of [x] in index order. *)
Lemma sum_sumF: forall [n] (x: 'I_n -> ftype t), sum x = sumF (map x (ord_enum n)).

(** [dotprod] agrees with [dotprodF] on the entry sequences of [x] and of
    the transpose of [y]. *)
Lemma dotprod_dotprodF:
  forall [n] (x: 'rV[ftype t]_n) (y: 'cV[ftype t]_n),
  dotprod x y = dotprodF (seq_of_rV x) (seq_of_rV (trmx y)).

(** A 1-by-1 [mulmx] product is the constant matrix holding [dotprodF] of
    the entry sequences. *)
Lemma mulmx_dotprodF:
  forall [n] (A: 'M[ftype t]_(1,n)) (B: 'M[ftype t]_(n,1)),
 mulmx A B = const_mx (dotprodF (seq_of_rV A) (seq_of_rV (trmx B))).

(** A 1-by-1 [FMA_mulmx] product is the constant matrix holding
    [fma_dotprod] of the entry sequences. *)
Lemma FMA_mulmx_fma_dotprod:
  forall [n] (A: 'M[ftype t]_(1,n)) (B: 'M[ftype t]_(n,1)),
 FMA_mulmx A B = const_mx (fma_dotprod (seq_of_rV A) (seq_of_rV (trmx B))).

(** Every entry of [A] satisfies [Binary.is_finite]. *)
Definition finitemx [m n] (A: 'M[ftype t]_(m,n)) : Prop :=
   (forall i j, Binary.is_finite (A i j)).

(** If [addmx A B] is finite, then both [A] and [B] are finite. *)
Lemma finitemx_addmx_e: forall [m n] (A B: 'M[ftype t]_(m,n)),
  finitemx (addmx A B) -> finitemx A /\ finitemx B.

(** If [scalemx c A] is finite, then [A] is finite. *)
Lemma finitemx_scalemx_e: forall [m n] (c: ftype t) (A: 'M[ftype t]_(m,n)),
  finitemx (scalemx c A) -> finitemx A.

End WithNAN.

End F.

(** Convert a matrix to the list of its rows.  Each row is the list of its
    entries, in the order of [ord_enum n]. *)
Definition listlist_of_mx {T} [m n: nat] (A: 'M[T]_(m,n)) : list (list T) :=
  map (fun i: 'I_m => map (A i) (ord_enum n)) (ord_enum m).

(** The entries of a column vector as a list, in the order of [ord_enum n]. *)
Definition list_of_cV {T} [n: nat] (V: 'cV[T]_n) : list T :=
   map (fun i => V i ord0) (ord_enum n).

(** Build a [rows]-by-[cols] matrix from a list of rows.  Entry [(i, j)] is
    element [j] of row [i]; a missing row is read as [nil], and a missing
    entry defaults to [d]. *)
Definition mx_of_listlist {T} {d: T} (rows cols: nat) (mval: list (list T)) : 'M[T]_(rows, cols) :=
 \matrix_(i,j) seq.nth (d: T) (seq.nth nil mval i) j.

(** Build an [n]-entry column vector from a list.  Entry [i] is element [i]
    of [vval], defaulting to [d] when the list is too short; list elements
    beyond index [n-1] are not used. *)
Definition cV_of_list {T} {d: T} (n: nat) (vval: list T) : 'cV[T]_n :=
  \matrix_(i,j) seq.nth (d:T) vval i.

(** Every row of the list-of-lists matrix [m] has length [cols]. *)
Definition matrix_cols_nat {T} (m: list (list T)) (cols: nat) :=
    Forall (fun r => size r = cols) m.

(** [listlist_of_mx] inverts [mx_of_listlist] on well-shaped input:
    exactly [rows] rows, each of length [cols]. *)
Lemma listlist_of_mx_of_listlist:
   forall {t} {d} rows cols (mval: list (list (ftype t))),
   rows = Datatypes.length mval ->
   matrix_cols_nat mval cols ->
   listlist_of_mx (@mx_of_listlist _ d rows cols mval) = mval.

(** [mx_of_listlist] inverts [listlist_of_mx], for any default [d]. *)
Lemma mx_of_listlist_of_mx:
   forall {T} {d:T} rows cols (A: 'M[T]_(rows,cols)),
   @mx_of_listlist _ d rows cols (listlist_of_mx A) = A.

(** [list_of_cV] inverts [cV_of_list] when the list has length [n]. *)
Lemma list_of_cV_of_list:
   forall {T} {d:T} n (vval: list T),
   size vval = n ->
   list_of_cV (@cV_of_list _ d n vval) = vval.

(** [cV_of_list] inverts [list_of_cV], for any default [d]. *)
Lemma cV_of_list_of_cV:
   forall {T} `{d:T} n (x: 'cV[T]_n),
  @cV_of_list _ d n (list_of_cV x) = x.

(** [listlist_of_mx] produces exactly [rows] rows. *)
Lemma matrix_rows_listlist_of_mx: forall {T} [rows cols] (A: 'M[T]_(rows,cols)),
   size (listlist_of_mx A) = rows.

(** Every row produced by [listlist_of_mx] has length [cols]. *)
Lemma matrix_cols_listlist_of_mx: forall {T} [rows cols] (A: 'M[T]_(rows,cols)),
  matrix_cols_nat (listlist_of_mx A) cols.

(** [list_of_cV] of an [n]-entry column vector has [n] elements. *)
Lemma size_list_of_cV: forall {T} [n] (vval: 'cV[T]_n),
  size (list_of_cV vval) = n.

(** Element [i] of [list_of_cV vval] is entry [i] of [vval], for any
    default [d]. *)
Lemma nth_list_of_cV:
   forall {T} {d:T} [n] (vval: 'cV[T]_n) (i: 'I_n),
   nth d (list_of_cV vval) (nat_of_ord i) = vval i ord0.

(** Dot product of two lists, folded left to right.  Each step maps the
    accumulator [s] and the next pair [(a, b)] to [BFMA a b s], starting
    from [Zconst t 0].  Pairs come from [zip v1 v2].
    NOTE(review): if this is [seq.zip], the extra elements of the longer
    list are ignored; confirm which [zip] is in scope. *)
Definition list_dotprod {NAN: FPCore.Nans} {t: type} (v1 v2: list (ftype t)) : ftype t :=
  foldl (fun s x12 => BFMA (fst x12) (snd x12) s) (Zconst t 0) (zip v1 v2) .

(** Matrix-vector product on lists: [list_dotprod] of each row of [m]
    with [v]. *)
Definition matrix_vector_mult {NAN: FPCore.Nans}{t: type} (m: list (list (ftype t))) (v: list (ftype t)) : list (ftype t) :=
      map (fun row => list_dotprod row v) m.

(** [list_of_cV] turns vertical concatenation [col_mx] into list append. *)
Lemma list_of_cV_col_mx: forall {T} n1 n2 (x: 'cV[T]_n1) (y: 'cV[T]_n2),
  list_of_cV (col_mx x y) = list_of_cV x ++ list_of_cV y.

(** Mapping a constant function over [al] gives [c] repeated [length al]
    times. *)
Lemma map_const_len:
   forall {A B} (c: B) (al: list A), map (fun _ => c) al = repeat c (length al).

(** [listlist_of_mx] turns vertical concatenation [col_mx] into list append. *)
Lemma listlist_of_mx_col_mx: forall {T} n1 n2 m (A: 'M[T]_(n1,m)) (B: 'M[T]_(n2,m)),
  listlist_of_mx (col_mx A B) = listlist_of_mx A ++ listlist_of_mx B.

(** [listlist_of_mx] is injective. *)
Lemma listlist_of_mx_inj: forall {T} [m n] (A B: 'M[T]_(m,n)),
  listlist_of_mx A = listlist_of_mx B -> A=B.

(** On well-shaped input, the list-based [matrix_vector_mult] agrees with
    [F.FMA_mulmx].  Well-shaped means [rows] rows of length [cols] and a
    vector of length [cols].  The matrix and vector are built with the
    default [Zconst t 0], and the result is read back with [list_of_cV]. *)
Lemma Fmulmx_matrix_vector_mult:
  forall {NAN: FPCore.Nans}{t} rows cols (mval: list (list (ftype t))) (vval: list (ftype t)),
   rows = size mval ->
   cols = size vval ->
   matrix_cols_nat mval cols ->
   matrix_vector_mult mval vval = list_of_cV (F.FMA_mulmx (@mx_of_listlist _ (Zconst t 0) rows cols mval)
                                                                                        (@cV_of_list _ (Zconst t 0) cols vval)).