(** * LAProof.C.sparse_model: Coq models of CSR and COO sparse matrices *)

Require Import VST.floyd.proofauto.
Require Import vcfloat.VCFloat.
Require Import vcfloat.FPStdCompCert.
Require Import VSTlib.spec_math.
From LAProof.C Require Import floatlib distinct.


(* Open the [logic] notation scope by default; presumably VST's
   separation-logic scope brought in by VST.floyd.proofauto -- confirm. *)
Open Scope logic.

(** Compressed-sparse-row (CSR) matrix. The stored entries of row [j] are
    the slice of [csr_vals] / [csr_col_ind] from position
    [Znth j csr_row_ptr] (inclusive) to [Znth (j+1) csr_row_ptr]
    (exclusive); see [csr_to_matrix] for the full representation relation. *)
Record csr_matrix {t: type} := {
  csr_cols: Z; (* number of columns of the represented dense matrix *)
  csr_vals: list (ftype t); (* stored values, row after row *)
  csr_col_ind: list Z; (* column index of each stored value *)
  csr_row_ptr: list Z; (* offset of each row's first stored entry, plus a final end offset *)
  csr_rows: Z := Zlength (csr_row_ptr) - 1 (* derived: one row per consecutive pair of pointers *)
}.
(* Make [t] explicit, so the type is written [csr_matrix t]. *)
Arguments csr_matrix t : clear implicits.

(** [csr_row_rep cols vals col_ind v] relates one CSR row (its stored
    values [vals] and their column indices [col_ind]) to the dense row [v],
    which has exactly [cols] elements. The row is described left to right:
    - [csr_row_rep_nil]: a row of width 0 has no stored entries;
    - [csr_row_rep_zero]: a leading implicit zero consumes one column, so
      every remaining column index is decremented ([Z.pred]);
    - [csr_row_rep_val]: a stored entry [x] at column [0] is the leading
      dense element, and the remaining indices are decremented. *)
Inductive csr_row_rep {t: type} :
  forall (cols: Z) (vals: list (ftype t)) (col_ind: list Z)
         (v: list (ftype t)), Prop :=
 | csr_row_rep_nil: csr_row_rep 0%Z nil nil nil
 | csr_row_rep_zero: forall cols vals col_ind v,
          csr_row_rep (cols-1) vals (map Z.pred col_ind) v ->
          csr_row_rep cols vals col_ind (Zconst t 0 :: v)
 | csr_row_rep_val: forall cols x vals col_ind v,
          csr_row_rep (cols-1) vals (map Z.pred col_ind) v ->
          csr_row_rep cols (x::vals) (0%Z::col_ind) (x::v).

(** [csr_to_matrix csr mval]: the CSR structure [csr] represents the dense
    matrix [mval]. Concretely:
    - [csr_row_ptr] has one more entry than [mval] has rows;
    - [csr_vals] and [csr_col_ind] both have length equal to the last row
      pointer (index [Zlength mval]);
    - the row pointers are nondecreasing and bounded below by [0] and above
      by [Int.max_unsigned];
    - for every row [j], the entries from [Znth j row_ptr] (inclusive) to
      [Znth (j+1) row_ptr] (exclusive) represent [Znth j mval] in the sense
      of [csr_row_rep]. *)
Definition csr_to_matrix {t} (csr: csr_matrix t) (mval: matrix t) :=
  Zlength (csr_row_ptr csr) = 1 + Zlength mval /\
  Zlength (csr_vals csr) = Znth (Zlength mval) (csr_row_ptr csr) /\
  Zlength (csr_col_ind csr) = Znth (Zlength mval) (csr_row_ptr csr) /\
  list_solver.sorted Z.le (0 :: csr_row_ptr csr ++ [Int.max_unsigned]) /\
  forall j, 0 <= j < Zlength mval ->
        csr_row_rep (csr_cols csr)
             (sublist (Znth j (csr_row_ptr csr)) (Znth (j+1) (csr_row_ptr csr)) (csr_vals csr))
             (sublist (Znth j (csr_row_ptr csr)) (Znth (j+1) (csr_row_ptr csr)) (csr_col_ind csr))
             (Znth j mval).

(** If a concatenation is sorted, so is its first part. *)
Lemma sorted_app_e1:
 forall {A} {HA: Inhabitant A} (le: A -> A -> Prop) al bl,
  list_solver.sorted le (al++bl) -> list_solver.sorted le al.

(** A CSR representation of [mval] has exactly as many rows as [mval]. *)
Lemma csr_to_matrix_rows {t: type}:
  forall (mval: matrix t) (csr: csr_matrix t),
   csr_to_matrix csr mval ->
  csr_rows csr = matrix_rows mval.

(** A matrix represented by [csr] satisfies [matrix_cols] with width
    [csr_cols csr] (presumably: every row has that many entries). *)
Lemma csr_to_matrix_cols {t: type}:
  forall (mval: matrix t) (csr: csr_matrix t),
   csr_to_matrix csr mval -> matrix_cols mval (csr_cols csr).

(** The width of any row related by [csr_row_rep] is nonnegative. *)
Lemma csr_row_rep_cols_nonneg:
 forall {t} cols (vals: list (ftype t)) col_ind vval,
  csr_row_rep cols vals col_ind vval ->
  0 <= cols.

(** Every stored column index of a related row satisfies
    [0 <= c < cols]. *)
Lemma csr_row_rep_col_range:
 forall {t} cols (vals: list (ftype t)) col_ind vval,
  csr_row_rep cols vals col_ind vval ->
  forall j, 0 <= j < Zlength col_ind -> 0 <= Znth j col_ind < cols.

(** Any property holding of every element of the dense row also holds of
    every stored value. The stored values form a subsequence of the dense
    row: see [csr_row_rep_val]. *)
Lemma csr_row_rep_property:
 forall {t} (P: ftype t -> Prop) cols (vals: list (ftype t)) col_ind vval,
  csr_row_rep cols vals col_ind vval ->
  Forall P vval -> Forall P vals.

(** Structural well-formedness of a CSR matrix, independent of any dense
    matrix it might represent:
    - [CSR_wf_rows]: at least one row pointer, so [csr_rows csr >= 0];
    - [CSR_wf_cols]: the column count is nonnegative;
    - [CSR_wf_vals], [CSR_wf_vals']: [csr_vals] and [csr_col_ind] have
      equal length, equal to the last row pointer;
    - [CSR_wf_sorted]: row pointers are nondecreasing and bounded below by
      [0] and above by [Int.max_unsigned];
    - [CSR_wf_rowsorted]: within each row, column indices are strictly
      increasing and lie strictly between [-1] and [csr_cols csr]. *)
Inductive csr_matrix_wellformed {t} (csr: csr_matrix t) : Prop :=
 build_csr_matrix_wellformed:
  forall (CSR_wf_rows: 0 <= csr_rows csr)
        (CSR_wf_cols: 0 <= csr_cols csr)
        (CSR_wf_vals: Zlength (csr_vals csr) = Zlength (csr_col_ind csr))
        (CSR_wf_vals': Zlength (csr_vals csr) = Znth (csr_rows csr) (csr_row_ptr csr))
        (CSR_wf_sorted: list_solver.sorted Z.le (0 :: csr_row_ptr csr ++ [Int.max_unsigned]))
        (CSR_wf_rowsorted: forall r, 0 <= r < csr_rows csr ->
              sorted Z.lt
                (-1 :: sublist (Znth r (csr_row_ptr csr)) (Znth (r+1) (csr_row_ptr csr)) (csr_col_ind csr) ++ [csr_cols csr])),
    csr_matrix_wellformed csr.

(** For in-range indices [i <= j], sorted row pointers satisfy
    [0 <= row_ptr[i] <= row_ptr[j] <= Int.max_unsigned]. *)
Lemma rowptr_sorted_e:
 forall row_ptr (H: list_solver.sorted Z.le (0 :: row_ptr ++ [Int.max_unsigned]))
       (i j: Z),
   0 <= i <= j /\ j < Zlength row_ptr ->
   0 <= Znth i row_ptr <= Znth j row_ptr /\ Znth j row_ptr <= Int.max_unsigned.

(** Every in-range sorted row pointer lies between [0] and
    [Int.max_unsigned]. *)
Lemma rowptr_sorted_e1:
  forall row_ptr (H: list_solver.sorted Z.le (0 :: row_ptr ++ [Int.max_unsigned]))
       (i: Z),
   0 <= i < Zlength row_ptr ->
   0 <= Znth i row_ptr <= Int.max_unsigned.

(** Build a dense row of width [cols] from stored [vals] and their column
    indices [col_ind]. For each stored pair [(v, c)], emit [c] zeros, then
    [v]. The rest of the row then has width [cols - c - 1], and the
    remaining indices are rebased by subtracting [c + 1]. Once either list
    runs out, the remaining width is filled with zeros.
    [build_csr_row_correct] states when the result is related to the inputs
    by [csr_row_rep]. *)
Fixpoint build_csr_row {t} (cols: Z) (vals: list (ftype t)) (col_ind: list Z) : list (ftype t) :=
 match vals, col_ind with
 | v::vals', c::col_ind' => Zrepeat (Zconst t 0) c ++ v ::
                            build_csr_row (cols-c-1) vals' (map (fun j => j-c-1) col_ind')

 | _, _ => Zrepeat (Zconst t 0) cols
 end.

(** Build the dense rows delimited by consecutive row pointers. [k] is the
    start offset of the current row. Each [k'] taken from [row_ptr] ends
    that row, with entries from [k] inclusive to [k'] exclusive, and starts
    the next one. *)
Fixpoint build_csr_rows {t} (csr: csr_matrix t) (k: Z) (row_ptr: list Z) : list (list (ftype t)) :=
 match row_ptr with
 | [] => nil
 | k'::row_ptr' => build_csr_row (csr_cols csr) (sublist k k' (csr_vals csr))
                  (sublist k k' (csr_col_ind csr)) ::
                  build_csr_rows csr k' row_ptr'
 end.

(** The dense matrix denoted by [csr]. The first row pointer is the start
    of row 0. An empty [csr_row_ptr] yields the empty matrix. *)
Definition build_csr_matrix {t} (csr: csr_matrix t) : matrix t :=
 match csr_row_ptr csr with
 | k::row_ptr' => build_csr_rows csr k row_ptr'
 | [] => []
 end.

(** [build_csr_row] is related to its inputs by [csr_row_rep] when three
    conditions hold: the width is nonnegative, the value and index lists
    have equal length, and the column indices are strictly increasing and
    lie strictly between [-1] and [cols]. *)
Lemma build_csr_row_correct:
  forall {t} cols (vals: list (ftype t)) col_ind,
     0 <= cols ->
     Zlength vals = Zlength col_ind ->
     sorted Z.lt (-1 :: col_ind ++ [cols]) ->
    csr_row_rep cols vals col_ind (build_csr_row cols vals col_ind).

(** Every well-formed CSR matrix represents the dense matrix built from it
    by [build_csr_matrix]. *)
Lemma build_csr_matrix_correct:
  forall {t} (csr: csr_matrix t),
  csr_matrix_wellformed csr ->
  csr_to_matrix csr (build_csr_matrix csr).

(** Accumulate a sparse row times the dense vector [vval], starting from
    accumulator [s]. For each stored pair [(v1, c1)], the accumulator
    becomes [BFMA v1 (Znth c1 vval) s]. [BFMA] is presumably VCFloat's fused
    multiply-add, computing [v1 * vval[c1] + s] with one rounding.
    Iteration stops as soon as either [vals] or [col_ind] runs out. *)
Fixpoint rowmult {t} (s: ftype t)
            (vals: list (ftype t)) (col_ind: list Z) (vval: list (ftype t)) :=
 match vals, col_ind with
  | v1::vals', c1::col_ind' => rowmult (BFMA v1 (Znth c1 vval) s) vals' col_ind' vval
  | _, _ => s
 end.

(** [rowmult] is compatible with [feq]. Given feq-related accumulators,
    elementwise feq-related value lists, identical column-index lists, and
    elementwise feq-related dense vectors, the results are feq-related. *)
Add Parametric Morphism {t: type} : rowmult
  with signature (@feq t) ==> Forall2 feq ==> @eq (list Z) ==> Forall2 feq ==> feq
  as rowmult_mor.

(** Variant of [rowmult_mor] in which the value lists and the dense vector
    are related elementwise by [strict_feq] instead of [feq]. The
    accumulator and the result are still related by [feq]. *)
Add Parametric Morphism {t: type} : rowmult
  with signature (@feq t) ==> Forall2 strict_feq ==> @eq (list Z) ==> Forall2 strict_feq ==> feq
  as rowmult_stricter_mor.

(** [partial_row i h vals col_ind row_ptr vval] is the [rowmult]
    accumulation, starting from zero, over the stored entries of row [i].
    It covers entries from the row's start pointer [Znth i row_ptr]
    (inclusive) up to position [h] (exclusive). [partial_row_start],
    [partial_row_next] and [partial_row_end] describe its value as [h]
    advances from the row's start to its end. *)
Definition partial_row {t} (i: Z) (h: Z) (vals: list (ftype t)) (col_ind: list Z) (row_ptr: list Z)
                (vval: vector t) : ftype t :=
 let vals' := sublist (Znth i row_ptr) h vals in
 let col_ind' := sublist (Znth i row_ptr) h col_ind in
   rowmult (Zconst t 0) vals' col_ind' vval.

(** A partial row that has consumed no entries, i.e. whose end position is
    the row's own start pointer, is zero. *)
Lemma partial_row_start:
 forall {t} i (mval: matrix t) csr vval,
  csr_to_matrix csr mval ->
  partial_row i (Znth i (csr_row_ptr csr)) (csr_vals csr) (csr_col_ind csr) (csr_row_ptr csr) vval = Zconst t 0.

(** Every finite float is [strict_feq] to itself. *)
Lemma strict_feq_i:
 forall {t} (x: ftype t), finite x -> strict_feq x x.

(** A list of finite floats is elementwise [strict_feq] to itself. *)
Lemma strict_floatlist_eqv_i:
  forall {t} (vec: list (ftype t)), Forall finite vec -> Forall2 strict_feq vec vec.
(* Let [auto] close strict_feq reflexivity goals on finite floats and float
   lists via [strict_feq_i] and [strict_floatlist_eqv_i]. *)
#[export] Hint Resolve strict_feq_i strict_floatlist_eqv_i : core.

(** Once row [i] is fully consumed, i.e. the end position is the next row
    pointer, the partial row is [feq] to entry [i] of
    [matrix_vector_mult mval vval]. This requires finite vector and matrix
    entries and a vector with [csr_cols csr] entries. *)
Lemma partial_row_end:
 forall {t} i (mval: matrix t) csr vval
  (FINvval: Forall finite vval)
  (FINmval: Forall (Forall finite) mval)
  (LEN: Zlength vval = csr_cols csr),
  0 <= i < matrix_rows mval ->
  csr_to_matrix csr mval ->
  feq (partial_row i (Znth (i+1) (csr_row_ptr csr)) (csr_vals csr) (csr_col_ind csr) (csr_row_ptr csr) vval)
      (Znth i (matrix_vector_mult mval vval)).

(** [rowmult] over concatenated lists equals [rowmult] over the second
    parts, started from the result over the first parts. This needs the
    first value and index lists to have equal length. *)
Lemma rowmult_app:
 forall {t} (s: ftype t) vals1 col_ind1 vals2 col_ind2 vvals,
   Zlength vals1 = Zlength col_ind1 ->
   rowmult s (vals1++vals2) (col_ind1++col_ind2) vvals =
   rowmult (rowmult s vals1 col_ind1 vvals) vals2 col_ind2 vvals.

(** Extending a partial row of row [i] past the stored entry at position [h]
    performs one more [BFMA] step with that entry's value and the matching
    element of [vval]. *)
Lemma partial_row_next:
 forall {t} i h (mval: matrix t) csr vval,
  0 <= Znth i (csr_row_ptr csr) ->
  Znth i (csr_row_ptr csr) <= h < Zlength (csr_vals csr) ->
  Zlength (csr_vals csr) = Zlength (csr_col_ind csr) ->
  csr_to_matrix csr mval ->
partial_row i (h + 1) (csr_vals csr) (csr_col_ind csr) (csr_row_ptr csr) vval =
BFMA (Znth h (csr_vals csr)) (Znth (Znth h (csr_col_ind csr)) vval)
  (partial_row i h (csr_vals csr) (csr_col_ind csr) (csr_row_ptr csr) vval).

(** [sum_any v s]: [s] is a floating-point sum ([BPLUS]) of the elements of
    [v], taken in some association order and after some permutation of the
    elements. The empty sum is [Zconst t 0]. *)
Inductive sum_any {t}: forall (v: vector t) (s: ftype t), Prop :=
| Sum_Any_0: sum_any nil (Zconst t 0)
| Sum_Any_1: forall x, sum_any [x] x
| Sum_Any_split: forall al bl a b, sum_any al a -> sum_any bl b -> sum_any (al++bl) (BPLUS a b)
| Sum_Any_perm: forall al bl s, Permutation al bl -> sum_any al s -> sum_any bl s.

Require LAProof.accuracy_proofs.common.

(** Accuracy of any [sum_any] summation. The absolute error against the
    exact real sum is at most [common.g t n * (n * mag)], where [n] is the
    length of [v] and [mag] is the largest magnitude [Rabs (FT2R x)] over
    [v] (or 0 for the empty vector).
    [mag] must use [Rabs]. A maximum over the signed values [FT2R x] is 0
    whenever every element is negative, and the bound then fails. For
    example, adding -1 and a tiny negative number in binary64 rounds to -1:
    the error is nonzero but the bound would be zero.
    NOTE(review): [FT2R] presumably maps non-finite floats to 0 (as
    Flocq's [B2R] does). If so, an overflowing sum can still violate this
    bound, and a finiteness hypothesis on [s] is likely needed -- confirm. *)
Lemma sum_any_accuracy{t}: forall (v: vector t) (s: ftype t),
  let mag := fold_left Rmax (map (fun x => Rabs (FT2R x)) v) R0 in
  sum_any v s ->
  (Rabs (fold_left Rplus (map FT2R v) R0 - FT2R s) <= @common.g t (length v) * (INR (length v) * mag))%R.

(** Coordinate-format (COO) sparse matrix: dimensions plus a list of
    [((row, col), value)] entries. Entries may come in any order, and the
    same coordinate may occur more than once; [coo_to_matrix] sums all
    entries at a coordinate. *)
Record coo_matrix {t: type} := {
  coo_rows: Z;
  coo_cols: Z;
  coo_entries: list (Z * Z * ftype t)
}.
(* Make [t] explicit, so the type is written [coo_matrix t]. *)
Arguments coo_matrix t : clear implicits.

(** A COO matrix is well formed when both dimensions are nonnegative and
    every entry's row and column lie within them. *)
Definition coo_matrix_wellformed {t} (coo: coo_matrix t) :=
 (0 <= coo_rows coo /\ 0 <= coo_cols coo)
  /\ Forall (fun e => 0 <= fst (fst e) < coo_rows coo /\ 0 <= snd (fst e) < coo_cols coo)
      (coo_entries coo).

(** Two COO matrices are equivalent when they have the same dimensions and
    their entry lists are permutations of each other. *)
Definition coo_matrix_equiv {t: type} (a b : coo_matrix t) :=
  coo_rows a = coo_rows b /\ coo_cols a = coo_cols b
  /\ Permutation (coo_entries a) (coo_entries b).

(** Well-formedness is preserved by [coo_matrix_equiv]. *)
Lemma coo_matrix_wellformed_equiv {t: type} (a b: coo_matrix t):
   coo_matrix_equiv a b -> coo_matrix_wellformed a -> coo_matrix_wellformed b.

(** Boolean equality of (row, column) coordinate pairs. *)
Definition coord_eqb (a b: Z * Z) :=
       andb (Z.eqb (fst a) (fst b)) (Z.eqb (snd a) (snd b)).

(** [coo_to_matrix coo m]: [coo] represents the dense matrix [m]. The row
    counts must agree, and [m] must satisfy [matrix_cols] with width
    [coo_cols coo]. Each in-range entry [m(i,j)] must be some [sum_any] of
    the values of all COO entries at coordinate [(i,j)]. [Sum_Any_0] covers
    coordinates with no entries. *)
Definition coo_to_matrix {t: type} (coo: coo_matrix t) (m: matrix t) : Prop :=
  coo_rows coo = matrix_rows m /\
  matrix_cols m (coo_cols coo) /\
  forall i, 0 <= i < coo_rows coo ->
  forall j, 0 <= j < coo_cols coo ->
  sum_any (map snd (filter (coord_eqb (i,j) oo fst) (coo_entries coo)))
          (matrix_index m (Z.to_nat i) (Z.to_nat j)).

(** Equivalent COO matrices represent the same dense matrices. *)
Lemma coo_to_matrix_equiv:
  forall {t} (m: matrix t) (coo coo': coo_matrix t),
    coo_matrix_equiv coo coo' -> coo_to_matrix coo m -> coo_to_matrix coo' m.

(** [coo_matrix_equiv] is reflexive. *)
Lemma coo_matrix_equiv_refl:
  forall {t} (a : coo_matrix t), coo_matrix_equiv a a.

(** [coo_matrix_equiv] is symmetric. *)
Lemma coo_matrix_equiv_symm:
  forall {t} (a b : coo_matrix t), coo_matrix_equiv a b -> coo_matrix_equiv b a.

(** [coo_matrix_equiv] is transitive. *)
Lemma coo_matrix_equiv_trans:
  forall {t} (a b c : coo_matrix t), coo_matrix_equiv a b -> coo_matrix_equiv b c -> coo_matrix_equiv a c.

(** Lexicographic (row, then column) order on COO entries. The stored
    values are ignored. *)
Definition coord_le {t} (a b : Z*Z*ftype t) : Prop :=
  fst (fst a) < fst (fst b)
 \/ fst (fst a) = fst (fst b) /\ snd (fst a) <= snd (fst b).

(** Boolean version of [coord_le]: true when [a]'s row is smaller than
    [b]'s, or the rows are equal and [a]'s column is at most [b]'s.
    [reflect_coord_le] states the correspondence. *)
Definition coord_leb {t} (a b : Z*Z*ftype t) : bool :=
  orb (fst (fst a) <? fst (fst b))
       (andb (fst (fst a) =? fst (fst b)) (snd (fst a) <=? snd (fst b))).

(** [coord_leb] decides [coord_le]: it returns [true] exactly when
    [coord_le a b] holds. *)
Lemma reflect_coord_le {t} a b : reflect (@coord_le t a b) (@coord_leb t a b).

(** Boolean order on COO entries, with [coord_leb] as the test and
    [reflect_coord_le] as its specification. *)
Instance CoordBO {t}: BoolOrder (@coord_le t) :=
  {| test := coord_leb; test_spec := reflect_coord_le |}.

(** [coord_le] is a preorder (reflexive and transitive). It is not
    antisymmetric on entries, because two entries with the same coordinates
    but different values are related in both directions. *)
Instance CoordPO {t: type}: PreOrder (@coord_le t).

(** Bundle [CoordBO] and [CoordPO] into a boolean preorder on COO entries.
    NOTE(review): presumably consumed by a sorting routine over
    [coo_entries] -- confirm. *)
Instance CoordBPO {t: type}: BPO.BoolPreOrder (@coord_le t) :=
 {| BPO.BO := CoordBO; BPO.PO := CoordPO |}.