(* Theory Cotree
   theory Cotree imports Applicative Adhoc_Overloading BNF_Corec *)
section {* A codatatype of infinite binary trees *}

theory Cotree imports 
  Main
  "../../AFP/Applicative_Lifting/Applicative"
  "~~/src/Tools/Adhoc_Overloading"
  "../../BNF_Corec"
begin


section {* A codatatype of infinite binary trees *}

(* Infinite binary trees: every node carries a label (selector root) and has
   exactly two subtrees (selectors left and right).  There is no leaf
   constructor, so every tree of this type is infinite.
   [[bnf_internals]] is switched on for this declaration; presumably this is
   what makes internal constants such as corec_tree (used in corec_tree_cong)
   available. *)
context notes [[bnf_internals]] [[bnf_typedef_threshold = 0]] begin
codatatype 'a tree = Node (root: 'a) (left: "'a tree") (right: "'a tree")
end

(* Destruction rules for the relator: rel_tree A x y gives A on the roots and
   rel_tree A on the corresponding left and right subtrees.  Each conclusion
   is proved by a case split on both trees followed by simp. *)
lemma rel_treeD:
  assumes "rel_tree A x y"
  shows rel_tree_rootD: "A (root x) (root y)"
  and rel_tree_leftD: "rel_tree A (left x) (left y)"
  and rel_tree_rightD: "rel_tree A (right x) (right y)"
using assms
by(cases x y rule: tree.exhaust[case_product tree.exhaust], simp_all)+

(* Use the BNF-generated map/selector equations tree.map_sel (how root, left
   and right act on map_tree f t) as simplification rules. *)
declare tree.map_sel [simp]

(* Induction over membership in set_tree.  To show P x t for x ∈ set_tree t it
   suffices that P holds of every root, and that P propagates from an element
   of the left (resp. right) subtree to the whole tree. *)
lemma set_tree_induct[consumes 1, case_names root left right]:
  assumes x: "x ∈ set_tree t"
  and root: "⋀t. P (root t) t"
  and left: "⋀x t. ⟦ x ∈ set_tree (left t); P x (left t) ⟧ ⟹ P x t"
  and right: "⋀x t. ⟦ x ∈ set_tree (right t); P x (right t) ⟧ ⟹ P x t"
  shows "P x t"
using x
proof(rule tree.set_induct)
  (* Root case of tree.set_induct: instantiate the root hypothesis at the node. *)
  fix l x r
  from root[of "Node x l r"] show "P x (Node x l r)" by simp
  (* The subtree cases follow from the left/right hypotheses. *)
qed(auto intro: left right)

(* Congruence rule for the primitive corecursor corec_tree.  The two
   corecursors are equal as soon as STOPL and STOPL' agree wherever stopL
   holds, LEFT and LEFT' agree wherever stopL fails, and symmetrically for
   the right-hand side (stopR, STOPR, RIGHT).  The bare `proof` reduces the
   function equality to a pointwise one, which is proved by strong
   coinduction. *)
lemma corec_tree_cong:
  assumes "⋀x. stopL x ⟹ STOPL x = STOPL' x"
  and "⋀x. ¬ stopL x ⟹ LEFT x = LEFT' x"
  and "⋀x. stopR x ⟹ STOPR x = STOPR' x"
  and "⋀x. ¬ stopR x ⟹ RIGHT x = RIGHT' x"
  shows "corec_tree ROOT stopL STOPL LEFT stopR STOPR RIGHT =
         corec_tree ROOT stopL STOPL' LEFT' stopR STOPR' RIGHT'"
  (is "?lhs = ?rhs")
proof
  fix x
  show "?lhs x = ?rhs x"
    by(coinduction arbitrary: x rule: tree.coinduct_strong)(auto simp add: assms)
qed

(* Generic unfold for trees.  unfold_tree g1 g22 g32 a builds a tree from the
   seed a: its root is g1 a, and its left and right subtrees are unfolded from
   the seeds g22 a and g32 a. *)
context 
  fixes g1 :: "'a ⇒ 'b"
  and g22 :: "'a ⇒ 'a"
  and g32 :: "'a ⇒ 'a"
begin

corec unfold_tree :: "'a ⇒ 'b tree"
where "unfold_tree a = Node (g1 a) (unfold_tree (g22 a)) (unfold_tree (g32 a))"

(* Selector equations for unfold_tree; each is obtained by unfolding the
   corecursive equation (unfold_tree.code) once. *)
lemma unfold_tree_simps [simp]:
  "root (unfold_tree a) = g1 a"
  "left (unfold_tree a) = unfold_tree (g22 a)"
  "right (unfold_tree a) = unfold_tree (g32 a)"
by(subst unfold_tree.code; simp; fail)+

end

(* Uniqueness of unfold: any f whose selectors satisfy the unfold_tree
   equations for ROOT, LEFT and RIGHT coincides with unfold_tree ROOT LEFT RIGHT. *)
lemma unfold_tree_unique:
  assumes "⋀s. root (f s) = ROOT s"
  and "⋀s. left (f s) = f (LEFT s)"
  and "⋀s. right (f s) = f (RIGHT s)"
  shows "f s = unfold_tree ROOT LEFT RIGHT s"
by(rule unfold_tree.unique[THEN fun_cong])(auto simp add: fun_eq_iff assms intro: tree.expand)

(* lemma corec_unfold_tree:
  "corec_tree ROOT (λ_. False) l LEFT (λ_. False) r RIGHT = unfold_tree ROOT LEFT RIGHT"
by(auto simp add: unfold_tree_def intro: corec_tree_cong)
 *)

subsection {* Applicative functor for @{typ "'a tree"} *}

(* pure_tree x: the constant tree that carries the label x at every node. *)
context fixes x :: "'a" begin
corec pure_tree :: "'a tree"
where "pure_tree = Node x pure_tree pure_tree"
end

(* Descriptive name for the unfolding equation of pure_tree generated by corec. *)
lemmas pure_tree_unfold = pure_tree.code

(* Selector equations for the constant tree: the root is the label and both
   subtrees are the same constant tree.  Each follows from a single unfolding
   step of pure_tree. *)
lemma pure_tree_simps [simp]:
  "root (pure_tree x) = x"
  "left (pure_tree x) = pure_tree x"
  "right (pure_tree x) = pure_tree x"
by (subst pure_tree.code; simp; fail)+

(* Make pure_tree the tree instance of the overloaded constant pure, so that
   pure x denotes pure_tree x at tree type. *)
adhoc_overloading pure pure_tree

(* Mapping f over a constant tree yields the constant tree of f x. *)
lemma map_pure_tree [simp]:
  "map_tree f (pure x) = pure (f x)"
by (coinduction arbitrary: x) auto

(* Descriptive name for the uniqueness principle corec derives for pure_tree
   (a tree satisfying pure_tree's defining equation equals pure_tree x). *)
lemmas pure_tree_unique = pure_tree.unique

(* Node-wise application of a tree of functions to a tree of arguments: the
   root of the result is root f applied to root x, and the subtrees are
   combined recursively in the same way. *)
primcorec ap_tree :: "('a ⇒ 'b) tree ⇒ 'a tree ⇒ 'b tree"
where
  "root (ap_tree f x) = root f (root x)"
| "left (ap_tree f x) = ap_tree (left f) (left x)"
| "right (ap_tree f x) = ap_tree (right f) (right x)"

(* Make ap_tree the tree instance of the overloaded application Applicative.ap,
   and enable the applicative syntax (the infix ⋄ used below). *)
adhoc_overloading Applicative.ap ap_tree

unbundle applicative_syntax

(* Applying a constant function tree to an explicit Node applies f at the root
   and pushes the constant function tree into both subtrees. *)
lemma ap_tree_pure_Node [simp]:
  "pure f ⋄ Node x l r = Node (f x) (pure f ⋄ l) (pure f ⋄ r)"
by(rule tree.expand) auto

(* Application of two explicit Nodes combines the corresponding components. *)
lemma ap_tree_Node_Node [simp]:
  "Node f fl fr ⋄ Node x l r = Node (f x) (fl ⋄ l) (fr ⋄ r)"
by(rule tree.expand) auto

text {* Applicative functor laws *}

(* Applying a constant function tree is the same as mapping the function. *)
lemma map_tree_ap_tree_pure_tree:
  "pure f ⋄ u = map_tree f u"
by(coinduction arbitrary: u) auto

(* Identity law: reduced to the BNF identity law tree.map_id via
   map_tree_ap_tree_pure_tree. *)
lemma ap_tree_identity: "pure id ⋄ t = t"
by(simp add: map_tree_ap_tree_pure_tree tree.map_id)

(* Composition law of applicative functors, proved by coinduction. *)
lemma ap_tree_composition:
  "pure (op ∘) ⋄ r1 ⋄ r2 ⋄ r3 = r1 ⋄ (r2 ⋄ r3)"
by(coinduction arbitrary: r1 r2 r3) auto

(* Homomorphism law: after rewriting with map_tree_ap_tree_pure_tree, this is
   the simp rule map_pure_tree. *)
lemma ap_tree_homomorphism:
  "pure f ⋄ pure x = pure (f x)"
by(simp add: map_tree_ap_tree_pure_tree)

(* Interchange law: applying t to a constant argument equals applying
   "evaluate at x" to t. *)
lemma ap_tree_interchange:
  "t ⋄ pure x = pure (λf. f x) ⋄ t"
by(coinduction arbitrary: t) auto

(* K combinator law: lifting the constant function discards the second tree. *)
lemma ap_tree_K_tree:
  "pure (λx y. x) ⋄ u ⋄ v = u"
by (coinduction arbitrary: u v) auto

(* C combinator law: lifting argument swapping exchanges the argument trees. *)
lemma ap_tree_C_tree:
  "pure (λf x y. f y x) ⋄ u ⋄ v ⋄ w = u ⋄ w ⋄ v"
by (coinduction arbitrary: u v w) auto

(* W combinator law: lifting argument duplication feeds x to f twice. *)
lemma ap_tree_W_tree:
  "pure (λf x. f x x) ⋄ f ⋄ x = f ⋄ x ⋄ x"
by (coinduction arbitrary: f x) auto

(* Register 'a tree as an applicative functor with pure_tree and ap_tree, also
   requesting the C, K and W combinators.  Every generated proof obligation is
   discharged by one of the law lemmas proved above. *)
applicative tree (C, K, W) for
  pure: pure_tree
  ap: ap_tree
by (rule ap_tree_identity[unfolded id_def]
         ap_tree_composition[unfolded o_def[abs_def]]
         ap_tree_homomorphism
         ap_tree_interchange
         ap_tree_K_tree
         ap_tree_C_tree
         ap_tree_W_tree)+

(* Register map_tree f u = pure f ⋄ u (the symmetric form) as an
   applicative_unfold rule; presumably this lets the applicative lifting
   machinery rewrite map_tree into applicative form. *)
declare map_tree_ap_tree_pure_tree[symmetric, applicative_unfold]

(* Strong extensionality: two function trees that agree on every constant
   argument tree pure x are equal.  By coinduction it suffices to show that the
   roots agree as functions, and that the left and right subtrees again agree
   on all constant trees. *)
lemma ap_tree_strong_extensional:
  "(⋀x. f ⋄ pure x = g ⋄ pure x) ⟹ f = g"
proof(coinduction arbitrary: f g)
  case [rule_format]: (Eq_tree f g)
  (* Each fact below is read off the hypothesis after rewriting both
     occurrences of ap_tree with its constructor equation ap_tree.ctr. *)
  have "root f = root g"
  proof
    fix x
    show "root f x = root g x"
      using Eq_tree[of x] by(subst (asm) (1 2) ap_tree.ctr) simp
  qed
  moreover {
    fix x
    have "left f ⋄ pure x = left g ⋄ pure x"
      using Eq_tree[of x] by(subst (asm) (1 2) ap_tree.ctr) simp
  } moreover {
    fix x
    have "right f ⋄ pure x = right g ⋄ pure x"
      using Eq_tree[of x] by(subst (asm) (1 2) ap_tree.ctr) simp
  } ultimately show ?case by simp
qed

(* Extensionality: function trees that agree on every argument tree are equal.
   This follows from the strong version by restricting to constant arguments. *)
lemma ap_tree_extensional:
  "(⋀x. f ⋄ x = g ⋄ x) ⟹ f = g"
by(rule ap_tree_strong_extensional) simp

subsection {* Standard tree combinators *}

subsubsection {* Recurse combinator *}

text {*
  This is the main combinator for defining trees recursively.

  Its uniqueness property gives us the unique fixed-point theorem for guarded recursive definitions.
*}
(* unf F unfolds a tree from the function seed F: the root is F x, and the
   children are unfolded from F ∘ l and F ∘ r.  Mapping G over unf F is the
   same as unfolding from G ∘ F. *)
lemma map_unfold_tree [simp]: fixes l r x
 defines "unf ≡ unfold_tree (λf. f x) (λf. f ∘ l) (λf. f ∘ r)"
 shows "map_tree G (unf F) = unf (G ∘ F)"
by(coinduction arbitrary: F G)(auto 4 3 simp add: unf_def o_assoc)

(* Register map_tree, at the endo-type ('a ⇒ 'a), as a friend of corec, so that
   corecursive calls may appear under map_tree (as in tree_recurse).
   The first obligation is the stated equation (by tree.expand).  The second
   is discharged by transfer_prover, presumably the parametricity condition. *)
friend_of_corec map_tree :: "('a ⇒ 'a) ⇒ 'a tree ⇒ 'a tree" where
  "map_tree f t = Node (f (root t)) (map_tree f (left t)) (map_tree f (right t))"
subgoal by (rule tree.expand; simp)
subgoal by (fold relator_eq; transfer_prover)
done

(* tree_recurse l r x: the tree with root x whose left subtree is the whole
   tree mapped by l, and whose right subtree is the whole tree mapped by r.
   The corecursive calls sit under map_tree, which was registered as a friend
   of corec. *)
context fixes l :: "'a ⇒ 'a" and r :: "'a ⇒ 'a" and x :: "'a" begin
corec tree_recurse :: "'a tree"
where "tree_recurse = Node x (map_tree l tree_recurse) (map_tree r tree_recurse)"
end

(* Selector equations for tree_recurse, each obtained by one unfolding step. *)
lemma tree_recurse_simps [simp]:
  "root (tree_recurse l r x) = x"
  "left (tree_recurse l r x) = map_tree l (tree_recurse l r x)"
  "right (tree_recurse l r x) = map_tree r (tree_recurse l r x)"
by(subst tree_recurse.code; simp; fail)+

(* The unfolding equation of tree_recurse, restated under a descriptive name. *)
lemma tree_recurse_unfold:
  "tree_recurse l r x = Node x (map_tree l (tree_recurse l r x)) (map_tree r (tree_recurse l r x))"
by(fact tree_recurse.code)

(* Fusion law: if h commutes with l/l' and with r/r', then mapping h over
   tree_recurse l r x gives tree_recurse l' r' (h x).  The proof uses the
   uniqueness principle of tree_recurse. *)
lemma tree_recurse_fusion:
  assumes "h ∘ l = l' ∘ h" and "h ∘ r = r' ∘ h"
  shows "map_tree h (tree_recurse l r x) = tree_recurse l' r' (h x)"
by(rule tree_recurse.unique)(simp add: tree.expand tree.map_comp assms)

subsubsection {* Tree iteration *}

(* tree_iterate l r s: the tree whose root is the seed s and whose left and
   right subtrees are iterated from the seeds l s and r s respectively. *)
context
  fixes l :: "'a ⇒ 'a"
    and r :: "'a ⇒ 'a"
begin

primcorec tree_iterate :: "'a ⇒ 'a tree" where
  "tree_iterate s = Node s (tree_iterate (l s)) (tree_iterate (r s))"

end

(* unfold_tree factors as iteration followed by mapping the output function:
   proved pointwise via unfold_tree_unique. *)
lemma unfold_tree_tree_iterate:
  "unfold_tree out l r = map_tree out ∘ tree_iterate l r"
by(rule ext)(rule unfold_tree_unique[symmetric]; simp)

(* Fusion law for iteration: if h commutes with l/l' and with r/r', mapping h
   over tree_iterate l r x gives tree_iterate l' r' (h x).  Proved by
   coinduction, using the commutation assumptions pointwise. *)
lemma tree_iterate_fusion:
  assumes "h ∘ l = l' ∘ h"
  assumes "h ∘ r = r' ∘ h"
  shows "map_tree h (tree_iterate l r x) = tree_iterate l' r' (h x)"
apply(coinduction arbitrary: x)
using assms by(auto simp add: fun_eq_iff)

(* lemma tree_iterate_unfold:
  "tree_iterate l r s = Node s (tree_iterate l r (l s)) (tree_iterate l r (r s))"
by(fact tree_iterate.code) *)

subsubsection {* Tree traversal *}

(* Directions for navigating a tree (L: left subtree, R: right subtree); a path
   is a list of directions. *)
datatype dir = L | R
type_synonym path = "dir list"

(* traverse_tree path t follows path from the root of t, taking left for L and
   right for R, and returns the subtree reached.  The first direction of the
   path is taken first (see traverse_tree_simps). *)
definition traverse_tree :: "path ⇒ 'a tree ⇒ 'a tree"
where "traverse_tree path ≡ foldr (λd f. f ∘ case_dir left right d) path id"

(* Recursive characterisation of traverse_tree: the empty path is the identity;
   a non-empty path first takes its head direction. *)
lemma traverse_tree_simps[simp]:
  "traverse_tree [] = id"
  "traverse_tree (d # path) = traverse_tree path ∘ (case d of L ⇒ left | R ⇒ right)"
by (simp_all add: traverse_tree_def)

(* Traversal commutes with map_tree. *)
lemma traverse_tree_map_tree [simp]:
  "traverse_tree path (map_tree f t) = map_tree f (traverse_tree path t)"
by (induct path arbitrary: t) (simp_all split: dir.splits)

(* Traversing a concatenated path means traversing the first part, then the
   second part from the subtree reached. *)
lemma traverse_tree_append [simp]:
  "traverse_tree (path @ ext) t = traverse_tree ext (traverse_tree path t)"
by (induct path arbitrary: t) simp_all

text{* @{const "traverse_tree"} is an applicative-functor homomorphism. *}

(* Every subtree of a constant tree is the same constant tree. *)
lemma traverse_tree_pure_tree [simp]:
  "traverse_tree path (pure x) = pure x"
by (induct path arbitrary: x) (simp_all split: dir.splits)

(* Traversal distributes over applicative application. *)
lemma traverse_tree_ap [simp]:
  "traverse_tree path (f ⋄ x) = traverse_tree path f ⋄ traverse_tree path x"
by (induct path arbitrary: f x) (simp_all split: dir.splits)

(* Interpret directions and paths as functions on seeds, given one function per
   direction. *)
context fixes l r :: "'a ⇒ 'a" begin

(* traverse_dir maps L to l and R to r. *)
primrec traverse_dir :: "dir ⇒ 'a ⇒ 'a"
where
  "traverse_dir L = l"
| "traverse_dir R = r"

(* traverse_path composes traverse_dir along a path with fold, so the first
   direction of the path is applied first. *)
abbreviation traverse_path :: "path ⇒ 'a ⇒ 'a"
where "traverse_path ≡ fold traverse_dir"

end

(* Following a path in an iterated tree reaches the tree iterated from the seed
   obtained by applying l/r along the same path. *)
lemma traverse_tree_tree_iterate:
  "traverse_tree path (tree_iterate l r s) =
   tree_iterate l r (traverse_path l r path s)"
by (induct path arbitrary: s) (simp_all split: dir.splits)

text{*

\citeauthor{DBLP:journals/jfp/Hinze09} shows that if the tree
construction function is suitably monoidal, then recursion
(@{const tree_recurse}) and iteration (@{const tree_iterate})
define the same tree.

*}

(* If f is associative with unit ε, then recursion with left multiplication by
   l and r, starting from ε, gives the same tree as iteration with right
   multiplication by l and r.  The proof uses the uniqueness principle of
   tree_recurse.  It expands the iterated tree with tree.expand and simplifies
   using tree_iterate_fusion, instantiated at right multiplication. *)
lemma tree_recurse_iterate:
  assumes monoid:
    "⋀x y z. f (f x y) z = f x (f y z)"
    "⋀x. f x ε = x"
    "⋀x. f ε x = x"
  shows "tree_recurse (f l) (f r) ε = tree_iterate (λx. f x l) (λx. f x r) ε"
apply(rule tree_recurse.unique[symmetric])
apply(rule tree.expand)
apply(simp add: tree_iterate_fusion[where r'="λx. f x r" and l'="λx. f x l"] fun_eq_iff monoid)
done

subsubsection {* Mirroring *}

(* mirror keeps every label and swaps the left and right subtrees at every node. *)
primcorec mirror :: "'a tree ⇒ 'a tree"
where
  "root (mirror t) = root t"
| "left (mirror t) = mirror (right t)"
| "right (mirror t) = mirror (left t)"

(* Constructor-level equation for mirror. *)
lemma mirror_unfold: "mirror (Node x l r) = Node x (mirror r) (mirror l)"
by(rule tree.expand) simp

(* Constant trees are invariant under mirroring. *)
lemma mirror_pure: "mirror (pure x) = pure x"
by(coinduction rule: tree.coinduct) simp

(* mirror distributes over applicative application. *)
lemma mirror_ap_tree: "mirror (f ⋄ x) = mirror f ⋄ mirror x"
by(coinduction arbitrary: f x) auto

end