]> matita.cs.unibo.it Git - helm.git/blob - helm/ocaml/cic_notation/cicNotationRew.ml
snapshot (ported to new "typed" ids_to_inner_sort table)
[helm.git] / helm / ocaml / cic_notation / cicNotationRew.ml
1 (* Copyright (C) 2004-2005, HELM Team.
2  * 
3  * This file is part of HELM, an Hypertextual, Electronic
4  * Library of Mathematics, developed at the Computer Science
5  * Department, University of Bologna, Italy.
6  * 
7  * HELM is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License
9  * as published by the Free Software Foundation; either version 2
10  * of the License, or (at your option) any later version.
11  * 
12  * HELM is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with HELM; if not, write to the Free Software
19  * Foundation, Inc., 59 Temple Place - Suite 330, Boston,
20  * MA  02111-1307, USA.
21  * 
22  * For details, see the HELM World-Wide-Web page,
23  * http://helm.cs.unibo.it/
24  *)
25
26 open Printf
27
28 type pattern_id = int
29 type interpretation_id = pattern_id
30
31 type term_info =
32   { sort: (Cic.id, CicNotationPt.sort_kind) Hashtbl.t;
33     uri: (Cic.id, string) Hashtbl.t;
34   }
35
36 exception No_match
37
38 module OrderedInt =
39   struct
40   type t = int
41   let compare (x1:t) (x2:t) = Pervasives.compare x2 x1  (* reverse order *)
42   end
43
44 module IntSet = Set.Make (OrderedInt)
45
46 let int_set_of_int_list l =
47   List.fold_left (fun acc i -> IntSet.add i acc) IntSet.empty l
48
49 let warning s = prerr_endline ("CicNotation WARNING: " ^ s)
50
51 module Patterns =
52   struct
53   type row_t = CicNotationPt.cic_appl_pattern list * pattern_id
54   type t = row_t list
55
56   let empty = []
57
58   let first_column t = List.map (fun (patterns, _) -> List.hd patterns) t
59   let pattern_ids t = List.map snd t
60
61   let prepend_column t column =
62     try
63       List.map2 (fun elt (pl, pid) -> (elt :: pl), pid) column t
64
65     with Invalid_argument _ -> assert false
66
67   let prepend_columns t columns =
68     List.fold_right
69       (fun column rows -> prepend_column rows column)
70       columns t
71
72   let partition t pidl =
73     let partitions = Hashtbl.create 11 in
74     let add pid row = Hashtbl.add partitions pid row in
75     (try
76       List.iter2 add pidl t
77     with Invalid_argument _ -> assert false);
78     let pidset = int_set_of_int_list pidl in
79     IntSet.fold
80       (fun pid acc ->
81         match Hashtbl.find_all partitions pid with
82         | [] -> acc
83         | patterns -> (pid, List.rev patterns) :: acc)
84       pidset []
85
86   let are_empty t = fst (List.hd t) = []
87     (* if first row has an empty list of patterns, then others will as well *)
88
89     (* return 2 lists of rows, first one containing homogeneous rows according
90      * to "compatible" below *)
91   let horizontal_split t =
92     let compatible ap1 ap2 =
93       match ap1, ap2 with
94       | CicNotationPt.UriPattern _, CicNotationPt.UriPattern _
95       | CicNotationPt.ArgPattern _, CicNotationPt.ArgPattern _
96       | CicNotationPt.ApplPattern _, CicNotationPt.ApplPattern _ -> true
97       | _ -> false
98     in
99     let ap =
100       match t with
101       | [] -> assert false
102       | ([], _) :: _ ->
103           assert false  (* are_empty should have been invoked in advance *)
104       | (hd :: _ , _) :: _ -> hd
105     in
106     let rec aux prev_t = function
107       | [] -> List.rev prev_t, []
108       | ([], _) :: _ -> assert false
109       | (((hd :: _), _) as row) :: tl when compatible ap hd ->
110           aux (row :: prev_t) tl
111       | t -> List.rev prev_t, t
112     in
113     aux [] t
114
115     (* return 2 lists, first one representing first column, second one
116      * representing rows stripped of the first element *)
117   let vertical_split t =
118     let l =
119       List.map
120         (function
121           | (hd :: tl, pid) -> hd, (tl, pid)
122           | _ -> assert false)
123         t
124     in
125     List.split l
126   end
127
128   (* acic -> ast auxiliary function s *)
129
130 let get_types uri =
131   let o,_ = CicEnvironment.get_obj CicUniv.empty_ugraph uri in
132     match o with
133       | Cic.InductiveDefinition (l,_,_,_) -> l 
134       | _ -> assert false
135
136 let name_of_inductive_type uri i = 
137   let types = get_types uri in
138   let (name, _, _, _) = try List.nth types i with Not_found -> assert false in
139   name
140
141   (* returns <name, type> pairs *)
142 let constructors_of_inductive_type uri i =
143   let types = get_types uri in
144   let (_, _, _, constructors) = 
145     try List.nth types i with Not_found -> assert false
146   in
147   constructors
148
149   (* returns name only *)
150 let constructor_of_inductive_type uri i j =
151   (try
152     fst (List.nth (constructors_of_inductive_type uri i) (j-1))
153   with Not_found -> assert false)
154
155 module Ast = CicNotationPt
156
157 let string_of_name = function
158   | Cic.Name s -> s
159   | Cic.Anonymous -> "_"
160
161 let ident_of_name n = Ast.Ident (string_of_name n, None)
162
163 let idref id t = Ast.AttributedTerm (`IdRef id, t)
164
165 let ast_of_acic0 term_info acic k =
166 (*   prerr_endline "ast_of_acic0"; *)
167   let k = k term_info in
168   let register_uri id uri = Hashtbl.add term_info.uri id uri in
169   let sort_of_id id =
170     try
171       Hashtbl.find term_info.sort id
172     with Not_found -> assert false
173   in
174   let aux_substs substs =
175     Some
176       (List.map
177         (fun (uri, annterm) -> (UriManager.name_of_uri uri, k annterm))
178         substs)
179   in
180   let aux_context context =
181     List.map
182       (function
183         | None -> None
184         | Some annterm -> Some (k annterm))
185       context
186   in
187   let aux = function
188     | Cic.ARel (id,_,_,b) -> idref id (Ast.Ident (b, None))
189     | Cic.AVar (id,uri,substs) ->
190         register_uri id (UriManager.string_of_uri uri);
191         idref id (Ast.Ident (UriManager.name_of_uri uri, aux_substs substs))
192     | Cic.AMeta (id,n,l) -> idref id (Ast.Meta (n, aux_context l))
193     | Cic.ASort (id,Cic.Prop) -> idref id (Ast.Sort `Prop)
194     | Cic.ASort (id,Cic.Set) -> idref id (Ast.Sort `Set)
195     | Cic.ASort (id,Cic.Type _) -> idref id (Ast.Sort `Type)
196     | Cic.ASort (id,Cic.CProp) -> idref id (Ast.Sort `CProp)
197     | Cic.AImplicit _ -> assert false
198     | Cic.AProd (id,n,s,t) ->
199         let binder_kind =
200           match sort_of_id id with
201           | `Set | `Type -> `Pi
202           | `Prop | `CProp -> `Forall
203         in
204         idref id (Ast.Binder (binder_kind, (ident_of_name n, Some (k s)), k t))
205     | Cic.ACast (id,v,t) ->
206         idref id (Ast.Appl [idref id (Ast.Symbol ("cast", 0)); k v; k t])
207     | Cic.ALambda (id,n,s,t) ->
208         idref id (Ast.Binder (`Lambda, (ident_of_name n, Some (k s)), k t))
209     | Cic.ALetIn (id,n,s,t) ->
210         idref id (Ast.LetIn ((ident_of_name n, None), k s, k t))
211     | Cic.AAppl (aid,args) -> idref aid (Ast.Appl (List.map k args))
212     | Cic.AConst (id,uri,substs) ->
213         register_uri id (UriManager.string_of_uri uri);
214         idref id (Ast.Ident (UriManager.name_of_uri uri, aux_substs substs))
215     | Cic.AMutInd (id,uri,i,substs) as t ->
216         let name = name_of_inductive_type uri i in
217         let uri_str = UriManager.string_of_uri uri in
218         let puri_str =
219           uri_str ^ "#xpointer(1/" ^ (string_of_int (i + 1)) ^ ")"
220         in
221         register_uri id puri_str;
222         idref id (Ast.Ident (name, aux_substs substs))
223     | Cic.AMutConstruct (id,uri,i,j,substs) ->
224         let name = constructor_of_inductive_type uri i j in
225         let uri_str = UriManager.string_of_uri uri in
226         let puri_str = sprintf "%s#xpointer(1/%d/%d)" uri_str (i + 1) j in
227         register_uri id puri_str;
228         idref id (Ast.Ident (name, aux_substs substs))
229     | Cic.AMutCase (id,uri,typeno,ty,te,patterns) ->
230         let name = name_of_inductive_type uri typeno in
231         let constructors = constructors_of_inductive_type uri typeno in
232         let rec eat_branch ty pat =
233           match (ty, pat) with
234           | Cic.Prod (_, _, t), Cic.ALambda (_, name, s, t') ->
235               let (cv, rhs) = eat_branch t t' in
236               (ident_of_name name, Some (k s)) :: cv, rhs
237           | _, _ -> [], k pat
238         in
239         let patterns =
240           List.map2
241             (fun (name, ty) pat ->
242               let (capture_variables, rhs) = eat_branch ty pat in
243               ((name, capture_variables), rhs))
244             constructors patterns
245         in
246         idref id (Ast.Case (k te, Some name, Some (k ty), patterns))
247     | Cic.AFix (id, no, funs) -> 
248         let defs = 
249           List.map
250             (fun (_, n, decr_idx, ty, bo) ->
251               ((Ast.Ident (n, None), Some (k ty)), k bo, decr_idx))
252             funs
253         in
254         let name =
255           try
256             (match List.nth defs no with
257             | (Ast.Ident (n, _), _), _, _ when n <> "_" -> n
258             | _ -> assert false)
259           with Not_found -> assert false
260         in
261         idref id (Ast.LetRec (`Inductive, defs, Ast.Ident (name, None)))
262     | Cic.ACoFix (id, no, funs) -> 
263         let defs = 
264           List.map
265             (fun (_, n, ty, bo) -> ((Ast.Ident (n, None), Some (k ty)), k bo, 0))
266             funs
267         in
268         let name =
269           try
270             (match List.nth defs no with
271             | (Ast.Ident (n, _), _), _, _ when n <> "_" -> n
272             | _ -> assert false)
273           with Not_found -> assert false
274         in
275         idref id (Ast.LetRec (`CoInductive, defs, Ast.Ident (name, None)))
276   in
277   aux acic
278
279   (* persistent state *)
280
281 let level2_patterns = Hashtbl.create 211
282
283 let (compiled32: (term_info -> Cic.annterm -> CicNotationPt.term) option ref) =
284   ref None
285
286 let pattern_matrix = ref Patterns.empty
287
288 let get_compiled32 () =
289   match !compiled32 with
290   | None -> assert false
291   | Some f -> f
292
293 let set_compiled32 f = compiled32 := Some f
294
295   (* "envl" is a list of triples:
296    *   <name environment, term environment, pattern id>, where
297    *   name environment: (string * string) list
298    *   term environment: (string * Cic.annterm) list *)
299 let return_closure success_k =
300   (fun term_info terms envl ->
301 (*     prerr_endline "return_closure"; *)
302     match terms with
303     | [] ->
304         (try
305           success_k term_info (List.hd envl)
306         with Failure _ -> assert false)
307     | _ -> assert false)
308
309 let variable_closure names k =
310   (fun term_info terms envl ->
311 (*     prerr_endline "variable_closure"; *)
312     match terms with
313     | hd :: tl ->
314         let envl' =
315           List.map2
316             (fun arg (name_env, term_env, pid) ->
317               let rec aux name_env term_env pid arg term =
318                 match arg, term with
319                   Ast.IdentArg name, _ ->
320                     (name_env, (name, term) :: term_env, pid)
321                 | Ast.EtaArg (Some name, arg'),
322                   Cic.ALambda (id, name', ty, body) ->
323                     aux
324                       ((name, (string_of_name name', Some (ty, id))) :: name_env)
325                       term_env pid arg' body
326                 | Ast.EtaArg (Some name, arg'), _ ->
327                     let name' = CicNotationUtil.fresh_name () in
328                     aux ((name, (name', None)) :: name_env)
329                       term_env pid arg' term
330                 | Ast.EtaArg (None, arg'), Cic.ALambda (id, name, ty, body) ->
331                     assert false
332                 | Ast.EtaArg (None, arg'), _ ->
333                     assert false
334               in
335                 aux name_env term_env pid arg hd)
336             names envl
337         in
338         k term_info tl envl'
339     | _ -> assert false)
340
341 let appl_closure ks k =
342   (fun term_info terms envl ->
343 (*     prerr_endline "appl_closure"; *)
344     (match terms with
345     | Cic.AAppl (_, args) :: tl ->
346         (try
347           let k' = List.assoc (List.length args) ks in
348           k' term_info (args @ tl) envl
349         with Not_found -> k term_info terms envl)
350     | [] -> assert false
351     | _ -> k term_info terms envl))
352
353 let uri_of_term t = CicUtil.uri_of_term (Deannotate.deannotate_term t)
354
355 let uri_closure ks k =
356   (fun term_info terms envl ->
357 (*     prerr_endline "uri_closure"; *)
358     (match terms with
359     | [] -> assert false
360     | hd :: tl ->
361 (*         prerr_endline (sprintf "uri_of_term = %s" (uri_of_term hd)); *)
362         begin
363           try
364             let k' = List.assoc (uri_of_term hd) ks in
365             k' term_info tl envl
366           with
367           | Invalid_argument _  (* raised by uri_of_term *)
368           | Not_found -> k term_info terms envl
369         end))
370
371   (* compiler from level 3 to level 2 *)
372 let compiler32 (t: Patterns.t) success_k fail_k =
373   let rec aux t k = (* k is a continuation *)
374     if t = [] then
375       k
376     else if Patterns.are_empty t then begin
377       (match t with
378       | _::_::_ ->
379           (* optimization possible here: throw away all except one of the rules
380            * which lead to ambiguity *)
381           warning "Ambiguous patterns"
382       | _ -> ());
383       return_closure success_k
384     end else
385       match Patterns.horizontal_split t with
386       | t', [] ->
387           (match t' with
388           | []
389           | ([], _) :: _ -> assert false
390           | (Ast.ArgPattern (Ast.IdentArg _) :: _, _) :: _
391           | (Ast.ArgPattern (Ast.EtaArg _) :: _, _) :: _ ->
392               let first_column, t'' = Patterns.vertical_split t' in
393               let names =
394                 List.map
395                   (function
396                     | Ast.ArgPattern arg -> arg
397                     | _ -> assert false)
398                   first_column
399               in
400               variable_closure names (aux t'' k)
401           | (Ast.ApplPattern _ :: _, _) :: _ ->
402               let pidl =
403                 List.map
404                   (function
405                     | (Ast.ApplPattern args) :: _, _ -> List.length args
406                     | _ -> assert false)
407                   t'
408               in
409                 (* arity partitioning *)
410               let clusters = Patterns.partition t' pidl in
411               let ks =  (* k continuation list *)
412                 List.map
413                   (fun (len, cluster) ->
414                     let cluster' =
415                       List.map  (* add args as patterns heads *)
416                         (function
417                           | (Ast.ApplPattern args) :: tl, pid ->
418                               (* let's throw away "teste di cluster" *)
419                               args @ tl, pid
420                           | _ -> assert false)
421                         cluster
422                     in
423                     len, aux cluster' k)
424                   clusters
425               in
426               appl_closure ks k
427           | (Ast.UriPattern _ :: _, _) :: _ ->
428               let uidmap, pidl =
429                 let urimap = ref [] in
430                 let uidmap = ref [] in
431                 let get_uri_id uri =
432                   try
433                     List.assoc uri !urimap
434                   with
435                     Not_found ->
436                       let uid = List.length !urimap in
437                       urimap := (uri, uid) :: !urimap ;
438                       uidmap := (uid, uri) :: !uidmap ;
439                       uid
440                 in
441                 let uidl =
442                   List.map
443                     (function
444                       | (Ast.UriPattern uri) :: _, _ -> get_uri_id uri
445                       | _ -> assert false)
446                     t'
447                 in
448                   !uidmap, uidl
449               in
450               let clusters = Patterns.partition t' pidl in
451               let ks =
452                 List.map
453                   (fun (uid, cluster) ->
454                     let cluster' =
455                       List.map
456                         (function
457                         | (Ast.UriPattern uri) :: tl, pid -> tl, pid
458                         | _ -> assert false)
459                       cluster
460                     in
461                     List.assoc uid uidmap, aux cluster' k)
462                   clusters
463               in
464               uri_closure ks k)
465       | t', tl -> aux t' (aux tl k)
466   in
467   let matcher = aux t (fun _ _ -> raise No_match) in
468   (fun term_info annterm ->
469     try
470       matcher term_info [annterm] (List.map (fun (_, pid) -> [], [], pid) t)
471     with No_match -> fail_k term_info annterm)
472
473 let ast_of_acic1 term_info annterm = (get_compiled32 ()) term_info annterm
474
475 let instantiate term_info name_env term_env pid =
476   let symbol, args =
477     try
478       Hashtbl.find level2_patterns pid
479     with Not_found -> assert false
480   in
481   let rec instantiate_arg = function
482     | Ast.IdentArg name ->
483         (try List.assoc name term_env with Not_found -> assert false)
484     | Ast.EtaArg (None, _) -> assert false  (* TODO *)
485     | Ast.EtaArg (Some name, arg) ->
486         let (name', ty_opt) =
487           try List.assoc name name_env with Not_found -> assert false
488         in
489         let body = instantiate_arg arg in
490         let name' = Ast.Ident (name', None) in
491         match ty_opt with
492         | None -> Ast.Binder (`Lambda, (name', None), body)
493         | Some (ty, id) ->
494             idref id (Ast.Binder (`Lambda, (name', Some ty), body))
495   in
496   let args' = List.map instantiate_arg args in
497   Ast.Appl (Ast.Symbol (symbol, 0) :: args')
498
499 let load_patterns t =
500   let ast_env_of_name_env term_info name_env =
501     List.map
502       (fun (name, (name', ty_opt)) ->
503         let ast_ty_opt =
504           match ty_opt with
505           | None -> None
506           | Some (annterm, id) -> Some (ast_of_acic1 term_info annterm, id)
507         in
508         (name, (name', ast_ty_opt)))
509       name_env
510   in
511   let ast_env_of_term_env term_info =
512     List.map (fun (name, term) -> (name, ast_of_acic1 term_info term))
513   in
514   let fail_k term_info annterm = ast_of_acic0 term_info annterm ast_of_acic1 in
515   let success_k term_info (name_env, term_env, pid) =
516     instantiate
517       term_info
518       (ast_env_of_name_env term_info name_env)
519       (ast_env_of_term_env term_info term_env)
520       pid
521   in
522   let compiled32 = compiler32 t success_k fail_k in
523   set_compiled32 compiled32
524
525 let ast_of_acic id_to_sort annterm =
526   let term_info = { sort = id_to_sort; uri = Hashtbl.create 211 } in
527   let ast = ast_of_acic1 term_info annterm in
528   ast, term_info.uri
529
530 let fresh_id =
531   let counter = ref ~-1 in
532   fun () ->
533     incr counter;
534     !counter
535
536 let add_interpretation (symbol, args) appl_pattern =
537   let id = fresh_id () in
538   Hashtbl.add level2_patterns id (symbol, args);
539   pattern_matrix := ([appl_pattern], id) :: !pattern_matrix;
540   load_patterns !pattern_matrix;
541   id
542
543 exception Interpretation_not_found
544
545 let remove_interpretation id =
546   (try
547     Hashtbl.remove level2_patterns id;
548   with Not_found -> raise Interpretation_not_found);
549   pattern_matrix := List.filter (fun (_, id') -> id <> id') !pattern_matrix;
550   load_patterns !pattern_matrix
551
552 let _ = load_patterns []
553