(* Source: root/src/monomorphise.ml
   (git blob 17add78c1091a89791a70fcda577e16096d14785) *)
open Parse_ast
open Ast
open Type_internal

(* TODO: put this somewhere common *)

(* Return the name string of an identifier, discarding its location.
   Both the [Id] and [DeIid] forms carry a plain name. *)
let id_to_string (Id_aux (id, _)) =
  match id with
  | Id name | DeIid name -> name

(* TODO: check for temporary failwiths *)

(* [optmap v f] applies [f] to the contents of [v] when it is [Some _] and
   propagates [None] otherwise -- an option map with the function last. *)
let optmap v f =
  match v with
  | Some x -> Some (f x)
  | None -> None

(* When set, splitting binds the split variable with a [let] around each new
   clause body instead of substituting its value into the body. *)
let disable_const_propagation : bool ref = ref false

(* Decide whether an identifier occurring in a pattern binds a new variable.
   This follows the current type checker: a name that [t_env] knows as a
   constructor or an enumeration member is matched against rather than bound.
   Every other name, including names absent from [t_env], is a variable. *)
let pat_id_is_variable t_env id =
  match Envmap.apply t_env id with
  | None -> true
  | Some (Base (_, tag, _, _, _, _)) ->
     (match tag with
      | Constructor _ | Enum _ -> false
      | _ -> true)
  | Some _ -> true
  

(* Collect the names of the variables bound by pattern [p].
   Identifier patterns that [pat_id_is_variable] rejects (constructors and
   enumeration members) bind nothing.  An [as] pattern contributes its alias
   (listed first) followed by the variables of the inner pattern. *)
let bindings_from_pat t_env p =
  let rec vars_of (P_aux (pat, _)) =
    match pat with
    | P_lit _ | P_wild -> []
    | P_as (inner, id) -> id_to_string id :: vars_of inner
    | P_typ (_, inner) -> vars_of inner
    | P_id id ->
       let name = id_to_string id in
       if pat_id_is_variable t_env name then [name] else []
    | P_vector sub
    | P_vector_concat sub
    | P_app (_, sub)
    | P_tup sub
    | P_list sub ->
       List.concat (List.map vars_of sub)
    | P_record (fields, _) ->
       List.concat (List.map field_vars fields)
    | P_vector_indexed indexed ->
       List.concat (List.map (fun (_, sub_p) -> vars_of sub_p) indexed)
  and field_vars (FP_aux (FP_Fpat (_, fp), _)) = vars_of fp
  in
  vars_of p

(* Remove from the substitution map [env] every variable bound by [pat], so
   that substitutions are not applied to uses that the pattern shadows.
   The fold must thread its accumulator: each removal builds on the previous
   one.  Restarting from [env] at every step would remove only the last name,
   so any other variable bound by a multi-variable pattern could still be
   wrongly substituted. *)
let remove_bound t_env env pat =
  let bound = bindings_from_pat t_env pat in
  List.fold_left (fun acc v -> Envmap.remove acc v) env bound

(* Monomorphisation by case splitting.

   [split_defs splits env defs] rewrites the definitions [defs].  For each
   requested split, a pattern variable is replaced by one copy of the
   enclosing function clause or case arm per possible value of that variable.

   [splits] is a list of [((filename, line), variable)] entries.  A
   user-written pattern is split on [variable] when two conditions hold: the
   basename of its source file equals [filename], and its location range
   covers [line] (see [match_l] below).

   In each copy, the variable is given the corresponding value.  This is done
   by substitution ([const_prop_exp]) or, when [disable_const_propagation] is
   set, by wrapping the body in a [let].

   From the environment, [d_env] supplies enumeration definitions and [t_env]
   distinguishes pattern variables from constructors; [b_env] and [tp_env]
   are not used here.  Only enumeration types can currently be split. *)
let split_defs splits (Type_check.Env (d_env,t_env,b_env,tp_env)) defs =

  (* Constant propagation.
     [substs] maps variable names to replacement expressions.  Every [E_id]
     whose name is in [substs] is replaced.  Names are dropped from [substs]
     under binders that shadow them: for-loop variables, case patterns and
     let patterns.  This only substitutes; it does not evaluate or simplify
     the resulting expression. *)
  let rec const_prop_exp substs ((E_aux (e,(l,annot))) as exp) =
    let re e = E_aux (e,(l,annot)) in
    match e with
      (* TODO: are there circumstances in which we should get rid of these? *)
    | E_block es -> re (E_block (List.map (const_prop_exp substs) es))
    | E_nondet es -> re (E_nondet (List.map (const_prop_exp substs) es))

    | E_id id ->
       (match Envmap.apply substs (id_to_string id) with
       | None -> exp
       | Some exp' -> exp')
    | E_lit _
    | E_sizeof _
    | E_internal_exp _
    | E_sizeof_internal _
    | E_internal_exp_user _
    | E_comment _
      -> exp
    | E_cast (t,e') -> re (E_cast (t, const_prop_exp substs e'))
    | E_app (id,es) -> re (E_app (id,List.map (const_prop_exp substs) es))
    | E_app_infix (e1,id,e2) -> re (E_app_infix (const_prop_exp substs e1,id,const_prop_exp substs e2))
    | E_tuple es -> re (E_tuple (List.map (const_prop_exp substs) es))
    | E_if (e1,e2,e3) -> re (E_if (const_prop_exp substs e1, const_prop_exp substs e2, const_prop_exp substs e3))
    (* The loop variable shadows any substitution for the same name in the body
       (but not in the bounds and step expressions). *)
    | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3,ord,const_prop_exp (Envmap.remove substs (id_to_string id)) e4))
    | E_vector es -> re (E_vector (List.map (const_prop_exp substs) es))
    | E_vector_indexed (ies,ed) -> re (E_vector_indexed (List.map (fun (i,e) -> (i,const_prop_exp substs e)) ies,
                                                         const_prop_opt_default substs ed))
    | E_vector_access (e1,e2) -> re (E_vector_access (const_prop_exp substs e1,const_prop_exp substs e2))
    | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3))
    | E_vector_update (e1,e2,e3) -> re (E_vector_update (const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3))
    | E_vector_update_subrange (e1,e2,e3,e4) -> re (E_vector_update_subrange (const_prop_exp substs e1,const_prop_exp substs e2,const_prop_exp substs e3,const_prop_exp substs e4))
    | E_vector_append (e1,e2) -> re (E_vector_append (const_prop_exp substs e1,const_prop_exp substs e2))
    | E_list es -> re (E_list (List.map (const_prop_exp substs) es))
    | E_cons (e1,e2) -> re (E_cons (const_prop_exp substs e1,const_prop_exp substs e2))
    | E_record fes -> re (E_record (const_prop_fexps substs fes))
    | E_record_update (e,fes) -> re (E_record_update (const_prop_exp substs e, const_prop_fexps substs fes))
    | E_field (e,id) -> re (E_field (const_prop_exp substs e,id))
    | E_case (e,cases) -> re (E_case (const_prop_exp substs e, List.map (const_prop_pexp substs) cases))
    (* Variables bound by the let pattern shadow [substs] in the body only;
       the bound expression itself sees the outer [substs]. *)
    | E_let (lb,e) ->
       let (lb',substs') = const_prop_letbind substs lb in
       re (E_let (lb', const_prop_exp substs' e))
    (* Assignment targets are handled by [const_prop_lexp], which never
       substitutes a plain identifier target. *)
    | E_assign (le,e) -> re (E_assign (const_prop_lexp substs le, const_prop_exp substs e))
    | E_exit e -> re (E_exit (const_prop_exp substs e))
    | E_return e -> re (E_return (const_prop_exp substs e))
    | E_assert (e1,e2) -> re (E_assert (const_prop_exp substs e1,const_prop_exp substs e2))
    | E_internal_cast (ann,e) -> re (E_internal_cast (ann,const_prop_exp substs e))
    (* The body of a structured comment is rebuilt but not traversed. *)
    | E_comment_struc e -> re (E_comment_struc e)
    (* These internal forms are not expected at this stage of the pipeline;
       meeting one is reported as an internal (unreachable) error. *)
    | E_internal_let _
    | E_internal_plet _
    | E_internal_return _
      -> raise (Reporting_basic.err_unreachable l
                  "Unexpected internal expression encountered in monomorphisation")
    and const_prop_opt_default substs ((Def_val_aux (ed,annot)) as eda) =
    match ed with
    | Def_val_empty -> eda
    | Def_val_dec e -> Def_val_aux (Def_val_dec (const_prop_exp substs e),annot)
  and const_prop_fexps substs (FES_aux (FES_Fexps (fes,flag), annot)) =
    FES_aux (FES_Fexps (List.map (const_prop_fexp substs) fes, flag), annot)
  and const_prop_fexp substs (FE_aux (FE_Fexp (id,e), annot)) =
    FE_aux (FE_Fexp (id,const_prop_exp substs e),annot)
  (* A case arm: variables bound by its pattern shadow [substs] in its body. *)
  and const_prop_pexp substs (Pat_aux (Pat_exp (p,e),l)) =
    Pat_aux (Pat_exp (p,const_prop_exp (remove_bound t_env substs p) e),l)
  (* Returns the rewritten binding together with the substitution to use for
     the scope it introduces (its pattern variables removed). *)
  and const_prop_letbind substs (LB_aux (lb,annot)) =
    match lb with
    | LB_val_explicit (tysch,p,e) ->
       (LB_aux (LB_val_explicit (tysch,p,const_prop_exp substs e), annot),
        remove_bound t_env substs p)
    | LB_val_implicit (p,e) ->
       (LB_aux (LB_val_implicit (p,const_prop_exp substs e), annot),
        remove_bound t_env substs p)
  (* Left-hand sides: only the expressions nested inside them (indices,
     memory arguments) are substituted. *)
  and const_prop_lexp substs ((LEXP_aux (e,annot)) as le) =
    let re e = LEXP_aux (e,annot) in
    match e with
    | LEXP_id _ (* shouldn't end up substituting here *)
    | LEXP_cast _
      -> le
    | LEXP_memory (id,es) -> re (LEXP_memory (id,List.map (const_prop_exp substs) es)) (* or here *)
    | LEXP_tup les -> re (LEXP_tup (List.map (const_prop_lexp substs) les))
    | LEXP_vector (le,e) -> re (LEXP_vector (const_prop_lexp substs le, const_prop_exp substs e))
    | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (const_prop_lexp substs le, const_prop_exp substs e1, const_prop_exp substs e2))
    | LEXP_field (le,id) -> re (LEXP_field (const_prop_lexp substs le, id))
  in

  (* Apply one substitution [subst = (name, value)] to [exp].
     Normally this is constant propagation via [const_prop_exp].  When
     [disable_const_propagation] is set, [exp] is instead wrapped as
     [let name = value in exp].  The wrapper uses a generated location derived
     from [exp], and the bound pattern carries the annotation of [value]. *)
  let subst_exp subst exp =
    if !disable_const_propagation then
    (* TODO: This just sticks a let in - we really need propagation *)
      let (subi,(E_aux (_,subannot) as sube)) = subst in
      let E_aux (e,(l,annot)) = exp in
      let lg = Generated l in
      let p = P_aux (P_id (Id_aux (Id subi, lg)), subannot) in
      E_aux (E_let (LB_aux (LB_val_implicit (p,sube),(lg,annot)), exp),(lg,annot))
    else
      let substs = Envmap.from_list [subst] in
      const_prop_exp substs exp
  in


  (* Split a variable pattern into every possible value.
     [split id l tannot] returns one pair per value of the type in [tannot]:
     - a pattern matching exactly that value, at location [l];
     - the substitution [(id, expression for that value)], at a generated
       location.
     A general error is raised if there is no type information, the type is
     overloaded or has type parameters, or the type is not an enumeration
     listed in [d_env.enum_env].  At most one level of type abbreviation is
     unfolded before that check. *)

  let split id l tannot =
    let new_l = Generated l in
    let new_id i = Id_aux (Id i, new_l) in
    match tannot with
    | Type_internal.NoTyp ->
       raise (Reporting_basic.err_general l ("No type information for variable " ^ id ^ " to split on"))
    | Type_internal.Overload _ ->
       raise (Reporting_basic.err_general l ("Type for variable " ^ id ^ " to split on is overloaded"))
    | Type_internal.Base ((tparams,ty0),_,cs,_,_,_) ->
       let () = match tparams with
         | [] -> ()
         | _ -> raise (Reporting_basic.err_general l ("Type for variable " ^ id ^ " to split on has parameters"))
       in
       let ty = match ty0.t with Tabbrev(_,ty) -> ty | _ -> ty0 in
       let cannot () =
         raise (Reporting_basic.err_general l
                  ("Cannot split type " ^ Type_internal.t_to_string ty ^ " for variable " ^ id))
       in
       (match ty.t with
       | Tid i ->
          (match Envmap.apply d_env.enum_env i with
          (* enumerations *)
          | Some ns -> List.map (fun n -> (P_aux (P_id (new_id n),(l,tannot)),
                                           (id,E_aux (E_id (new_id n),(new_l,tannot))))) ns
          | None -> cannot ())
     (*|  vectors TODO *)
     (*|  numbers TODO *)
       | _ -> cannot ())
  in

  (* Split variable patterns at the given locations.
     [ls] is the list of requested splits, as [((filename, line), variable)];
     the result is the rewritten list of definitions. *)

  let map_locs ls (Defs defs) =
    (* Return the entries of [ls] whose (file basename, line) falls inside
       location [l].  Unknown, internal and generated locations match
       nothing. *)
    let rec match_l = function
      | Unknown
      | Int _ -> []
      | Generated l -> [] (* Could do match_l l, but only want to split user-written patterns *)
      | Range (p,q) ->
         List.filter (fun ((filename,line),_) ->
           Filename.basename p.Lexing.pos_fname = filename &&
             p.Lexing.pos_lnum <= line && line <= q.Lexing.pos_lnum) ls
    in

    (* Try to split variable [var] inside pattern [p].
       Returns [None] if no identifier pattern named [var] is found.
       Otherwise returns the list of (pattern, substitution) pairs obtained by
       replacing that occurrence with each possible value.  Within a list of
       sub-patterns, only the first one (left to right) that contains [var]
       is split. *)
    let split_pat var p =
      (* [list f xs] finds the first [h] in [xs] with [f h = Some ps], and
         returns [Some (elements before h, ps, elements after h)]. *)
      let rec list f = function
        | [] -> None
        | h::t ->
           match f h with
           | None -> (match list f t with None -> None | Some (l,ps,r) -> Some (h::l,ps,r))
           | Some ps -> Some ([],ps,t)
      in
      let rec spl (P_aux (p,(l,annot))) =
        (* Rebuild the enclosing pattern (via [ctx]) once for each alternative
           of the first splittable element of [ps]. *)
        let relist f ctx ps =
          optmap (list f ps)
            (fun (left,ps,right) ->
              List.map (fun (p,sub) -> P_aux (ctx (left@p::right),(l,annot)),sub) ps)
        in
        (* Same, for a pattern with a single sub-pattern. *)
        let re f p =
          optmap (spl p)
            (fun ps -> List.map (fun (p,sub) -> (P_aux (f p,(l,annot)), sub)) ps)
        in
        let fpat (FP_aux ((FP_Fpat (id,p),annot))) =
          optmap (spl p)
            (fun ps -> List.map (fun (p,sub) -> FP_aux (FP_Fpat (id,p), annot), sub) ps)
        in
        let ipat (i,p) = optmap (spl p) (List.map (fun (p,sub) -> (i,p),sub))
        in
        match p with
        | P_lit _
        | P_wild
          -> None
        (* Splitting on the alias name of an [as] pattern is rejected. *)
        | P_as (p',id) ->
           let i = id_to_string id in
           if i = var
           then raise (Reporting_basic.err_general l
                         ("Cannot split " ^ var ^ " on 'as' pattern"))
           else re (fun p -> P_as (p,id)) p'
        | P_typ (t,p') -> re (fun p -> P_typ (t,p)) p'
        (* The split variable itself: replace it with every possible value. *)
        | P_id id ->
           let i = id_to_string id in
           if i = var
           then Some (split i l annot)
           else None
        | P_app (id,ps) ->
           relist spl (fun ps -> P_app (id,ps)) ps
        | P_record (fps,flag) ->
           relist fpat (fun fps -> P_record (fps,flag)) fps
        | P_vector ps ->
           relist spl (fun ps -> P_vector ps) ps
        | P_vector_indexed ips ->
           relist ipat (fun ips -> P_vector_indexed ips) ips
        | P_vector_concat ps ->
           relist spl (fun ps -> P_vector_concat ps) ps
        | P_tup ps ->
           relist spl (fun ps -> P_tup ps) ps
        | P_list ps ->
           relist spl (fun ps -> P_list ps) ps
      in spl p
    in

    (* Split [p] if exactly one variable was requested at its location.
       Requesting more than one variable at the same location is an error. *)
    let map_pat (P_aux (_,(l,_)) as p) =
      match match_l l with
      | [] -> None
      | [(_,var)] -> split_pat var p
      | lvs -> raise (Reporting_basic.err_general l
                        ("Multiple variables to split on: " ^ String.concat ", " (List.map snd lvs)))
    in

    (* Patterns that bind a single scope (let bindings) cannot be duplicated.
       If a split was requested on a variable they bind, report it via
       [Reporting_basic.print_err]; in every case [p] is returned unchanged. *)
    let check_single_pat (P_aux (_,(l,_)) as p) =
      match match_l l with
      | [] -> p
      | lvs ->
         let pvs = bindings_from_pat t_env p in
         let overlap = List.exists (fun (_,v) -> List.mem v pvs) lvs in
         let () =
           if overlap then
             Reporting_basic.print_err false true l "Monomorphisation"
               "Splitting a singleton pattern is not possible"
         in p
    in

    (* Traverse an expression, splitting the selected case arms.  Each split
       arm is replaced by several arms (see [map_pexp]). *)
    let rec map_exp ((E_aux (e,annot)) as ea) =
      let re e = E_aux (e,annot) in
      match e with
      | E_block es -> re (E_block (List.map map_exp es))
      | E_nondet es -> re (E_nondet (List.map map_exp es))
      | E_id _
      | E_lit _
      | E_sizeof _
      | E_internal_exp _
      | E_sizeof_internal _
      | E_internal_exp_user _
      | E_comment _
        -> ea
      | E_cast (t,e') -> re (E_cast (t, map_exp e'))
      | E_app (id,es) -> re (E_app (id,List.map map_exp es))
      | E_app_infix (e1,id,e2) -> re (E_app_infix (map_exp e1,id,map_exp e2))
      | E_tuple es -> re (E_tuple (List.map map_exp es))
      | E_if (e1,e2,e3) -> re (E_if (map_exp e1, map_exp e2, map_exp e3))
      | E_for (id,e1,e2,e3,ord,e4) -> re (E_for (id,map_exp e1,map_exp e2,map_exp e3,ord,map_exp e4))
      | E_vector es -> re (E_vector (List.map map_exp es))
      | E_vector_indexed (ies,ed) -> re (E_vector_indexed (List.map (fun (i,e) -> (i,map_exp e)) ies,
                                                           map_opt_default ed))
      | E_vector_access (e1,e2) -> re (E_vector_access (map_exp e1,map_exp e2))
      | E_vector_subrange (e1,e2,e3) -> re (E_vector_subrange (map_exp e1,map_exp e2,map_exp e3))
      | E_vector_update (e1,e2,e3) -> re (E_vector_update (map_exp e1,map_exp e2,map_exp e3))
      | E_vector_update_subrange (e1,e2,e3,e4) -> re (E_vector_update_subrange (map_exp e1,map_exp e2,map_exp e3,map_exp e4))
      | E_vector_append (e1,e2) -> re (E_vector_append (map_exp e1,map_exp e2))
      | E_list es -> re (E_list (List.map map_exp es))
      | E_cons (e1,e2) -> re (E_cons (map_exp e1,map_exp e2))
      | E_record fes -> re (E_record (map_fexps fes))
      | E_record_update (e,fes) -> re (E_record_update (map_exp e, map_fexps fes))
      | E_field (e,id) -> re (E_field (map_exp e,id))
      (* Each arm may expand to several arms, hence the concat. *)
      | E_case (e,cases) -> re (E_case (map_exp e, List.concat (List.map map_pexp cases)))
      | E_let (lb,e) -> re (E_let (map_letbind lb, map_exp e))
      | E_assign (le,e) -> re (E_assign (map_lexp le, map_exp e))
      | E_exit e -> re (E_exit (map_exp e))
      | E_return e -> re (E_return (map_exp e))
      | E_assert (e1,e2) -> re (E_assert (map_exp e1,map_exp e2))
      | E_internal_cast (ann,e) -> re (E_internal_cast (ann,map_exp e))
      | E_comment_struc e -> re (E_comment_struc e)
      | E_internal_let (le,e1,e2) -> re (E_internal_let (map_lexp le, map_exp e1, map_exp e2))
      | E_internal_plet (p,e1,e2) -> re (E_internal_plet (check_single_pat p, map_exp e1, map_exp e2))
      | E_internal_return e -> re (E_internal_return (map_exp e))
    and map_opt_default ((Def_val_aux (ed,annot)) as eda) =
      match ed with
      | Def_val_empty -> eda
      | Def_val_dec e -> Def_val_aux (Def_val_dec (map_exp e),annot)
    and map_fexps (FES_aux (FES_Fexps (fes,flag), annot)) =
      FES_aux (FES_Fexps (List.map map_fexp fes, flag), annot)
    and map_fexp (FE_aux (FE_Fexp (id,e), annot)) =
      FE_aux (FE_Fexp (id,map_exp e),annot)
    (* A case arm becomes one arm per value of the split variable.  The
       substituted body is traversed again by [map_exp], so further splits
       requested inside it are also applied. *)
    and map_pexp (Pat_aux (Pat_exp (p,e),l)) =
      match map_pat p with
      | None -> [Pat_aux (Pat_exp (p,map_exp e),l)]
      | Some patsubsts ->
         List.map (fun (pat',subst) ->
           let exp' = subst_exp subst e in
           Pat_aux (Pat_exp (pat', map_exp exp'),l))
           patsubsts
    (* Let patterns are never split, only checked (see [check_single_pat]). *)
    and map_letbind (LB_aux (lb,annot)) =
      match lb with
      | LB_val_explicit (tysch,p,e) -> LB_aux (LB_val_explicit (tysch,check_single_pat p,map_exp e), annot)
      | LB_val_implicit (p,e) -> LB_aux (LB_val_implicit (check_single_pat p,map_exp e), annot)
    and map_lexp ((LEXP_aux (e,annot)) as le) =
      let re e = LEXP_aux (e,annot) in
      match e with
      | LEXP_id _
      | LEXP_cast _
        -> le
      | LEXP_memory (id,es) -> re (LEXP_memory (id,List.map map_exp es))
      | LEXP_tup les -> re (LEXP_tup (List.map map_lexp les))
      | LEXP_vector (le,e) -> re (LEXP_vector (map_lexp le, map_exp e))
      | LEXP_vector_range (le,e1,e2) -> re (LEXP_vector_range (map_lexp le, map_exp e1, map_exp e2))
      | LEXP_field (le,id) -> re (LEXP_field (map_lexp le, id))
    in

    (* A function clause becomes one clause per value of the split variable,
       with the same re-traversal of the substituted body as [map_pexp]. *)
    let map_funcl (FCL_aux (FCL_Funcl (id,pat,exp),annot)) =
      match map_pat pat with
      | None -> [FCL_aux (FCL_Funcl (id, pat, map_exp exp), annot)]
      | Some patsubsts ->
         List.map (fun (pat',subst) ->
           let exp' = subst_exp subst exp in
           FCL_aux (FCL_Funcl (id, pat', map_exp exp'), annot))
           patsubsts
    in

    let map_fundef (FD_aux (FD_function (r,t,e,fcls),annot)) =
      FD_aux (FD_function (r,t,e,List.concat (List.map map_funcl fcls)),annot)
    in
    (* Only scattered function clauses can be split; any other scattered
       definition is kept as is. *)
    let map_scattered_def sd =
      match sd with
      | SD_aux (SD_scattered_funcl fcl, annot) ->
         List.map (fun fcl' -> SD_aux (SD_scattered_funcl fcl', annot)) (map_funcl fcl)
      | _ -> [sd]
    in
    (* Function definitions (including scattered clauses) may be split; a
       top-level [let] is only checked; every other definition is unchanged. *)
    let map_def d =
      match d with
      | DEF_kind _
      | DEF_type _
      | DEF_spec _
      | DEF_default _
      | DEF_reg_dec _
      | DEF_comm _
        -> [d]
      | DEF_fundef fd -> [DEF_fundef (map_fundef fd)]
      | DEF_val lb -> [DEF_val (map_letbind lb)]
      | DEF_scattered sd -> List.map (fun x -> DEF_scattered x) (map_scattered_def sd)

    in
    Defs (List.concat (List.map map_def defs))

  in map_locs splits defs