1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
|
open Big_int
open Ast
open Type_internal
(* Local shorthand names for types that are defined in other modules. *)
(* Alias for the type representation declared in Type_internal. *)
type typ = Type_internal.t
(* Alias for the annotated expression type declared in Ast. *)
type 'a exp = 'a Ast.exp
(* Alias for the map type declared in Envmap. *)
type 'a emap = 'a Envmap.t
(* Alias for the [envs] type declared in Type_check. *)
type envs = Type_check.envs
(* A table of rewriting functions, one per syntactic category.
   Every function receives the whole table as its first argument, so an
   implementation can recurse through the table and thereby pick up whatever
   functions a caller has installed for the other categories.
   The expression, l-expression, pattern and let-binding rewriters additionally
   take an optional [nexp_map]; the function, definition and definitions
   rewriters do not. *)
type 'a rewriters = { rewrite_exp : 'a rewriters -> nexp_map option -> 'a exp -> 'a exp;
rewrite_lexp : 'a rewriters -> nexp_map option -> 'a lexp -> 'a lexp;
rewrite_pat : 'a rewriters -> nexp_map option -> 'a pat -> 'a pat;
rewrite_let : 'a rewriters -> nexp_map option -> 'a letbind -> 'a letbind;
rewrite_fun : 'a rewriters -> 'a fundef -> 'a fundef;
rewrite_def : 'a rewriters -> 'a def -> 'a def;
rewrite_defs : 'a rewriters -> 'a defs -> 'a defs;
}
(* [partial_assoc eq v ls] scans the association list [ls] from the front and
   returns [Some b] for the first pair [(a, b)] such that [eq a v] holds.
   Returns [None] when no key matches (in particular for the empty list). *)
let rec partial_assoc (eq: 'a -> 'a -> bool) (v: 'a) (ls : ('a *'b) list ) : 'b option =
  match ls with
  | [] -> None
  | (key, value) :: rest ->
    (match eq key v with
     | true -> Some value
     | false -> partial_assoc eq v rest)
(* [mk_atom_typ i] builds a fresh type record whose body is the type
   application [Tapp ("atom", [TA_nexp i])]. *)
let mk_atom_typ i =
  let atom_args = [TA_nexp i] in
  { t = Tapp ("atom", atom_args) }
(* [rewrite_nexp_to_exp program_vars l nexp] builds an expression, located at
   [l] and annotated with [mk_atom_typ nexp], that stands for the numeric
   expression [nexp].

   When [program_vars] is [Some vars] and [nexp] has an entry in [vars]
   (looked up with [partial_assoc nexp_eq_check]), the entry decides the result:
   [(None, ev)] gives the identifier [ev], and [(Some f, ev)] gives the
   application of [f] to the identifier [ev].
   In every other case [nexp] is translated structurally, recursing on its
   sub-terms with the same [program_vars] and [l].

   Raises [Reporting_basic.err_unreachable] when [nexp] has a shape that is not
   handled by the structural translation. *)
let rec rewrite_nexp_to_exp program_vars l nexp =
  let typ = mk_atom_typ nexp in
  let recur n = rewrite_nexp_to_exp program_vars l n in
  (* Small constructors shared by the cases below; all nodes are located at [l]. *)
  let ident name = Id_aux (Id name, l) in
  let var_exp name = E_aux (E_id (ident name), (l, simple_annot typ)) in
  let int_lit value value_nexp =
    E_aux (E_lit (L_aux (L_num value, l)), (l, simple_annot (mk_atom_typ value_nexp))) in
  let binop lhs op rhs extern =
    E_aux (E_app_infix (lhs, ident op, rhs), (l, tag_annot typ (External (Some extern)))) in
  let structural n =
    match n.nexp with
    | Nconst i -> E_aux (E_lit (L_aux (L_num (int_of_big_int i), l)), (l, simple_annot typ))
    | Nadd (n1, n2) -> binop (recur n1) "+" (recur n2) "add"
    | Nmult (n1, n2) -> binop (recur n1) "*" (recur n2) "multiply"
    | Nsub (n1, n2) -> binop (recur n1) "-" (recur n2) "minus"
    | N2n (exponent, _) -> binop (int_lit 2 n_two) "**" (recur exponent) "power"
    | Npow (base, i) -> binop (recur base) "**" (int_lit i (mk_c_int i)) "power"
    | Nneg operand -> binop (int_lit 0 n_zero) "-" (recur operand) "minus"
    | Nvar v ->
      (*TODO these need to generate an error as it's a place where there's insufficient specification.
        But, for now I need to permit this to make power.sail compile, and most errors are in trap
        or vectors *)
      var_exp v
    | _ ->
      raise (Reporting_basic.err_unreachable l
               ("rewrite_nexp given n that can't be rewritten: " ^ (n_to_string n))) in
  match program_vars with
  | None -> structural nexp
  | Some vars ->
    (match partial_assoc nexp_eq_check nexp vars with
     | None -> structural nexp
     | Some (None, ev) -> var_exp ev
     | Some (Some f, ev) ->
       E_aux (E_app (ident f, [var_exp ev]), (l, tag_annot typ (External (Some f)))))
(* [match_to_program_vars ns bounds] keeps, in order, those elements [n] of [ns]
   for which [find_var_from_nexp n bounds] returns [Some binding], and pairs
   each of them with that binding. Elements with no binding are dropped. *)
let rec match_to_program_vars ns bounds =
  match ns with
  | [] -> []
  | n :: remaining ->
    (* Look up the head first, then process the tail, as a plain recursion would. *)
    let found = find_var_from_nexp n bounds in
    let tail = match_to_program_vars remaining bounds in
    (match found with
     | None -> tail
     | Some (augment, ev) -> (n, (augment, ev)) :: tail)
(* [explode s] returns the characters of [s] as a list, in string order.
   The empty string gives the empty list. *)
let explode s =
  let chars = ref [] in
  (* Walk from the last index down so that consing yields string order. *)
  for i = String.length s - 1 downto 0 do
    chars := s.[i] :: !chars
  done;
  !chars
(* [vector_string_to_bit_list lit] turns a hexadecimal ([L_hex]) or binary
   ([L_bin]) literal into the list of its bits, as [L_zero] / [L_one] literals
   located at [Parse_ast.Unknown], in the order the digits are written.
   Each hexadecimal digit contributes four bits; upper- and lower-case hex
   digits are both accepted.

   Raises [Reporting_basic.err_unreachable] when [lit] is neither [L_hex] nor
   [L_bin], when a hex literal contains a non-hex character, or when a binary
   literal contains a character other than 0 or 1. *)
let vector_string_to_bit_list lit =
  (* The table matches both cases of a-f directly, so the digits do not have to
     be upper-cased first (String.uppercase is deprecated and absent from
     recent standard libraries). *)
  let hexchar_to_binlist = function
    | '0' -> ['0';'0';'0';'0']
    | '1' -> ['0';'0';'0';'1']
    | '2' -> ['0';'0';'1';'0']
    | '3' -> ['0';'0';'1';'1']
    | '4' -> ['0';'1';'0';'0']
    | '5' -> ['0';'1';'0';'1']
    | '6' -> ['0';'1';'1';'0']
    | '7' -> ['0';'1';'1';'1']
    | '8' -> ['1';'0';'0';'0']
    | '9' -> ['1';'0';'0';'1']
    | 'A' | 'a' -> ['1';'0';'1';'0']
    | 'B' | 'b' -> ['1';'0';'1';'1']
    | 'C' | 'c' -> ['1';'1';'0';'0']
    | 'D' | 'd' -> ['1';'1';'0';'1']
    | 'E' | 'e' -> ['1';'1';'1';'0']
    | 'F' | 'f' -> ['1';'1';'1';'1']
    | _ -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown "hexchar_to_binlist given unrecognized character") in
  let s_bin = match lit with
    | L_hex s_hex -> List.flatten (List.map hexchar_to_binlist (explode s_hex))
    | L_bin s_bin -> explode s_bin
    | _ -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown "s_bin given non vector literal") in
  List.map (function '0' -> L_aux (L_zero, Parse_ast.Unknown)
                   | '1' -> L_aux (L_one,Parse_ast.Unknown)
                   | _ -> raise (Reporting_basic.err_unreachable Parse_ast.Unknown "binary had non-zero or one")) s_bin
(* Default pattern rewriter.
   A hexadecimal or binary literal pattern is replaced by a [P_vector] of
   single-bit literal patterns, each annotated with the type [bit] at
   [Parse_ast.Unknown]. Other literals, wildcards and identifiers are returned
   unchanged. Every compound pattern is rebuilt with its sub-patterns rewritten
   through [rewriters.rewrite_pat]. The outer location and annotation are kept.
   Note that a rebuilt [P_record] always carries [false] as its second
   component, whatever the input had. *)
let rewrite_pat rewriters nmap (P_aux (pat,(l,annot))) =
  let descend = rewriters.rewrite_pat rewriters nmap in
  let descend_all pats = List.map descend pats in
  let node =
    match pat with
    | P_lit (L_aux ((L_hex _ | L_bin _) as lit,_)) ->
      (* A fresh annotation is built for every bit. *)
      let bit_pat b = P_aux (P_lit b,(Parse_ast.Unknown,simple_annot {t = Tid "bit"})) in
      P_vector (List.map bit_pat (vector_string_to_bit_list lit))
    | P_lit _ | P_wild | P_id _ -> pat
    | P_as (inner,id) -> P_as (descend inner, id)
    | P_typ (typ,inner) -> P_typ (typ, descend inner)
    | P_app (id,args) -> P_app (id, descend_all args)
    | P_record (fpats,_) ->
      let field (FP_aux (FP_Fpat (id,inner),pannot)) =
        FP_aux (FP_Fpat (id, descend inner), pannot) in
      P_record (List.map field fpats, false)
    | P_vector elems -> P_vector (descend_all elems)
    | P_vector_indexed ipats ->
      P_vector_indexed (List.map (fun (i,inner) -> (i, descend inner)) ipats)
    | P_vector_concat parts -> P_vector_concat (descend_all parts)
    | P_tup comps -> P_tup (descend_all comps)
    | P_list elems -> P_list (descend_all elems) in
  P_aux (node,(l,annot))
(* Default expression rewriter.

   - Hexadecimal and binary literals become an [E_vector] of single-bit
     literals, each annotated with the type [bit].
   - Identifiers and other literals are returned unchanged.
   - Compound expressions are rebuilt with their sub-expressions rewritten
     through [rewriters.rewrite_exp]; let-bindings and l-expressions go through
     [rewriters.rewrite_let] and [rewriters.rewrite_lexp]; the patterns of
     [E_case] arms go through [rewriters.rewrite_pat], so that literal patterns
     are expanded the same way as the expressions they are matched against.
   - [E_internal_cast] is removed: depending on the vector types found in the
     two annotations it is either dropped or replaced by an explicit [E_cast].
   - [E_internal_exp] and [E_internal_exp_user] are replaced by an expression
     computed from a numeric expression of the annotated type, using
     [rewrite_nexp_to_exp] and the bounds of the annotation (extended with
     [nmap] when it is present).

   The location and annotation of the rewritten node are kept, except in the
   internal-expression cases, whose result comes from [rewrite_nexp_to_exp].

   Raises [Reporting_basic.err_unreachable] on [E_internal_let], and on
   internal expressions whose annotation is not [Base] or whose type is not one
   of the expected shapes. *)
let rewrite_exp rewriters nmap (E_aux (exp,(l,annot))) =
  let rewrap e = E_aux (e,(l,annot)) in
  let rewrite = rewriters.rewrite_exp rewriters nmap in
  match exp with
  | E_block exps -> rewrap (E_block (List.map rewrite exps))
  | E_nondet exps -> rewrap (E_nondet (List.map rewrite exps))
  | E_lit (L_aux ((L_hex _ | L_bin _) as lit,_)) ->
    let es = List.map (fun p -> E_aux (E_lit p ,(Parse_ast.Unknown,simple_annot {t = Tid "bit"})))
        (vector_string_to_bit_list lit) in
    rewrap (E_vector es)
  | E_id _ | E_lit _ -> rewrap exp
  | E_cast (typ, exp) -> rewrap (E_cast (typ, rewrite exp))
  | E_app (id,exps) -> rewrap (E_app (id,List.map rewrite exps))
  | E_app_infix(el,id,er) -> rewrap (E_app_infix(rewrite el,id,rewrite er))
  | E_tuple exps -> rewrap (E_tuple (List.map rewrite exps))
  | E_if (c,t,e) -> rewrap (E_if (rewrite c,rewrite t, rewrite e))
  | E_for (id, e1, e2, e3, o, body) ->
    rewrap (E_for (id, rewrite e1, rewrite e2, rewrite e3, o, rewrite body))
  | E_vector exps -> rewrap (E_vector (List.map rewrite exps))
  | E_vector_indexed (exps,(Def_val_aux(default,dannot))) ->
    let def = match default with
      | Def_val_empty -> default
      | Def_val_dec e -> Def_val_dec (rewrite e) in
    rewrap (E_vector_indexed (List.map (fun (i,e) -> (i, rewrite e)) exps, Def_val_aux(def,dannot)))
  | E_vector_access (vec,index) -> rewrap (E_vector_access (rewrite vec,rewrite index))
  | E_vector_subrange (vec,i1,i2) ->
    rewrap (E_vector_subrange (rewrite vec,rewrite i1,rewrite i2))
  | E_vector_update (vec,index,new_v) ->
    rewrap (E_vector_update (rewrite vec,rewrite index,rewrite new_v))
  | E_vector_update_subrange (vec,i1,i2,new_v) ->
    rewrap (E_vector_update_subrange (rewrite vec,rewrite i1,rewrite i2,rewrite new_v))
  | E_vector_append (v1,v2) -> rewrap (E_vector_append (rewrite v1,rewrite v2))
  | E_list exps -> rewrap (E_list (List.map rewrite exps))
  | E_cons(h,t) -> rewrap (E_cons (rewrite h,rewrite t))
  | E_record (FES_aux (FES_Fexps(fexps, bool),fannot)) ->
    rewrap (E_record
              (FES_aux (FES_Fexps
                          (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) ->
                               FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps, bool), fannot)))
  | E_record_update (re,(FES_aux (FES_Fexps(fexps, bool),fannot))) ->
    rewrap (E_record_update ((rewrite re),
                             (FES_aux (FES_Fexps
                                         (List.map (fun (FE_aux(FE_Fexp(id,e),fannot)) ->
                                              FE_aux(FE_Fexp(id,rewrite e),fannot)) fexps, bool), fannot))))
  | E_field(exp,id) -> rewrap (E_field(rewrite exp,id))
  | E_case (exp ,pexps) ->
    (* Arm patterns are rewritten too, so that hex/bin literal patterns are
       expanded consistently with the rewritten expressions. *)
    rewrap (E_case (rewrite exp,
                    (List.map
                       (fun (Pat_aux (Pat_exp(p,e),pannot)) ->
                          Pat_aux (Pat_exp(rewriters.rewrite_pat rewriters nmap p,rewrite e),pannot)) pexps)))
  | E_let (letbind,body) -> rewrap (E_let(rewriters.rewrite_let rewriters nmap letbind,rewrite body))
  | E_assign (lexp,exp) -> rewrap (E_assign(rewriters.rewrite_lexp rewriters nmap lexp,rewrite exp))
  | E_exit e -> rewrap (E_exit (rewrite e))
  | E_internal_cast ((_,casted_annot),exp) ->
    let new_exp = rewrite exp in
    (*let _ = Printf.eprintf "Removing an internal_cast with %s\n" (tannot_to_string casted_annot) in*)
    (match casted_annot,exp with
     | Base((_,t),_,_,_,_),E_aux(ec,(ecl,Base((_,exp_t),_,_,_,_))) ->
       (*let _ = Printf.eprintf "Considering removing an internal cast where the two types are %s and %s\n" (t_to_string t) (t_to_string exp_t) in*)
       (match t.t,exp_t.t with
        (*TODO should pass d_env into here so that I can look at the abbreviations if there are any here*)
        | Tapp("vector",[TA_nexp n1;TA_nexp nw1;TA_ord o1;_]),
          Tapp("vector",[TA_nexp n2;TA_nexp nw2;TA_ord o2;_])
        | Tapp("vector",[TA_nexp n1;TA_nexp nw1;TA_ord o1;_]),
          Tapp("reg",[TA_typ {t=(Tapp("vector",[TA_nexp n2; TA_nexp nw2; TA_ord o2;_]))}]) ->
          (match n1.nexp with
           (* Constant start index: the cast is only needed when the two start indices differ. *)
           | Nconst i1 -> if nexp_eq n1 n2 then new_exp else rewrap (E_cast (t_to_typ t,new_exp))
           | _ -> (match o1.order with
               | Odec ->
                 (*let _ = Printf.eprintf "Considering removing a cast or not: %s %s, %b\n"
                   (n_to_string nw1) (n_to_string n1) (nexp_one_more_than nw1 n1) in*)
                 rewrap (E_cast (Typ_aux (Typ_var (Kid_aux((Var "length"),Parse_ast.Unknown)),
                                          Parse_ast.Unknown),new_exp))
               | _ -> new_exp))
        | _ -> new_exp)
     | Base((_,t),_,_,_,_),_ ->
       (*let _ = Printf.eprintf "Considering removing an internal cast where the remaining type is %s\n%!"
            (t_to_string t) in*)
       (match t.t with
        | Tapp("vector",[TA_nexp n1;TA_nexp nw1;TA_ord o1;_]) ->
          (match o1.order with
           | Odec ->
             (*let _ = Printf.eprintf "Considering removing a cast or not: %s %s, %b\n"
               (n_to_string nw1) (n_to_string n1) (nexp_one_more_than nw1 n1) in*)
             rewrap (E_cast (Typ_aux (Typ_var (Kid_aux((Var "length"), Parse_ast.Unknown)),
                                      Parse_ast.Unknown), new_exp))
           | _ -> new_exp)
        | _ -> new_exp)
     | _ -> (*let _ = Printf.eprintf "Not a base match?\n" in*) new_exp)
  | E_internal_exp (l,impl) ->
    (match impl with
     | Base((_,t),_,_,_,bounds) ->
       let bounds = match nmap with | None -> bounds | Some nm -> add_map_to_bounds nm bounds in
       (*let _ = Printf.eprintf "Rewriting internal expression, with type %s\n" (t_to_string t) in*)
       (match t.t with
        (*Old case; should possibly be removed*)
        | Tapp("register",[TA_typ {t= Tapp("vector",[ _; TA_nexp r;_;_])}])
        | Tapp("vector", [_;TA_nexp r;_;_]) ->
          (*let _ = Printf.eprintf "vector case with %s, bounds are %s\n"
                  (n_to_string r) (bounds_to_string bounds) in*)
          let nexps = expand_nexp r in
          (match (match_to_program_vars nexps bounds) with
           | [] -> rewrite_nexp_to_exp None l r
           | map -> rewrite_nexp_to_exp (Some map) l r)
        | Tapp("implicit", [TA_nexp i]) ->
          (*let _ = Printf.eprintf "Implicit case with %s\n" (n_to_string i) in*)
          let nexps = expand_nexp i in
          (match (match_to_program_vars nexps bounds) with
           | [] -> rewrite_nexp_to_exp None l i
           | map -> rewrite_nexp_to_exp (Some map) l i)
        | _ ->
          raise (Reporting_basic.err_unreachable l
                   ("Internal_exp given unexpected types " ^ (t_to_string t))))
     | _ -> raise (Reporting_basic.err_unreachable l ("Internal_exp given none Base annot")))
  | E_internal_exp_user ((l,user_spec),(_,impl)) ->
    (match (user_spec,impl) with
     | (Base((_,tu),_,_,_,_), Base((_,ti),_,_,_,bounds)) ->
       (*let _ = Printf.eprintf "E_interal_user getting rewritten two types are %s and %s\n"
            (t_to_string tu) (t_to_string ti) in*)
       let bounds = match nmap with | None -> bounds | Some nm -> add_map_to_bounds nm bounds in
       (match (tu.t,ti.t) with
        | (Tapp("implicit", [TA_nexp u]),Tapp("implicit",[TA_nexp i])) ->
          (*let _ = Printf.eprintf "Implicit case with %s\n" (n_to_string i) in*)
          let nexps = expand_nexp i in
          (match (match_to_program_vars nexps bounds) with
           | [] -> rewrite_nexp_to_exp None l i
           (*add u to program_vars env; for now it will work out properly by accident*)
           | map -> rewrite_nexp_to_exp (Some map) l i)
        | _ ->
          raise (Reporting_basic.err_unreachable l
                   ("Internal_exp_user given unexpected types " ^ (t_to_string tu) ^ ", " ^ (t_to_string ti))))
     | _ -> raise (Reporting_basic.err_unreachable l ("Internal_exp_user given none Base annot")))
  | E_internal_let _ -> raise (Reporting_basic.err_unreachable l "Internal let found before it should have been introduced")
(* Default let-binding rewriter.
   The incoming [map] is first merged (with [merge_option_maps]) with the map
   obtained from the binding's own annotation via [get_map_tannot]; the merged
   map is then used to rewrite both the bound pattern, through
   [rewriters.rewrite_pat], and the bound expression, through
   [rewriters.rewrite_exp]. The type scheme of an explicit binding, and the
   location and annotation of the binding, are kept unchanged. *)
let rewrite_let rewriters map (LB_aux(letbind,(l,annot))) =
  let map = merge_option_maps map (get_map_tannot annot) in
  match letbind with
  | LB_val_explicit (typschm, pat,exp) ->
    LB_aux(LB_val_explicit (typschm,
                            rewriters.rewrite_pat rewriters map pat,
                            rewriters.rewrite_exp rewriters map exp),(l,annot))
  | LB_val_implicit ( pat, exp) ->
    LB_aux(LB_val_implicit (rewriters.rewrite_pat rewriters map pat,
                            rewriters.rewrite_exp rewriters map exp),(l,annot))
(* Default l-expression rewriter.
   Identifier and cast l-expressions are returned unchanged. For the other
   forms, nested l-expressions are rewritten through [rewriters.rewrite_lexp]
   and nested expressions (memory arguments, vector indices and range bounds)
   through [rewriters.rewrite_exp], both with the given [map]. The location and
   annotation are kept. *)
let rewrite_lexp rewriters map (LEXP_aux(lexp,(l,annot))) =
  let on_exp e = rewriters.rewrite_exp rewriters map e in
  let on_lexp le = rewriters.rewrite_lexp rewriters map le in
  let node =
    match lexp with
    | LEXP_id _ | LEXP_cast _ -> lexp
    | LEXP_memory (id,args) -> LEXP_memory (id, List.map on_exp args)
    | LEXP_vector (vec,index) -> LEXP_vector (on_lexp vec, on_exp index)
    | LEXP_vector_range (vec,first,last) ->
      LEXP_vector_range (on_lexp vec, on_exp first, on_exp last)
    | LEXP_field (inner,id) -> LEXP_field (on_lexp inner, id) in
  LEXP_aux (node,(l,annot))
(* Default function-definition rewriter.
   Every clause of the function has its parameter pattern rewritten through
   [rewriters.rewrite_pat] and its body through [rewriters.rewrite_exp]; both
   receive the map extracted from the annotation of the whole definition with
   [get_map_tannot]. Everything else in the definition is kept unchanged. *)
let rewrite_fun rewriters (FD_aux (FD_function(recopt,tannotopt,effectopt,funcls),(l,fdannot))) =
  let rewrite_funcl (FCL_aux (FCL_Funcl(id,pat,exp),(l,annot))) =
    (*let _ = Printf.eprintf "Rewriting function %s, pattern %s\n"
      (match id with (Id_aux (Id i,_)) -> i) (Pretty_print.pat_to_string pat) in*)
    let fmap = get_map_tannot fdannot in
    (FCL_aux (FCL_Funcl (id,
                         rewriters.rewrite_pat rewriters fmap pat,
                         rewriters.rewrite_exp rewriters fmap exp),(l,annot)))
  in FD_aux (FD_function(recopt,tannotopt,effectopt,List.map rewrite_funcl funcls),(l,fdannot))
(* Default definition rewriter.
   Function definitions go through [rewriters.rewrite_fun] and top-level value
   bindings through [rewriters.rewrite_let] (with no map). Type, spec, default
   and register declarations are returned unchanged.
   Raises [Reporting_basic.err_unreachable] on a scattered definition. *)
let rewrite_def rewriters d =
  match d with
  | DEF_fundef fdef -> DEF_fundef (rewriters.rewrite_fun rewriters fdef)
  | DEF_val letbind -> DEF_val (rewriters.rewrite_let rewriters None letbind)
  | DEF_scattered _ ->
    raise (Reporting_basic.err_unreachable Parse_ast.Unknown "DEF_scattered survived to rewritter")
  | DEF_type _ | DEF_spec _ | DEF_default _ | DEF_reg_dec _ -> d
(* [rewrite_defs_base rewriters (Defs defs)] applies [rewriters.rewrite_def] to
   every definition in [defs] and returns the results, in the original order,
   wrapped again in [Defs]. *)
let rewrite_defs_base rewriters (Defs defs) =
  let rewrite_one d rewritten_rest =
    (rewriters.rewrite_def rewriters d) :: rewritten_rest in
  Defs (List.fold_right rewrite_one defs [])
(* [rewrite_defs defs] rewrites a whole list of definitions with the default
   rewriters of this file: [rewrite_exp], [rewrite_pat], [rewrite_let],
   [rewrite_lexp], [rewrite_fun], [rewrite_def] and [rewrite_defs_base]. *)
let rewrite_defs (Defs defs) =
  let default_rewriters =
    { rewrite_exp = rewrite_exp
    ; rewrite_pat = rewrite_pat
    ; rewrite_let = rewrite_let
    ; rewrite_lexp = rewrite_lexp
    ; rewrite_fun = rewrite_fun
    ; rewrite_def = rewrite_def
    ; rewrite_defs = rewrite_defs_base
    } in
  rewrite_defs_base default_rewriters (Defs defs)
(* signature of patterns *)
(* An algebra over patterns: one field per pattern constructor, each saying how
   to combine the already-computed results of the sub-patterns.
   ['pat] and ['pat_aux] are the result types for annotated patterns and bare
   pattern nodes, ['fpat] and ['fpat_aux] the corresponding result types for
   record-field patterns, and ['annot] is the annotation type handed to
   [p_aux] and [fP_aux]. [p_wild] is a plain value because the wildcard
   pattern has no arguments. *)
type ('pat,'pat_aux,'fpat,'fpat_aux,'annot) pat_alg =
{ p_lit : lit -> 'pat_aux
; p_wild : 'pat_aux
; p_as : 'pat * id -> 'pat_aux
; p_typ : Ast.typ * 'pat -> 'pat_aux
; p_id : id -> 'pat_aux
; p_app : id * 'pat list -> 'pat_aux
; p_record : 'fpat list * bool -> 'pat_aux
; p_vector : 'pat list -> 'pat_aux
; p_vector_indexed : (int * 'pat) list -> 'pat_aux
; p_vector_concat : 'pat list -> 'pat_aux
; p_tup : 'pat list -> 'pat_aux
; p_list : 'pat list -> 'pat_aux
; p_aux : 'pat_aux * 'annot -> 'pat
; fP_aux : 'fpat_aux * 'annot -> 'fpat
; fP_Fpat : id * 'pat -> 'fpat_aux
}
(* fold from term alg into alg *)
(* Bottom-up fold over a pattern: the sub-patterns of a node are folded first,
   and the field of [alg] that corresponds to the node's constructor is then
   applied to their results.
   [fold_pat_aux] handles a bare pattern node, [fold_pat] an annotated pattern,
   and [fold_fpat_aux] / [fold_fpat] the bare and annotated record-field
   patterns. *)
let rec fold_pat_aux alg node =
  let sub p = fold_pat alg p in
  let subs ps = List.map sub ps in
  match node with
  | P_lit lit -> alg.p_lit lit
  | P_wild -> alg.p_wild
  | P_id id -> alg.p_id id
  | P_as (p,id) -> alg.p_as (sub p,id)
  | P_typ (typ,p) -> alg.p_typ (typ,sub p)
  | P_app (id,ps) -> alg.p_app (id,subs ps)
  | P_record (fps,b) -> alg.p_record (List.map (fold_fpat alg) fps, b)
  | P_vector ps -> alg.p_vector (subs ps)
  | P_vector_indexed ips ->
    alg.p_vector_indexed (List.map (fun (i,p) -> (i, sub p)) ips)
  | P_vector_concat ps -> alg.p_vector_concat (subs ps)
  | P_tup ps -> alg.p_tup (subs ps)
  | P_list ps -> alg.p_list (subs ps)
and fold_pat alg (P_aux (pat,annot)) =
  alg.p_aux (fold_pat_aux alg pat,annot)
and fold_fpat_aux alg (FP_Fpat (id,pat)) =
  alg.fP_Fpat (id,fold_pat alg pat)
and fold_fpat alg (FP_aux (fpat,annot)) =
  alg.fP_aux (fold_fpat_aux alg fpat,annot)
(* identity fold from term alg to term alg *)
(* Each field rebuilds the constructor it is named after, so folding a pattern
   with [id_f] reconstructs that pattern. It is meant as a base record to be
   specialised with [{id_f with ...}]. *)
let id_f : ('a pat, 'a pat_aux, 'a fpat, 'a fpat_aux, 'a annot) pat_alg =
  { p_lit            = (fun l -> P_lit l)
  ; p_wild           = P_wild
  ; p_as             = (fun (p,name) -> P_as (p,name))
  ; p_typ            = (fun (ty,p) -> P_typ (ty,p))
  ; p_id             = (fun name -> P_id name)
  ; p_app            = (fun (name,args) -> P_app (name,args))
  ; p_record         = (fun (fields,flag) -> P_record (fields,flag))
  ; p_vector         = (fun elems -> P_vector elems)
  ; p_vector_indexed = (fun ielems -> P_vector_indexed ielems)
  ; p_vector_concat  = (fun parts -> P_vector_concat parts)
  ; p_tup            = (fun comps -> P_tup comps)
  ; p_list           = (fun elems -> P_list elems)
  ; p_aux            = (fun (node,ann) -> P_aux (node,ann))
  ; fP_aux           = (fun (fnode,ann) -> FP_aux (fnode,ann))
  ; fP_Fpat          = (fun (name,p) -> FP_Fpat (name,p))
  }
(* [remove_vector_concat_pat pat] rewrites the vector-concatenation patterns
   inside [pat] in a sequence of folds and returns [(pat', decls)]:

   1. every [P_vector_concat] node is wrapped in a [P_as] with a fresh name;
   2. every [P_vector] that is a direct child of a [P_vector_concat] is wrapped
      in a [P_as] with a fresh name;
   3. the names introduced on the children are stripped again and one
      declaration is recorded per named child; the [P_as] around the
      concatenation itself is dropped as well;
   4. directly nested [P_vector_concat] nodes are flattened into their parent;
   5. inside each [P_vector_concat], [P_vector] children are spliced in and
      identifier children of constant length are replaced by that many
      bit-typed wildcards.

   [decls] is the list of declarations collected in step 3. At present each
   entry is a fixed placeholder string rather than a generated binding, and
   the node left by step 5 is still a [P_vector_concat].

   Fresh names have the form [__v<n>], with a counter local to each call.

   Raises [Reporting_basic.err_unreachable] when a child of a
   [P_vector_concat] does not have a vector type of constant length. *)
let remove_vector_concat_pat pat =
  let counter = ref 0 in
  let fresh_name () =
    let current = !counter in
    let () = counter := (current + 1) in
    Id_aux (Id ("__v" ^ string_of_int current), Parse_ast.Unknown) in
  (* expects that P_typ elements have been removed from AST,
     that the length of all vectors involved is known,
     that we don't have indexed vectors *)
  (* introduce names for all patterns of form P_vector_concat *)
  let name_vector_concat_roots =
    let p_aux (pat,annot) = match pat with
      | P_vector_concat pats -> P_aux (P_as (P_aux (pat,annot),fresh_name()),annot)
      | _ -> P_aux (pat,annot) in
    {id_f with p_aux = p_aux} in
  let pat = fold_pat name_vector_concat_roots pat in
  (* introduce names for all unnamed child nodes of P_vector_concat *)
  let name_vector_concat_elements =
    let p_vector_concat pats =
      let aux ((P_aux (p,a)) as pat) = match p with
        | P_vector _ -> P_aux (P_as (pat,fresh_name()),a)
        (* | P_vector_concat. cannot happen after fold function name_vector_concat_roots *)
        | _ -> pat in (* this can only be P_as and P_id *)
      P_vector_concat (List.map aux pats) in
    {id_f with p_vector_concat = p_vector_concat} in
  let pat = fold_pat name_vector_concat_elements pat in
  let zip l1 l2 = List.fold_right2 (fun x y acc -> (x,y) :: acc) l1 l2 [] in
  let unzip l = List.fold_right (fun (a,b) (accA,accB) -> (a :: accA, b :: accB)) l ([],[]) in
  (* remove names from vectors in vector_concat patterns and collect them as declarations for the
     function body or expression *)
  let unname_vector_concat_elements : ('a pat * (string list), 'a pat_aux * (string list), 'a fpat * (string list),
                                       'a fpat_aux * (string list), 'a annot) pat_alg =
    let p_aux ((pattern,decls),annot) = match pattern with
      | P_as (P_aux (P_vector_concat pats,_),name) ->
        (* Walk the children left to right, tracking the running start position. *)
        let aux (pat_acc,decl_acc,pos) = function
          | (P_aux (P_as (P_aux (p,annot),name2),
                    (l,Base(([],{t = Tapp ("vector",[_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_)))) ->
            (pat_acc @ [P_aux (p,annot)],
             decl_acc @ ["define name2 as vector <name> [pos;pos + length -1]"],
             add_big_int pos length)
          | (P_aux (P_id name2,
                    ((l,Base(([],{t = Tapp ("vector", [_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_)) as annot))) ->
            (pat_acc @ [P_aux (P_id name2,annot)],
             decl_acc @ ["define name2 as vector <name> [pos;pos + length -1]"],
             add_big_int pos length)
          | (P_aux (_,(l,Base(([],{t = Tapp ("vector", [_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_))) as p)
            -> (pat_acc @ [p],decl_acc,add_big_int pos length)
          | (P_aux (_,(l,_))) -> raise (Reporting_basic.err_unreachable l "Non-vector in vector-concat pattern") in
        let (pats',decls',_) = List.fold_left aux ([],[],zero_big_int) pats in
        (P_aux (P_vector_concat pats',annot),decls @ decls')
      | _ -> (P_aux (pattern,annot),decls) in
    { p_lit = (fun lit -> (P_lit lit,[]))
    ; p_wild = (P_wild,[])
    ; p_as = (fun ((pat,decls),id) -> (P_as (pat,id),decls))
    ; p_typ = (fun (typ,(pat,decls)) -> (P_typ (typ,pat),decls))
    ; p_id = (fun id -> (P_id id,[]))
    ; p_app = (fun (id,ps) -> let (ps,decls) = unzip ps in (P_app (id,ps),List.flatten decls))
    ; p_record = (fun (ps,b) -> let (ps,decls) = unzip ps in (P_record (ps,b),List.flatten decls))
    ; p_vector = (fun ps -> let (ps,decls) = unzip ps in (P_vector ps,List.flatten decls))
    ; p_vector_indexed = (fun ps -> let (is,ps) = unzip ps in let (ps,decls) = unzip ps in let ps = zip is ps in
                           (P_vector_indexed ps,List.flatten decls))
    ; p_vector_concat = (fun ps -> let (ps,decls) = unzip ps in (P_vector_concat ps,List.flatten decls))
    ; p_tup = (fun ps -> let (ps,decls) = unzip ps in (P_tup ps,List.flatten decls))
    ; p_list = (fun ps -> let (ps,decls) = unzip ps in (P_list ps,List.flatten decls))
    ; p_aux = (fun ((pat,decls),annot) -> p_aux ((pat,decls),annot))
    ; fP_aux = (fun ((fpat,decls),annot) -> (FP_aux (fpat,annot),decls))
    ; fP_Fpat = (fun (id,(pat,decls)) -> (FP_Fpat (id,pat),decls))
    } in
  let (pat,decls) = fold_pat unname_vector_concat_elements pat in
  (* at this point shouldn't have P_as patterns in P_vector_concat patterns any more,
     all P_as and P_id vectors should have their declarations in decls.
     Now flatten all vector_concat patterns*)
  let flatten =
    let p_vector_concat ps =
      let aux p acc = match p with
        | (P_aux (P_vector_concat pats,_)) -> pats @ acc
        | pat -> pat :: acc in
      P_vector_concat (List.fold_right aux ps []) in
    {id_f with p_vector_concat = p_vector_concat} in
  let pat = fold_pat flatten pat in
  (* at this point pat should be a flat pattern: no vector_concat patterns
     with vector_concats patterns as direct child-nodes anymore *)
  (* [wildcards n] is a list of exactly [n] bit-typed wildcard patterns, each
     with its own annotation; it is empty when [n <= 0], so a zero-length
     vector contributes no pattern at all. *)
  let wildcards n =
    let wild () = P_aux (P_wild,(Parse_ast.Unknown,simple_annot {t = Tid "bit"})) in
    let rec build k acc = if k <= 0 then acc else build (k - 1) (wild () :: acc) in
    build n [] in
  let remove_vector_concats =
    let p_vector_concat ps =
      let aux acc = function
        | P_aux (P_vector ps,annot) -> acc @ ps
        | P_aux (P_id name2, (_,Base(([],{t = Tapp ("vector", [_;TA_nexp {nexp = Nconst length};_;_])}),_,_,_,_))) ->
          acc @ (wildcards (int_of_big_int length))
        | (P_aux (_,(l,_))) -> raise (Reporting_basic.err_unreachable l "Non-vector in vector-concat pattern") in
      P_vector_concat (List.fold_left aux [] ps) in
    {id_f with p_vector_concat = p_vector_concat} in
  let pat = fold_pat remove_vector_concats pat in
  (pat,decls)
|