(* Theory IO_Monad: imports Monad_Syntax and Misc *)
section {* The IO monad *}

theory IO_Monad imports
  "~~/src/HOL/Library/Monad_Syntax"
  Misc
begin

subsection {* The real world *}

text {*
  We model the real world as an uninterpreted type @{text "real_world"}
  and build the IO monad as in Haskell: an IO action is a function from
  a world state to an optional pair of result and successor state.
  The additional result @{term None} denotes a non-terminating computation.
*}

(* The state of the outside world: an uninterpreted type about which
   nothing is assumed. *)
typedecl real_world

(* An IO action is any function from a world state to an optional pair of
   result and successor state.  The representing set is all of UNIV. *)
typedef 'a IO =
  "UNIV :: (real_world ⇒ ('a × real_world) option) set"
  ..

(* Register the new type with the Lifting package so that lift_definition
   and transfer can be used for it below. *)
setup_lifting type_definition_IO

context includes lifting_syntax begin

(* Parametricity of the raw implementation of return; it is handed to the
   lift_definition below through its parametric clause. *)
lemma return_IO_raw_parametric:
  "(A ===> op = ===> rel_option (rel_prod A op =))
     (λx state. Some (x, state))
     (λx state. Some (x, state))"
  by transfer_prover

(* return_IO x yields x and passes the world state on unchanged. *)
lift_definition return_IO :: "'a ⇒ 'a IO"
  is "λx state. Some (x, state)"
  parametric return_IO_raw_parametric
  .

(* Parametricity of the raw implementation of bind; it is handed to the
   lift_definition below through its parametric clause. *)
lemma bind_IO_parametric_raw:
  "((A ===> rel_option (rel_prod B op =)) ===>
    (B ===> op = ===> rel_option (rel_prod D op =)) ===>
    A ===> rel_option (rel_prod D op =))
   (λf (g :: 'a ⇒ real_world ⇒ ('b × real_world) option) state. f state ⤜ case_prod g)
   (λf (g :: 'c ⇒ real_world ⇒ ('d × real_world) option) state. f state ⤜ case_prod g)"
  by transfer_prover

(* bind_IO f g runs f on the current state and feeds the resulting pair of
   value and successor state to g.  On the raw level the sequencing is the
   bind of the option type, applied to the outcome of f. *)
lift_definition bind_IO :: "'a IO ⇒ ('a ⇒ 'b IO) ⇒ 'b IO"
  is "λf (g :: 'a ⇒ real_world ⇒ ('b × real_world) option) state. f state ⤜ case_prod g"
  parametric bind_IO_parametric_raw
  .

(* Make the generic bind notation of Monad_Syntax available for IO. *)
adhoc_overloading bind bind_IO

(* Left identity: binding a returned value just applies the continuation. *)
lemma return_IO_bind [simp]: "return_IO x ⤜ g = g x"
  by transfer simp

(* Right identity: binding with return_IO changes nothing. *)
lemma bind_IO_return [simp]: "f ⤜ return_IO = f"
  by transfer simp

(* Associativity.  As a simp rule it rewrites towards the left-nested
   form on the right-hand side. *)
lemma bind_IO_bind [simp]:
  fixes f :: "'a IO"
  shows "f ⤜ (λy. g y ⤜ h) = (f ⤜ g) ⤜ h"
  by transfer (simp add: split_def)

(* The three monad laws collected under one name. *)
lemmas IO_monad =
  return_IO_bind
  bind_IO_return
  bind_IO_bind

(* run_IO f state executes f from the given world state and keeps only the
   result value, discarding the successor state. *)
lift_definition run_IO :: "'a IO ⇒ real_world ⇒ 'a option"
  is "λf state. map_option fst (f state)"
  .

(* Running a returned value yields that value. *)
lemma run_IO_return_IO [simp]:
  "run_IO (return_IO x) state = Some x"
  by transfer simp

(* Running a bind: run the first action on the raw level, then run the
   continuation from the successor state.  Not declared as a simp rule. *)
lemma run_IO_bind_IO:
  "run_IO (x ⤜ f) state =
     (Rep_IO x state ⤜ (λ(y, state'). run_IO (f y) state'))"
  by transfer (simp add: split_def map_option_bind o_def)

(* The action that yields no outcome in any world state. *)
lift_definition fail_IO :: "'a IO"
  is "λstate. None"
  .

lemma run_IO_fail_IO [simp]:
  "run_IO fail_IO state = None"
  by transfer simp

(* Variant of fail_IO that takes a unit argument.  Its defining equation is
   a simp rule but is removed from the code equations.
   NOTE(review): the unit argument presumably delays evaluation in strict
   target languages; confirm. *)
definition fail_IO' :: "unit ⇒ 'a IO"
  where [simp, code del]: "fail_IO' _ = fail_IO"

(* For code generation, fail_IO is expressed through fail_IO'. *)
lemma fail_IO_code [code, code_unfold]:
  "fail_IO = fail_IO' ()"
  by simp

(* fail_IO' has no code equations; the code generator is told to implement
   it by aborting. *)
declare [[code abort: fail_IO']]

subsection {* Setup for lifting and transfer *}

(* Raw-level parametricity of the relator construction: the function that
   maps a relation A to op = ===> rel_option (rel_prod A op =) respects the
   relations stated in the first line.  It is handed to the lift_definition
   of IO_rel through its parametric clause. *)
lemma IO_rel_parametric_raw:
  "((A ===> B ===> op =) ===> (op = ===> rel_option (rel_prod A op =)) ===> (op = ===> rel_option (rel_prod B op =)) ===> op =)
   (λA. op = ===> rel_option (rel_prod A op =)) (λB. op = ===> rel_option (rel_prod B op =))"
by(clarsimp simp add: rel_fun_def rel_prod_conv split_beta rel_option_unfold)(safe, blast+)

(* Relator for IO.  On the raw level two actions are related by IO_rel A
   when, for equal world states, their outcomes are related by
   rel_option (rel_prod A op =): result values related by A, successor
   states equal. *)
lift_definition IO_rel :: "('a ⇒ 'b ⇒ bool) ⇒ 'a IO ⇒ 'b IO ⇒ bool"
is "λA. rel_fun op = (rel_option (rel_prod A op =))" parametric IO_rel_parametric_raw .

(* Drop the code equations of IO_rel from the code generator setup. *)
declare [[code drop: IO_rel]]

(* Raw-level parametricity of the predicator construction; handed to the
   lift_definition of IO_pred through its parametric clause. *)
lemma IO_pred_parametric_raw:
  "((A ===> op =) ===> (op = ===> rel_option (rel_prod A op =)) ===> op =)
   (λP. fun_eq_pred (pred_option (pred_prod P (λ_. True))))
   (λP. fun_eq_pred (pred_option (pred_prod P (λ_. True))))"
by transfer_prover

(* Predicator for IO, built from the predicators of the function, option
   and product types.  P constrains the result component; the successor
   state is unconstrained (the predicate on it is constantly True). *)
lift_definition IO_pred :: "('a ⇒ bool) ⇒ 'a IO ⇒ bool"
is "λP. fun_eq_pred (pred_option (pred_prod P (λ_. True)))"
parametric IO_pred_parametric_raw .

(* Drop the code equations of IO_pred from the code generator setup. *)
declare [[code drop: IO_pred]]

(* The relator maps equality to equality. *)
lemma IO_rel_eq [relator_eq]:
  "IO_rel op = = op ="
  unfolding fun_eq_iff
  by transfer (simp add: relator_eq)

(* The relator is monotone. *)
lemma IO_rel_mono [relator_mono]:
  "A ≤ B ⟹ IO_rel A ≤ IO_rel B"
  unfolding le_fun_def
  by transfer
     (auto simp add: rel_fun_def rel_prod_conv split_beta rel_option_unfold)

(* The relator distributes over relation composition. *)
lemma IO_rel_OO [relator_distr]:
  "IO_rel A OO IO_rel B = IO_rel (A OO B)"
  unfolding fun_eq_iff
  by transfer
     (simp add: fun_rel_eq_OO prod.rel_compp[symmetric] eq_OO
                option.rel_compp[symmetric])

(* The domain of a lifted relation is the lifted domain predicate. *)
lemma Domainp_IO [relator_domain]:
  "Domainp A = P ⟹ Domainp (IO_rel A) = IO_pred P"
  by transfer
     (simp add: Domainp_fun_rel_eq prod.Domainp_rel Domainp_eq
                option.Domainp_rel)

(* If A is an equality relation and B is reflexive, then the function
   relation A ===> B is reflexive.  In the visible source this lemma is only
   referenced by the disabled lemma reflp_IO_rel that follows. *)
lemma reflp_fun1: "⟦ is_equality A; reflp B ⟧ ⟹ reflp (A ===> B)"
by(simp add: reflp_def rel_fun_def is_equality_def)

(*
lemma reflp_IO_rel [reflexivity_rule]:
  "reflp A ⟹ reflp (IO_rel A)"
by transfer(simp add: reflp_fun1 is_equality_eq)
*)

(* IO_rel preserves totality and uniqueness properties of its argument.
   The four one-sided lemmas are proved on the raw level from the
   corresponding rules for functions, options and products; the two bi_
   lemmas are obtained from them via the alternative characterisations. *)

lemma left_total_IO_rel [transfer_rule]:
  "left_total A ⟹ left_total (IO_rel A)"
  by transfer
     (intro left_total_eq left_total_fun left_unique_eq
            prod.left_total_rel option.left_total_rel)+

lemma left_unique_IO_rel [transfer_rule]:
  "left_unique A ⟹ left_unique (IO_rel A)"
  by transfer
     (intro left_unique_fun left_total_eq prod.left_unique_rel
            left_unique_eq option.left_unique_rel)

lemma right_total_IO_rel [transfer_rule]:
  "right_total A ⟹ right_total (IO_rel A)"
  by transfer
     (intro right_total_fun right_unique_eq prod.right_total_rel
            right_total_eq option.right_total_rel)

lemma right_unique_IO_rel [transfer_rule]:
  "right_unique A ⟹ right_unique (IO_rel A)"
  by transfer
     (intro right_unique_fun right_total_eq prod.right_unique_rel
            right_unique_eq option.right_unique_rel)

lemma bi_total_IO_rel [transfer_rule]:
  "bi_total A ⟹ bi_total (IO_rel A)"
  unfolding bi_total_alt_def
  by (simp add: left_total_IO_rel right_total_IO_rel)

lemma bi_unique_IO_rel [transfer_rule]:
  "bi_unique A ⟹ bi_unique (IO_rel A)"
  unfolding bi_unique_alt_def
  by (simp add: left_unique_IO_rel right_unique_IO_rel)

(* The relator applied to a restricted equality eq_onp P is the equality
   restricted by the lifted predicate IO_pred P.  Proved on the raw level
   from the corresponding facts for functions, products and options. *)
lemma IO_invariant_commute [relator_eq_onp]:
  "IO_rel (eq_onp P) = eq_onp (IO_pred P)"
by transfer(simp add: fun_eq_invariant[symmetric] prod.rel_eq_onp[symmetric] option.rel_eq_onp[symmetric] invariant_True)

(* Raw-level parametricity of the map function; handed to the
   lift_definition of IO_map through its parametric clause. *)
lemma IO_map_parametric_raw:
  "((A ===> B) ===> (op = ===> rel_option (rel_prod A op =)) ===> (op = ===> rel_option (rel_prod B op =)))
    (λf. op ∘ (map_option (apfst f)))
    (λf. op ∘ (map_option (apfst f)))"
by transfer_prover

(* IO_map f post-composes an action with map_option (apfst f): f is applied
   to the result component, the successor state and a None outcome are left
   untouched. *)
lift_definition IO_map :: "('a ⇒ 'b) ⇒ 'a IO ⇒ 'b IO"
is "λf. op ∘ (map_option (apfst f))" parametric IO_map_parametric_raw .

(* Code equation expressing IO_map through bind and return.
   The apply script is fragile: case_tac refers to the system-generated
   variable name xa, so it depends on the exact goal left by clarsimp. *)
lemma IO_map_code [code]:
  "IO_map f x = x ⤜ return_IO ∘ f"
apply transfer
apply(clarsimp simp add: fun_eq_iff apfst_def split_beta map_prod_def)
apply(case_tac "x xa")
apply(simp_all add: split_beta)
done

(* Quotient map theorem for IO: a quotient on the result type lifts to a
   quotient on IO with IO_map Abs and IO_map Rep as abstraction and
   representation functions.  It is derived from the quotient theorems for
   functions, options and products, with the identity quotient on
   real_world.
   NOTE(review): despite its name this lemma is about IO, not about
   products; the name was presumably copied from the product case.  A rename
   would change the name seen by importing theories, so it is left as is. *)
lemma Quotient_prod [quot_map]:
  assumes "Quotient R Abs Rep T"
  shows "Quotient (IO_rel R) (IO_map Abs) (IO_map Rep) (IO_rel T)"
using fun_quotient[OF identity_quotient[where ?'a=real_world] option.Quotient[OF prod.Quotient[OF assms identity_quotient[where ?'a=real_world]]]]
unfolding Quotient_alt_def
by transfer(simp add: apfst_def map_fun_def)

(* Parametricity of return_IO on the abstract type, stated with IO_rel and
   registered as a transfer rule. *)
lemma return_IO_parametric [transfer_rule]:
  "(A ===> IO_rel A) return_IO return_IO"
proof
  fix x y
  (* The assumption is made available to transfer_prover as a local
     transfer rule. *)
  assume [transfer_rule]: "A x y"
  show "IO_rel A (return_IO x) (return_IO y)"
    by transfer transfer_prover
qed

(* Parametricity of bind_IO on the abstract type, stated with IO_rel and
   registered as a transfer rule.  The proof transfers the goal to the raw
   level and closes it with bind_IO_parametric_raw. *)
lemma bind_IO_parametric [transfer_rule]:
  "(IO_rel A ===> (A ===> IO_rel B) ===> IO_rel B) bind_IO bind_IO"
proof(rule rel_funI)+
  fix x y f g
  (* Local transfer rule relating the function relator with itself; it is
     proved by unfolding pcr_IO op = to cr_IO.
     NOTE(review): presumably needed so that transfer can handle the
     assumption that mentions ===>; confirm. *)
  have [transfer_rule]: 
    "(op = ===> (pcr_IO op = ===> pcr_IO op = ===> op =) ===> (op = ===> pcr_IO op =) ===> (op = ===> pcr_IO op =) ===> op =) op ===> op ===>" 
    by(auto simp add: rel_fun_def IO.pcr_cr_eq cr_IO_def)

  assume "IO_rel A x y" "(A ===> IO_rel B) f g"
  (* A and B are kept fixed during transfer. *)
  thus "IO_rel B (x ⤜ f) (y ⤜ g)"
    by(transfer fixing: A B)(rule bind_IO_parametric_raw[THEN rel_funD, THEN rel_funD])
qed

(* Parametricity of IO_map, IO_pred and IO_rel on the abstract type.  Each
   proof introduces the related arguments with rel_funI, transfers to the
   raw level and applies the matching raw parametricity lemma. *)

lemma IO_map_parametric [transfer_rule]:
  "((A ===> B) ===> IO_rel A ===> IO_rel B) IO_map IO_map"
  by (rule rel_funI)+
     (transfer, rule IO_map_parametric_raw[THEN rel_funD, THEN rel_funD])

lemma IO_pred_parametric [transfer_rule]:
  "((A ===> op =) ===> IO_rel A ===> op =) IO_pred IO_pred"
  by (rule rel_funI)+
     (transfer, rule IO_pred_parametric_raw[THEN rel_funD, THEN rel_funD])

lemma IO_rel_parametric [transfer_rule]:
  "((A ===> B ===> op =) ===> IO_rel A ===> IO_rel B ===> op =) IO_rel IO_rel"
  by (rule rel_funI)+
     (transfer, rule IO_rel_parametric_raw[THEN rel_funD, THEN rel_funD, THEN rel_funD])

(* Store the current lifting and transfer setup for IO in the bundle
   IO.lifting, then remove that setup from the context.
   NOTE(review): presumably later theories re-enable it by including the
   bundle; confirm against the users of this theory. *)
lifting_update IO.lifting
lifting_forget IO.lifting

end

subsection {* Code generator setup *}

(* Each trivial equation is declared as a code equation and deleted again by
   the same attribute list.
   NOTE(review): this looks like the usual idiom for clearing the code
   equations of a constant; confirm. *)
lemma [code, code del]: "return_IO = return_IO" ..

lemma [code, code del]: "bind_IO = bind_IO" ..

(* bind_IO and return_IO need no code equations; where no target-language
   implementation is supplied the code generator implements them by
   aborting. *)
declare [[code abort: bind_IO return_IO]]

(* Scala support module providing bind(x, g), which applies g to x.  It is
   the Scala implementation referenced for bind_IO in the code_printing
   declaration that follows; the module name is reserved so that generated
   identifiers do not clash with it. *)
code_printing code_module Reactive_Concurrency  (Scala)
{*object Reactive_Concurrency {
  def bind[A, B](x: A, g: A => B) : B = { g(x) }
}*}
code_reserved Scala Reactive_Concurrency

(* Target-language representation of the IO monad.
   - Type IO: printed as its type argument in SML, OCaml and Scala, and as
     Prelude.IO in Haskell.
   - return_IO: printed as its argument in SML, OCaml and Scala, and as
     return in Haskell.
   - bind_IO: printed as an application of the second argument to the first
     in SML and OCaml (the template abstracts over the continuation), and as
     Reactive_Concurrency.bind in Scala.  No Haskell template is given here;
     code_monad registers bind_IO as a monad bind for Haskell instead.
   NOTE(review): in the templates a bare underscore is an argument
   placeholder and a quoted underscore is a literal underscore; confirm
   against the code generator documentation. *)
code_printing type_constructor IO 
  (SML) "_" and
  (OCaml) "_" and
  (Haskell) "Prelude.IO _" and
  (Scala) "_"
| constant return_IO 
  (SML) "_" and
  (OCaml) "_" and
  (Haskell) "return" and
  (Scala) "_"
| constant bind_IO 
  (SML) "!(fn f'_ => f'_ _)" and
  (OCaml) "!(fun f'_ -> f'_ _)" and
  (Scala) "Reactive'_Concurrency.bind"
code_monad bind_IO Haskell

(* Declare quotient theorems as transfer rules.  The theorems for integer
   and IO are first restated by folding the respective pcr_cr_eq equation.
   NOTE(review): presumably these rules let the transfer machinery discharge
   the Quotient premise of undefined_transfer below; confirm. *)
lemmas [transfer_rule] =
  identity_quotient
  fun_quotient
  Quotient_integer[folded integer.pcr_cr_eq]
  Quotient_IO[folded IO.pcr_cr_eq]

(* For any quotient, the transfer relation T relates Rep undefined on the
   raw side with undefined on the abstract side. *)
lemma undefined_transfer:
  assumes "Quotient R Abs Rep T"
  shows "T (Rep undefined) undefined"
using assms unfolding Quotient_alt_def by blast

(* The lemma is not a transfer rule by default; including this bundle
   declares it as one in the including context. *)
bundle undefined_transfer = undefined_transfer[transfer_rule]

end