(* Theory Stream_More_Corec_Upto3
   (imports Stream_Corec_Upto3) *)
header {* More on corecursion and coinduction up to *}

theory Stream_More_Corec_Upto3
imports Stream_Corec_Upto3
begin


subsection{* A natural-transformation-based version of the up-to corecursion principle *}

(* algρ3 maps J K3 to J; it is K3_as_ΣΣ3 followed by eval3. *)
definition algρ3 :: "J K3 => J" where "algρ3 ≡ eval3 o K3_as_ΣΣ3"

(* dd3 precomposed with K3_as_ΣΣ3 is ρ3.
   Proof: unfold K3_as_ΣΣ3_def and dd3_def, then rewrite with the listed facts
   (ddd3_\<oo>\<pp>3, snd_convol, ddd3_leaf3, Λ3_natural, leaf3_flat3, Λ3_Inr, ...).
   The interleaved o_assoc / o_assoc[symmetric] steps only re-bracket compositions;
   their order matters, so do not reorder them. *)
lemma dd3_K3_as_ΣΣ3:
"dd3 o K3_as_ΣΣ3 = ρ3"
unfolding K3_as_ΣΣ3_def dd3_def
unfolding o_assoc apply(subst o_assoc[symmetric])
unfolding ddd3_\<oo>\<pp>3 unfolding o_assoc snd_convol
unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
unfolding Σ3.map_comp0[symmetric] ddd3_leaf3 Λ3_natural
unfolding o_assoc F_map_comp[symmetric] leaf3_flat3 F_map_id id_o Λ3_Inr ..

(* Behaviour of algρ3 under dtor_J:
   dtor_J o algρ3 = F_map eval3 o ρ3 o K3_map <id, dtor_J>.
   Proof: replace ρ3 by dd3 o K3_as_ΣΣ3 (dd3_K3_as_ΣΣ3 used right-to-left), commute
   K3_as_ΣΣ3 past the map with K3_as_ΣΣ3_natural, then finish with the fact eval3 and
   the definition of algρ3. *)
lemma algρ3: "dtor_J o algρ3 = F_map eval3 o ρ3 o K3_map <id, dtor_J>"
unfolding dd3_K3_as_ΣΣ3[symmetric] o_assoc
unfolding o_assoc[symmetric] K3_as_ΣΣ3_natural
unfolding o_assoc eval3 algρ3_def ..

(* Compatibility of flat3/flat2 with embL3:
   flat3 o embL3 o ΣΣ2_map embL3 = embL3 o flat2.
   Strategy: both sides are shown equal to the same term
   ext2 (\<oo>\<pp>3 o Abs_Σ3 o Inl) embL3 via ext2_unique. Each use of ext2_unique
   leaves two obligations: one equation on leaf2 and one on \<oo>\<pp>2. *)
lemma flat3_embL3: "flat3 o embL3 o ΣΣ2_map embL3 = embL3 o flat2" (is "?L = ?R")
proof-
  (* Part 1: the left-hand side equals the ext2 term. *)
  have "?L = ext2 (\<oo>\<pp>3 o Abs_Σ3 o Inl) embL3"
  proof(rule ext2_unique)
    (* leaf2 obligation *)
    show "flat3 o embL3 o ΣΣ2_map embL3 o leaf2 = embL3"
    unfolding o_assoc[symmetric] unfolding leaf2_natural
    unfolding o_assoc apply(subst o_assoc[symmetric])
    apply(subst embL3_def) unfolding ext2_comp_leaf2 flat3_leaf3 id_o ..
  next
    (* \<oo>\<pp>2 obligation; embL3_def is unfolded only temporarily so that ext2_commute
       applies, and is folded back right afterwards *)
    show "flat3 o embL3 o ΣΣ2_map embL3 o \<oo>\<pp>2 = \<oo>\<pp>3 o Abs_Σ3 o Inl o Σ2_map (flat3 o embL3 o ΣΣ2_map embL3)"
    apply(subst o_assoc[symmetric]) unfolding embL3_natural
    unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric],
                            subst o_assoc[symmetric])
    unfolding embL3_def unfolding ext2_commute unfolding embL3_def[symmetric]
    unfolding o_assoc apply(subst o_assoc[symmetric])
    unfolding \<oo>\<pp>3_natural[symmetric]
    unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
    unfolding map_sum_Inl Abs_Σ3_natural
    unfolding o_assoc[symmetric] map_sum_Inl Σ2.map_comp0[symmetric] embL3_natural[symmetric]
    apply(rule sym) apply(subst Σ2.map_comp0) unfolding o_assoc
    unfolding flat3_def unfolding ext3_commute
    apply(rule sym) apply(subst o_assoc[symmetric])
    unfolding Abs_Σ3_natural unfolding o_assoc[symmetric] map_sum_Inl \<oo>\<pp>3_natural[symmetric] ..
  qed
  (* Part 2: the same ext2 term equals the right-hand side (ext2_unique applied after
     flipping the equation with sym). *)
  also have "... = ?R"
  proof(rule sym, rule ext2_unique)
    (* leaf2 obligation *)
    show "embL3 o flat2 o leaf2 = embL3"
    unfolding o_assoc[symmetric] flat2_leaf2 o_id ..
  next
    (* \<oo>\<pp>2 obligation *)
    show "embL3 o flat2 o \<oo>\<pp>2 = \<oo>\<pp>3 o Abs_Σ3 o Inl o Σ2_map (embL3 o flat2)"
    unfolding flat2_def o_assoc[symmetric] ext2_commute
    unfolding o_assoc
    apply(subst embL3_def) unfolding ext2_commute apply(subst embL3_def[symmetric])
    unfolding Σ2.map_comp0 o_assoc ..
  qed
  (* Chain Part 1 and Part 2. *)
  finally show ?thesis .
qed

(* Compatibility of ddd3/ddd2 with embL3:
   ddd3 o embL3 = (embL3 ** F_map embL3) o ddd2.
   Strategy: both sides are shown equal to one common ext2 term (stated in the first
   `have`) via ext2_unique, which leaves a leaf2 obligation and a \<oo>\<pp>2 obligation.
   The pair-valued equations are split componentwise with fst_snd_cong into a
   fst-equation and a snd-equation. *)
lemma ddd3_embL3: "ddd3 o embL3 = (embL3 ** F_map embL3) o ddd2" (is "?L = ?R")
proof-
  (* Part 1: the left-hand side equals the ext2 term. *)
  have "?L = ext2 <\<oo>\<pp>3 o Abs_Σ3 o Inl o Σ2_map fst, F_map (flat3 o embL3) o Λ2> (leaf3 ** F_map leaf3)"
  proof(rule ext2_unique)
    (* leaf2 obligation. fst_snd_cong yields two subgoals; `defer` postpones the first
       one so that the same o_assoc rewrite can also be applied to the second, and the
       final `by` closes both with refl. *)
    show "ddd3 o embL3 o leaf2 = leaf3 ** F_map leaf3"
    apply(rule fst_snd_cong)
    unfolding fst_comp_map_prod snd_comp_map_prod
    unfolding  embL3_def
    apply (subst (3) o_assoc[symmetric]) defer apply (subst (3) o_assoc[symmetric])
    unfolding ext2_comp_leaf2
    unfolding ddd3_def ext3_comp_leaf3 fst_comp_map_prod snd_comp_map_prod by(rule refl, rule refl)
  next
    (* \<oo>\<pp>2 obligation, proved componentwise. embL3_def and ddd3_def are unfolded
       only around the ext2_commute / ext3_commute steps and folded back afterwards. *)
    show "ddd3 o embL3 o \<oo>\<pp>2 =
          <\<oo>\<pp>3 o Abs_Σ3 o Inl o Σ2_map fst , F_map (flat3 o embL3) o Λ2> o Σ2_map (ddd3 o embL3)" (is "?A = ?B")
    proof(rule fst_snd_cong)
      (* first component *)
      show "fst o ?A = fst o ?B"
      unfolding o_assoc fst_convol unfolding o_assoc[symmetric] Σ2.map_comp0[symmetric]
      unfolding o_assoc
      apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
      apply(subst embL3_def) unfolding ext2_commute apply(subst embL3_def[symmetric])
      unfolding o_assoc apply(subst o_assoc[symmetric])
      apply(subst ddd3_def) unfolding ext3_commute apply(subst ddd3_def[symmetric])
      unfolding o_assoc fst_convol
      apply(subst o_assoc[symmetric]) unfolding Σ3.map_comp0[symmetric]
      apply(subst o_assoc[symmetric]) unfolding Abs_Σ3_natural map_sum_Inl o_assoc[symmetric]
      unfolding Σ2.map_comp0[symmetric] o_assoc ..
    next
      (* second component *)
      show "snd o ?A = snd o ?B"
      unfolding o_assoc snd_convol unfolding o_assoc[symmetric]
      apply(subst embL3_def) unfolding ext2_commute apply(subst embL3_def[symmetric])
      unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric])
      apply(subst ddd3_def) unfolding ext3_commute apply(subst ddd3_def[symmetric])
      unfolding o_assoc snd_convol
      unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
      unfolding Abs_Σ3_natural map_sum_Inl o_assoc[symmetric]
      unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric])
      unfolding Λ3_Inl unfolding Σ2.map_comp0 F_map_comp o_assoc ..
    qed
  qed
  (* Part 2: the same ext2 term equals the right-hand side. *)
  also have "... = ?R"
  proof(rule sym, rule ext2_unique)
    (* leaf2 obligation *)
    show "(embL3 ** F_map embL3) o ddd2 o leaf2 = leaf3 ** F_map leaf3"
    unfolding o_assoc apply(subst o_assoc[symmetric])
    apply(subst ddd2_def) unfolding ext2_comp_leaf2
    unfolding map_prod.comp unfolding F_map_comp[symmetric]
    apply(subst embL3_def, subst embL3_def) unfolding ext2_comp_leaf2 ..
  next
    (* \<oo>\<pp>2 obligation, again componentwise *)
    show "embL3 ** F_map embL3 o ddd2 o \<oo>\<pp>2 =
          <\<oo>\<pp>3 o Abs_Σ3 o Inl o Σ2_map fst , F_map (flat3 o embL3) o Λ2> o Σ2_map (embL3 ** F_map embL3 o ddd2)"
    (is "?A = ?B") proof(rule fst_snd_cong)
      (* first component *)
      show "fst o ?A = fst o ?B"
      unfolding o_assoc fst_convol fst_comp_map_prod
      unfolding o_assoc[symmetric] Σ2.map_comp0[symmetric] unfolding o_assoc unfolding fst_comp_map_prod
      unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
      apply(subst ddd2_def) unfolding ext2_commute apply(subst ddd2_def[symmetric])
      unfolding o_assoc fst_convol
      apply(subst embL3_def) unfolding ext2_commute apply(subst embL3_def[symmetric])
      unfolding Σ2.map_comp0 o_assoc ..
    next
      (* second component; this is where the previously proved flat3_embL3 is used
         (right-to-left) together with naturality of Λ2 *)
      show "snd o ?A = snd o ?B"
      unfolding o_assoc snd_convol snd_comp_map_prod
      unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
      apply(subst ddd2_def) unfolding ext2_commute apply(subst ddd2_def[symmetric])
      unfolding o_assoc apply(subst o_assoc[symmetric]) unfolding snd_convol
      unfolding o_assoc F_map_comp[symmetric]
      unfolding flat3_embL3[symmetric]
      unfolding F_map_comp
      unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
      unfolding Λ2_natural[symmetric]
      unfolding o_assoc Σ2.map_comp0 ..
    qed
  qed
  (* Chain Part 1 and Part 2. *)
  finally show ?thesis .
qed

(* dd3 o embL3 = F_map embL3 o dd2.
   Obtained from ddd3_embL3 after unfolding the definitions of dd3 and dd2; the
   remaining goal is discharged by auto. *)
lemma dd3_embL3: "dd3 o embL3 = F_map embL3 o dd2"
unfolding dd3_def dd2_def o_assoc[symmetric] ddd3_embL3 by auto

(* Combines dd3_embL3 with naturality of embL3 (embL3_natural, used right-to-left),
   instantiated at the map <id , dtor_J>:
   F_map embL3 o dd2 o ΣΣ2_map <id , dtor_J> = dd3 o ΣΣ3_map <id , dtor_J> o embL3. *)
lemma F_map_embL3_ΣΣ2_map:
"F_map embL3 o dd2 o ΣΣ2_map <id , dtor_J> =
 dd3 o ΣΣ3_map <id , dtor_J> o embL3"
unfolding o_assoc[symmetric] unfolding embL3_natural[symmetric]
unfolding o_assoc dd3_embL3 ..

(* eval3 restricted along embL3 is eval2: eval3 o embL3 = eval2.
   Proof: unfold eval2_def and apply the uniqueness rule J.dtor_unfold_unique; the
   resulting equation is rewritten using the definition of eval3,
   dtor_unfold_J_pointfree and F_map_embL3_ΣΣ2_map. *)
lemma eval3_embL3: "eval3 o embL3 = eval2"
unfolding eval2_def apply(rule J.dtor_unfold_unique)
unfolding eval3_def unfolding o_assoc
unfolding dtor_unfold_J_pointfree
unfolding F_map_comp
apply(subst o_assoc[symmetric], subst o_assoc[symmetric])
unfolding F_map_embL3_ΣΣ2_map o_assoc ..

(* algΛ3 precomposed with Abs_Σ3 is the case distinction of algΛ2 (left summand)
   and algρ3 (right summand):
   algΛ3 o Abs_Σ3 = case_sum algΛ2 algρ3.
   Proof: sum_comp_cases reduces the goal to one equation after Inl and one after Inr. *)
theorem algΛ3_algΛ2_algρ3:
"algΛ3 o Abs_Σ3 = case_sum algΛ2 algρ3" (is "?L = ?R")
proof(rule sum_comp_cases)
  (* Inl case: reduced via dtor_J_o_inj (used right-to-left) to an equation under
     dtor_J, then rewritten with dtor_J_algΛ3, Λ3_Inl, eval3_embL3 and dtor_J_algΛ2. *)
  show "?L o Inl = ?R o Inl"
  unfolding case_sum_o_inj apply(subst dtor_J_o_inj[symmetric])
  unfolding o_assoc dtor_J_algΛ3 unfolding Abs_Σ3_natural o_assoc[symmetric] map_sum_Inl
  unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric]) unfolding Λ3_Inl
  unfolding o_assoc F_map_comp[symmetric] eval3_embL3 dtor_J_algΛ2 ..
next
  (* Inr case: follows by unfolding the definitions of algρ3, algΛ3 and K3_as_ΣΣ3. *)
  show "?L o Inr = ?R o Inr"
  unfolding algρ3_def case_sum_o_inj algΛ3_def K3_as_ΣΣ3_def o_assoc ..
qed

(* Pointwise left-summand instance of algΛ3_algΛ2_algρ3:
   algΛ3 applied to Abs_Σ3 (Inl x) is algΛ2 x.
   The left-hand side is rewritten with algΛ3_algΛ2_algρ3 (turned into a pointwise
   rule by o_eq_dest_lhs) and simp finishes the goal. *)
theorem algΛ3_Inl: "algΛ3 (Abs_Σ3 (Inl x)) = algΛ2 x"
unfolding o_eq_dest_lhs[OF algΛ3_algΛ2_algρ3] by simp

(* Point-free form of algΛ3_Inl: algΛ3 o Abs_Σ3 o Inl = algΛ2.
   Proved by unfolding composition (o_def) and function extensionality (fun_eq_iff)
   and rewriting with the pointwise fact algΛ3_Inl. *)
lemma algΛ3_Inl_pointfree: "algΛ3 o Abs_Σ3 o Inl = algΛ2"
unfolding o_def fun_eq_iff algΛ3_Inl by simp

(* Pointwise right-summand instance of algΛ3_algΛ2_algρ3:
   algΛ3 applied to Abs_Σ3 (Inr x) is algρ3 x.
   The left-hand side is rewritten with algΛ3_algΛ2_algρ3 (turned into a pointwise
   rule by o_eq_dest_lhs) and simp finishes the goal. *)
theorem algΛ3_Inr: "algΛ3 (Abs_Σ3 (Inr x)) = algρ3 x"
unfolding o_eq_dest_lhs[OF algΛ3_algΛ2_algρ3] by simp



subsection{* Up-to corecursor with guard not necessarily at the top *}

(* ff3 maps 'a F to 'a Σ3: apply ff2, inject with Inl, then wrap with Abs_Σ3. *)
definition ff3 :: "'a F => 'a Σ3" where "ff3 ≡ Abs_Σ3 o Inl o ff2"

(* algΛ3 after ff3 is the constructor ctor_J.
   Follows from the definition of ff3, algΛ3_Inl_pointfree and algΛ2_ff2. *)
lemma algΛ3_ff3: "algΛ3 o ff3 = ctor_J"
unfolding ff3_def o_assoc algΛ3_Inl_pointfree algΛ2_ff2 ..

(* Transfer rule for ff3: ff3 is related to itself by F_rel R ===> Σ3_rel R, for any R.
   Declared [transfer_rule]; proved by transfer_prover after unfolding ff3_def. *)
lemma ff3_transfer[transfer_rule]: "(F_rel R ===> Σ3_rel R) ff3 ff3"
unfolding ff3_def by transfer_prover

(* Naturality of ff3: Σ3_map f o ff3 = ff3 o F_map f.
   Derived from ff3_transfer by instantiating the relation R with the graph relation
   BNF_Def.Grp UNIV f; Σ3.rel_Grp and F_rel_Grp rewrite the lifted relations, and
   auto finishes after unfolding Grp_def and rel_fun_def. *)
lemma ff3_natural: "Σ3_map f o ff3 = ff3 o F_map f"
using ff3_transfer[of "BNF_Def.Grp UNIV f"]
unfolding Σ3.rel_Grp F_rel_Grp
unfolding Grp_def rel_fun_def by auto

(* gg3 maps 'a ΣΣ3 F to 'a ΣΣ3: apply ff3, then \<oo>\<pp>3. *)
definition gg3 :: "'a ΣΣ3 F => 'a ΣΣ3" where
"gg3 ≡ \<oo>\<pp>3 o ff3"

(* eval3 after gg3 is ctor_J after F_map eval3:
   eval3 o gg3 = ctor_J o F_map eval3.
   Proof: unfold gg3_def, rewrite with eval3_comp_\<oo>\<pp>3, move the map past ff3 with
   ff3_natural, and conclude with algΛ3_ff3. *)
lemma eval3_gg3: "eval3 o gg3 = ctor_J o F_map eval3"
unfolding gg3_def
unfolding o_assoc unfolding eval3_comp_\<oo>\<pp>3
unfolding o_assoc[symmetric] ff3_natural
unfolding o_assoc algΛ3_ff3 ..

(* Transfer rule for gg3: gg3 is related to itself by
   F_rel (ΣΣ3_rel R) ===> ΣΣ3_rel R, for any R.
   Declared [transfer_rule]; proved by transfer_prover after unfolding gg3_def. *)
lemma gg3_transfer[transfer_rule]: "(F_rel (ΣΣ3_rel R) ===> ΣΣ3_rel R) gg3 gg3"
unfolding gg3_def by transfer_prover

(* Naturality of gg3: ΣΣ3_map f o gg3 = gg3 o F_map (ΣΣ3_map f).
   Derived from gg3_transfer by instantiating the relation R with the graph relation
   BNF_Def.Grp UNIV f, in the same way as ff3_natural is derived from ff3_transfer. *)
lemma gg3_natural: "ΣΣ3_map f o gg3 = gg3 o F_map (ΣΣ3_map f)"
using gg3_transfer[of "BNF_Def.Grp UNIV f"]
unfolding ΣΣ3.rel_Grp F_rel_Grp
unfolding Grp_def rel_fun_def by auto

(* unfoldUU3 s: for s of type 'a => 'a ΣΣ3 F ΣΣ3, a function 'a => J.
   It is defined by reduction to unfoldU3: s is post-processed by
   ΣΣ3_map <gg3, id>, then dd3, then F_map flat3, and the result is passed to unfoldU3. *)
definition unfoldUU3 :: "('a => 'a ΣΣ3 F ΣΣ3) => 'a => J" where
"unfoldUU3 s ≡ unfoldU3 (F_map flat3 o dd3 o ΣΣ3_map <gg3, id> o s)"

(* Characteristic equation of unfoldUU3:
   unfoldUU3 s = eval3 o ΣΣ3_map (ctor_J o F_map eval3 o F_map (ΣΣ3_map (unfoldUU3 s))) o s.
   Proof outline: unfold unfoldUU3_def, rewrite once with unfoldU3_ctor_J_pointfree
   (right-to-left), fold unfoldUU3_def back, and then massage both sides with the
   naturality facts (flat3_natural, dd3_natural, gg3_natural) and with eval3_flat3,
   eval3_gg3 and eval3. The many o_assoc / o_assoc[symmetric] steps only re-bracket
   compositions and the `rule sym` steps flip the equation; the script depends on the
   exact order of these steps, so they must not be reordered. *)
theorem unfoldUU3:
"unfoldUU3 s =
 eval3 o ΣΣ3_map (ctor_J o F_map eval3 o F_map (ΣΣ3_map (unfoldUU3 s))) o s"
unfolding unfoldUU3_def apply(subst unfoldU3_ctor_J_pointfree[symmetric]) unfolding unfoldUU3_def[symmetric]
unfolding extdd3_def F_map_comp[symmetric] o_assoc
apply(subst o_assoc[symmetric]) unfolding F_map_comp[symmetric]
apply(subst o_assoc[symmetric]) unfolding flat3_natural[symmetric]
apply(subst o_assoc) unfolding eval3_flat3
unfolding F_map_comp
apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
unfolding dd3_natural[symmetric]
unfolding o_assoc apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
unfolding dd3_natural[symmetric]
unfolding o_assoc[symmetric] ΣΣ3.map_comp0[symmetric]
unfolding o_assoc eval3_gg3 unfolding ΣΣ3.map_comp0 o_assoc
apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric],
      subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
unfolding ΣΣ3.map_comp0[symmetric]
apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
unfolding ΣΣ3.map_comp0[symmetric] map_prod.comp map_prod_o_convol o_id F_map_comp[symmetric]
apply(rule sym) unfolding o_assoc
apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
unfolding ΣΣ3.map_comp0[symmetric] F_map_comp[symmetric] o_assoc[symmetric] gg3_natural
unfolding o_assoc eval3_gg3
apply(rule sym)
apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
unfolding F_map_comp[symmetric] convol_comp_id2 convol_ctor_J_dtor_J
          ΣΣ3.map_comp0 o_assoc eval3 ctor_dtor_J_pointfree id_o ..

(* Uniqueness for unfoldUU3: any f satisfying the characteristic equation
   f = eval3 o ΣΣ3_map (ctor_J o F_map eval3 o F_map (ΣΣ3_map f)) o s
   equals unfoldUU3 s.
   Proof outline: unfold unfoldUU3_def and apply unfoldU3_unique; the resulting
   obligation is proved by substituting the assumption f once and rewriting with the
   definition of eval3 (unfolded only around dtor_unfold_J_pointfree and folded back),
   flat3_natural, eval3_flat3, dd3_natural, gg3_natural and eval3_gg3. *)
theorem unfoldUU3_unique:
assumes f: "f = eval3 o ΣΣ3_map (ctor_J o F_map eval3 o F_map (ΣΣ3_map f)) o s"
shows "f = unfoldUU3 s"
unfolding unfoldUU3_def apply(rule unfoldU3_unique)
apply(rule sym) apply(subst f) unfolding extdd3_def
unfolding o_assoc
apply(subst eval3_def) unfolding dtor_unfold_J_pointfree apply(subst eval3_def[symmetric])
unfolding o_assoc
apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
unfolding o_assoc ΣΣ3.map_comp0[symmetric]  convol_o id_o dtor_J_ctor_pointfree F_map_comp[symmetric]
unfolding o_assoc[symmetric] flat3_natural[symmetric]
unfolding o_assoc eval3_flat3 unfolding o_assoc[symmetric] unfolding F_map_comp
apply(rule sym) apply(subst F_map_comp[symmetric], subst ΣΣ3.map_comp0[symmetric])
unfolding o_assoc apply(subst o_assoc[symmetric])
unfolding dd3_natural[symmetric]
unfolding o_assoc[symmetric] ΣΣ3.map_comp0[symmetric] map_prod_o_convol o_id
unfolding o_assoc[symmetric] gg3_natural
unfolding o_assoc eval3_gg3 F_map_comp ..

(* Corecursion: corecUU3 s, for s of type 'a => (J + 'a) ΣΣ3 F ΣΣ3, is a function
   'a => J. It is defined through unfoldUU3 on a sum-typed argument: the Inl branch of
   the case_sum uses the fixed map leaf3 o dd3 o leaf3 o <Inl , F_map Inl o dtor_J>,
   the Inr branch uses s, and the result is restricted to the right summand by
   composing with Inr. *)
definition corecUU3 :: "('a => (J + 'a) ΣΣ3 F ΣΣ3) => 'a => J" where
"corecUU3 s ≡
 unfoldUU3 (case_sum (leaf3 o dd3 o leaf3 o <Inl , F_map Inl o dtor_J>) s) o Inr"

(* On the left summand, the unfoldUU3 instance used in the definition of corecUU3 is
   the identity:
   unfoldUU3 (case_sum (leaf3 o dd3 o leaf3 o <Inl , F_map Inl o dtor_J>) s) o Inl = id.
   Strategy: two applications of unfoldUU3_unique, chained through the intermediate
   term unfoldUU3 (leaf3 o dd3 o leaf3 o <id, dtor_J>). *)
lemma unfoldUU3_Inl:
"unfoldUU3 (case_sum (leaf3 o dd3 o leaf3 o <Inl , F_map Inl o dtor_J>) s) o Inl = id"
(is "?L = ?R")
proof-
  (* Part 1: the left-hand side satisfies the characteristic equation for the
     intermediate argument. The equation is obtained from theorem unfoldUU3 and then
     rewritten with leaf3_natural and dd3_natural; the long o_assoc[symmetric] runs
     only re-bracket compositions. *)
  have "?L = unfoldUU3 (leaf3 o dd3 o leaf3 o <id, dtor_J>)"
  apply(rule unfoldUU3_unique)
  apply(subst unfoldUU3)
  unfolding o_assoc[symmetric] case_sum_o_inj snd_convol
  unfolding F_map_comp ΣΣ3.map_comp0
  unfolding o_assoc
  apply(rule sym)
  unfolding o_assoc
  apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric],
              subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
  apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric],
              subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric],
              subst o_assoc[symmetric])
  unfolding leaf3_natural apply(subst o_assoc[symmetric])
  apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric],
              subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
  unfolding dd3_natural[symmetric]
  apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric],
              subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
  apply(subst o_assoc[symmetric], subst o_assoc[symmetric])
  unfolding leaf3_natural
  unfolding o_assoc[symmetric] map_prod_o_convol o_id ..
  (* Part 2: id also satisfies that characteristic equation, so by uniqueness the
     intermediate term equals id. Uses dd3_leaf3, leaf3_natural, eval3_leaf3 and
     ctor_dtor_J_pointfree. *)
  also have "... = ?R"
  apply(rule sym, rule unfoldUU3_unique)
  unfolding ΣΣ3.map_id0 F_map_id o_id
  unfolding o_assoc
  apply(subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric],
              subst o_assoc[symmetric], subst o_assoc[symmetric], subst o_assoc[symmetric])
  unfolding dd3_leaf3
  unfolding o_assoc[symmetric] snd_convol
  unfolding o_assoc
  apply(subst o_assoc[symmetric])
  unfolding leaf3_natural unfolding o_assoc eval3_leaf3 id_o
  apply(subst o_assoc[symmetric])
  unfolding F_map_comp[symmetric] eval3_leaf3 F_map_id o_id ctor_dtor_J_pointfree ..
  (* Chain Part 1 and Part 2. *)
  finally show ?thesis .
qed

(* Characteristic equation of corecUU3 in point-free form:
   corecUU3 s =
   eval3 o ΣΣ3_map (ctor_J o F_map eval3 o F_map (ΣΣ3_map (case_sum id (corecUU3 s)))) o s.
   Proof: unfold corecUU3_def and rewrite once with theorem unfoldUU3; the identity in
   the case_sum is introduced by using unfoldUU3_Inl right-to-left, and
   case_sum_Inl_Inr_L recombines the two restrictions into a single case_sum. *)
theorem corecUU3_pointfree:
"corecUU3 s =
 eval3 o ΣΣ3_map (ctor_J o F_map eval3 o F_map (ΣΣ3_map (case_sum id (corecUU3 s)))) o s"
unfolding corecUU3_def
apply(subst unfoldUU3)
unfolding o_assoc[symmetric] unfolding case_sum_o_inj
apply(subst unfoldUU3_Inl[symmetric, of s])
unfolding o_assoc case_sum_Inl_Inr_L extdd3_def ..

(* Uniqueness for corecUU3: any f satisfying the characteristic equation
   f = eval3 o ΣΣ3_map (ctor_J o F_map eval3 o F_map (ΣΣ3_map (case_sum id f))) o s
   equals corecUU3 s.
   Proof: unfold corecUU3_def and apply eq_o_InrI instantiated with unfoldUU3_Inl and
   unfoldUU3_unique; the remaining goal is solved by substituting the assumption f and
   calling auto with the listed simp rules and case splitting on sums. *)
theorem corecUU3_unique:
  assumes f: "f = eval3 o ΣΣ3_map (ctor_J o F_map eval3 o F_map (ΣΣ3_map (case_sum id f))) o s"
  shows "f = corecUU3 s"
  unfolding corecUU3_def
  apply(rule eq_o_InrI[OF unfoldUU3_Inl unfoldUU3_unique])
  apply (subst f)
  apply (auto simp: fun_eq_iff eval3_leaf3' pre_J.map_comp o_eq_dest[OF dd3_leaf3] convol_def
    leaf3_natural o_assoc case_sum_o_inj(1) eval3_leaf3 pre_J.map_id J.ctor_dtor split: sum.splits)
  done

(* Pointwise form of corecUU3_pointfree, stated for an argument a:
   corecUU3 s a =
   eval3 (ΣΣ3_map (ctor_J o F_map eval3 o F_map (ΣΣ3_map (case_sum id (corecUU3 s)))) (s a)).
   Obtained from corecUU3_pointfree by unfolding composition (o_def) and function
   extensionality (fun_eq_iff) and eliminating the resulting universal quantifier. *)
theorem corecUU3:
"corecUU3 s a =
 eval3 (ΣΣ3_map (ctor_J o F_map eval3 o F_map (ΣΣ3_map (case_sum id (corecUU3 s)))) (s a))"
using corecUU3_pointfree unfolding o_def fun_eq_iff by(rule allE)

end