(* Code-generator setup for probability mass functions ('a pmf).
   The first part of the theory collects auxiliary material: transfer rules,
   list/multiset summation lemmas, an operator mUnion that sums a multiset of
   multisets, and the type nnreal of non-negative reals, which later serves as
   the type of probability weights. *)
theory PMF_Code imports PMF_Impl begin

(* Transfer rule: type_definition is parametric provided A is bi-unique and
   bi-total and B is bi-unique; the carrier set is related by rel_set B.
   The proof unfolds type_definition_def and folds Ball_def, presumably so that
   transfer_prover can use the rel_set transfer rule for bounded quantifiers. *)
lemma type_definition_parametric [transfer_rule]:
  includes lifting_syntax
  assumes [transfer_rule]: "bi_unique A" "bi_total A" "bi_unique B"
  shows "((A ===> B) ===> (B ===> A) ===> rel_set B ===> op =) type_definition type_definition"
unfolding type_definition_def
by(fold Ball_def) transfer_prover

(* Transfer rule for apsnd: the first pair component stays related by the same
   relation C, only the relation on the second component changes from A to B. *)
lemma apsnd_parametric [transfer_rule]:
  includes lifting_syntax
  shows "((A ===> B) ===> rel_prod C A ===> rel_prod C B) apsnd apsnd"
unfolding apsnd_def by transfer_prover

(* The sum of a list equals the sum of its multiset; needs commutativity. *)
lemma (in comm_monoid_add) sum_mset_mset: "sum_mset (mset xs) = sum_list xs"
by(induction xs)(simp_all add: add.commute)

(* Consequently, lists with the same multiset of elements have the same sum. *)
lemma (in comm_monoid_add) sum_list_cong: "mset xs = mset ys ⟹ sum_list xs = sum_list ys"
by(simp add: sum_mset_mset[symmetric])

(* Summing a concatenation equals summing the partial sums; associativity
   (monoid_add) suffices here. *)
lemma (in monoid_add) sum_list_concat: "sum_list (concat xs) = sum_list (map sum_list xs)"
by(induction xs) simp_all

(* mUnion M is the multiset sum of all multisets contained in M, obtained by
   folding op + over M starting from the empty multiset. *)
definition mUnion :: "'a multiset multiset ⇒ 'a multiset"
where "mUnion = fold_mset (op +) {#}"

(* The interpretation supplies the comp_fun_commute facts about op + that the
   fold_mset reasoning in the following lemmas relies on. It is confined to this
   unnamed context and is not visible after the closing end. *)
context begin
interpretation mUnion: comp_fun_commute "op + :: 'a multiset ⇒ _"
by unfold_locales(simp add: add_ac fun_eq_iff)

(* Simplification rules: mUnion of the empty multiset, of a singleton, of a
   sum, and of an insertion. *)
lemma mUnion_empty [simp]: "mUnion {#} = {#}"
by(simp add: mUnion_def)

lemma mUnion_single [simp]: "mUnion {# A #} = A"
by(simp add: mUnion_def)

lemma mUnion_union [simp]: "mUnion (A + B) = mUnion A + mUnion B"
unfolding mUnion_def by(induction B)(simp_all add: add_ac)

lemma mUnion_add_mset [simp]: "mUnion (add_mset x A) = x + (mUnion A)"
by(simp add: mUnion_def)

end

(* mUnion commutes with image_mset. *)
lemma image_mset_mUnion: "image_mset f (mUnion A) = mUnion (image_mset (image_mset f) A)"
by(induction A) auto

(* The multiplicity of x in mUnion A is the sum of its multiplicities in the
   members of A. *)
lemma count_mUnion [simp]: "count (mUnion A) x = sum_mset (image_mset (λB. count B x) A)"
by(induction A) simp_all

(* The multiset of a concatenation is the mUnion of the multisets of the parts. *)
lemma mset_concat [simp]: "mset (concat xss) = mUnion (mset (map mset xss))"
by(induction xss)(simp_all add: add_ac)

(* For a list without duplicates, mset coincides with mset_set of its set. *)
lemma mset_distinct: "distinct xs ⟹ mset xs = mset_set (set xs)"
by(induction xs)(auto simp add: add_ac)

(* Image of mset_set A under a function that is injective on the finite set A. *)
lemma image_mset_mset_set:
  assumes inj: "inj_on f A"
  and "finite A"
  shows "image_mset f (mset_set A) = mset_set (f ` A)"
proof -
  (* Generalise to an arbitrary subset B of A, so that the finite-set induction
     on B keeps injectivity available through the premise B ⊆ A (insert.prems).
     NOTE(review): def is a legacy Isar command that later Isabelle releases
     replace by define; revisit when porting. *)
  def B ≡ A
  then have "finite B" "B ⊆ A" using ‹finite A› by simp_all
  then show "image_mset f (mset_set B) = mset_set (f ` B)"
  proof(induction)
    case (insert x F)
    (* Injectivity keeps the new image point f x outside f ` F. *)
    have "f x ∉ f ` F" using ‹x ∉ F› insert.prems
      by(auto dest: inj_onD[OF inj])
    thus ?case using insert by simp
  qed simp
qed

(* Filtering with a constantly false predicate yields the empty multiset. *)
lemma filter_mset_False [simp]: "filter_mset (λ_. False) A = {#}"
by(induction A) simp_all

(* The multiset of remdups xs is mset_set of the set of xs (via mset_distinct). *)
lemma mset_remdups: "mset (remdups xs) = mset_set (set xs)"
by(subst mset_distinct; simp)

text ‹Non-negative real numbers›

(* nnreal is the subtype of reals that are at least 0, with representation
   function real_of_nnreal and raw abstraction function nnreal'. *)
typedef nnreal = "{x :: real. x ≥ 0}"
  morphisms real_of_nnreal nnreal'
by auto

setup_lifting type_definition_nnreal

(* Total conversion from real: the argument x is mapped to sup 0 x, so negative
   inputs are clamped to 0. *)
lift_definition nnreal :: "real ⇒ nnreal" is "sup 0"
by(rule sup_ge1)

(* Every nnreal is of the form nnreal y for some y ≥ 0; declared as the default
   case rule for the type. *)
lemma nnreal_cases [cases type]:
  obtains y where "x = nnreal y" "y ≥ 0"
by transfer (simp add: sup.absorb_iff2)

(* The probability mass function of a pmf, viewed as nnreal-valued; the lifting
   obligation is discharged by pmf_nonneg. *)
lift_definition pmf_nnr :: "'a pmf ⇒ 'a ⇒ nnreal" is pmf
by(rule pmf_nonneg)

(* embed_pmf lifted to nnreal-valued weight functions. *)
lift_definition embed_pmf_nnr :: "('a ⇒ nnreal) ⇒ 'a pmf" is embed_pmf .

(* Embedding of nnreal into the extended non-negative reals ennreal. *)
lift_definition ennreal_of_nnreal :: "nnreal ⇒ ennreal" is ennreal .
(* Basic properties of the embedding ennreal_of_nnreal. *)
lemma ennreal_of_nnreal_nnreal [simp]: "ennreal_of_nnreal (nnreal x) = ennreal x"
by transfer (metis ennreal_eq_0_iff le_iff_sup linear sup.orderE)

lemma ennreal_of_nnreal_inject [simp]: "ennreal_of_nnreal x = ennreal_of_nnreal y ⟷ x = y"
by transfer simp

(* Semiring structure on nnreal, inherited from real by lifting 0, 1, addition
   and multiplication; the class axioms are proved by transfer to real. *)
instantiation nnreal :: "{zero, one, plus, times, comm_monoid_add, comm_monoid_mult, comm_semiring_0}" begin

lift_definition zero_nnreal :: nnreal is 0 by simp
lift_definition one_nnreal :: nnreal is 1 by simp
lift_definition plus_nnreal :: "nnreal ⇒ nnreal ⇒ nnreal" is "op +" by simp
lift_definition times_nnreal :: "nnreal ⇒ nnreal ⇒ nnreal" is "op *" by simp

instance by(intro_classes; transfer; simp add: algebra_simps)+

end

lemma ennreal_of_nnreal_0 [simp]: "ennreal_of_nnreal 0 = 0"
by transfer simp

lemma ennreal_of_nnreal_eq_0 [simp]: "ennreal_of_nnreal x = 0 ⟷ x = 0"
by transfer simp

(* Conversion back from ennreal via real_of_ereal, defined under the ennreal
   lifting setup. It is a left inverse of ennreal_of_nnreal
   (ennreal_of_nnreal_inverse). *)
context includes ennreal.lifting begin

lift_definition nnreal_of_ennreal :: "ennreal ⇒ nnreal" is real_of_ereal
by(rule real_of_ereal_pos)

lemma ennreal_of_nnreal_inverse [simp]: "nnreal_of_ennreal (ennreal_of_nnreal x) = x"
apply(cases x)
apply clarsimp
apply transfer
apply(clarsimp simp add: ereal_max_0 ereal_max[symmetric] sup_real_def simp del: ereal_max)
done

lemma nnreal_of_ennreal_0 [simp]: "nnreal_of_ennreal 0 = 0"
by transfer simp

end

(* ennreal_of_nnreal commutes with addition, multiplication and sums. *)
lemma ennreal_of_nnreal_plus [simp]: "ennreal_of_nnreal x + ennreal_of_nnreal y = ennreal_of_nnreal (x + y)"
by transfer simp

lemma ennreal_of_nnreal_times [simp]: "ennreal_of_nnreal x * ennreal_of_nnreal y = ennreal_of_nnreal (x * y)"
by transfer(simp add: ennreal_mult)

(* The finite case is proved by induction. In the infinite case both sums
   collapse to 0 under simp. The name keeps the older setsum spelling although
   the statement uses sum; presumably kept for existing references. *)
lemma setsum_ennreal_of_nnreal [simp]: "sum (%x. ennreal_of_nnreal (f x)) A = ennreal_of_nnreal (sum f A)"
apply(cases "finite A")
subgoal by(induction rule: finite_induct; auto)
subgoal by simp
done

(* Executable equality on nnreal, lifted from equality on real. Code equations
   that test nnreal values for equality, such as the p = 0 test in
   scale_pmf_impl_code, need this instance. *)
instantiation nnreal :: equal begin
lift_definition equal_nnreal :: "nnreal ⇒ nnreal ⇒ bool" is "op =" .
instance by(intro_classes;transfer;simp)
end

(* Store the nnreal lifting setup under the bundle nnreal.lifting and then
   deactivate it. The proofs below re-enable it locally with
   including nnreal.lifting. *)
lifting_update nnreal.lifting
lifting_forget nnreal.lifting

(* pmf_impl: raw weight functions 'a ⇒ nnreal with no normalisation condition;
   wf_pmf singles out those whose weights integrate to 1 over the counting
   measure. *)
typedef 'a pmf_impl = "UNIV :: ('a ⇒ nnreal) set" ..

setup_lifting type_definition_pmf_impl

lift_definition rep_pmf :: "'a pmf ⇒ 'a pmf_impl" is pmf_nnr .
lift_definition abs_pmf :: "'a pmf_impl ⇒ 'a pmf" is embed_pmf_nnr .
lift_definition wf_pmf :: "'a pmf_impl ⇒ bool" is "λf. (∫⇧+ x. ennreal_of_nnreal (f x) ∂count_space UNIV) = 1" .

(* rep_pmf and abs_pmf form a type definition with carrier {f. wf_pmf f}. This
   justifies registering 'a pmf as an abstract type for the code generator
   (pmf_code_cert). *)
lemma td_pmf_code: "type_definition rep_pmf abs_pmf {f. wf_pmf f}"
apply transfer
including nnreal.lifting
apply transfer
apply(simp add: conj_commute td_pmf_embed_pmf)
done

lemma pmf_code_cert [code abstype]: "abs_pmf (rep_pmf p) = p"
by(rule type_definition.Rep_inverse[OF td_pmf_code])

(* Point mass: weight 1 at x and 0 everywhere else. *)
lift_definition return_pmf_impl :: "'a ⇒ 'a pmf_impl" is "λx y. if x = y then 1 else 0" .

lemma return_pmf_code [code abstract]: "rep_pmf (return_pmf x) = return_pmf_impl x"
apply transfer
including nnreal.lifting
apply transfer
apply(simp add: fun_eq_iff)
done

(* Weight of x after binding: the integral over y of f y times g y x, computed
   in ennreal and converted back with nnreal_of_ennreal. *)
lift_definition bind_pmf_impl :: "'a pmf_impl ⇒ ('a ⇒ 'b pmf_impl) ⇒ 'b pmf_impl" is
  "λf g x. nnreal_of_ennreal (∫⇧+ y. ennreal_of_nnreal (f y) * ennreal_of_nnreal (g y x) ∂count_space UNIV)" .

lemma bind_pmf_code [code abstract]: "rep_pmf (bind_pmf p f) = bind_pmf_impl (rep_pmf p) (λx. rep_pmf (f x))"
proof(transfer; rule ext)
  fix p and f :: "'b ⇒ 'a pmf" and x
  (* First establish the equation in ennreal, where the pmf of a bind is an
     integral; then move back to nnreal with simp, presumably via
     ennreal_of_nnreal_inverse. *)
  have "ennreal_of_nnreal (pmf_nnr (p ⤜ f) x) = (∫⇧+ y. ennreal_of_nnreal (pmf_nnr p y) * ennreal_of_nnreal (pmf_nnr (f y) x) ∂count_space UNIV)"
    (is "ennreal_of_nnreal ?lhs = ?rhs")
    including nnreal.lifting by transfer(simp add: ennreal_pmf_bind nn_integral_measure_pmf)
  from this[symmetric] show "?lhs = nnreal_of_ennreal ?rhs" by simp
qed

text ‹Implementation of @{typ "'a pmf_impl"} as association lists: a list of
  (element, weight) pairs represents the weight function that maps each element
  to the sum of all weights paired with it in the list.›

lift_definition pmf_impl_of_pfp :: "('a × nnreal) list ⇒ 'a pmf_impl" is
  "λxs x. sum_list (map snd [(x', y :: nnreal) ← xs. x' = x])" .

(* pmf_impl_of_pfp is the only code-level constructor of pmf_impl. *)
code_datatype pmf_impl_of_pfp

(* The weight function that is 0 everywhere. *)
lift_definition zero_pmf_impl :: "'a pmf_impl" is "λ_. 0" .
(* The zero weight function is represented by the empty association list. *)
lemma zero_pmf_impl_code [code]: "zero_pmf_impl = pmf_impl_of_pfp []"
by transfer simp

lemma pmf_impl_of_pfp_Nil [simp]: "pmf_impl_of_pfp [] = zero_pmf_impl"
by transfer simp

(* Multiply every weight by r; the factor is applied on the right (f x * r). *)
lift_definition scale_pmf_impl :: "nnreal ⇒ 'a pmf_impl ⇒ 'a pmf_impl" is "λ(r :: nnreal) f x. f x * r" .

(* Code equation: scaling multiplies every weight in the list by p. For p = 0
   the result is the empty list rather than a list of zero weights; this test
   needs the executable equality on nnreal. *)
lemma scale_pmf_impl_code [code]:
  "scale_pmf_impl p (pmf_impl_of_pfp xs) = pmf_impl_of_pfp (if p = 0 then [] else map (apsnd (op * p)) xs)"
apply transfer
including nnreal.lifting
apply transfer
apply(auto simp add: filter_map o_def split_def mult_ac sum_list_const_mult)
done

(* Unconditional variant without the p = 0 case split. It is not a code
   equation and is used in the proof of map_pmf_impl_code. *)
lemma scale_pmf_impl_alt:
  "scale_pmf_impl p (pmf_impl_of_pfp xs) = pmf_impl_of_pfp (map (apsnd (op * p)) xs)"
apply transfer
including nnreal.lifting
apply transfer
apply(auto simp add: filter_map o_def split_def mult_ac sum_list_const_mult)
done

(* Pointwise sum of weight functions. *)
lift_definition plus_pmf_impl :: "'a pmf_impl ⇒ 'a pmf_impl ⇒ 'a pmf_impl" is "λf g x. f x + g x" .

(* Code equation: adding two association lists concatenates them. *)
lemma plus_pmf_impl_code [code]:
  "plus_pmf_impl (pmf_impl_of_pfp xs) (pmf_impl_of_pfp ys) = pmf_impl_of_pfp (xs @ ys)"
by transfer auto

lemma plus_pmf_impl_zero [simp]: "plus_pmf_impl zero_pmf_impl p = p"
by transfer simp

lemma plus_pmf_impl_assoc [simp]:
  "plus_pmf_impl (plus_pmf_impl p q) r = plus_pmf_impl p (plus_pmf_impl q r)"
by transfer(simp add: add_ac)

(* Simp rules that decompose an association list into plus/scale/return terms. *)
lemma pmf_impl_of_pfp_Cons [simp]:
  "pmf_impl_of_pfp ((x, p) # xs) = plus_pmf_impl (scale_pmf_impl p (return_pmf_impl x)) (pmf_impl_of_pfp xs)"
by transfer(simp add: fun_eq_iff)

lemma pmf_impl_of_pfp_append [simp]:
  "pmf_impl_of_pfp (xs @ ys) = plus_pmf_impl (pmf_impl_of_pfp xs) (pmf_impl_of_pfp ys)"
by(induction xs) auto

(* Code equation: a point mass is the one-element list [(x, 1)]. *)
lemma return_pmf_impl_code [code]: "return_pmf_impl x = pmf_impl_of_pfp [(x, 1)]"
by transfer(simp add: fun_eq_iff)

(* Code equation for bind: fold over the association list, adding for each pair
   (x, p) the implementation f x scaled by p. The proof works for a fixed output
   point: restrict the integral to the finite support fst ` set xs, turn it
   into a finite sum, regroup that sum by key using remdups, show the regrouped
   list has the same multiset as xs using inj and the auxiliary fact named *,
   and finally rewrite sum_list as foldr. *)
lemma bind_pmf_impl_code [code]:
  "bind_pmf_impl (pmf_impl_of_pfp xs) f = foldr (λ(x, p). plus_pmf_impl (scale_pmf_impl p (f x))) xs zero_pmf_impl"
proof(transfer; rule ext)
  fix xs :: "('b × nnreal) list" and f :: "'b ⇒ 'a ⇒ nnreal" and x
  (* Keys of xs determine their filtered sublists uniquely, up to multiset. *)
  have inj: "inj_on (mset ∘ (λa. [(x', y)←xs . x' = a])) (fst ` set xs)"
    by(rule inj_onI)(auto 4 3 simp add: mset_filter dest!: arg_cong[where f=set_mset])
  (* For a finite key set A, the mUnion of the per-key sub-multisets of B is the
     sub-multiset of B whose keys lie in A. The proof is by induction on A. A
     new key x that does not occur in B has an empty sub-multiset and
     contributes nothing. *)
  have *: "mUnion (mset_set ((λx. {# (x', y) ∈# B. x' = x#}) ` A)) = {# (x, y) ∈# B. x ∈ A #}"
    if "finite A" for A :: "'c set" and B :: "('c × 'd) multiset"
    using that
  proof(induction A)
    case (insert x A)
    show ?case
    proof(cases "x ∈ fst ` set_mset B")
      case True
      (* The sub-multiset for x is not among those of the keys in A. *)
      then have "{# (x', y) ∈# B. x' = x#} ∉ (λx. {# (x', y) ∈# B. x' = x#}) ` A"
        using ‹x ∉ A› by(auto 4 3 dest!: arg_cong[where f="set_mset"] elim: equalityE)
      then show ?thesis using insert by(auto intro: multiset_eqI)
    next
      case False
      hence "{# (x', y) ∈# B. x' = x#} = {#}"
        by(auto 4 3 intro!: multiset_eqI simp add: count_eq_zero_iff intro: rev_image_eqI)
      then have "mUnion (mset_set ((λx. {# a ∈# B. case a of (x', y) ⇒ x' = x#}) ` insert x A)) = mUnion (mset_set ((λx. {# a ∈# B. case a of (x', y) ⇒ x' = x#}) ` A))"
        using ‹finite A›
        by(simp add: mset_set.insert_remove)(metis (no_types, lifting) Diff_empty Diff_insert0 add.left_neutral finite_imageI mUnion_add_mset mset_set.remove)
      also have "… = {# (x, y) ∈# B. x ∈ A#}" by(simp add: insert.IH)
      also have "… = {# (x', y) ∈# B. x' ∈ insert x A#}" using False
        by(auto 4 3 intro!: multiset_eqI simp add: count_eq_zero_iff intro: rev_image_eqI)
      finally show ?thesis .
    qed
  qed simp
  (* Main calculation for the fixed output point x. Step 1: outside the keys
     of xs the weight is 0, so the integral can be restricted to them. *)
  have "(∫⇧+ y. ennreal_of_nnreal (sum_list (map snd [(x', ya)←xs . x' = y])) * ennreal_of_nnreal (f y x) ∂count_space UNIV) = (∫⇧+ y. ennreal_of_nnreal (sum_list (map snd [(x', ya)←xs . x' = y])) * ennreal_of_nnreal (f y x) ∂count_space (fst ` set xs))"
    (is "?lhs = _")
    apply(clarsimp simp add: nn_integral_count_space_indicator intro!: nn_integral_cong split: split_indicator)
    apply(subst filter_False; auto intro: rev_image_eqI)
    done
  (* Step 2: an integral over a finite counting space is a finite sum. *)
  also have "… = ennreal_of_nnreal (∑a∈set (map fst xs). sum_list (map snd [(x', ya)←xs . x' = a]) * f a x)"
    by(simp add: nn_integral_count_space_finite)
  (* Step 3: express the set sum as a list sum over the distinct keys. *)
  also have "(∑a∈set (map fst xs). sum_list (map snd [(x', ya)←xs . x' = a]) * f a x) = (sum_list (map sum_list (map (λa. map (λ(a, y). f a x * y) [(x', y)←xs. x' = a]) (remdups (map fst xs)))))"
    unfolding sum_code
    by(simp add: o_def sum_list_const_mult[symmetric] split_def mult.commute cong: list.map_cong_simp)
  (* Step 4: flatten the per-key groups into a single list. *)
  also have "… = sum_list (map (λ(a, y). f a x * y) (concat (map (λa. [(x', y)←xs. x' = a]) (remdups (map fst xs)))))"
    unfolding map_concat by(simp add: sum_list_concat o_def)
  (* Step 5: the flattened groups have the same multiset as xs, so their sums agree. *)
  also have "… = sum_list (map (λ(a, y). f a x * y) xs)"
    using inj
    apply(intro sum_list_cong)
    apply(simp add: mset_remdups image_mset_mset_set o_def)
    apply(rule arg_cong2[where f="image_mset"])
    apply(auto simp add: mset_filter * filter_mset_eq_conv intro: rev_image_eqI)
    done
  (* Step 6: rewrite the list sum as the foldr from the code equation. *)
  also have "… = foldr (λ(x, p) g xa. f x xa * p + g xa) xs (λ_. 0) x"
    unfolding sum_list.eq_foldr by(induction xs) auto
  finally show "nnreal_of_ennreal ?lhs = …" by simp
qed

(* map is implemented as bind followed by a point mass. *)
definition map_pmf_impl :: "('a ⇒ 'b) ⇒ 'a pmf_impl ⇒ 'b pmf_impl"
where "map_pmf_impl f p = bind_pmf_impl p (λx. return_pmf_impl (f x))"

(* NOTE(review): unlike return_pmf_code and bind_pmf_code, this equation is
   tagged [code] rather than [code abstract]; presumably the code attribute
   recognises the rep_pmf-headed left-hand side as an abstract equation. Confirm. *)
lemma map_pmf_code [code]: "rep_pmf (map_pmf f p) = map_pmf_impl f (rep_pmf p)"
by(simp add: map_pmf_def bind_pmf_code map_pmf_impl_def return_pmf_code)

(* Code equation: mapping over an association list applies f to the first
   components. The proof is by induction on xs. The simp rules
   pmf_impl_of_pfp_Cons and pmf_impl_of_pfp_Nil are removed, presumably so that
   list representations are not decomposed back into plus/scale terms. *)
lemma map_pmf_impl_code [code]:
  "map_pmf_impl f (pmf_impl_of_pfp xs) = pmf_impl_of_pfp (map (apfst f) xs)"
unfolding map_pmf_impl_def
apply(induction xs)
apply(simp add: bind_pmf_impl_code return_pmf_impl_code scale_pmf_impl_alt zero_pmf_impl_code del: pmf_impl_of_pfp_Cons pmf_impl_of_pfp_Nil)
apply clarsimp
apply(simp add: bind_pmf_impl_code return_pmf_impl_code scale_pmf_impl_alt plus_pmf_impl_code del: pmf_impl_of_pfp_Cons pmf_impl_of_pfp_Nil)
done

end