(** * LAProof.C.partial_csr *)

Require Import VST.floyd.proofauto.
From LAProof.C Require Import floatlib sparse_model distinct.
Require Import vcfloat.FPStdCompCert.
Require Import vcfloat.FPStdLib.
Require Import Coq.Classes.RelationClasses.


(** In a well-formed COO matrix, the row and column of every entry
    (index [i] within bounds) lie within the matrix dimensions. *)
Lemma coo_entry_bounds {t} [coo: coo_matrix t]:
  coo_matrix_wellformed coo ->
  forall i,
  0 <= i < Zlength (coo_entries coo) ->
  0 <= fst (fst (Znth i (coo_entries coo))) < coo_rows coo /\
  0 <= snd (fst (Znth i (coo_entries coo))) < coo_cols coo.

(** [coo_upto i coo]: the COO matrix with the same dimensions as [coo]
    but keeping only its first [i] entries ([sublist 0 i]).
    Used as the already-processed prefix in the [partial_CSR] invariant. *)
Definition coo_upto (i: Z) {t} (coo: coo_matrix t) :=
  Build_coo_matrix _ (coo_rows coo) (coo_cols coo) (sublist 0 i (coo_entries coo)).

(** [cd_upto i coo]: [count_distinct] applied to the first [i] entries of
    [coo]. Presumably the number of distinct coordinates among them,
    i.e. the number of CSR slots they occupy (assumes [count_distinct]
    counts distinct coordinate keys; see the [distinct] module). *)
Definition cd_upto (i: Z) {t} (coo: coo_matrix t) : Z :=
   count_distinct (sublist 0 i (coo_entries coo)).

(** The tail of a [sorted] list is [sorted]. *)
Lemma sorted_cons_e2: forall {A} (lt: relation A) a al, sorted lt (a::al) -> sorted lt al.

(** Every COO entry [h], with coordinates [(r,c)], is represented in [csr]:
    its slot [k] (one less than the distinct count of entries [0..h])
    lies within row [r] of the row-pointer array, holds column [c],
    and its value is a sum ([sum_any]) of all COO values at [(r,c)]. *)
Definition entries_correspond {t} (coo: coo_matrix t) (csr: csr_matrix t) :=
forall h,
0 <= h < Zlength (coo_entries coo) ->
let '(r,c) := fst (Znth h (coo_entries coo)) in
let k := cd_upto (h+1) coo - 1 in
  Znth r (csr_row_ptr csr) <= k < Znth (r+1) (csr_row_ptr csr) /\
  Znth k (csr_col_ind csr) = c /\
  sum_any (map snd (filter (coord_eqb (r,c) oo fst) (coo_entries coo))) (Znth k (csr_vals csr)).

(** Every slot [k] that [csr] assigns to row [r] names a coordinate
    [(r, col_ind[k])] that actually occurs in [coo]: the CSR form
    introduces no entries absent from the COO form. *)
Definition no_extra_zeros {t} (coo: coo_matrix t) (csr: csr_matrix t) :=
  forall r k, 0 <= r < coo_rows coo ->
     Znth r (csr_row_ptr csr) <= k < Znth (r+1) (csr_row_ptr csr) ->
     let c := Znth k (csr_col_ind csr) in
        In (r,c) (map fst (coo_entries coo)).

(** [coo_csr coo csr]: [csr] is a CSR encoding of [coo]. The dimensions
    agree, there is one value slot per [count_distinct] of the COO entries,
    every COO entry is represented ([entries_correspond]), and no slot
    lacks a COO entry ([no_extra_zeros]). *)
Inductive coo_csr {t} (coo: coo_matrix t) (csr: csr_matrix t) : Prop :=
 build_coo_csr:
 forall
    (coo_csr_rows: coo_rows coo = csr_rows csr)
    (coo_csr_cols: coo_cols coo = csr_cols csr)
    (coo_csr_vals: Zlength (csr_vals csr) = count_distinct (coo_entries coo))
    (coo_csr_entries: entries_correspond coo csr)
    (coo_csr_zeros: no_extra_zeros coo csr),
    coo_csr coo csr.

(** [partial_CSR h r coo rowptr colind val]: invariant for converting the
    sorted, well-formed COO matrix [coo] into CSR arrays.
    - the first [h] entries of [coo] have been consumed;
    - row-pointer slots [0..r] have been written; [r] starts at [-1];
    - [csr] is a well-formed witness with [coo_csr (coo_upto h coo) csr];
    - the prefixes of [val], [colind] and [rowptr] agree with [csr];
    - [val] and [colind] have length [count_distinct (coo_entries coo)],
      and [rowptr] has length [coo_rows coo + 1]. *)
Inductive partial_CSR (h: Z) (r: Z) (coo: coo_matrix Tdouble)
      (rowptr: list val) (colind: list val) (val: list val) : Prop :=
build_partial_CSR:
   forall
    (partial_CSR_coo: coo_matrix_wellformed coo)
    (partial_CSR_coo_sorted: sorted coord_le (coo_entries coo))
    (partial_CSR_i: 0 <= h <= Zlength (coo_entries coo))
    (partial_CSR_r: -1 <= r <= coo_rows coo)
    (* consumed entries lie in rows at most r *)
    (partial_CSR_r': Forall (fun e => fst (fst e) <= r) (coo_entries (coo_upto h coo)))
    (* unconsumed entries lie in rows at least r *)
    (partial_CSR_r'': Forall (fun e => fst (fst e) >= r) (sublist h (Zlength (coo_entries coo)) (coo_entries coo)))
    (csr: csr_matrix Tdouble)
    (partial_CSR_wf: csr_matrix_wellformed csr)
    (partial_CSR_coo_csr: coo_csr (coo_upto h coo) csr)
    (partial_CSR_val: sublist 0 (Zlength (csr_vals csr)) val = map Vfloat (csr_vals csr))
    (partial_CSR_colind: sublist 0 (Zlength (csr_col_ind csr)) colind = map (Vint oo Int.repr) (csr_col_ind csr))
    (partial_CSR_rowptr: sublist 0 (r+1) rowptr = map (Vint oo Int.repr) (sublist 0 (r+1) (csr_row_ptr csr)))
    (partial_CSR_val': Zlength val = count_distinct (coo_entries coo))
    (partial_CSR_colind': Zlength colind = count_distinct (coo_entries coo))
    (partial_CSR_rowptr': Zlength rowptr = coo_rows coo + 1)
    (* the total slot count must be representable as an unsigned int *)
    (partial_CSR_dbound: count_distinct (coo_entries coo) <= Int.max_unsigned),
    partial_CSR h r coo rowptr colind val.

(* Register [csr_rows] in the [list_solve_unfold] hint database so that
   VST's list_solve tactic may unfold it. *)
Hint Unfold csr_rows : list_solve_unfold.
(** If every COO entry lies in a row at most [r], then every row-pointer
    slot after [r] in a corresponding CSR matrix equals the total number
    of stored values. *)
Lemma partial_CSR_rowptr': forall {t} r (coo: coo_matrix t) (csr: csr_matrix t),
   coo_matrix_wellformed coo ->
   csr_matrix_wellformed csr ->
   coo_csr coo csr ->
   -1 <= r <= coo_rows coo ->
   Forall (fun e => fst (fst e) <= r) (coo_entries coo) ->
   sublist (r+1) (coo_rows coo + 1) (csr_row_ptr csr) = Zrepeat (Zlength (csr_vals csr)) (coo_rows coo - r).

(** [matrix_upd i j m x]: [m] with element [(i,j)] replaced by [x].
    Row [i] is read with [Znth], its position [j] is overwritten, and
    the updated row is written back at index [i]. *)
Definition matrix_upd {t} (i j: Z) (m: matrix t) (x: ftype t) : matrix t :=
  upd_Znth i m (upd_Znth j (Znth i m) x).

(** Equivalence in the coordinate order [CoordBPO] is exactly equality
    of coordinates ([fst]). *)
Lemma BPO_eqv_iff: forall {t} a b, @BPO.eqv _ _ (@CoordBPO t) a b <-> fst a = fst b.

(** Step for a duplicate entry. Entry [h] has the same coordinates as
    entry [h-1] in row [r]. Its value is added with [Float.add] into the
    last filled slot [cd_upto h coo - 1] of [VAL], and the invariant
    advances to [h+1]. *)
Lemma partial_CSR_duplicate:
    forall h r coo (f: ftype Tdouble) ROWPTR COLIND VAL,
    0 < h < Zlength (coo_entries coo) ->
    fst (Znth (h-1) (coo_entries coo)) = fst (Znth h (coo_entries coo)) ->
    r = fst (fst (Znth (h-1) (coo_entries coo))) ->
    Znth (cd_upto h coo - 1) VAL = Vfloat f ->
    partial_CSR h r coo ROWPTR COLIND VAL ->
    partial_CSR (h+1) r coo ROWPTR COLIND
      (upd_Znth (cd_upto h coo - 1) VAL
         (Vfloat (Float.add f (snd (Znth h (coo_entries coo)))))).

(** Truncating a well-formed COO matrix to its first [i] entries keeps it
    well-formed. *)
Lemma coo_upto_wellformed: forall {t} i (coo: coo_matrix t),
  0 <= i <= Zlength (coo_entries coo) ->
  coo_matrix_wellformed coo -> coo_matrix_wellformed (coo_upto i coo).

(** In a [coord_le]-sorted list, any earlier in-bounds element is
    [coord_le] any later one (index [i <= j]). *)
Lemma coord_sorted_e: forall {t} (al: list (Z*Z*ftype t)) (H: sorted coord_le al)
   (i j: Z), 0 <= i <= j /\ j < Zlength al -> coord_le (Znth i al) (Znth j al).

(** Step for a new column in the current row. Entry [i] = [(r,c,x)] is in
    the same row as entry [i-1] but in a different column. Column [c] and
    value [x] are stored at slot [count_distinct] of the first [i]
    entries, and the invariant advances to [i+1]. *)
Lemma partial_CSR_newcol:
   forall i r c x coo ROWPTR COLIND VAL,
   0 < i < Zlength (coo_entries coo) ->
   Znth i (coo_entries coo) = (r, c, x) ->
   r = fst (fst (Znth (i-1) (coo_entries coo))) ->
   c <> snd (fst (Znth (i-1) (coo_entries coo))) ->
   partial_CSR i r coo ROWPTR COLIND VAL ->
   partial_CSR (i + 1) r coo ROWPTR
  (upd_Znth (count_distinct (sublist 0 i (coo_entries coo))) COLIND (Vint (Int.repr c)))
  (upd_Znth (count_distinct (sublist 0 i (coo_entries coo))) VAL (Vfloat x)).

(** Initial state of the invariant. No entries are consumed and [r = -1].
    All output arrays are [Vundef]: [coo_rows coo + 1] row pointers and
    [k] column/value slots, provided [k] fits in an unsigned int. *)
Lemma partial_CSR_0: forall (coo: coo_matrix Tdouble),
  coo_matrix_wellformed coo ->
    sorted coord_le (coo_entries coo) ->
 let k := count_distinct (coo_entries coo)
 in k <= Int.max_unsigned ->
   partial_CSR 0 (-1) coo (Zrepeat Vundef (coo_rows coo + 1))
  (Zrepeat Vundef k) (Zrepeat Vundef k).

(** Advance to row [r] without consuming an entry. This applies while
    [r] is at most the row of the next entry [i]. Row-pointer slot [r]
    is set to the distinct count of the first [i] entries. *)
Lemma partial_CSR_skiprow:
    forall i r coo ROWPTR COLIND VAL,
    0 <= i < Zlength (coo_entries coo) ->
    r <= fst (fst (Znth i (coo_entries coo))) ->
    partial_CSR i (r-1) coo ROWPTR COLIND VAL ->
    partial_CSR i r coo
  (upd_Znth r ROWPTR
     (Vint
        (Int.repr (count_distinct (sublist 0 i (coo_entries coo))))))
  COLIND VAL.

(** Step for the first entry of row [r]. Entry [i] = [(r,c,x)] is either
    the very first entry or follows an entry in a different row. Column
    [c] and value [x] are stored at slot [count_distinct] of the first
    [i] entries. *)
Lemma partial_CSR_newrow:
    forall i r c x coo ROWPTR COLIND VAL,
    0 <= i < Zlength (coo_entries coo) ->
    Znth i (coo_entries coo) = (r,c,x) ->
    (i <> 0 -> fst (fst (Znth (i - 1) (coo_entries coo))) <> r) ->
    partial_CSR i r coo ROWPTR COLIND VAL ->
    partial_CSR (i + 1) r coo ROWPTR
     (upd_Znth (count_distinct (sublist 0 i (coo_entries coo))) COLIND
        (Vint (Int.repr c)))
     (upd_Znth (count_distinct (sublist 0 i (coo_entries coo))) VAL
        (Vfloat x)).

(** After all entries are consumed, each remaining row-pointer slot [r]
    (up to [coo_rows coo]) is set to the total [count_distinct]. *)
Lemma partial_CSR_lastrows:
   forall r coo ROWPTR COLIND VAL,
    r <= coo_rows coo ->
   partial_CSR (Zlength (coo_entries coo)) (r-1) coo ROWPTR COLIND VAL ->
   partial_CSR (Zlength (coo_entries coo)) r coo
     (upd_Znth r ROWPTR (Vint (Int.repr (count_distinct (coo_entries coo))))) COLIND VAL.

(** From a [csr_row_rep] relation with matching [vals]/[col_ind] lengths,
    every column index, and [cols] itself, is nonnegative. *)
Lemma csr_row_rep_colsnonneg:
    forall {t} cols (vals: list (ftype t)) col_ind v,
       csr_row_rep cols vals col_ind v ->
       Zlength vals = Zlength col_ind ->
       Forall (Z.le 0) (col_ind ++ [cols]).

(** For in-bounds [i], [j], the nat-indexed [matrix_index] agrees with
    Z-indexed [Znth] on rows and then columns. *)
Lemma matrix_index_Z: forall {t} (m: matrix t) cols i j,
  matrix_cols m cols ->
  0 <= i < matrix_rows m ->
  0 <= j < cols ->
  matrix_index m (Z.to_nat i) (Z.to_nat j) = Znth j (Znth i m).

(** Let [csr] be a well-formed CSR encoding of the well-formed [coo]. If
    [csr] represents the matrix [build_csr_matrix csr], then so does
    [coo]. *)
Lemma coo_to_matrix_build_csr_matrix:
   forall {t}
  (coo : coo_matrix t)
  (csr : csr_matrix t)
  (partial_CSR_coo : coo_matrix_wellformed coo)
  (partial_CSR_wf : csr_matrix_wellformed csr)
  (partial_CSR_coo_csr : coo_csr coo csr),
  csr_to_matrix csr (build_csr_matrix csr) ->
  coo_to_matrix coo (build_csr_matrix csr).

(** Final state of the invariant: all entries consumed and [r = coo_rows coo].
    Then there exist a matrix [m] and a CSR matrix [csr] such that:
    - both [csr] and [coo] represent [m];
    - the dimensions agree;
    - the three arrays are exactly the value, column and row-pointer
      arrays of [csr]. *)
Lemma partial_CSR_properties:
  forall coo ROWPTR COLIND VAL,
    partial_CSR (Zlength (coo_entries coo)) (coo_rows coo) coo ROWPTR COLIND VAL ->
    exists (m: matrix Tdouble) (csr: csr_matrix Tdouble),
            csr_to_matrix csr m /\ coo_to_matrix coo m
            /\ coo_rows coo = matrix_rows m
            /\ coo_cols coo = csr_cols csr
            /\ map Vfloat (csr_vals csr) = VAL
            /\ Zlength (csr_col_ind csr) = count_distinct (coo_entries coo)
            /\ map Vint (map Int.repr (csr_col_ind csr)) = COLIND
            /\ map Vint (map Int.repr (csr_row_ptr csr)) = ROWPTR
            /\ Zlength (csr_vals csr) = count_distinct (coo_entries coo).

(** While entry [i] is pending, every already-filled slot of [VAL] holds
    a float. These are slots [0 .. count_distinct (first i entries) - 1]. *)
Lemma partial_CSR_VAL_defined:
  forall i r coo ROWPTR COLIND VAL h,
    0 <= i < Zlength (coo_entries coo) ->
    0 < h <= count_distinct (sublist 0 i (coo_entries coo)) ->
    partial_CSR i r coo ROWPTR COLIND VAL ->
    is_float (Znth (h-1) VAL).