about summary refs log tree commit diff
path: root/src/nbe
diff options
context:
space:
mode:
Diffstat (limited to 'src/nbe')
-rw-r--r--src/nbe/Conversion.ml63
-rw-r--r--src/nbe/Data.ml16
-rw-r--r--src/nbe/Domain.ml12
-rw-r--r--src/nbe/Eval.ml27
-rw-r--r--src/nbe/NbE.ml3
-rw-r--r--src/nbe/Quote.ml8
-rw-r--r--src/nbe/Syntax.ml1
7 files changed, 100 insertions, 30 deletions
diff --git a/src/nbe/Conversion.ml b/src/nbe/Conversion.ml
index 00bc942..8c4478e 100644
--- a/src/nbe/Conversion.ml
+++ b/src/nbe/Conversion.ml
@@ -7,31 +7,61 @@ exception Unequal
 
 module Internal =
 struct
-  (** Context size *)
-  type env = int
+  type env = {
+    mode : [`Rigid | `Flex | `Full];
+    size : int
+  }
 
   module Eff = Algaeff.Reader.Make (struct type nonrec t = env end)
 
+  let with_mode mode f = Eff.scope (fun env -> { env with mode }) f
+
   let bind f =
-    let arg = D.var (Eff.read()) in
-    Eff.scope (fun size -> size + 1) (fun () -> f arg)
+    let arg = D.var (Eff.read()).size in
+    Eff.scope (fun env -> { env with size = env.size + 1 }) @@ fun () -> f arg
+
+  let equal_unfold_head hd1 hd2 = match hd1, hd2 with
+    | D.Def (p1, _), D.Def (p2, _) -> p1 = p2
 
-  let rec equate v1 v2 = match v1, v2 with
-    | D.Neutral ne, v | v, D.Neutral ne -> equate_ne ne v
-    | D.Pi (_, base1, fam1), D.Pi (_, base2, fam2) ->
+  let rec equate v1 v2 = match v1, v2, (Eff.read()).mode with
+    | D.Neutral ne, v, _ | v, D.Neutral ne, _ -> equate_ne ne v
+    | D.Pi (_, base1, fam1), D.Pi (_, base2, fam2), _  ->
       equate base1 base2;
       equate_clo fam1 fam2
-    | D.Lam (_, clo1), D.Lam (_, clo2) -> equate_clo clo1 clo2
-    | D.Sg (_, base1, fam1), D.Sg (_, base2, fam2) ->
+    | D.Lam (_, clo1), D.Lam (_, clo2), _  -> equate_clo clo1 clo2
+    | D.Sg (_, base1, fam1), D.Sg (_, base2, fam2), _  ->
       equate base1 base2;
       equate_clo fam1 fam2
-    | D.Pair (fst1, snd1), D.Pair (fst2, snd2) ->
+    | D.Pair (fst1, snd1), D.Pair (fst2, snd2), _  ->
       equate fst1 fst2;
       equate snd1 snd2
-    | Type, Type -> ()
-    | Bool, Bool -> ()
-    | True, True -> ()
-    | False, False -> ()
+    | Type, Type, _  -> ()
+    | Bool, Bool, _  -> ()
+    | True, True, _  -> ()
+    | False, False, _  -> ()
+
+    (* approximate conversion checking in the style of smalltt, see
+       https://github.com/AndrasKovacs/smalltt?tab=readme-ov-file#approximate-conversion-checking *)
+    (* in "full" mode, we immediately unfold any defined symbol *)
+    | D.Unfold (_, _, v1), v2, `Full -> equate (Lazy.force v1) v2
+    | v1, D.Unfold (_, _, v2), `Full -> equate v1 (Lazy.force v2)
+    (* in "flex" mode, we cannot unfold any top-level definition; we can only
+       recurse into spines if head symbols are equal *)
+    | D.Unfold (hd1, sp1, _), D.Unfold (hd2, sp2, _), `Flex ->
+      if equal_unfold_head hd1 hd2 then
+        equate_spine sp1 sp2
+      else raise Unequal
+    (* in "rigid" mode, we can initiate speculation if we have the same
+       top-level head symbol on both sides *)
+    | D.Unfold (hd1, sp1, v1), D.Unfold (hd2, sp2, v2), `Rigid ->
+      if equal_unfold_head hd1 hd2 then
+        try with_mode `Flex @@ fun () -> equate_spine sp1 sp2 with
+        | Unequal -> with_mode `Full @@ fun () -> equate (Lazy.force v1) (Lazy.force v2)
+      else
+        equate (Lazy.force v1) (Lazy.force v2)
+    | D.Unfold (_, _, v1), v2, `Rigid -> equate (Lazy.force v1) v2
+    | v1, D.Unfold (_, _, v2), `Rigid -> equate v1 (Lazy.force v2)
+
     | _ -> raise Unequal
 
   and equate_clo clo1 clo2 = bind @@ fun arg ->
@@ -71,5 +101,6 @@ struct
     | _ -> raise Unequal
 end
 
-let equate ~size v1 v2 = Internal.Eff.run ~env:size @@ fun () ->
-  Internal.equate v1 v2
+let equate ~size v1 v2 =
+  let env = { Internal.mode = `Rigid; size } in
+  Internal.Eff.run ~env @@ fun () -> Internal.equate v1 v2
diff --git a/src/nbe/Data.ml b/src/nbe/Data.ml
index d2f94d3..8589ea5 100644
--- a/src/nbe/Data.ml
+++ b/src/nbe/Data.ml
@@ -1,9 +1,8 @@
 open Bwd
 
-(** Syntactic terms *)
-
 type syn =
   | Var of int
+  | Def of Ident.t * value Lazy.t
   | Pi of Ident.local * syn * (* BINDS *) syn
   | Lam of Ident.local * (* BINDS *) syn
   | App of syn * syn
@@ -23,10 +22,9 @@ type syn =
       scrut : syn;
     }
 
-(** Semantic domain *)
-
-type value =
+and value =
   | Neutral of ne
+  | Unfold of unfold
   | Pi of Ident.local * value * clo
   | Lam of Ident.local * clo
   | Sg of Ident.local * value * clo
@@ -37,7 +35,13 @@ type value =
   | False
 
 and ne = ne_head * frm bwd
-and ne_head = Var of int (* De Bruijn levels *)
+and ne_head =
+  | Var of int (* De Bruijn levels *)
+
+and unfold = unfold_head * frm bwd * value Lazy.t
+and unfold_head =
+  | Def of Ident.t * value Lazy.t
+
 and frm = 
   | App of value
   | Fst
diff --git a/src/nbe/Domain.ml b/src/nbe/Domain.ml
index 6f5e11c..bf5ed90 100644
--- a/src/nbe/Domain.ml
+++ b/src/nbe/Domain.ml
@@ -2,6 +2,7 @@ open Bwd
 
 type t = Data.value =
   | Neutral of ne
+  | Unfold of unfold
   | Pi of Ident.local * t * clo
   | Lam of Ident.local * clo
   | Sg of Ident.local * t * clo
@@ -12,7 +13,13 @@ type t = Data.value =
   | False
 
 and ne = Data.ne
-and ne_head = Data.ne_head = Var of int (* De Bruijn levels *)
+and ne_head = Data.ne_head =
+  | Var of int (* De Bruijn levels *)
+
+and unfold = Data.unfold
+and unfold_head = Data.unfold_head =
+  | Def of Ident.t * t Lazy.t
+
 and frm = Data.frm =
   | App of t
   | Fst
@@ -27,4 +34,5 @@ and frm = Data.frm =
 and env = Data.env
 and clo = Data.clo = Clo of { body : Data.syn; env : env }
 
-let var i = Neutral (Var i, Bwd.Emp)
+let var i = Neutral (Var i, Emp)
+let def p v = Unfold (Def (p, v), Emp, v)
diff --git a/src/nbe/Eval.ml b/src/nbe/Eval.ml
index bd8326e..2730ac0 100644
--- a/src/nbe/Eval.ml
+++ b/src/nbe/Eval.ml
@@ -16,27 +16,41 @@ struct
 
   and app v w = match v with
     | D.Lam (_, clo) -> inst_clo clo w
-    | D.Neutral (hd, frms) -> D.Neutral (hd, frms <: D.App w)
+    | D.Neutral (hd, frms) ->
+      D.Neutral (hd, frms <: D.App w)
+    | D.Unfold (hd, frms, v) ->
+      D.Unfold (hd, frms <: D.App w, Lazy.map (fun v -> app v w) v)
     | _ -> invalid_arg "Eval.app"
 
   and fst = function
     | D.Pair (v, _) -> v
-    | D.Neutral (hd, frms) -> D.Neutral (hd, frms <: D.Fst)
+    | D.Neutral (hd, frms) ->
+      D.Neutral (hd, frms <: D.Fst)
+    | D.Unfold (hd, frms, v) ->
+      D.Unfold (hd, frms <: D.Fst, Lazy.map (fun v -> fst v) v)
     | _ -> invalid_arg "Eval.fst"
 
   and snd = function
     | D.Pair (_, v) -> v
-    | D.Neutral (hd, frms) -> D.Neutral (hd, frms <: D.Snd)
+    | D.Neutral (hd, frms) ->
+      D.Neutral (hd, frms <: D.Snd)
+    | D.Unfold (hd, frms, v) ->
+      D.Unfold (hd, frms <: D.Snd, Lazy.map (fun v -> snd v) v)
     | _ -> invalid_arg "Eval.snd"
 
   and bool_elim motive_var motive true_case false_case = function
     | D.True -> true_case
     | D.False -> false_case
-    | D.Neutral (hd, frms) -> D.Neutral (hd, frms <: D.BoolElim { motive_var; motive; true_case; false_case })
+    | D.Neutral (hd, frms) ->
+      D.Neutral (hd, frms <: D.BoolElim { motive_var; motive; true_case; false_case })
+    | D.Unfold (hd, frms, v) ->
+      D.Unfold (hd, frms <: D.BoolElim { motive_var; motive; true_case; false_case }, 
+                Lazy.map (fun v -> bool_elim motive_var motive true_case false_case v) v)
     | _ -> invalid_arg "Eval.bool_elim"
 
   and eval = function
     | S.Var i -> Bwd.nth (Eff.read()) i
+    | S.Def (p, v) -> D.def p v
     | S.Pi (name, base, fam) -> D.Pi (name, eval base, make_clo fam)
     | S.Lam (name, body) -> D.Lam (name, make_clo body)
     | S.App (a, b) -> app (eval a) (eval b)
@@ -53,4 +67,9 @@ struct
 end
 
 let eval ~env tm = Internal.Eff.run ~env @@ fun () -> Internal.eval tm
+let eval_toplevel tm = eval ~env:Emp tm
 let inst_clo = Internal.inst_clo
+
+let rec force_all = function
+  | D.Unfold (_, _, v) -> force_all (Lazy.force v)
+  | v -> v
diff --git a/src/nbe/NbE.ml b/src/nbe/NbE.ml
index 3ec1f6d..edd71a5 100644
--- a/src/nbe/NbE.ml
+++ b/src/nbe/NbE.ml
@@ -2,9 +2,12 @@ module Syntax = Syntax
 module Domain = Domain
 
 let eval = Eval.eval
+let eval_toplevel = Eval.eval_toplevel
 let inst_clo = Eval.inst_clo
+let force_all = Eval.force_all
 
 let quote = Quote.quote
+let quote_toplevel = Quote.quote_toplevel
 
 exception Unequal = Conversion.Unequal
 let equate = Conversion.equate
diff --git a/src/nbe/Quote.ml b/src/nbe/Quote.ml
index cc5a81e..94c8395 100644
--- a/src/nbe/Quote.ml
+++ b/src/nbe/Quote.ml
@@ -12,9 +12,10 @@ struct
     let arg = D.var (Eff.read()) in
     Eff.scope (fun size -> size + 1) @@ fun () ->
     f arg
-    
+
   let rec quote = function
     | D.Neutral ne -> quote_ne ne
+    | D.Unfold uf -> quote_unfold uf
     | D.Pi (name, base, fam) -> S.Pi (name, quote base, quote_clo fam)
     | D.Lam (name, clo) -> S.Lam (name, quote_clo clo)
     | D.Sg (name, base, fam) -> S.Sg (name, quote base, quote_clo fam)
@@ -27,9 +28,11 @@ struct
   and quote_clo clo = bind @@ fun arg -> quote (Eval.inst_clo clo arg)
 
   and quote_ne (hd, frms) = Bwd.fold_left quote_frm (quote_ne_head hd) frms
-
   and quote_ne_head (D.Var i) = S.Var (Eff.read() - i - 1) (* converting from levels to indices *)
 
+  and quote_unfold (hd, frms, _) = Bwd.fold_left quote_frm (quote_unfold_head hd) frms
+  and quote_unfold_head (D.Def (p, v)) = S.Def (p, v)
+
   and quote_frm hd = function
     | D.App v -> S.App (hd, quote v)
     | D.Fst -> S.Fst hd
@@ -45,3 +48,4 @@ struct
 end
 
 let quote ~size v = Internal.Eff.run ~env:size (fun () -> Internal.quote v)
+let quote_toplevel v = quote ~size:0 v
diff --git a/src/nbe/Syntax.ml b/src/nbe/Syntax.ml
index dd690fb..5de5281 100644
--- a/src/nbe/Syntax.ml
+++ b/src/nbe/Syntax.ml
@@ -1,5 +1,6 @@
 type t = Data.syn =
   | Var of int
+  | Def of Ident.t * Data.value Lazy.t
   | Pi of Ident.local * t * (* BINDS *) t
   | Lam of Ident.local * (* BINDS *) t
   | App of t * t