(* Theory PMF_Code *)

theory PMF_Code imports
  PMF_Impl
begin

(* Parametricity rule for type_definition, registered with the transfer package.
   Assuming A is bi-unique and bi-total and B is bi-unique, type_definition is
   related to itself when its first argument, second argument and carrier set
   are related by A ===> B, B ===> A and rel_set B respectively.
   Proof: unfold the definition, fold bounded quantification back into Ball,
   and let transfer_prover assemble the rule. *)
lemma type_definition_parametric [transfer_rule]:
  includes lifting_syntax
  assumes [transfer_rule]: "bi_unique A" "bi_total A" "bi_unique B" 
  shows "((A ===> B) ===> (B ===> A) ===> rel_set B ===> op =) type_definition type_definition"
  unfolding type_definition_def
  by(fold Ball_def) transfer_prover

(* Parametricity rule for apsnd: functions related by A ===> B are lifted to
   pairs, leaving the first component (related by C) untouched.
   Proved by unfolding apsnd_def and running transfer_prover. *)
lemma apsnd_parametric [transfer_rule]: includes lifting_syntax shows
  "((A ===> B) ===> rel_prod C A ===> rel_prod C B) apsnd apsnd"
  unfolding apsnd_def by transfer_prover

(* In a commutative additive monoid, summing the multiset of a list's elements
   gives the same value as summing the list. Induction on the list. *)
lemma (in comm_monoid_add) sum_mset_mset: "sum_mset (mset xs) = sum_list xs"
  by(induction xs)(simp_all add: add.commute)

(* Lists with the same multiset of elements (i.e. permutations of each other)
   have the same sum. Follows by rewriting both sums with sum_mset_mset
   from right to left. *)
lemma (in comm_monoid_add) sum_list_cong: "mset xs = mset ys ⟹ sum_list xs = sum_list ys"
  by(simp add: sum_mset_mset[symmetric])

(* The sum of a concatenated list of lists is the sum of the individual list
   sums. Only needs an additive monoid; induction on the outer list. *)
lemma (in monoid_add) sum_list_concat: "sum_list (concat xs) = sum_list (map sum_list xs)"
  by(induction xs) simp_all


(* Big union of a multiset of multisets: fold multiset addition over the outer
   multiset, starting from the empty multiset. *)
definition mUnion :: "'a multiset multiset ⇒ 'a multiset"
where "mUnion = fold_mset (op +) {#}"

(* Basic simp rules for mUnion. The interpretation below is local to this
   context block and is not exported. *)
context begin

(* Multiset addition satisfies comp_fun_commute; presumably this is what makes
   the fold_mset simplification rules applicable in the proofs below. *)
interpretation mUnion: comp_fun_commute "op + :: 'a multiset ⇒ _"
by unfold_locales(simp add: add_ac fun_eq_iff)

(* mUnion of the empty multiset is empty. *)
lemma mUnion_empty [simp]: "mUnion {#} = {#}"
by(simp add: mUnion_def)

(* mUnion of a singleton is its only member. *)
lemma mUnion_single [simp]: "mUnion {# A #} = A"
by(simp add: mUnion_def)

(* mUnion distributes over multiset addition; induction on the second summand. *)
lemma mUnion_union [simp]: "mUnion (A + B) = mUnion A + mUnion B"
unfolding mUnion_def by(induction B)(simp_all add: add_ac)

(* Adding one member x to the outer multiset adds x to the union. *)
lemma mUnion_add_mset [simp]: "mUnion (add_mset x A) = x + (mUnion A)"
by(simp add: mUnion_def)

end

(* image_mset commutes with mUnion: mapping f over the union equals the union
   of the member multisets each mapped with f. Induction on the outer multiset. *)
lemma image_mset_mUnion: "image_mset f (mUnion A) = mUnion (image_mset (image_mset f) A)"
  by(induction A) auto

(* The multiplicity of x in mUnion A is the sum of its multiplicities in the
   members of A. Induction on the outer multiset. *)
lemma count_mUnion [simp]: "count (mUnion A) x = sum_mset (image_mset (λB. count B x) A)"
  by(induction A) simp_all

(* The multiset of a concatenated list of lists is the mUnion of the multisets
   of the individual lists. Induction on the outer list. *)
lemma mset_concat [simp]: "mset (concat xss) = mUnion (mset (map mset xss))"
by(induction xss)(simp_all add: add_ac)

(* For a list without duplicates, its multiset coincides with the multiset
   built from its set of elements. Induction on the list. *)
lemma mset_distinct: "distinct xs ⟹ mset xs = mset_set (set xs)"
by(induction xs)(auto simp add: add_ac)

(* Mapping a function that is injective on a finite set A over the multiset
   mset_set A yields the multiset of the image f ` A.
   The claim is proved for a locally defined copy B of A, carried together with
   the facts "finite B" and "B ⊆ A". The induction then runs over the
   finiteness of B while the injectivity assumption stays stated for all of A;
   the subset fact is what lets inj_onD be applied in the insert case.
   The local definition of B is expanded again when the result is exported. *)
lemma image_mset_mset_set: 
  assumes inj: "inj_on f A"
  and "finite A"
  shows "image_mset f (mset_set A) = mset_set (f ` A)"
proof -
  def B ≡ A
  then have "finite B" "B ⊆ A" using ‹finite A› by simp_all
  then show "image_mset f (mset_set B) = mset_set (f ` B)"
  proof(induction)
    case (insert x F)
    (* f x is fresh in the image of F because f is injective on A and x ∉ F. *)
    have "f x ∉ f ` F" using ‹x ∉ F› insert.prems by(auto dest: inj_onD[OF inj])
    thus ?case using insert by simp
  qed simp
qed

(* Filtering a multiset with the constantly-False predicate yields the empty
   multiset. Induction on the multiset. *)
lemma filter_mset_False [simp]: "filter_mset (λ_. False) A = {#}"
  by(induction A) simp_all

(* After removing duplicates, the multiset of a list is the multiset of its
   element set. Obtained by substituting mset_distinct and simplifying. *)
lemma mset_remdups:
  "mset (remdups xs) = mset_set (set xs)"
  by(subst mset_distinct; simp)

text ‹Non-negative real numbers›

(* The type of non-negative reals, carved out of real by the predicate x ≥ 0.
   real_of_nnreal is the representation function, nnreal' the abstraction
   function. *)
typedef nnreal = "{x :: real. x ≥ 0}" 
  morphisms real_of_nnreal nnreal' by auto

(* Enable lift_definition and transfer for nnreal. *)
setup_lifting type_definition_nnreal

(* Total conversion from real to nnreal: the argument x is mapped to sup 0 x,
   so the result is always at least 0. *)
lift_definition nnreal :: "real ⇒ nnreal" is "sup 0" by(rule sup_ge1)

(* Case-analysis rule for nnreal: every value has the form nnreal y for some
   real y ≥ 0. Declared as the default cases rule for the type. *)
lemma nnreal_cases [cases type]:
  obtains y where "x = nnreal y" "y ≥ 0"
  by transfer (simp add: sup.absorb_iff2)

(* nnreal-valued counterparts of pmf, embed_pmf and ennreal, obtained by
   lifting along the nnreal type definition. *)
(* Probability mass of a pmf at a point, as an nnreal (pmf is non-negative). *)
lift_definition pmf_nnr :: "'a pmf ⇒ 'a ⇒ nnreal" is pmf by(rule pmf_nonneg)
(* Builds a pmf from an nnreal-valued function via embed_pmf. *)
lift_definition embed_pmf_nnr :: "('a ⇒ nnreal) ⇒ 'a pmf" is embed_pmf .
(* Injection of nnreal into the extended non-negative reals via ennreal. *)
lift_definition ennreal_of_nnreal :: "nnreal ⇒ ennreal" is ennreal .

(* Converting a real to nnreal and then to ennreal is the same as applying
   ennreal directly. *)
lemma ennreal_of_nnreal_nnreal [simp]: "ennreal_of_nnreal (nnreal x) = ennreal x"
  by transfer (metis ennreal_eq_0_iff le_iff_sup linear sup.orderE)

(* ennreal_of_nnreal is injective. *)
lemma ennreal_of_nnreal_inject [simp]: "ennreal_of_nnreal x = ennreal_of_nnreal y ⟷ x = y"
  by transfer simp

(* Algebraic structure on nnreal: 0, 1, addition and multiplication are lifted
   from real (each preserves non-negativity), and the class axioms are proved
   by transferring to real and simplifying with algebra_simps. *)
instantiation nnreal :: "{zero, one, plus, times, comm_monoid_add, comm_monoid_mult, comm_semiring_0}" begin
lift_definition zero_nnreal :: nnreal is 0 by simp
lift_definition one_nnreal :: nnreal is 1 by simp
lift_definition plus_nnreal :: "nnreal ⇒ nnreal ⇒ nnreal" is "op +" by simp
lift_definition times_nnreal :: "nnreal ⇒ nnreal ⇒ nnreal" is "op *" by simp
instance by(intro_classes; transfer; simp add: algebra_simps)+
end

(* ennreal_of_nnreal maps 0 to 0. *)
lemma ennreal_of_nnreal_0 [simp]: "ennreal_of_nnreal 0 = 0"
  by transfer simp

(* ennreal_of_nnreal x is 0 exactly when x is 0. *)
lemma ennreal_of_nnreal_eq_0 [simp]: "ennreal_of_nnreal x = 0 ⟷ x = 0"
  by transfer simp

(* Conversion back from ennreal to nnreal. The ennreal lifting setup is only
   included for the duration of this context block. *)
context includes ennreal.lifting begin

(* Lifted from real_of_ereal; the proof obligation is discharged by
   real_of_ereal_pos. *)
lift_definition nnreal_of_ennreal :: "ennreal ⇒ nnreal" is real_of_ereal by(rule real_of_ereal_pos)

(* nnreal_of_ennreal is a left inverse of ennreal_of_nnreal. *)
lemma ennreal_of_nnreal_inverse [simp]: "nnreal_of_ennreal (ennreal_of_nnreal x) = x"
  apply(cases x)
  apply clarsimp
  apply transfer
  apply(clarsimp simp add: ereal_max_0 ereal_max[symmetric] sup_real_def simp del: ereal_max)
  done

(* nnreal_of_ennreal maps 0 to 0. *)
lemma nnreal_of_ennreal_0 [simp]: "nnreal_of_ennreal 0 = 0"
  by transfer simp

end

(* ennreal_of_nnreal is additive. Oriented as a simp rule that pulls the
   conversion outwards. *)
lemma ennreal_of_nnreal_plus [simp]: "ennreal_of_nnreal x + ennreal_of_nnreal y = ennreal_of_nnreal (x + y)"
  by transfer simp

(* ennreal_of_nnreal is multiplicative. Oriented as a simp rule that pulls the
   conversion outwards. *)
lemma ennreal_of_nnreal_times [simp]: "ennreal_of_nnreal x * ennreal_of_nnreal y = ennreal_of_nnreal (x * y)"
  by transfer(simp add: ennreal_mult)

(* ennreal_of_nnreal commutes with sums over a set. Case split on finiteness of
   the index set: the finite case is by induction, the infinite case by simp. *)
lemma setsum_ennreal_of_nnreal [simp]:
  "sum (%x. ennreal_of_nnreal (f x)) A = ennreal_of_nnreal (sum f A)"
  apply(cases "finite A")
  subgoal by(induction rule: finite_induct; auto)
  subgoal by simp
  done

(* Instance of the equal class for nnreal: the equality test is lifted from
   equality on real. *)
instantiation nnreal :: equal begin
lift_definition equal_nnreal :: "nnreal ⇒ nnreal ⇒ bool" is "op =" .
instance by(intro_classes;transfer;simp)
end

(* Refresh the nnreal.lifting bundle and then remove the nnreal lifting setup
   from the ambient context. Proofs further down that need it re-enable it
   explicitly with "including nnreal.lifting". *)
lifting_update nnreal.lifting
lifting_forget nnreal.lifting


(* Implementation type for pmfs: a copy of the full function space
   'a ⇒ nnreal (carrier UNIV, so no well-formedness is built into the type). *)
typedef 'a pmf_impl = "UNIV :: ('a ⇒ nnreal) set" ..

(* Enable lift_definition and transfer for pmf_impl. *)
setup_lifting type_definition_pmf_impl


(* Conversions between 'a pmf and its implementation type. *)
(* rep_pmf wraps the nnreal-valued mass function pmf_nnr of a pmf. *)
lift_definition rep_pmf :: "'a pmf ⇒ 'a pmf_impl" is pmf_nnr .
(* abs_pmf turns an implementation back into a pmf via embed_pmf_nnr. *)
lift_definition abs_pmf :: "'a pmf_impl ⇒ 'a pmf" is embed_pmf_nnr .
(* Well-formedness predicate on implementations: the underlying function's
   non-negative integral over the counting measure on UNIV equals 1. *)
lift_definition wf_pmf :: "'a pmf_impl ⇒ bool" is "λf. (∫⇧+ x. ennreal_of_nnreal (f x) ∂count_space UNIV) = 1" .

(* rep_pmf and abs_pmf form a type definition of 'a pmf over the well-formed
   implementations. The proof transfers twice (first through pmf_impl, then
   through nnreal with its lifting bundle re-enabled) and finishes with
   td_pmf_embed_pmf. *)
lemma td_pmf_code: "type_definition rep_pmf abs_pmf {f. wf_pmf f}"
  apply transfer 
  including nnreal.lifting
  apply transfer
  apply(simp add: conj_commute td_pmf_embed_pmf)
  done

(* Certificate for the code generator: declares 'a pmf an abstract type with
   abs_pmf as abstraction and rep_pmf as representation function. *)
lemma pmf_code_cert [code abstype]: "abs_pmf (rep_pmf p) = p"
  by(rule type_definition.Rep_inverse[OF td_pmf_code])

(* Implementation of the point distribution at x: mass 1 at x, 0 elsewhere. *)
lift_definition return_pmf_impl :: "'a ⇒ 'a pmf_impl" is "λx y. if x = y then 1 else 0" .

(* Abstract code equation for return_pmf: its representation is
   return_pmf_impl. Proved by transferring through pmf_impl and nnreal. *)
lemma return_pmf_code [code abstract]: "rep_pmf (return_pmf x) = return_pmf_impl x"
  apply transfer
  including nnreal.lifting
  apply transfer
  apply(simp add: fun_eq_iff)
  done

 

(* Implementation of monadic bind: the mass at x is the non-negative integral,
   over the counting measure on UNIV, of f y times g y x, converted back from
   ennreal to nnreal. *)
lift_definition bind_pmf_impl :: "'a pmf_impl ⇒ ('a ⇒ 'b pmf_impl) ⇒ 'b pmf_impl"
is "λf g x. nnreal_of_ennreal (∫⇧+ y. ennreal_of_nnreal (f y) * ennreal_of_nnreal (g y x) ∂count_space UNIV)" .

(* Abstract code equation for bind_pmf: its representation is bind_pmf_impl
   applied to the representations of the arguments.
   The key step shows the equation after embedding both sides into ennreal;
   applying nnreal_of_ennreal (a left inverse of the embedding) then yields the
   claim. *)
lemma bind_pmf_code [code abstract]:
  "rep_pmf (bind_pmf p f) = bind_pmf_impl (rep_pmf p) (λx. rep_pmf (f x))"
proof(transfer; rule ext)
  fix p and f :: "'b ⇒ 'a pmf" and x
  have "ennreal_of_nnreal (pmf_nnr (p ⤜ f) x) = (∫⇧+ y. ennreal_of_nnreal (pmf_nnr p y) * ennreal_of_nnreal (pmf_nnr (f y) x) ∂count_space UNIV)"
    (is "ennreal_of_nnreal ?lhs = ?rhs") including nnreal.lifting
    by transfer(simp add: ennreal_pmf_bind nn_integral_measure_pmf)
  from this[symmetric] show "?lhs = nnreal_of_ennreal ?rhs" by simp
qed


text ‹Implementation of @{typ "'a pmf_impl"} as association lists›

(* Interprets a list of (element, weight) pairs as an implementation: the mass
   at x is the sum of the weights of all entries whose first component is x.
   Repeated keys are allowed and their weights add up. *)
lift_definition pmf_impl_of_pfp :: "('a × nnreal) list ⇒ 'a pmf_impl"
is "λxs x. sum_list (map snd [(x', y :: nnreal) ← xs. x' = x])" .

(* Make pmf_impl_of_pfp the only constructor of pmf_impl for code generation. *)
code_datatype pmf_impl_of_pfp

(* The implementation that assigns mass 0 everywhere. *)
lift_definition zero_pmf_impl :: "'a pmf_impl" is "λ_. 0" .

(* Code equation: the zero implementation is the empty association list. *)
lemma zero_pmf_impl_code [code]: "zero_pmf_impl = pmf_impl_of_pfp []"
  by transfer simp

(* Simp rule in the opposite direction of zero_pmf_impl_code: the empty
   association list is rewritten to zero_pmf_impl. *)
lemma pmf_impl_of_pfp_Nil [simp]: "pmf_impl_of_pfp [] = zero_pmf_impl"
  by transfer simp

(* Pointwise scaling of an implementation by the factor r. *)
lift_definition scale_pmf_impl :: "nnreal ⇒ 'a pmf_impl ⇒ 'a pmf_impl"
  is "λ(r :: nnreal) f x. f x * r" .

(* Code equation for scaling an association list: multiply every weight by p.
   When p = 0 the result is the empty list instead of a list of zero weights. *)
lemma scale_pmf_impl_code [code]: 
  "scale_pmf_impl p (pmf_impl_of_pfp xs) = pmf_impl_of_pfp (if p = 0 then [] else map (apsnd (op * p)) xs)"
  apply transfer
  including nnreal.lifting
  apply transfer
  apply(auto simp add: filter_map o_def split_def mult_ac sum_list_const_mult)
  done

(* Variant of scale_pmf_impl_code without the p = 0 shortcut; not a code
   equation, but used as a rewrite rule in the proof of map_pmf_impl_code. *)
lemma scale_pmf_impl_alt: "scale_pmf_impl p (pmf_impl_of_pfp xs) = pmf_impl_of_pfp (map (apsnd (op * p)) xs)"
  apply transfer
  including nnreal.lifting
  apply transfer
  apply(auto simp add: filter_map o_def split_def mult_ac sum_list_const_mult)
  done


(* Pointwise sum of two implementations. *)
lift_definition plus_pmf_impl :: "'a pmf_impl ⇒ 'a pmf_impl ⇒ 'a pmf_impl" is "λf g x. f x + g x" .

(* Code equation: adding two association lists is list concatenation. *)
lemma plus_pmf_impl_code [code]:
  "plus_pmf_impl (pmf_impl_of_pfp xs) (pmf_impl_of_pfp ys) = pmf_impl_of_pfp (xs @ ys)"
  by transfer auto

(* zero_pmf_impl is a left unit of plus_pmf_impl. *)
lemma plus_pmf_impl_zero [simp]: "plus_pmf_impl zero_pmf_impl p = p"
  by transfer simp

(* plus_pmf_impl is associative; as a simp rule it re-associates to the right. *)
lemma plus_pmf_impl_assoc [simp]:
  "plus_pmf_impl (plus_pmf_impl p q) r = plus_pmf_impl p (plus_pmf_impl q r)"
  by transfer(simp add: add_ac)

(* Unfolds a non-empty association list: the head entry (x, p) contributes the
   point distribution at x scaled by p, added to the meaning of the tail. *)
lemma pmf_impl_of_pfp_Cons [simp]:
  "pmf_impl_of_pfp ((x, p) # xs) = plus_pmf_impl (scale_pmf_impl p (return_pmf_impl x)) (pmf_impl_of_pfp xs)"
  by transfer(simp add: fun_eq_iff)

(* Simp-rule form of plus_pmf_impl_code read right to left: the meaning of an
   appended list is the sum of the meanings. Induction on the first list. *)
lemma pmf_impl_of_pfp_append [simp]:
  "pmf_impl_of_pfp (xs @ ys) = plus_pmf_impl (pmf_impl_of_pfp xs) (pmf_impl_of_pfp ys)"
  by(induction xs) auto

(* Code equation: the point distribution at x is the one-entry list [(x, 1)]. *)
lemma return_pmf_impl_code [code]: "return_pmf_impl x = pmf_impl_of_pfp [(x, 1)]"
  by transfer(simp add: fun_eq_iff)

(* Code equation for bind on association lists: fold over the list from the
   right, adding for each entry (x, p) the continuation f x scaled by p,
   starting from zero_pmf_impl.
   Proof outline (after transfer, pointwise in x):
     1. restrict the integral from UNIV to the finitely many keys of xs,
     2. turn it into a finite sum over those keys,
     3. rewrite that sum as a list sum over the entries grouped by key,
     4. show that grouping by key is a permutation of xs (via multisets),
     5. identify the resulting list sum with the foldr on the right-hand side. *)
lemma bind_pmf_impl_code [code]: 
  "bind_pmf_impl (pmf_impl_of_pfp xs) f = foldr (λ(x, p). plus_pmf_impl (scale_pmf_impl p (f x))) xs zero_pmf_impl"
proof(transfer; rule ext)
  fix xs :: "('b × nnreal) list" and f :: "'b ⇒ 'a ⇒ nnreal" and x
  (* Distinct keys of xs select distinct groups of entries. *)
  have inj: "inj_on (mset ∘ (λa. [(x', y)←xs . x' = a])) (fst ` set xs)"
    by(rule inj_onI)(auto 4 3 simp add: mset_filter dest!: arg_cong[where f=set_mset])

  (* Auxiliary fact: for a finite key set A, the union of the per-key groups of
     B is B filtered to the entries whose key lies in A. *)
  have *: "mUnion (mset_set ((λx. {# (x', y) ∈# B. x' = x#}) ` A)) = {# (x, y) ∈# B. x ∈ A #}"
    if "finite A" for A :: "'c set" and B :: "('c × 'd) multiset"
    using that
  proof(induction A)
    case (insert x A)
    show ?case
    proof(cases "x ∈ fst ` set_mset B")
      case True
      (* x occurs as a key in B, so its non-empty group is new to the image. *)
      then have "{# (x', y) ∈# B. x' = x#} ∉ (λx. {# (x', y) ∈# B. x' = x#}) ` A" using ‹x ∉ A›
        by(auto 4 3 dest!: arg_cong[where f="set_mset"] elim: equalityE)
      then show ?thesis using insert by(auto intro: multiset_eqI)
    next
      case False
      (* x does not occur as a key in B, so its group is empty and adds nothing. *)
      hence "{# (x', y) ∈# B. x' = x#} = {#}"
        by(auto 4 3 intro!: multiset_eqI simp add: count_eq_zero_iff intro: rev_image_eqI)
      then have "mUnion (mset_set ((λx. {# a ∈# B. case a of (x', y) ⇒ x' = x#}) ` insert x A)) = mUnion (mset_set ((λx. {# a ∈# B. case a of (x', y) ⇒ x' = x#}) ` A))"
        using ‹finite A›
        by(simp add: mset_set.insert_remove)(metis (no_types, lifting) Diff_empty Diff_insert0 add.left_neutral finite_imageI mUnion_add_mset mset_set.remove)
      also have "… = {# (x, y) ∈# B. x ∈ A#}" by(simp add: insert.IH)
      also have "… = {# (x', y) ∈# B. x' ∈ insert x A#}" using False
        by(auto 4 3 intro!: multiset_eqI simp add: count_eq_zero_iff intro: rev_image_eqI)
      finally show ?thesis .
    qed
  qed simp

  (* Step 1: outside the keys of xs the integrand is 0, so the integral can be
     restricted to fst ` set xs. *)
  have "(∫⇧+ y. ennreal_of_nnreal (sum_list (map snd [(x', ya)←xs . x' = y])) * ennreal_of_nnreal (f y x) ∂count_space UNIV) = 
        (∫⇧+ y. ennreal_of_nnreal (sum_list (map snd [(x', ya)←xs . x' = y])) * ennreal_of_nnreal (f y x) ∂count_space (fst ` set xs))"
    (is "?lhs = _")
    apply(clarsimp simp add: nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
    apply(subst filter_False; auto intro: rev_image_eqI)
    done
  (* Step 2: over a finite counting measure the integral is a finite sum. *)
  also have "… = ennreal_of_nnreal (∑a∈set (map fst xs). sum_list (map snd [(x', ya)←xs . x' = a]) * f a x)"
    by(simp add: nn_integral_count_space_finite)
  (* Step 3: express the set sum as a list sum over the duplicate-free keys. *)
  also have "(∑a∈set (map fst xs). sum_list (map snd [(x', ya)←xs . x' = a]) * f a x) = 
    (sum_list (map sum_list (map (λa. map (λ(a, y). f a x * y) [(x', y)←xs. x' = a]) (remdups (map fst xs)))))"
    unfolding sum_code 
    by(simp add: o_def sum_list_const_mult[symmetric] split_def mult.commute cong: list.map_cong_simp)
  also have "… = sum_list (map (λ(a, y). f a x * y) (concat (map (λa. [(x', y)←xs. x' = a]) (remdups (map fst xs)))))"
    unfolding map_concat by(simp add: sum_list_concat o_def)
  (* Step 4: the grouped list has the same multiset of entries as xs. *)
  also have "… = sum_list (map (λ(a, y). f a x * y) xs)" using inj
    apply(intro sum_list_cong)
    apply(simp add: mset_remdups image_mset_mset_set o_def)
    apply(rule arg_cong2[where f="image_mset"])
    apply(auto simp add: mset_filter * filter_mset_eq_conv intro: rev_image_eqI)
    done
  (* Step 5: the list sum is the pointwise foldr of the right-hand side. *)
  also have "… = foldr (λ(x, p) g xa. f x xa * p + g xa) xs (λ_. 0) x"
    unfolding sum_list.eq_foldr by(induction xs) auto
  finally show "nnreal_of_ennreal ?lhs = …" by simp
qed

(* Implementation of map: bind followed by the point distribution at f x. *)
definition map_pmf_impl :: "('a ⇒ 'b) ⇒ 'a pmf_impl ⇒ 'b pmf_impl"
  where "map_pmf_impl f p = bind_pmf_impl p (λx. return_pmf_impl (f x))"
  
(* Code equation for map_pmf: its representation is map_pmf_impl applied to the
   representation of the argument. Follows from the code equations for bind
   and return. *)
lemma map_pmf_code [code]: "rep_pmf (map_pmf f p) = map_pmf_impl f (rep_pmf p)"
by(simp add: map_pmf_def bind_pmf_code map_pmf_impl_def return_pmf_code)

(* Code equation for map on association lists: apply f to every key and keep
   the weights. Induction on the list; the Cons and Nil simp rules for
   pmf_impl_of_pfp are switched off so that the list form is preserved while
   rewriting with the code equations. *)
lemma map_pmf_impl_code [code]:
  "map_pmf_impl f (pmf_impl_of_pfp xs) = pmf_impl_of_pfp (map (apfst f) xs)"
  unfolding map_pmf_impl_def
  apply(induction xs)
   apply(simp add: bind_pmf_impl_code return_pmf_impl_code scale_pmf_impl_alt zero_pmf_impl_code del: pmf_impl_of_pfp_Cons pmf_impl_of_pfp_Nil)
  apply clarsimp
  apply(simp add: bind_pmf_impl_code return_pmf_impl_code scale_pmf_impl_alt plus_pmf_impl_code del: pmf_impl_of_pfp_Cons pmf_impl_of_pfp_Nil)
  done
  
end