Theory Misc

theory Misc
imports Complex_Main Eisbach Rewrite Simps_Case_Conv
theory Misc imports 
  Complex_Main 
  "~~/src/HOL/Eisbach/Eisbach"
  "~~/src/HOL/Library/Rewrite"
  "~~/src/HOL/Library/Simps_Case_Conv"
begin


(* In a well-order: if Q has a witness x and every element satisfying Q lies
   at or above some element satisfying P, then the least P-element is at most
   the least Q-element. *)
lemma Least_le_Least:
  fixes x :: "'a :: wellorder"
  assumes "Q x"
  and Q: "⋀x. Q x ⟹ ∃y≤x. P y"
  shows "Least P ≤ Least Q"
proof -
  (* f chooses, for every a with Q a, a P-element f a with f a ≤ a. *)
  obtain f :: "'a ⇒ 'a" where "∀a. ¬ Q a ∨ f a ≤ a ∧ P (f a)" using Q by moura
  (* Least Q satisfies Q because Q is inhabited by x. *)
  moreover have "Q (Least Q)" using ‹Q x› by(rule LeastI)
  ultimately show ?thesis by (metis (full_types) le_cases le_less less_le_trans not_less_Least)
qed

subsection ‹Transfer rules›

(* Parametricity rules, registered with [transfer_rule], for monotone, fun_ord,
   fun_lub and Ex1. Each is proved by unfolding the constant's definition and
   calling transfer_prover. The bi_total / bi_unique hypotheses are themselves
   declared [transfer_rule], so transfer_prover can use them. *)
context includes lifting_syntax begin

(* monotone is parametric in both orders and the function; A must be bi-total. *)
lemma monotone_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_total A"
  shows "((A ===> A ===> op =) ===> (B ===> B ===> op =) ===> (A ===> B) ===> op =) monotone monotone"
unfolding monotone_def[abs_def] by transfer_prover

(* The pointwise order fun_ord is parametric; the domain relation C must be bi-total. *)
lemma fun_ord_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_total C"
  shows "((A ===> B ===> op =) ===> (C ===> A) ===> (C ===> B) ===> op =) fun_ord fun_ord"
unfolding fun_ord_def[abs_def] by transfer_prover

(* The pointwise lub fun_lub is parametric; A must be bi-total and bi-unique. *)
lemma fun_lub_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_total A"  "bi_unique A"
  shows "((rel_set A ===> B) ===> rel_set (C ===> A) ===> C ===> B) fun_lub fun_lub"
unfolding fun_lub_def[abs_def] by transfer_prover

(* Unique existence is parametric for bi-unique, bi-total relations. *)
lemma Ex1_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_unique A" "bi_total A"
  shows "((A ===> op =) ===> op =) Ex1 Ex1"
unfolding Ex1_def[abs_def] by transfer_prover

end

subsection ‹Algebraic stuff›

(* Triangle inequality for differences: |a - b| ≤ |a - c| + |c - b|. *)
lemma abs_diff_triangle_ineq2: "¦a - b :: _ :: ordered_ab_group_add_abs¦ ≤ ¦a - c¦ + ¦c - b¦"
by(rule order_trans[OF _ abs_diff_triangle_ineq]) simp

(* An upper bound a + b can be weakened to a + c when b ≤ c. *)
lemma (in ordered_ab_semigroup_add) add_left_mono_trans:
  "⟦ x ≤ a + b; b ≤ c ⟧ ⟹ x ≤ a + c"
by(erule order_trans)(rule add_left_mono)

(* The cast of a natural number into the reals is at most 1 iff the number is. *)
lemma of_nat_le_one_cancel_iff [simp]:
  fixes n :: nat shows "real n ≤ 1 ⟷ n ≤ 1"
by linarith

(* Right-multiplication variant of mult_left_le, obtained by commutativity. *)
lemma (in linordered_semidom) mult_right_le: "c ≤ 1 ⟹ 0 ≤ a ⟹ c * a ≤ a"
by(subst mult.commute)(rule mult_left_le)


subsection ‹Option›

(* Use is_none_bind as a simplification rule from here on. *)
declare is_none_bind [simp]

(* None is in the image under map_option f exactly when None was already in A. *)
lemma None_in_map_option_image [simp]: "None ∈ map_option f ` A ⟷ None ∈ A"
by auto

(* Some x is in the image iff x = f y for some Some y in A. *)
lemma Some_in_map_option_image [simp]: "Some x ∈ map_option f ` A ⟷ (∃y. x = f y ∧ Some y ∈ A)"
by(auto intro: rev_image_eqI dest: sym)

(* A case distinction whose branches both return x is just x. *)
lemma case_option_collapse: "case_option x (λ_. x) y = x"
by(simp split: option.split)

(* Rebuilding an option by case distinction is the identity. *)
lemma case_option_id: "case_option None Some = id"
by(rule ext)(simp split: option.split)

subsubsection ‹Orders on option›

(* ord_option ord lifts a relation ord to options: None is below every option,
   and Some x is below Some y iff ord x y. There is no rule placing Some x below
   None. *)
inductive ord_option :: "('a ⇒ 'b ⇒ bool) ⇒ 'a option ⇒ 'b option ⇒ bool"
  for ord :: "'a ⇒ 'b ⇒ bool"
where
  None: "ord_option ord None x"
| Some: "ord x y ⟹ ord_option ord (Some x) (Some y)"

(* Simplification rules derived from the inductive definition for all
   combinations of None/Some arguments. *)
inductive_simps ord_option_simps [simp]:
  "ord_option ord None x"
  "ord_option ord x None"
  "ord_option ord (Some x) (Some y)"
  "ord_option ord (Some x) None"

(* Specialised simplification rules for the lifting of equality. *)
inductive_simps ord_option_eq_simps [simp]:
  "ord_option op = None y"
  "ord_option op = (Some x) y"

(* ord_option ord is reflexive on x if ord is reflexive on the element(s) of x. *)
lemma ord_option_reflI: "(⋀y. y ∈ set_option x ⟹ ord y y) ⟹ ord_option ord x x"
by(cases x) simp_all

(* Reflexivity is preserved by the lifting. *)
lemma reflp_ord_option: "reflp ord ⟹ reflp (ord_option ord)"
by(simp add: reflp_def ord_option_reflI)

(* Transitivity of the lifted relation only needs transitivity of ord on the
   elements actually present in x, y and z. *)
lemma ord_option_trans:
  "⟦ ord_option ord x y; ord_option ord y z;
    ⋀a b c. ⟦ a ∈ set_option x; b ∈ set_option y; c ∈ set_option z; ord a b; ord b c ⟧ ⟹ ord a c ⟧
  ⟹ ord_option ord x z"
by(auto elim!: ord_option.cases)

(* Transitivity is preserved by the lifting. *)
lemma transp_ord_option: "transp ord ⟹ transp (ord_option ord)"
unfolding transp_def by(blast intro: ord_option_trans)

(* Antisymmetry is preserved by the lifting. *)
lemma antisymP_ord_option: "antisymP ord ⟹ antisymP (ord_option ord)"
by(auto intro!: antisymI elim!: ord_option.cases dest: antisymD)

(* The Some-contents of a chain of options form a chain of the underlying order. *)
lemma ord_option_chainD:
  "Complete_Partial_Order.chain (ord_option ord) Y
  ⟹ Complete_Partial_Order.chain ord {x. Some x ∈ Y}"
by(rule chainI)(auto dest: chainD)

(* Lift a lub operation to sets of options: the result is None if Y contains
   no Some-element, and otherwise Some of the lub of all Some-contents of Y. *)
definition lub_option :: "('a set ⇒ 'b) ⇒ 'a option set ⇒ 'b option"
where "lub_option lub Y = (if Y ⊆ {None} then None else Some (lub {x. Some x ∈ Y}))"

(* Mapping over the result can be folded into the lub operation. *)
lemma map_lub_option: "map_option f (lub_option lub Y) = lub_option (f ∘ lub) Y"
by(simp add: lub_option_def)

(* If lub yields upper bounds of chains, so does lub_option for chains of options. *)
lemma lub_option_upper:
  assumes "Complete_Partial_Order.chain (ord_option ord) Y" "x ∈ Y"
  and lub_upper: "⋀Y x. ⟦ Complete_Partial_Order.chain ord Y; x ∈ Y ⟧ ⟹ ord x (lub Y)"
  shows "ord_option ord x (lub_option lub Y)"
using assms(1-2)
by(cases x)(auto simp add: lub_option_def intro: lub_upper[OF ord_option_chainD])

(* If lub yields least upper bounds of chains, so does lub_option. *)
lemma lub_option_least:
  assumes Y: "Complete_Partial_Order.chain (ord_option ord) Y"
  and upper: "⋀x. x ∈ Y ⟹ ord_option ord x y"
  assumes lub_least: "⋀Y y. ⟦ Complete_Partial_Order.chain ord Y; ⋀x. x ∈ Y ⟹ ord x y ⟧ ⟹ ord (lub Y) y"
  shows "ord_option ord (lub_option lub Y) y"
using Y
by(cases y)(auto 4 3 simp add: lub_option_def intro: lub_least[OF ord_option_chainD] dest: upper)

(* Mapping the set of options first amounts to applying f to the Some-contents
   before taking the lub. *)
lemma lub_map_option: "lub_option lub (map_option f ` Y) = lub_option (lub ∘ op ` f) Y"
apply(auto simp add: lub_option_def)
apply(erule notE)
apply(rule arg_cong[where f=lub])
apply(auto intro: rev_image_eqI dest: sym)
done

(* The lifting is monotone in the underlying relation. *)
lemma ord_option_mono: "⟦ ord_option A x y; ⋀x y. A x y ⟹ B x y ⟧ ⟹ ord_option B x y"
by(auto elim: ord_option.cases)

(* Object-level implication form of ord_option_mono, declared with [mono]. *)
lemma ord_option_mono' [mono]:
  "(⋀x y. A x y ⟶ B x y) ⟹ ord_option A x y ⟶ ord_option B x y"
by(blast intro: ord_option_mono)

(* The lifting distributes over relation composition. *)
lemma ord_option_compp: "ord_option (A OO B) = ord_option A OO ord_option B"
by(auto simp add: fun_eq_iff elim!: ord_option.cases intro: ord_option.intros)

(* The lifting commutes with the infimum of relations. *)
lemma ord_option_inf: "inf (ord_option A) (ord_option B) = ord_option (inf A B)" (is "?lhs = ?rhs")
proof(rule antisym)
  show "?lhs ≤ ?rhs" by(auto elim!: ord_option.cases)
  (* The remaining direction ?rhs ≤ ?lhs follows from monotonicity. *)
qed(auto elim: ord_option_mono)

(* A map_option on the right argument can be pushed into the relation. *)
lemma ord_option_map2: "ord_option ord x (map_option f y) = ord_option (λx y. ord x (f y)) x y"
by(auto elim: ord_option.cases)

(* A map_option on the left argument can be pushed into the relation. *)
lemma ord_option_map1: "ord_option ord (map_option f x) y = ord_option (λx y. ord (f x) y) x y"
by(auto elim: ord_option.cases)

(* In the flat order option_ord (unfolded via flat_ord_def), Some x is below
   only itself. *)
lemma option_ord_Some1_iff: "option_ord (Some x) y ⟷ y = Some x"
by(auto simp add: flat_ord_def)



text ‹
  We need a separate polymorphic constant for the empty map so that @{text "transfer_prover"}
  can use a custom transfer rule for @{const Map.empty}.
›
(* Map_empty is defined to be Map.empty; the [simp] attribute on the definition
   lets the simplifier rewrite Map_empty back to Map.empty. *)
definition Map_empty where [simp]: "Map_empty ≡ Map.empty"

subsubsection ‹filter for option›

(* Keep the content of an option only if it satisfies P; otherwise return None. *)
fun filter_option :: "('a ⇒ bool) ⇒ 'a option ⇒ 'a option"
where
  "filter_option P None = None"
| "filter_option P (Some x) = (if P x then Some x else None)"

(* The element set of the filtered option is the P-filtered element set. *)
lemma set_filter_option [simp]: "set_option (filter_option P x) = {y ∈ set_option x. P y}"
by(cases x) auto

(* Filtering after mapping equals mapping after filtering with P ∘ f. *)
lemma filter_map_option: "filter_option P (map_option f x) = map_option f (filter_option (P ∘ f) x)"
by(cases x) simp_all

(* The result is None iff the input is None or its content fails P. *)
lemma is_none_filter_option [simp]: "Option.is_none (filter_option P x) ⟷ Option.is_none x ∨ ¬ P (the x)"
by(cases x) simp_all

(* The result is Some y iff the input is Some y and P y holds. *)
lemma filter_option_eq_Some_iff [simp]: "filter_option P x = Some y ⟷ x = Some y ∧ P y"
by(cases x) auto

(* Symmetric version of filter_option_eq_Some_iff. *)
lemma Some_eq_filter_option_iff [simp]: "Some y = filter_option P x ⟷ x = Some y ∧ P y"
by(cases x) auto

(* filter_option expressed via Option.bind. *)
lemma filter_conv_bind_option: "filter_option P x = Option.bind x (λy. if P y then Some y else None)"
by(cases x) simp_all


subsubsection ‹map for the combination of set and option›

(* Lift f :: 'a ⇒ 'b option set to options: None is mapped to {None} and
   Some x to f x. *)
fun map_option_set :: "('a ⇒ 'b option set) ⇒ 'a option ⇒ 'b option set"
where
  "map_option_set f None = {None}"
| "map_option_set f (Some x) = f x"

(* None is in the result iff the input is None or f yields None for its content. *)
lemma None_in_map_option_set:
  "None ∈ map_option_set f x ⟷ None ∈ Set.bind (set_option x) f ∨ x = None"
by(cases x) simp_all

(* Introduction rules for membership in map_option_set. *)
lemma None_in_map_option_set_None [intro!]: "None ∈ map_option_set f None"
by simp

lemma None_in_map_option_set_Some [intro!]: "None ∈ f x ⟹ None ∈ map_option_set f (Some x)"
by simp

lemma Some_in_map_option_set [intro!]: "Some y ∈ f x ⟹ Some y ∈ map_option_set f (Some x)"
by simp

(* With singleton-valued f, map_option_set degenerates to Option.bind. *)
lemma map_option_set_singleton [simp]: "map_option_set (λx. {f x}) y = {Option.bind y f}"
by(cases y) simp_all

(* Characterisation of Some y as the result of Option.bind. *)
lemma Some_eq_bind_conv: "Some y = Option.bind x f ⟷ (∃z. x = Some z ∧ f z = Some y)"
by(cases x) auto

(* map_option_set after Option.bind composes into a single map_option_set. *)
lemma map_option_set_bind: "map_option_set f (Option.bind x g) = map_option_set (map_option_set f ∘ g) x"
by(cases x) simp_all

(* Some y is in the result iff the input is Some z with Some y ∈ f z. *)
lemma Some_in_map_option_set_conv: "Some y ∈ map_option_set f x ⟷ (∃z. x = Some z ∧ Some y ∈ f z)"
by(cases x) auto

subsubsection ‹Join on options›

(* Flatten a nested option: Some y becomes y, None stays None. *)
definition join_option :: "'a option option ⇒ 'a option"
where "join_option x = (case x of Some y ⇒ y | None ⇒ None)"

(* Derive pattern-matching equations (join_simps) from the case-based
   definition and use them for simp and code generation. *)
simps_of_case join_simps [simp, code]: join_option_def

(* The elements of the joined option are the elements of the inner option. *)
lemma set_join_option [simp]: "set_option (join_option x) = ⋃(set_option ` set_option x)"
by(cases x)(simp_all)

lemma in_set_join_option: "x ∈ set_option (join_option (Some (Some x)))"
by simp

(* map_option commutes with join_option (mapping the inner option). *)
lemma map_join_option: "map_option f (join_option x) = join_option (map_option (map_option f) x)"
by(cases x) simp_all

(* Option.bind is map followed by join ... *)
lemma bind_conv_join_option: "Option.bind x f = join_option (map_option f x)"
by(cases x) simp_all

(* ... and join is bind with the identity. *)
lemma join_conv_bind_option: "join_option x = Option.bind x id"
by(cases x) simp_all

(* join_option is parametric; proved through its Option.bind characterisation. *)
context includes lifting_syntax begin
lemma join_option_parametric [transfer_rule]:
  "(rel_option (rel_option R) ===> rel_option R) join_option join_option"
unfolding join_conv_bind_option[abs_def] by transfer_prover
end

(* Characterisations of when join_option returns Some or None. *)
lemma join_option_eq_Some [simp]: "join_option x = Some y ⟷ x = Some (Some y)"
by(cases x) simp_all

lemma Some_eq_join_option [simp]: "Some y = join_option x ⟷ x = Some (Some y)"
by(cases x) auto

lemma join_option_eq_None: "join_option x = None ⟷ x = None ∨ x = Some None"
by(cases x) simp_all

lemma None_eq_join_option: "None = join_option x ⟷ x = None ∨ x = Some None"
by(cases x) auto

subsubsection ‹Zip on options›

(* Pair two options: the result is Some (x, y) if both are Some, else None.
   The last two equations overlap on (None, None); `function` proves that they
   are compatible. *)
function zip_option :: "'a option ⇒ 'b option ⇒ ('a × 'b) option"
where
  "zip_option (Some x) (Some y) = Some (x, y)"
| "zip_option _ None = None"
| "zip_option None _ = None"
by pat_completeness auto
termination by lexicographic_order

(* The result is Some (a, b) exactly when both inputs are Some. *)
lemma zip_option_eq_Some_iff [iff]:
  "zip_option x y = Some (a, b) ⟷ x = Some a ∧ y = Some b"
by(cases "(x, y)" rule: zip_option.cases) simp_all

(* The element set of the zipped option is the product of the element sets. *)
lemma set_zip_option [simp]:
  "set_option (zip_option x y) = set_option x × set_option y"
by auto

(* map_option on either argument can be pulled out of the zip. *)
lemma zip_map_option1: "zip_option (map_option f x) y = map_option (apfst f) (zip_option x y)"
by(cases "(x, y)" rule: zip_option.cases) simp_all

lemma zip_map_option2: "zip_option x (map_option g y) = map_option (apsnd g) (zip_option x y)"
by(cases "(x, y)" rule: zip_option.cases) simp_all

(* Mapping both components of the zip equals zipping the mapped options. *)
lemma map_zip_option:
  "map_option (map_prod f g) (zip_option x y) = zip_option (map_option f x) (map_option g y)"
by(simp add: zip_map_option1 zip_map_option2 option.map_comp apfst_def apsnd_def o_def prod.map_comp)

(* zip_option expressed via nested Option.bind. *)
lemma zip_conv_bind_option:
  "zip_option x y = Option.bind x (λx. Option.bind y (λy. Some (x, y)))"
by(cases "(x, y)" rule: zip_option.cases) simp_all

(* zip_option is parametric; proved through its Option.bind characterisation. *)
context includes lifting_syntax begin
lemma zip_option_parametric [transfer_rule]:
  "(rel_option R ===> rel_option Q ===> rel_option (rel_prod R Q)) zip_option zip_option"
unfolding zip_conv_bind_option[abs_def] by transfer_prover
end





subsection ‹Predicator for function relator with relation @{term "op ="} on the domain›

(* fun_eq_pred P f holds iff every value of f satisfies P. It is the domain
   predicate of the function relator op = ===> A (see Domainp_fun_rel_eq). *)
definition fun_eq_pred :: "('b ⇒ bool) ⇒ ('a ⇒ 'b) ⇒ bool"
where "fun_eq_pred P f ⟷ (∀x. P (f x))"

context includes lifting_syntax begin

(* With equality on the domain, the function relator distributes over relation
   composition. *)
lemma fun_rel_eq_OO: "(op = ===> A) OO (op = ===> B) = (op = ===> A OO B)"
by(clarsimp simp add: rel_fun_def fun_eq_iff relcompp.simps) metis

(* The domain of op = ===> A is characterised by fun_eq_pred of the domain of A. *)
lemma Domainp_fun_rel_eq:
  "Domainp (op = ===> A) = fun_eq_pred (Domainp A)"
by(simp add: rel_fun_def Domainp.simps fun_eq_pred_def fun_eq_iff) metis

(* Every element is in the domain of equality. *)
lemma Domainp_eq: "Domainp op = = (λ_. True)"
by(simp add: Domainp.simps fun_eq_iff)

(* eq_onp with the trivial predicate is plain equality. *)
lemma invariant_True: "eq_onp (λ_. True) = op ="
by(simp add: eq_onp_def)

(* The function relator maps eq_onp P to eq_onp (fun_eq_pred P); registered
   as a relator_eq_onp rule. *)
lemma fun_eq_invariant [relator_eq_onp]:
  "op = ===> eq_onp P = eq_onp (fun_eq_pred P)"
by(simp add: rel_fun_def eq_onp_def fun_eq_iff fun_eq_pred_def) metis

(* For a quotient with correspondence relation T, rel_set R on representation
   sets corresponds to equality of abstract sets. *)
lemma Quotient_set_rel_eq:
  assumes "Quotient R Abs Rep T"
  shows "(rel_set T ===> rel_set T ===> op =) (rel_set R) op ="
proof(rule rel_funI iffI)+
  fix A B C D
  assume AB: "rel_set T A B" and CD: "rel_set T C D"
  (* Two consequences of the quotient assumption: R expressed via Abs, and
     T-related elements are related by Abs. *)
  have *: "⋀x y. R x y = (T x (Abs x) ∧ T y (Abs y) ∧ Abs x = Abs y)"
    "⋀a b. T a b ⟹ Abs a = b"
    using assms unfolding Quotient_alt_def by simp_all

  { assume [simp]: "B = D"
    thus "rel_set R A C"
      by(auto 4 4 intro!: rel_setI dest: rel_setD1[OF AB, simplified] rel_setD2[OF AB, simplified] rel_setD2[OF CD] rel_setD1[OF CD] simp add: * elim!: rev_bexI)
  next
    assume AC: "rel_set R A C"
    show "B = D"
      apply safe
       (* x ∈ B ⟹ x ∈ D: go from B to A, across to C, then to D. *)
       apply(drule rel_setD2[OF AB], erule bexE)
       apply(drule rel_setD1[OF AC], erule bexE)
       apply(drule rel_setD1[OF CD], erule bexE)
       apply(simp add: *)
      (* x ∈ D ⟹ x ∈ B: go from D to C, across to A, then to B. *)
      apply(drule rel_setD2[OF CD], erule bexE)
      apply(drule rel_setD2[OF AC], erule bexE)
      apply(drule rel_setD1[OF AB], erule bexE)
      apply(simp add: *)
      done
  }
qed

(* fun_eq_pred is parametric provided the domain relation A is bi-total. *)
lemma fun_eq_pred_parametric [transfer_rule]:
  assumes [transfer_rule]: "bi_total A"
  shows "((B ===> op =) ===> (A ===> B) ===> op =) fun_eq_pred fun_eq_pred"
unfolding fun_eq_pred_def[abs_def] by transfer_prover

end


subsection ‹A relator for sets that treats sets like predicates›

context includes lifting_syntax begin
(* rel_pred R A B: R-related elements agree on membership in A and B, i.e. the
   membership predicates of A and B are related by R ===> op =. *)
definition rel_pred :: "('a ⇒ 'b ⇒ bool) ⇒ 'a set ⇒ 'b set ⇒ bool"
where "rel_pred R A B = (R ===> op =) (λx. x ∈ A) (λy. y ∈ B)"

(* Introduction rule: restates the definition. *)
lemma rel_predI: "(R ===> op =) (λx. x ∈ A) (λy. y ∈ B) ⟹ rel_pred R A B"
by(simp add: rel_pred_def)

(* Elimination rule: R-related elements have the same membership status. *)
lemma rel_predD: "⟦ rel_pred R A B; R x y ⟧ ⟹ x ∈ A ⟷ y ∈ B"
by(simp add: rel_pred_def rel_fun_def)

(* Collect is parametric with respect to rel_pred. *)
lemma Collect_parametric [transfer_rule]: "((A ===> op =) ===> rel_pred A) Collect Collect"
by(simp add: rel_funI rel_predI)
end


subsection ‹Parametrisation of transfer rules›

ML ‹ (* from lifting_term.ML *)
(* Failure inside the relation-lookup helpers of this theory, carrying a
   printable explanation. gen_merge_transfer_relations (below) converts it into
   Lifting_Term.MERGE_TRANSFER_REL. *)
exception QUOT_THM_INTERNAL of Pretty.T

(* Relator distribution rules registered for the type constructor s, selected
   by the polarity constant tm: the positive rules for Lifting.POS, the negative
   rules for Lifting.NEG. The rules are transferred into the current theory.
   Raises QUOT_THM_INTERNAL if s has no distribution data, or if tm is neither
   Lifting.POS nor Lifting.NEG. *)
fun get_rel_distr_rules ctxt s tm =
  let
    val thy = Proof_Context.theory_of ctxt
  in
    (case Lifting_Info.lookup_relator_distr_data ctxt s of
      SOME rel_distr_thm => (
        case tm of
          Const (@{const_name Lifting.POS}, _) => map (Thm.transfer thy) (#pos_distr_rules rel_distr_thm)
          | Const (@{const_name Lifting.NEG}, _) => map (Thm.transfer thy) (#neg_distr_rules rel_distr_thm)
          (* Report any other head like the other lookup failures instead of
             letting a Match exception escape. *)
          | _ => raise QUOT_THM_INTERNAL (Pretty.block
            [Pretty.str "Expected Lifting.POS or Lifting.NEG as polarity, but found",
             Pretty.brk 1,
             Syntax.pretty_term ctxt tm])
      )
    | NONE => raise QUOT_THM_INTERNAL (Pretty.block 
      [Pretty.str ("No relator distr. data for the type " ^ quote s), 
       Pretty.brk 1,
       Pretty.str "found."]))
  end

(* Quotient data registered in quotients for the type constructor s; raises
   QUOT_THM_INTERNAL if there is none. *)
fun get_quot_data (quotients: Lifting_Term.quotients) s =
  case Symtab.lookup quotients s of
    SOME qdata => qdata
  | NONE => raise QUOT_THM_INTERNAL (Pretty.block 
    [Pretty.str ("No quotient type " ^ quote s), 
     Pretty.brk 1, 
     Pretty.str "found."])

(* The pcr_info of the quotient type s; raises QUOT_THM_INTERNAL if s is not a
   known quotient type or its pcr_info is NONE. *)
fun get_pcrel_info quotients s =
  case #pcr_info (get_quot_data quotients s) of
    SOME pcr_info => pcr_info
  | NONE => raise QUOT_THM_INTERNAL (Pretty.block 
    [Pretty.str ("No parametrized correspondence relation for " ^ quote s), 
     Pretty.brk 1, 
     Pretty.str "found."])

(* The pcrel_def theorem of the quotient type s, transferred into the current
   theory. *)
fun get_pcrel_def quotients ctxt s =
  let
    val thy = Proof_Context.theory_of ctxt
  in
    Thm.transfer thy (#pcrel_def (get_pcrel_info quotients s))
  end
›

(* TODO:
- check that only pcrel definitions of the right form are registered
- transform op = into op ==
*)
(* Defining equations "pcr_c ... == ..." for parametrized correspondence
   relations. Lookups (pcrel_of_pcrel_def) identify an entry by the head
   constant of its left-hand side. *)
named_theorems pcrel_def "equation collection for parametrized correspondence relations"

ML ‹
(* Last element of a list; raises List.Empty on the empty list. *)
fun last [] = raise List.Empty
  | last [x] = x
  | last (_ :: xs) = last xs

(* Fallback lookup of a pcr definition through the lifting package's quotient
   data. The type constructor of the relation constant's last argument type is
   used as the key into quotients. If T has no argument types, or the last one
   is not a type constructor application, QUOT_THM_INTERNAL is raised (like
   every other lookup failure) instead of a raw Empty/TYPE exception. *)
fun pcrel_of_quotients quotients ctxt ((_, T) : string * typ) =
  let
    val qty_name =
      (case try (fst o dest_Type o last o binder_types) T of
        SOME name => name
      | NONE => raise QUOT_THM_INTERNAL (Pretty.block
          [Pretty.str "Cannot determine a quotient type from the relation type",
           Pretty.brk 1,
           Syntax.pretty_typ ctxt T]))
  in
    get_pcrel_def quotients ctxt qty_name
  end
›

ML ‹
(* Split an equation into its left- and right-hand side. Accepts HOL equality,
   also under Trueprop, and Pure meta-equality; raises TERM otherwise. *)
fun dest_eq tm =
  (case tm of
    Const (@{const_name Trueprop}, _) $ body => dest_eq body
  | Const (@{const_name HOL.eq}, _) $ lhs $ rhs => (lhs, rhs)
  | Const (@{const_name "Pure.eq"}, _) $ lhs $ rhs => (lhs, rhs)
  | _ => raise TERM ("dest_eq", [tm]))
›

ML ‹
(* Look up, in the pcrel_def collection, the defining equation of the relation
   constant named s. An entry matches if the head constant of its left-hand
   side is s. Entries whose conclusion is not a meta-equation with a constant
   head (the collection is not checked, see the TODO above it) are skipped
   instead of aborting the whole lookup. Raises QUOT_THM_INTERNAL if no entry
   matches, so that pcrel_of_all can fall back to the quotient data. *)
fun pcrel_of_pcrel_def ctxt ((s, _) : string * typ) =
  let
    val thms = Named_Theorems.get ctxt @{named_theorems pcrel_def}
    fun lhs_head_name thm =
      thm |> Thm.concl_of |> Logic.dest_equals |> fst |> head_of |> dest_Const |> fst
    fun matches thm = (lhs_head_name thm = s) handle TERM _ => false
  in
    case find_first matches thms of NONE =>
      raise QUOT_THM_INTERNAL (Pretty.block [Pretty.str "No defining equation for", Pretty.brk 1, Pretty.str (quote s), Pretty.brk 1, Pretty.str "found"])
    | SOME thm => thm
  end
›

ML ‹
(* Defining equation of the parametrized correspondence relation sT = (name, type).
   The pcrel_def collection is consulted first, because the pcrel constant is
   more specific than the type; only if that lookup raises QUOT_THM_INTERNAL is
   the lifting package's quotient data used. *)
fun pcrel_of_all ctxt sT =
  let
    fun via_quotient_data () =
      pcrel_of_quotients (Lifting_Info.get_quotients ctxt) ctxt sT
  in
    pcrel_of_pcrel_def ctxt sT
      handle QUOT_THM_INTERNAL _ => via_quotient_data ()
  end
›


ML ‹
(* Parametrization *)
(* Merging of composed relations "par_R OO T" into a single parametrized
   relation (see the comment before merge_transfer_relations at the end).
   The local helpers implement "rewriting by implication": instead of
   rewriting with an equation, a rule whose conclusion matches the goal is
   instantiated and returned with its premises still open. *)
local
  (* Conclusion of rule with its outer constant application (Trueprop) and its
     last argument removed. *)
  fun get_lhs rule = (Thm.dest_fun o Thm.dest_arg o strip_imp_concl o Thm.cprop_of) rule;
  
  (* Always fails; the starting point of first_imp's fallback chain. *)
  fun no_imp _ = raise CTERM ("no implication", []);
  
  infix 0 else_imp

  (* Apply cv1 to ct; if it raises THM, CTERM, TERM or TYPE, apply cv2 instead. *)
  fun (cv1 else_imp cv2) ct =
    (cv1 ct
      handle THM _ => cv2 ct
        | CTERM _ => cv2 ct
        | TERM _ => cv2 ct
        | TYPE _ => cv2 ct);
  
  (* Try the functions in cvs in order; fail with CTERM if none succeeds. *)
  fun first_imp cvs = fold_rev (curry op else_imp) cvs no_imp
  
  (* Instantiate rule so that its conclusion (as returned by get_lhs) matches
     ct with its last argument removed. The rule's indexes are first shifted
     above those of ct and its bound variables renamed after ct. Raises CTERM
     if matching fails. *)
  fun rewr_imp rule ct = 
    let
      val rule1 = Thm.incr_indexes (Thm.maxidx_of_cterm ct + 1) rule;
      val lhs_rule = get_lhs rule1;
      val rule2 = Thm.rename_boundvars (Thm.term_of lhs_rule) (Thm.term_of ct) rule1;
      val lhs_ct = Thm.dest_fun ct
    in
        Thm.instantiate (Thm.match (lhs_rule, lhs_ct)) rule2
          handle Pattern.MATCH => raise CTERM ("rewr_imp", [lhs_rule, lhs_ct])
   end
  
  (* rewr_imp with the first rule in rules that matches. *)
  fun rewrs_imp rules = first_imp (map rewr_imp rules)
in

  (* Merge the relation in ctm, recursing into the premises of the rule used.
     pcrs is called as "pcrs ctxt (name, T)" for the head constant of a
     correspondence relation and must return its defining equation
     "pcr ... == ...". Lookup failures (QUOT_THM_INTERNAL) are reported as
     Lifting_Term.MERGE_TRANSFER_REL. *)
  fun gen_merge_transfer_relations pcrs ctxt ctm =
    let
      (* Drop the outer constant application (presumably Trueprop). *)
      val ctm = Thm.dest_arg ctm
      val tm = Thm.term_of ctm
      (* The relation argument of POS/NEG, expected to be of the form par_R OO T. *)
      val rel = (hd o Lifting_Util.get_args 2) tm
  
      fun same_constants (Const (n1,_)) (Const (n2,_)) = n1 = n2
        | same_constants _ _  = false
      
      (* Instantiate distr_rule against ctm and try to prove its premises other
         than POS/NEG ones using the registered transfer rules. Returns NONE if
         the rule does not match or a premise cannot be proved. *)
      fun prove_extra_assms ctxt ctm distr_rule =
        let
          fun prove_assm assm = try (Goal.prove ctxt [] [] (Thm.term_of assm))
            (fn _ => SOLVED' (REPEAT_ALL_NEW (resolve_tac ctxt (Transfer.get_transfer_raw ctxt))) 1)
  
          fun is_POS_or_NEG ctm =
            case (head_of o Thm.term_of o Thm.dest_arg) ctm of
              Const (@{const_name Lifting.POS}, _) => true
              | Const (@{const_name Lifting.NEG}, _) => true
              | _ => false
  
          val inst_distr_rule = rewr_imp distr_rule ctm
          val extra_assms = filter_out is_POS_or_NEG (cprems_of inst_distr_rule)
          val proved_assms = Lifting_Util.map_interrupt prove_assm extra_assms
        in
          Option.map (curry op OF inst_distr_rule) proved_assms
        end
        handle CTERM _ => NONE
  
      (* Error message showing the relation, both types and the term. *)
      fun cannot_merge_error_msg rty' qty tm () = Pretty.block
         [Pretty.str "Rewriting (merging) of this term has failed:",
          Pretty.brk 1,
          Syntax.pretty_term ctxt rel
, Pretty.brk 1, Syntax.pretty_typ ctxt rty', Pretty.brk 1, Syntax.pretty_typ ctxt qty, Pretty.brk 1, Syntax.pretty_term ctxt tm]
  
    in
      case Lifting_Util.get_args 2 rel of
          (* Equality on either side of OO: use the dedicated OO-with-eq rules. *)
          [Const (@{const_name "HOL.eq"}, _), _] => rewrs_imp @{thms neg_eq_OO pos_eq_OO} ctm
          | [_, Const (@{const_name "HOL.eq"}, _)] => rewrs_imp @{thms neg_OO_eq pos_OO_eq} ctm
          | [_, trans_rel] =>
            let
              val (rty', qty) = (Lifting_Util.relation_types o fastype_of) trans_rel
            in
              (* trans_rel relates two types with the same type constructor:
                 distribute via that relator's distribution rules, then merge
                 the resulting premises recursively. *)
              if Lifting_Util.same_type_constrs (rty', qty) then
                let
                  val distr_rules = get_rel_distr_rules ctxt ((fst o dest_Type) rty') (head_of tm)
                  val distr_rule = get_first (prove_extra_assms ctxt ctm) distr_rules
                in
                  case distr_rule of
                    NONE => raise Lifting_Term.MERGE_TRANSFER_REL (cannot_merge_error_msg rty' qty tm ())
                    | SOME distr_rule =>  Lifting_Util.MRSL (map (gen_merge_transfer_relations pcrs ctxt) 
                                            (cprems_of distr_rule),
                       distr_rule)
                end
              else
(* We know the constant, so we could use it just as well for lookup *)
                (* Otherwise trans_rel must be a parametrized correspondence
                   relation: unfold its definition, apply POS_pcr_rule or
                   NEG_pcr_rule, merge the premises recursively and fold the
                   definition back in the result. *)
                let
                  val pcr = head_of trans_rel |> dest_Const
                  val pcrel_def = pcrs ctxt pcr
                  val pcrel_const = (head_of o fst o Logic.dest_equals o Thm.prop_of) pcrel_def
                in
                  if same_constants pcrel_const (head_of trans_rel) then
                    let
                      val unfolded_ctm = Thm.rhs_of (Conv.arg1_conv (Conv.arg_conv (Conv.rewr_conv pcrel_def)) ctm)
                      val distr_rule = rewrs_imp @{thms POS_pcr_rule NEG_pcr_rule} unfolded_ctm
                      val result = Lifting_Util.MRSL (map (gen_merge_transfer_relations pcrs ctxt) 
                        (cprems_of distr_rule), distr_rule)
                      val fold_pcr_rel = Conv.rewr_conv (Thm.symmetric pcrel_def)
                    in  
                      Conv.fconv_rule (HOLogic.Trueprop_conv (Conv.combination_conv 
                        (Conv.arg_conv (Conv.arg_conv fold_pcr_rel)) fold_pcr_rel)) result
                    end
                  else
                    raise Lifting_Term.MERGE_TRANSFER_REL 
                      (Pretty.block [Pretty.str "Non-parametric correspondence relation used.",
                       Pretty.brk 1,
                       Syntax.pretty_term ctxt pcrel_const,
                       Pretty.brk 1,
                       Syntax.pretty_term ctxt (head_of trans_rel)
                      ]
                      )
                end
            end
    end
    handle QUOT_THM_INTERNAL pretty_msg => raise Lifting_Term.MERGE_TRANSFER_REL pretty_msg

  (*
    ctm - of the form "[POS|NEG] (par_R OO T) t f) ?X", where par_R is a parametricity transfer 
    relation for t and T is a transfer relation between t and f, which consists only of
    parametrized transfer relations (i.e., pcr_?) and equalities op=. POS or NEG encodes
    co-variance or contra-variance.
    
    The function merges par_R OO T using definitions of parametrized correspondence relations
    (e.g., (rel_S R) OO (pcr_T op=) --> pcr_T R using the definition pcr_T R = (rel_S R) OO cr_T).

    Definitions are looked up with pcrel_of_all (pcrel_def collection first,
    quotient data second).
  *)

  fun merge_transfer_relations ctxt ctm = gen_merge_transfer_relations 
    pcrel_of_all ctxt ctm
end
›


ML ‹ (* from lifting_def.ML *)
  (* Combine transfer_rule with parametric_transfer_rule:
     1. compose the two relations with relcomppI, giving "(par_R OO T) t f";
     2. preprocess the parametricity relation par_R;
     3. fix the variables and assume the premises, also as transfer rules;
     4. merge par_R OO T with merge_transfer_relations (via POS_apply);
     5. discharge the assumed premises again and export the result. *)
  fun generate_parametric_transfer_rule ctxt transfer_rule parametric_transfer_rule =
  let
    (* Instantiate each schematic relation ?R :: 'a ⇒ τ ⇒ bool occurring in
       par_R, where 'a is a schematic type variable and τ a type constructor
       application, with equality on τ. Then rewrite par_R (left operand of OO
       in the conclusion) with Transfer.get_sym_relator_eq. *)
    fun preprocess ctxt thm =
      let
        val tm = (Lifting_Util.strip_args 2 o HOLogic.dest_Trueprop o Thm.concl_of) thm;
        (* par_R from "par_R OO T" *)
        val param_rel = (snd o dest_comb o fst o dest_comb) tm;
        val free_vars = Term.add_vars param_rel [];
        
        fun make_subst (xi, typ) subst = 
          let
            val [rty, rty'] = binder_types typ
          in
            if Term.is_TVar rty andalso Lifting_Util.is_Type rty' then
              (xi, Thm.cterm_of ctxt (HOLogic.eq_const rty')) :: subst
            else
              subst
          end;

        val inst_thm = infer_instantiate ctxt (fold make_subst free_vars []) thm;
      in
        Conv.fconv_rule 
          ((Conv.concl_conv (Thm.nprems_of inst_thm) o
            HOLogic.Trueprop_conv o Conv.fun2_conv o Conv.arg1_conv)
            (Raw_Simplifier.rewrite ctxt false (Transfer.get_sym_relator_eq ctxt))) inst_thm
      end

    (* Instantiate relcomppI so that its two premises become the conclusion of
       ant1 (relation and its two arguments) and the proposition of ant2
       (relation and its last argument); it is then discharged with OF. *)
    fun inst_relcomppI ctxt ant1 ant2 =
      let
        val t1 = (HOLogic.dest_Trueprop o Thm.concl_of) ant1
        val t2 = (HOLogic.dest_Trueprop o Thm.prop_of) ant2
        val fun1 = Thm.cterm_of ctxt (Lifting_Util.strip_args 2 t1)
        val args1 = map (Thm.cterm_of ctxt) (Lifting_Util.get_args 2 t1)
        val fun2 = Thm.cterm_of ctxt (Lifting_Util.strip_args 2 t2)
        val args2 = map (Thm.cterm_of ctxt) (Lifting_Util.get_args 1 t2)
        val relcomppI = Drule.incr_indexes2 ant1 ant2 @{thm relcomppI}
        val vars = map #1 (rev (Term.add_vars (Thm.prop_of relcomppI) []))
      in
        infer_instantiate ctxt (vars ~~ ([fun1] @ args1 @ [fun2] @ args2)) relcomppI
      end

    (* Build the goal "POS rel ?X" for the relation rel of thm, solve it with
       merge_transfer_relations and combine the result with thm via POS_apply. *)
    fun zip_transfer_rules ctxt thm =
      let
        fun mk_POS ty = Const (@{const_name Lifting.POS}, ty --> ty --> HOLogic.boolT)
        val rel = (Thm.dest_fun2 o Thm.dest_arg o Thm.cprop_of) thm
        val typ = Thm.typ_of_cterm rel
        val POS_const = Thm.cterm_of ctxt (mk_POS typ)
        val var = Thm.cterm_of ctxt (Var (("X", Thm.maxidx_of_cterm rel + 1), typ))
        val goal =
          Thm.apply (Thm.cterm_of ctxt HOLogic.Trueprop) (Thm.apply (Thm.apply POS_const rel) var)
      in
        Lifting_Util.MRSL ([merge_transfer_relations ctxt goal, thm], @{thm POS_apply})
      end
     
    val thm =
      inst_relcomppI ctxt parametric_transfer_rule transfer_rule
        OF [parametric_transfer_rule, transfer_rule]
    val preprocessed_thm = preprocess ctxt thm
    val orig_ctxt = ctxt
    val (fixed_thm, ctxt) = yield_singleton (apfst snd oo Variable.import true) preprocessed_thm ctxt
    val assms = cprems_of fixed_thm
    (* Assume the premises and make them available as transfer rules while merging. *)
    val add_transfer_rule = Thm.attribute_declaration Transfer.transfer_add
    val (prems, ctxt) = fold_map Thm.assume_hyps assms ctxt
    val ctxt = Context.proof_map (fold add_transfer_rule prems) ctxt
    val zipped_thm =
      fixed_thm
      |> Lifting_Util.undisch_all
      |> zip_transfer_rules ctxt
      |> implies_intr_list assms
      |> singleton (Variable.export ctxt orig_ctxt)
  in
    zipped_thm
  end
›

attribute_setup transfer_parametric = ‹
  let
    (* Combine one transfer rule with the given parametricity theorem and
       report merge failures as ordinary errors. The transfer rule is passed on
       unchanged; Lifting_Term.parametrize_transfer_rule is not applied. *)
    fun combine parametricity context transfer_rule =
      generate_parametric_transfer_rule (Context.proof_of context) transfer_rule parametricity
        handle Lifting_Term.MERGE_TRANSFER_REL msg => error (Pretty.string_of msg)
  in
    Attrib.thm >> (fn parametricity => Thm.rule_attribute [] (combine parametricity))
  end
› "combine transfer rule with parametricity theorem"


subsection ‹Lists of a given length›

(* nlists A n: the set of lists of length n whose elements all lie in A. *)
inductive_set nlists :: "'a set ⇒ nat ⇒ 'a list set" for A n
where nlists: "⟦ set xs ⊆ A; length xs = n ⟧ ⟹ xs ∈ nlists A n"
(* Hide the unqualified name of the introduction rule. *)
hide_fact (open) nlists

(* Set-comprehension characterisation of nlists; most lemmas below use it. *)
lemma nlists_alt_def: "nlists A n = {xs. set xs ⊆ A ∧ length xs = n}"
by(auto simp add: nlists.simps)

(* Over the empty alphabet only the empty list (of length 0) exists. *)
lemma nlists_empty: "nlists {} n = (if n = 0 then {[]} else {})"
by(auto simp add: nlists_alt_def)

lemma nlists_empty_gt0 [simp]: "n > 0 ⟹ nlists {} n = {}"
by(simp add: nlists_empty)

(* The only list of length 0 is []. *)
lemma nlists_0 [simp]: "nlists A 0 = {[]}"
by(auto simp add: nlists_alt_def)

(* Membership of a cons list, for a Suc length. *)
lemma Cons_in_nlists_Suc [simp]: "x # xs ∈ nlists A (Suc n) ⟷ x ∈ A ∧ xs ∈ nlists A n"
by(simp add: nlists_alt_def)

lemma Nil_in_nlists [simp]: "[] ∈ nlists A n ⟷ n = 0"
by(auto simp add: nlists_alt_def)

(* Membership of a cons list, for an arbitrary length. *)
lemma Cons_in_nlists_iff: "x # xs ∈ nlists A n ⟷ (∃n'. n = Suc n' ∧ x ∈ A ∧ xs ∈ nlists A n')"
by(cases n) simp_all

(* Every list of length Suc n is a cons of an A-element onto a list of length n. *)
lemma in_nlists_Suc_iff: "xs ∈ nlists A (Suc n) ⟷ (∃x xs'. xs = x # xs' ∧ x ∈ A ∧ xs' ∈ nlists A n)"
by(cases xs) simp_all

(* Recursive decomposition of nlists A (Suc n). *)
lemma nlists_Suc: "nlists A (Suc n) = (⋃x∈A. op # x ` nlists A n)"
by(auto 4 3 simp add: in_nlists_Suc_iff intro: rev_image_eqI)

lemma replicate_in_nlists [simp, intro]: "x ∈ A ⟹ replicate n x ∈ nlists A n"
by(simp add: nlists_alt_def set_replicate_conv_if)

(* nlists A n is empty only for a positive length over an empty alphabet. *)
lemma nlists_eq_empty_iff [simp]: "nlists A n = {} ⟷ n > 0 ∧ A = {}"
using replicate_in_nlists by(cases n)(auto)

(* Finitely many lists of a fixed length over a finite alphabet. *)
lemma finite_nlists [simp]: "finite A ⟹ finite (nlists A n)"
by(induction n)(simp_all add: nlists_Suc)

(* Conversely, finiteness of nlists A n forces A to be finite unless n = 0. *)
lemma finite_nlistsD: 
  assumes "finite (nlists A n)"
  shows "finite A ∨ n = 0"
proof(rule disjCI)
  assume "n ≠ 0"
  then obtain n' where n: "n = Suc n'" by(cases n)auto
  (* For positive n, every element of A is the head of some list in nlists A n. *)
  then have "A = hd ` nlists A n" by(auto 4 4 simp add: nlists_Suc intro: rev_image_eqI rev_bexI)
  also have "finite …" using assms ..
  finally show "finite A" .
qed

lemma finite_nlists_iff: "finite (nlists A n) ⟷ finite A ∨ n = 0"
by(auto dest: finite_nlistsD)

(* There are card A ^ n lists of length n. With infinite A, both sides are 0
   for n > 0 because card of an infinite set is 0. *)
lemma card_nlists: "card (nlists A n) = card A ^ n"
proof(induction n)
  case (Suc n)
  have "card (⋃x∈A. op # x ` nlists A n) = card A * card (nlists A n)"
  proof(cases "finite A")
    case True
    (* The union is disjoint (distinct heads) and each x # _ is injective. *)
    then show ?thesis by(subst card_UN_disjoint)(auto simp add: card_image inj_on_def)
  next
    case False
    (* Infinite A: the union is infinite, so both sides are 0. *)
    hence "¬ finite (⋃x∈A. op # x ` nlists A n)"
      unfolding nlists_Suc[symmetric] by(auto dest: finite_nlistsD)
    then show ?thesis using False by simp
  qed
  then show ?case using Suc.IH by(simp add: nlists_Suc)
qed simp

(* Over UNIV, membership only constrains the length. *)
lemma in_nlists_UNIV: "xs ∈ nlists UNIV n ⟷ length xs = n"
by(simp add: nlists_alt_def)



subsection ‹ Corecursor for products ›

(* Build a pair from a state s by applying f and g to it. *)
definition corec_prod :: "('s ⇒ 'a) ⇒ ('s ⇒ 'b) ⇒ 's ⇒ 'a × 'b"
where "corec_prod f g = (λs. (f s, g s))"

(* Pointwise unfolding of the definition; the lemmas below are proved from it. *)
lemma corec_prod_apply: "corec_prod f g s = (f s, g s)"
by(simp add: corec_prod_def)

(* Selectors applied to corec_prod. *)
lemma corec_prod_sel [simp]:
  shows fst_corec_prod: "fst (corec_prod f g s) = f s"
  and snd_corec_prod: "snd (corec_prod f g s) = g s"
by(simp_all add: corec_prod_apply)

(* Maps on the components fuse into the component functions. *)
lemma apfst_corec_prod [simp]: "apfst h (corec_prod f g s) = corec_prod (h ∘ f) g s"
by(simp add: corec_prod_apply)

lemma apsnd_corec_prod [simp]: "apsnd h (corec_prod f g s) = corec_prod f (h ∘ g) s"
by(simp add: corec_prod_apply)

lemma map_corec_prod [simp]: "map_prod f g (corec_prod h k s) = corec_prod (f ∘ h) (g ∘ k) s"
by(simp add: corec_prod_apply)

(* case_prod on corec_prod applies h to both components directly. *)
lemma split_corec_prod [simp]: "case_prod h (corec_prod f g s) = h (f s) (g s)"
by(simp add: corec_prod_apply)

subsection ‹ @{const If}›

(* Collection for distributivity theorems involving If. *)
named_theorems if_distribs "Distributivity theorems for If"

(* Monotonicity of If in both branches, each under its branch condition. *)
lemma if_mono_cong: "⟦b ⟹ x ≤ x'; ¬ b ⟹ y ≤ y' ⟧ ⟹ If b x y ≤ If b x' y'"
by simp

end