open Bwd
open Bwd.Infix

module D = Domain

(** Raised when two values are found not to be convertible. *)
exception Unequal

module Internal =
struct
  (** Environment threaded through conversion checking.
      - [mode] controls whether top-level definitions ([D.Unfold]) may be
        unfolded: [`Full] always unfolds, [`Flex] never unfolds, and [`Rigid]
        unfolds but first tries a cheap spine comparison when both sides share
        the same head symbol.
      - [size] is the value handed to [D.var] for the next fresh variable
        created by [bind] (presumably the number of variables in scope —
        confirm against [Domain]). *)
  type env = { mode : [`Rigid | `Flex | `Full]; size : int }

  module Eff = Algaeff.Reader.Make (struct type nonrec t = env end)

  (** Run [f] with the unfolding mode temporarily replaced by [mode]. *)
  let with_mode mode f =
    Eff.scope (fun env -> { env with mode }) f

  (** Create a fresh variable from the current [size] and run [f] on it in an
      environment whose [size] is one larger. *)
  let bind f =
    let arg = D.var (Eff.read ()).size in
    Eff.scope (fun env -> { env with size = env.size + 1 }) @@ fun () ->
    f arg

  (** Whether two unfolding heads refer to the same definition. *)
  let equal_unfold_head hd1 hd2 =
    match hd1, hd2 with
    | D.Def (p1, _), D.Def (p2, _) -> p1 = p2

  (** Check that [v1] and [v2] are convertible; raise [Unequal] otherwise.

      Structural and unfolding cases are tried BEFORE the neutral cases.
      Otherwise a neutral compared against a [D.Unfold] would be sent to
      [equate_ne], which cannot unfold, and be rejected even in [`Rigid] or
      [`Full] mode. *)
  let rec equate v1 v2 =
    match v1, v2, (Eff.read ()).mode with
    | D.Pi (_, base1, fam1), D.Pi (_, base2, fam2), _ ->
      equate base1 base2;
      equate_clo fam1 fam2
    | D.Lam (_, clo1), D.Lam (_, clo2), _ ->
      equate_clo clo1 clo2
    | D.Sg (_, base1, fam1), D.Sg (_, base2, fam2), _ ->
      equate base1 base2;
      equate_clo fam1 fam2
    | D.Pair (fst1, snd1), D.Pair (fst2, snd2), _ ->
      equate fst1 fst2;
      equate snd1 snd2
    | D.Type, D.Type, _ -> ()
    | D.Bool, D.Bool, _ -> ()
    | D.True, D.True, _ -> ()
    | D.False, D.False, _ -> ()

    (* Approximate conversion checking in the style of smalltt, see
       https://github.com/AndrasKovacs/smalltt?tab=readme-ov-file#approximate-conversion-checking *)

    (* In "full" mode, we immediately unfold any defined symbol. *)
    | D.Unfold (_, _, v1), v2, `Full -> equate (Lazy.force v1) v2
    | v1, D.Unfold (_, _, v2), `Full -> equate v1 (Lazy.force v2)

    (* In "flex" mode, we cannot unfold any top-level definition; we can only
       recurse into spines if head symbols are equal. *)
    | D.Unfold (hd1, sp1, _), D.Unfold (hd2, sp2, _), `Flex ->
      if equal_unfold_head hd1 hd2 then equate_spine sp1 sp2
      else raise Unequal

    (* In "rigid" mode, we can initiate speculation if we have the same
       top-level head symbol on both sides: compare the spines without
       unfolding, and only if that fails, unfold both sides fully. *)
    | D.Unfold (hd1, sp1, v1), D.Unfold (hd2, sp2, v2), `Rigid ->
      if equal_unfold_head hd1 hd2 then
        begin
          try with_mode `Flex @@ fun () -> equate_spine sp1 sp2 with
          | Unequal ->
            with_mode `Full @@ fun () -> equate (Lazy.force v1) (Lazy.force v2)
        end
      else
        equate (Lazy.force v1) (Lazy.force v2)
    | D.Unfold (_, _, v1), v2, `Rigid -> equate (Lazy.force v1) v2
    | v1, D.Unfold (_, _, v2), `Rigid -> equate v1 (Lazy.force v2)

    (* Neutrals are handled last so that any [D.Unfold] on the other side has
       already been forced (except in "flex" mode, where that fails as
       intended). [equate_ne] also performs eta expansion. *)
    | D.Neutral ne, v, _ | v, D.Neutral ne, _ -> equate_ne ne v

    | _ -> raise Unequal

  (** Compare two closures by instantiating both at the same fresh variable. *)
  and equate_clo clo1 clo2 =
    bind @@ fun arg ->
    equate (Eval.inst_clo clo1 arg) (Eval.inst_clo clo2 arg)

  (** Neutral heads are equal iff they are the same variable. *)
  and equate_ne_head (D.Var lvl1) (D.Var lvl2) =
    if lvl1 <> lvl2 then raise Unequal

  (** Compare two elimination frames of a spine. *)
  and equate_frm frm1 frm2 =
    match frm1, frm2 with
    | D.App arg1, D.App arg2 -> equate arg1 arg2
    | D.Fst, D.Fst -> ()
    | D.Snd, D.Snd -> ()
    | D.BoolElim { motive = mot1; true_case = t1; false_case = f1; _ },
      D.BoolElim { motive = mot2; true_case = t2; false_case = f2; _ } ->
      equate_clo mot1 mot2;
      equate t1 t2;
      equate f1 f2
    | _ -> raise Unequal

  (** Compare two spines frame by frame; spines of different length are
      unequal. *)
  and equate_spine sp1 sp2 =
    match sp1, sp2 with
    | Emp, Emp -> ()
    | Snoc (sp1, frm1), Snoc (sp2, frm2) ->
      equate_frm frm1 frm2;
      equate_spine sp1 sp2
    | _ -> raise Unequal

  (** Compare the neutral [(hd, sp)] against [v], eta-expanding the neutral
      when [v] is a lambda or a pair. The eta-expanded neutral is compared via
      [equate] rather than [equate_ne], so that a lambda body or pair component
      that is a [D.Unfold] can still be unfolded. *)
  and equate_ne (hd, sp) v =
    match v with
    | D.Neutral (hd2, sp2) ->
      equate_ne_head hd hd2;
      equate_spine sp sp2
    (* eta expansion *)
    | D.Lam (_, clo) ->
      bind @@ fun arg ->
      equate (D.Neutral (hd, sp <: D.App arg)) (Eval.inst_clo clo arg)
    | D.Pair (fst, snd) ->
      equate (D.Neutral (hd, sp <: D.Fst)) fst;
      equate (D.Neutral (hd, sp <: D.Snd)) snd
    | _ -> raise Unequal
end

(** Check that [v1] and [v2] are convertible, starting in rigid mode; raise
    [Unequal] if they are not. [size] seeds the level used for the next fresh
    variable (presumably the number of variables already in scope). *)
let equate ~size v1 v2 =
  let env = { Internal.mode = `Rigid; size } in
  Internal.Eff.run ~env @@ fun () -> Internal.equate v1 v2