about summary refs log tree commit diff
path: root/src/nbe/Conversion.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/nbe/Conversion.ml')
-rw-r--r--src/nbe/Conversion.ml63
1 files changed, 47 insertions, 16 deletions
diff --git a/src/nbe/Conversion.ml b/src/nbe/Conversion.ml
index 00bc942..8c4478e 100644
--- a/src/nbe/Conversion.ml
+++ b/src/nbe/Conversion.ml
@@ -7,31 +7,61 @@ exception Unequal
 
 module Internal =
 struct
-  (** Context size *)
-  type env = int
+  type env = {
+    mode : [`Rigid | `Flex | `Full];
+    size : int
+  }
 
   module Eff = Algaeff.Reader.Make (struct type nonrec t = env end)
 
+  let with_mode mode f = Eff.scope (fun env -> { env with mode }) f
+
   let bind f =
-    let arg = D.var (Eff.read()) in
-    Eff.scope (fun size -> size + 1) (fun () -> f arg)
+    let arg = D.var (Eff.read()).size in
+    Eff.scope (fun env -> { env with size = env.size + 1 }) @@ fun () -> f arg
+
+  let equal_unfold_head hd1 hd2 = match hd1, hd2 with
+    | D.Def (p1, _), D.Def (p2, _) -> p1 = p2
 
-  let rec equate v1 v2 = match v1, v2 with
-    | D.Neutral ne, v | v, D.Neutral ne -> equate_ne ne v
-    | D.Pi (_, base1, fam1), D.Pi (_, base2, fam2) ->
+  let rec equate v1 v2 = match v1, v2, (Eff.read()).mode with
+    | D.Neutral ne, v, _ | v, D.Neutral ne, _ -> equate_ne ne v
+    | D.Pi (_, base1, fam1), D.Pi (_, base2, fam2), _  ->
       equate base1 base2;
       equate_clo fam1 fam2
-    | D.Lam (_, clo1), D.Lam (_, clo2) -> equate_clo clo1 clo2
-    | D.Sg (_, base1, fam1), D.Sg (_, base2, fam2) ->
+    | D.Lam (_, clo1), D.Lam (_, clo2), _  -> equate_clo clo1 clo2
+    | D.Sg (_, base1, fam1), D.Sg (_, base2, fam2), _  ->
       equate base1 base2;
       equate_clo fam1 fam2
-    | D.Pair (fst1, snd1), D.Pair (fst2, snd2) ->
+    | D.Pair (fst1, snd1), D.Pair (fst2, snd2), _  ->
       equate fst1 fst2;
       equate snd1 snd2
-    | Type, Type -> ()
-    | Bool, Bool -> ()
-    | True, True -> ()
-    | False, False -> ()
+    | Type, Type, _  -> ()
+    | Bool, Bool, _  -> ()
+    | True, True, _  -> ()
+    | False, False, _  -> ()
+
+    (* approximate conversion checking in the style of smalltt, see
+       https://github.com/AndrasKovacs/smalltt?tab=readme-ov-file#approximate-conversion-checking *)
+    (* in "full" mode, we immediately unfold any defined symbol *)
+    | D.Unfold (_, _, v1), v2, `Full -> equate (Lazy.force v1) v2
+    | v1, D.Unfold (_, _, v2), `Full -> equate v1 (Lazy.force v2)
+    (* in "flex" mode, we cannot unfold any top-level definition; we can only
+       recurse into spines if head symbols are equal *)
+    | D.Unfold (hd1, sp1, _), D.Unfold (hd2, sp2, _), `Flex ->
+      if equal_unfold_head hd1 hd2 then
+        equate_spine sp1 sp2
+      else raise Unequal
+    (* in "rigid" mode, we can initiate speculation if we have the same
+       top-level head symbol on both sides *)
+    | D.Unfold (hd1, sp1, v1), D.Unfold (hd2, sp2, v2), `Rigid ->
+      if equal_unfold_head hd1 hd2 then
+        try with_mode `Flex @@ fun () -> equate_spine sp1 sp2 with
+        | Unequal -> with_mode `Full @@ fun () -> equate (Lazy.force v1) (Lazy.force v2)
+      else
+        equate (Lazy.force v1) (Lazy.force v2)
+    | D.Unfold (_, _, v1), v2, `Rigid -> equate (Lazy.force v1) v2
+    | v1, D.Unfold (_, _, v2), `Rigid -> equate v1 (Lazy.force v2)
+
     | _ -> raise Unequal
 
   and equate_clo clo1 clo2 = bind @@ fun arg ->
@@ -71,5 +101,6 @@ struct
     | _ -> raise Unequal
 end
 
-let equate ~size v1 v2 = Internal.Eff.run ~env:size @@ fun () ->
-  Internal.equate v1 v2
+let equate ~size v1 v2 =
+  let env = { Internal.mode = `Rigid; size } in
+  Internal.Eff.run ~env @@ fun () -> Internal.equate v1 v2