(* Source: src/nbe/Conversion.ml
   (blob 8c4478ea32e60eac2c56451146ddbb904dd0fe8c) *)
open Bwd
open Bwd.Infix

(* [D] is the value domain compared by this module ([D.Pi], [D.Lam],
   [D.Neutral], [D.Unfold], ...). *)
module D = Domain

(* Raised by [equate] when the two values being compared are not
   convertible. [equate] itself does not catch it; callers are expected to. *)
exception Unequal

module Internal =
struct
  (* Reader environment for conversion checking.
     - [mode] controls how [D.Unfold] values are treated (see [equate]).
     - [size] is the level given to the next fresh variable: [bind] reads it
       to create a variable and increments it for the scope of the binder. *)
  type env = {
    mode : [`Rigid | `Flex | `Full];
    size : int
  }

  module Eff = Algaeff.Reader.Make (struct type nonrec t = env end)

  (* [with_mode mode f] runs [f ()] with the conversion mode set to [mode]
     (scoped via [Eff.scope]). *)
  let with_mode mode f = Eff.scope (fun env -> { env with mode }) f

  (* [bind f] passes [f] the fresh variable [D.var size] and runs it with
     [size] incremented, i.e. under one more binder. *)
  let bind f =
    let arg = D.var (Eff.read ()).size in
    Eff.scope (fun env -> { env with size = env.size + 1 }) @@ fun () -> f arg

  (* Two unfolding heads are equal when the first components of their
     [D.Def]s are structurally equal (presumably the definition's name or
     path -- TODO confirm against [Domain]). *)
  let equal_unfold_head hd1 hd2 = match hd1, hd2 with
    | D.Def (p1, _), D.Def (p2, _) -> p1 = p2

  (* [equate v1 v2] returns [()] when [v1] and [v2] are convertible and
     raises [Unequal] otherwise.

     Unfoldings of definitions, [D.Unfold (head, spine, lazy_value)], are
     handled by approximate conversion checking in the style of smalltt, see
     https://github.com/AndrasKovacs/smalltt?tab=readme-ov-file#approximate-conversion-checking
     Behaviour depends on the current mode:
     - [`Full]: any unfolding is forced immediately;
     - [`Flex]: no definition may be unfolded; two unfoldings are compared
       only by head symbol and then spine;
     - [`Rigid]: two unfoldings with the same head are first compared by
       spine in [`Flex] mode and, if that fails, forced and compared in
       [`Full] mode; any other unfolding is forced and checking continues in
       [`Rigid] mode.

     The neutral case is deliberately matched AFTER all unfolding cases. If
     it came first, a neutral compared against an unfolding would be handed
     to [equate_ne], which never unfolds, and would be rejected even in
     [`Rigid] / [`Full] mode (e.g. a variable against a definition that
     unfolds to that same variable).

     @raise Unequal if the values are not convertible. *)
  let rec equate v1 v2 = match v1, v2, (Eff.read ()).mode with
    | D.Pi (_, base1, fam1), D.Pi (_, base2, fam2), _ ->
      equate base1 base2;
      equate_clo fam1 fam2
    | D.Lam (_, clo1), D.Lam (_, clo2), _ -> equate_clo clo1 clo2
    | D.Sg (_, base1, fam1), D.Sg (_, base2, fam2), _ ->
      equate base1 base2;
      equate_clo fam1 fam2
    | D.Pair (fst1, snd1), D.Pair (fst2, snd2), _ ->
      equate fst1 fst2;
      equate snd1 snd2
    | D.Type, D.Type, _ -> ()
    | D.Bool, D.Bool, _ -> ()
    | D.True, D.True, _ -> ()
    | D.False, D.False, _ -> ()

    (* [`Full]: unfold any defined symbol immediately. *)
    | D.Unfold (_, _, u1), v2, `Full -> equate (Lazy.force u1) v2
    | v1, D.Unfold (_, _, u2), `Full -> equate v1 (Lazy.force u2)

    (* [`Flex]: never unfold a top-level definition; only recurse into the
       spines when the head symbols are equal. *)
    | D.Unfold (hd1, sp1, _), D.Unfold (hd2, sp2, _), `Flex ->
      if equal_unfold_head hd1 hd2 then equate_spine sp1 sp2
      else raise Unequal

    (* [`Rigid]: speculate on the spines when both sides share a head;
       otherwise unfold. *)
    | D.Unfold (hd1, sp1, u1), D.Unfold (hd2, sp2, u2), `Rigid ->
      if equal_unfold_head hd1 hd2 then
        (try with_mode `Flex @@ fun () -> equate_spine sp1 sp2 with
         | Unequal ->
           with_mode `Full @@ fun () -> equate (Lazy.force u1) (Lazy.force u2))
      else
        equate (Lazy.force u1) (Lazy.force u2)
    | D.Unfold (_, _, u1), v2, `Rigid -> equate (Lazy.force u1) v2
    | v1, D.Unfold (_, _, u2), `Rigid -> equate v1 (Lazy.force u2)

    (* Must stay after the unfolding cases; see the comment on [equate]. In
       [`Flex] mode a neutral against an unfolding still ends up here and is
       rejected by [equate_ne], as intended. *)
    | D.Neutral ne, v, _ | v, D.Neutral ne, _ -> equate_ne ne v

    | _ -> raise Unequal

  (* Compare two closures by instantiating both at the same fresh variable. *)
  and equate_clo clo1 clo2 = bind @@ fun arg ->
    equate (Eval.inst_clo clo1 arg) (Eval.inst_clo clo2 arg)

  (* Neutral heads are equal only when they are the same variable level. *)
  and equate_ne_head (D.Var lvl1) (D.Var lvl2) =
    if lvl1 <> lvl2 then raise Unequal

  (* Compare two eliminator frames. For [D.BoolElim] only the motive and the
     two branches are compared; any other fields are ignored. *)
  and equate_frm frm1 frm2 = match frm1, frm2 with
    | D.App arg1, D.App arg2 -> equate arg1 arg2
    | D.Fst, D.Fst -> ()
    | D.Snd, D.Snd -> ()
    | D.BoolElim { motive = mot1; true_case = t1; false_case = f1; _ },
      D.BoolElim { motive = mot2; true_case = t2; false_case = f2; _ } ->
      equate_clo mot1 mot2;
      equate t1 t2;
      equate f1 f2
    | _ -> raise Unequal

  (* Compare two spines frame by frame, starting from the most recently
     snoc'd frame; spines of different lengths are unequal. *)
  and equate_spine sp1 sp2 = match sp1, sp2 with
    | Emp, Emp -> ()
    | Snoc (sp1, frm1), Snoc (sp2, frm2) ->
      equate_frm frm1 frm2;
      equate_spine sp1 sp2
    | _ -> raise Unequal

  (* [equate_ne (hd, sp) v] compares the neutral [(hd, sp)] against [v],
     eta-expanding when [v] is a lambda or a pair. The eta cases rebuild a
     neutral and go back through [equate] (not [equate_ne]) so that the
     instantiated body or the pair components may themselves be unfoldings,
     which only [equate] knows how to handle. *)
  and equate_ne (hd, sp) v = match v with
    | D.Neutral (hd2, sp2) ->
      equate_ne_head hd hd2;
      equate_spine sp sp2
    | D.Lam (_, clo) -> bind @@ fun arg ->
      equate (D.Neutral (hd, sp <: D.App arg)) (Eval.inst_clo clo arg)
    | D.Pair (fst, snd) ->
      equate (D.Neutral (hd, sp <: D.Fst)) fst;
      equate (D.Neutral (hd, sp <: D.Snd)) snd
    | _ -> raise Unequal
end

(* [equate ~size v1 v2] checks that [v1] and [v2] are convertible, starting
   in [`Rigid] mode with [size] as the level of the first fresh variable.
   Returns [()] on success.
   @raise Unequal if the values are not convertible. *)
let equate ~size v1 v2 =
  let check () = Internal.equate v1 v2 in
  Internal.Eff.run ~env:{ Internal.size; mode = `Rigid } check