about summary refs log tree commit diff
path: root/src/elaborator/Elaborator.ml
blob: 6fa44ff1a6f15202713ac6737f5c676662518074 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
open Bwd
open Bwd.Infix

module A = Ast
module S = NbE.Syntax
module D = NbE.Domain

(* Elaboration environment, carried implicitly through the elaborator by the
   reader effect `Eff`. It combines the local (bound-variable) context with
   the top-level context of definitions. *)
type env = {
  (* local context *)
  (* invariant: `tps`, `tms` and `names` all have length `size` *)
  (* types of the local variables; `lookup` reads the entry at the index
     where the variable's name was found in `names` *)
  tps : D.env;
  (* values of the local variables; this is the environment handed to
     `NbE.eval` by `eval` and `eval_at` *)
  tms : D.env;
  (* names of the local variables, used for name resolution in `lookup` and
     passed to `Pretty.pp` by `pp_tm` *)
  names : Name.local bwd;
  (* number of local variables; `bind` uses it to build the variable for a
     new binder, and it is passed to `NbE.quote` / `NbE.equate` *)
  size : int;

  (* top-level context *)
  (* searched by `lookup` when a name is not found in the local context *)
  toplvl : TopLevel.t;
}

(* Reader effect over `env`. In this file `Eff.read` fetches the current
   environment, `Eff.scope` runs a thunk under an updated environment (see
   `bind`), and `Eff.run` installs the initial environment (see the
   `*_toplevel` entry points). *)
module Eff = Algaeff.Reader.Make (struct type nonrec t = env end)

(* general helpers *)

(* Resolve `name` to its type and the core term that refers to it.

   The local context is consulted first: when `name` converts to a local
   name that occurs in `names`, the result is the matching entry of `tps`
   paired with `S.Var` at the index that was found. Otherwise the top-level
   context is searched, and a hit yields the definition's type together with
   `S.Def`. A name found in neither context is reported through
   `Error.unbound_variable`. *)
let lookup (name : Name.t) : D.t * S.t =
  let { tps; names; toplvl; _ } = Eff.read () in
  let local_index =
    match Name.to_local name with
    | None -> None
    | Some local -> Bwd.find_index (fun candidate -> local = candidate) names
  in
  match local_index with
  | Some ix ->
    let var_tp = Bwd.nth tps ix in
    (var_tp, S.Var ix)
  | None ->
    (* not a local variable: fall back to the top-level context *)
    begin match Yuujinchou.Trie.find_singleton name toplvl with
      | Some (Def { tp; tm }, _) -> (tp, S.Def (name, tm))
      | None -> Error.unbound_variable name
    end

(* Run `f` under a local context extended by one binder called `name` of
   type `tp`. The variable standing for the binder is built from the current
   context size and is both pushed onto `tms` and passed to `f`. The result
   of `f` is returned. *)
let bind ~(name : Name.local) ~(tp : D.t) f =
  let level = (Eff.read ()).size in
  let arg = D.var level in
  let extend (ctx : env) : env =
    { ctx with
      tps = ctx.tps <: tp;
      tms = ctx.tms <: arg;
      names = ctx.names <: name;
      size = ctx.size + 1;
    }
  in
  Eff.scope extend @@ fun () -> f arg

(* NbE helpers *)

(* evaluate `tm` in the term environment (`tms`) of the current context *)
let eval tm =
  let { tms; _ } = Eff.read () in
  NbE.eval ~env:tms tm

(* Evaluate under the current term environment extended with `arg`.
   The environment is read as soon as `arg` is supplied; the result is the
   evaluator waiting for the term to evaluate. *)
(* TODO: this is kind of inelegant, can we do better? *)
let eval_at arg =
  let extended = (Eff.read ()).tms <: arg in
  NbE.eval ~env:extended

(* pretty-printing helpers *)

(* Build a printer for core terms. The local names are captured when
   `pp_tm ()` is called, not when the returned printer is run. *)
let pp_tm () =
  let { names; _ } = Eff.read () in
  fun ppf term -> Pretty.pp ~names ppf term

(* Build a printer for semantic values: a value is quoted back to a term at
   the context size captured when `pp_val ()` is called, then printed with
   the term printer obtained at that same moment. *)
let pp_val () =
  let ctx = Eff.read () in
  let print_term = pp_tm () in
  fun ppf value ->
    let quoted = NbE.quote ~size:ctx.size value in
    print_term ppf quoted

(* main algorithm *)

(* The two binding type formers that `check_connective` handles uniformly:
   `Pi builds `S.Pi` and `Sigma builds `S.Sg`. *)
type connective = [ `Pi | `Sigma ]

(* Check the surface term `tm` against the semantic type `tp` and return the
   elaborated core term.

   Type formers, lambdas, pairs and the boolean constants are checked
   directly against the shape of `tp`. Every other term is handed to `infer`,
   and the inferred type is then compared with `tp` using `NbE.equate`.
   Failures are reported through `Error.type_mismatch`. *)
let rec check ~(tm : A.expr) ~(tp : D.t) : S.t =
  Error.tracef ?loc:tm.loc "when checking against the type @[%a@]" (pp_val()) tp @@ fun () ->
  (* report a mismatch between `tp` and the textual description `shape` *)
  let mismatch shape =
    Error.type_mismatch (pp_val ()) tp Fmt.string shape
  in
  match tm.value, tp with
  (* dependent and non-dependent function types *)
  | A.Pi ((name, base), fam), _ ->
    check_connective `Pi ~name:name.value ~base ~fam ~tp
  | A.Fun (base, fam), _ ->
    check_connective `Pi ~name:None ~base ~fam ~tp
  (* lambdas: check the body under the binder, against the instantiated family *)
  | A.Lam (name, body), D.Pi (_, base, fam) ->
    let check_body arg =
      let fib = NbE.inst_clo fam arg in
      check ~tm:body ~tp:fib
    in
    let body_tm = bind ~name:name.value ~tp:base check_body in
    S.Lam (name.value, body_tm)
  | A.Lam _, _ -> mismatch "a Pi type"
  (* dependent and non-dependent pair types *)
  | A.Sg ((name, base), fam), _ ->
    check_connective `Sigma ~name:name.value ~base ~fam ~tp
  | A.Prod (base, fam), _ ->
    check_connective `Sigma ~name:None ~base ~fam ~tp
  (* pairs: the type of the second component depends on the first's value *)
  | A.Pair (fst, snd), D.Sg (_, base, fam) ->
    let fst_tm = check ~tm:fst ~tp:base in
    let fib = NbE.inst_clo fam (eval fst_tm) in
    let snd_tm = check ~tm:snd ~tp:fib in
    S.Pair (fst_tm, snd_tm)
  | A.Pair _, _ -> mismatch "a Sigma type"
  (* TODO type-in-type *)
  | A.Type, D.Type -> S.Type
  | A.Type, _ -> mismatch "type"
  | A.Bool, D.Type -> S.Bool
  | A.Bool, _ -> mismatch "type"
  | A.True, D.Bool -> S.True
  | A.True, _ -> mismatch "bool"
  | A.False, D.Bool -> S.False
  | A.False, _ -> mismatch "bool"
  (* everything else: infer, then compare the inferred type with `tp` *)
  | _ ->
    let (inferred_tp, elaborated) = infer tm in
    let size = (Eff.read ()).size in
    begin
      try NbE.equate ~size inferred_tp tp
      with NbE.Unequal ->
        Error.type_mismatch (pp_val ()) tp (pp_val ()) inferred_tp
    end;
    elaborated

(* Shared checking rule for Pi and Sigma types. When the expected type `tp`
   is `D.Type`, the base is checked against it, then the family is checked
   against it under a binder `name` whose type is the evaluated base. The
   `connective` tag selects the core constructor to build. Any other `tp` is
   reported through `Error.type_mismatch`. *)
and check_connective connective ~(name : Name.local) ~(base : A.expr) ~(fam : A.expr) ~(tp : D.t) =
  match tp with
  | D.Type ->
    let base_tm = check ~tm:base ~tp in
    let base_val = eval base_tm in
    let fam_tm = bind ~name ~tp:base_val (fun _ -> check ~tm:fam ~tp) in
    (match connective with
     | `Pi -> S.Pi (name, base_tm, fam_tm)
     | `Sigma -> S.Sg (name, base_tm, fam_tm))
  | _ -> Error.type_mismatch (pp_val ()) tp Fmt.string "type"

(* check that the surface expression `tp` is a type, i.e. check it against `D.Type` *)
and check_tp (tp : A.expr) =
  check ~tm:tp ~tp:D.Type

(* Infer the type of the surface term `tm`, returning the semantic type
   together with the elaborated core term.

   Handles variables, type ascriptions, applications, projections and the
   boolean eliminator. Any other term is rejected with
   `Error.not_inferable`. *)
and infer (tm : A.expr) : D.t * S.t =
  Error.tracef ?loc:tm.loc "when inferring the type" @@ fun () ->
  match tm.value with
  | A.Var name -> lookup name
  | A.Check (body, ascription) ->
    (* elaborate and evaluate the ascribed type, then check the body against it *)
    let tp = eval (check_tp ascription) in
    let body_tm = check ~tp ~tm:body in
    (tp, body_tm)
  | A.App (fn, arg) ->
    let (fn_tp, fn_tm) = infer fn in
    begin match fn_tp with
      | D.Pi (_, base, fam) ->
        let arg_tm = check ~tm:arg ~tp:base in
        (* the result type is the family instantiated at the argument's value *)
        let result_tp = NbE.inst_clo fam (eval arg_tm) in
        (result_tp, S.App (fn_tm, arg_tm))
      | _ ->
        Error.type_mismatch ?loc:fn.loc Fmt.string "a Pi type" (pp_val ()) fn_tp
    end
  | A.Fst p ->
    let (p_tp, p_tm) = infer p in
    begin match p_tp with
      | D.Sg (_, base, _) -> (base, S.Fst p_tm)
      | _ ->
        Error.type_mismatch ?loc:p.loc Fmt.string "a Sigma type" (pp_val ()) p_tp
    end
  | A.Snd p ->
    let (p_tp, p_tm) = infer p in
    begin match p_tp with
      | D.Sg (_, _, fam) ->
        (* the family is instantiated at the value of the first projection *)
        let result_tp = NbE.inst_clo fam (eval (S.Fst p_tm)) in
        (result_tp, S.Snd p_tm)
      | _ ->
        Error.type_mismatch ?loc:p.loc Fmt.string "a Sigma type" (pp_val ()) p_tp
    end
  | A.BoolElim { motive_var; motive_body; true_case; false_case; scrut } ->
    let scrut_tm = check ~tm:scrut ~tp:D.Bool in
    let motive_tm =
      bind ~name:motive_var.value ~tp:D.Bool (fun _ ->
          check ~tm:motive_body ~tp:D.Type)
    in
    (* the motive evaluated with its bound variable set to a given value *)
    let motive_at value = eval_at value motive_tm in
    let tp_true = motive_at D.True in
    let tp_false = motive_at D.False in
    let tp_scrut = motive_at (eval scrut_tm) in
    let true_tm = check ~tm:true_case ~tp:tp_true in
    let false_tm = check ~tm:false_case ~tp:tp_false in
    let elim =
      S.BoolElim {
        motive_var = motive_var.value;
        motive = motive_tm;
        true_case = true_tm;
        false_case = false_tm;
        scrut = scrut_tm;
      }
    in
    (tp_scrut, elim)
  | _ -> Error.not_inferable ()

(* elaborating definitions *)

(* Elaborate a definition with arguments `args`, result type `tp` and body
   `tm`. Arguments are processed left to right: each argument type is
   checked, the argument is bound, and the rest of the definition is
   elaborated under that binder. The returned pair is the core type and core
   term of the whole definition, with one `S.Pi` / `S.Lam` wrapped around the
   right-hand side per argument. *)
let check_def ~(args : A.arg list) ~(tp : A.expr) ~(tm : A.expr) : S.t * S.t =
  let rec go = function
    | [] ->
      (* no arguments left: elaborate the declared type, then the body against it *)
      let rhs_tp = check_tp tp in
      let rhs_tm = check ~tp:(eval rhs_tp) ~tm in
      (rhs_tp, rhs_tm)
    | ((arg_name, arg_tp) : A.arg) :: rest ->
      let arg_tp = check_tp arg_tp in
      bind ~name:arg_name.value ~tp:(eval arg_tp) @@ fun _ ->
      let (inner_tp, inner_tm) = go rest in
      let wrapped_tp = S.Pi (arg_name.value, arg_tp, inner_tp) in
      let wrapped_tm = S.Lam (arg_name.value, inner_tm) in
      (wrapped_tp, wrapped_tm)
  in
  go args

(* interface *)

(* environment with an empty local context over the top-level context `toplvl` *)
let initial_env toplvl : env =
  { toplvl; size = 0; names = Emp; tms = Emp; tps = Emp }

(* run `check` on `tm` against `tp`, starting from an empty local context over `toplvl` *)
let check_toplevel ~(toplvl : TopLevel.t) ~(tm : A.expr) ~(tp : D.t) =
  let env = initial_env toplvl in
  Eff.run ~env (fun () -> check ~tm ~tp)

(* run `check_tp` on `tp`, starting from an empty local context over `toplvl` *)
let check_tp_toplevel ~(toplvl : TopLevel.t) (tp : A.expr) =
  let env = initial_env toplvl in
  Eff.run ~env (fun () -> check_tp tp)

(* run `infer` on `tm`, starting from an empty local context over `toplvl` *)
let infer_toplevel ~(toplvl : TopLevel.t) (tm : A.expr) =
  let env = initial_env toplvl in
  Eff.run ~env (fun () -> infer tm)

(* run `check_def` on a definition, starting from an empty local context over `toplvl` *)
let check_def_toplevel ~(toplvl : TopLevel.t) ~(args : A.arg list) ~(tp : A.expr) ~(tm : A.expr) =
  let env = initial_env toplvl in
  Eff.run ~env (fun () -> check_def ~args ~tp ~tm)