]> matita.cs.unibo.it Git - helm.git/blobdiff - helm/ocaml/cic_notation/cicNotationRew.ml
snapshot (ported to new "typed" ids_to_inner_sort table)
[helm.git] / helm / ocaml / cic_notation / cicNotationRew.ml
index 8253efd161735b99f459062fe67afd8ba4af658a..b979e84c99f255aa313a3382106455d79d7702ad 100644 (file)
  * http://helm.cs.unibo.it/
  *)
 
+open Printf
+
 type pattern_id = int
+type interpretation_id = pattern_id
 
 type term_info =
   { sort: (Cic.id, CicNotationPt.sort_kind) Hashtbl.t;
@@ -43,16 +46,6 @@ module IntSet = Set.Make (OrderedInt)
 let int_set_of_int_list l =
   List.fold_left (fun acc i -> IntSet.add i acc) IntSet.empty l
 
-let (compiled32: (term_info -> Cic.annterm -> CicNotationPt.term) option ref) =
-  ref None
-
-let get_compiled32 () =
-  match !compiled32 with
-  | None -> assert false
-  | Some f -> f
-
-let set_compiled32 f = compiled32 := Some f
-
 let warning s = prerr_endline ("CicNotation WARNING: " ^ s)
 
 module Patterns =
@@ -60,6 +53,8 @@ module Patterns =
   type row_t = CicNotationPt.cic_appl_pattern list * pattern_id
   type t = row_t list
 
+  let empty = []
+
   let first_column t = List.map (fun (patterns, _) -> List.hd patterns) t
   let pattern_ids t = List.map snd t
 
@@ -130,6 +125,8 @@ module Patterns =
     List.split l
   end
 
+  (* acic -> ast auxiliary function s *)
+
 let get_types uri =
   let o,_ = CicEnvironment.get_obj CicUniv.empty_ugraph uri in
     match o with
@@ -163,7 +160,10 @@ let string_of_name = function
 
 let ident_of_name n = Ast.Ident (string_of_name n, None)
 
+let idref id t = Ast.AttributedTerm (`IdRef id, t)
+
 let ast_of_acic0 term_info acic k =
+(*   prerr_endline "ast_of_acic0"; *)
   let k = k term_info in
   let register_uri id uri = Hashtbl.add term_info.uri id uri in
   let sort_of_id id =
@@ -171,7 +171,6 @@ let ast_of_acic0 term_info acic k =
       Hashtbl.find term_info.sort id
     with Not_found -> assert false
   in
-  let idref id t = Ast.AttributedTerm (`IdRef id, t) in
   let aux_substs substs =
     Some
       (List.map
@@ -211,12 +210,21 @@ let ast_of_acic0 term_info acic k =
         idref id (Ast.LetIn ((ident_of_name n, None), k s, k t))
     | Cic.AAppl (aid,args) -> idref aid (Ast.Appl (List.map k args))
     | Cic.AConst (id,uri,substs) ->
+        register_uri id (UriManager.string_of_uri uri);
         idref id (Ast.Ident (UriManager.name_of_uri uri, aux_substs substs))
-    | Cic.AMutInd (id,uri,i,substs) ->
+    | Cic.AMutInd (id,uri,i,substs) as t ->
         let name = name_of_inductive_type uri i in
+        let uri_str = UriManager.string_of_uri uri in
+        let puri_str =
+          uri_str ^ "#xpointer(1/" ^ (string_of_int (i + 1)) ^ ")"
+        in
+        register_uri id puri_str;
         idref id (Ast.Ident (name, aux_substs substs))
     | Cic.AMutConstruct (id,uri,i,j,substs) ->
         let name = constructor_of_inductive_type uri i j in
+        let uri_str = UriManager.string_of_uri uri in
+        let puri_str = sprintf "%s#xpointer(1/%d/%d)" uri_str (i + 1) j in
+        register_uri id puri_str;
         idref id (Ast.Ident (name, aux_substs substs))
     | Cic.AMutCase (id,uri,typeno,ty,te,patterns) ->
         let name = name_of_inductive_type uri typeno in
@@ -268,12 +276,29 @@ let ast_of_acic0 term_info acic k =
   in
   aux acic
 
+  (* persistent state *)
+
+let level2_patterns = Hashtbl.create 211
+
+let (compiled32: (term_info -> Cic.annterm -> CicNotationPt.term) option ref) =
+  ref None
+
+let pattern_matrix = ref Patterns.empty
+
+let get_compiled32 () =
+  match !compiled32 with
+  | None -> assert false
+  | Some f -> f
+
+let set_compiled32 f = compiled32 := Some f
+
   (* "envl" is a list of triples:
    *   <name environment, term environment, pattern id>, where
    *   name environment: (string * string) list
    *   term environment: (string * Cic.annterm) list *)
 let return_closure success_k =
   (fun term_info terms envl ->
+(*     prerr_endline "return_closure"; *)
     match terms with
     | [] ->
         (try
@@ -283,6 +308,7 @@ let return_closure success_k =
 
 let variable_closure names k =
   (fun term_info terms envl ->
+(*     prerr_endline "variable_closure"; *)
     match terms with
     | hd :: tl ->
         let envl' =
@@ -293,14 +319,15 @@ let variable_closure names k =
                   Ast.IdentArg name, _ ->
                     (name_env, (name, term) :: term_env, pid)
                 | Ast.EtaArg (Some name, arg'),
-                  Cic.ALambda (_, name', ty, body) ->
-                    aux ((name, (string_of_name name', Some ty)) :: name_env)
+                  Cic.ALambda (id, name', ty, body) ->
+                    aux
+                      ((name, (string_of_name name', Some (ty, id))) :: name_env)
                       term_env pid arg' body
                 | Ast.EtaArg (Some name, arg'), _ ->
                     let name' = CicNotationUtil.fresh_name () in
                     aux ((name, (name', None)) :: name_env)
                       term_env pid arg' term
-                | Ast.EtaArg (None, arg'), Cic.ALambda (_, name, ty, body) ->
+                | Ast.EtaArg (None, arg'), Cic.ALambda (id, name, ty, body) ->
                     assert false
                 | Ast.EtaArg (None, arg'), _ ->
                     assert false
@@ -313,6 +340,7 @@ let variable_closure names k =
 
 let appl_closure ks k =
   (fun term_info terms envl ->
+(*     prerr_endline "appl_closure"; *)
     (match terms with
     | Cic.AAppl (_, args) :: tl ->
         (try
@@ -326,9 +354,11 @@ let uri_of_term t = CicUtil.uri_of_term (Deannotate.deannotate_term t)
 
 let uri_closure ks k =
   (fun term_info terms envl ->
+(*     prerr_endline "uri_closure"; *)
     (match terms with
     | [] -> assert false
     | hd :: tl ->
+(*         prerr_endline (sprintf "uri_of_term = %s" (uri_of_term hd)); *)
         begin
           try
             let k' = List.assoc (uri_of_term hd) ks in
@@ -345,7 +375,10 @@ let compiler32 (t: Patterns.t) success_k fail_k =
       k
     else if Patterns.are_empty t then begin
       (match t with
-      | _::_::_ -> warning "Ambiguous patterns"
+      | _::_::_ ->
+          (* optimization possible here: throw away all except one of the rules
+           * which lead to ambiguity *)
+          warning "Ambiguous patterns"
       | _ -> ());
       return_closure success_k
     end else
@@ -439,14 +472,38 @@ let compiler32 (t: Patterns.t) success_k fail_k =
 
 let ast_of_acic1 term_info annterm = (get_compiled32 ()) term_info annterm
 
-let load_patterns t instantiate =
+let instantiate term_info name_env term_env pid =
+  let symbol, args =
+    try
+      Hashtbl.find level2_patterns pid
+    with Not_found -> assert false
+  in
+  let rec instantiate_arg = function
+    | Ast.IdentArg name ->
+        (try List.assoc name term_env with Not_found -> assert false)
+    | Ast.EtaArg (None, _) -> assert false  (* TODO *)
+    | Ast.EtaArg (Some name, arg) ->
+        let (name', ty_opt) =
+          try List.assoc name name_env with Not_found -> assert false
+        in
+        let body = instantiate_arg arg in
+        let name' = Ast.Ident (name', None) in
+        match ty_opt with
+        | None -> Ast.Binder (`Lambda, (name', None), body)
+        | Some (ty, id) ->
+            idref id (Ast.Binder (`Lambda, (name', Some ty), body))
+  in
+  let args' = List.map instantiate_arg args in
+  Ast.Appl (Ast.Symbol (symbol, 0) :: args')
+
+let load_patterns t =
   let ast_env_of_name_env term_info name_env =
     List.map
       (fun (name, (name', ty_opt)) ->
         let ast_ty_opt =
           match ty_opt with
           | None -> None
-          | Some annterm -> Some (ast_of_acic1 term_info annterm)
+          | Some (annterm, id) -> Some (ast_of_acic1 term_info annterm, id)
         in
         (name, (name', ast_ty_opt)))
       name_env
@@ -470,3 +527,27 @@ let ast_of_acic id_to_sort annterm =
   let ast = ast_of_acic1 term_info annterm in
   ast, term_info.uri
 
+let fresh_id =
+  let counter = ref ~-1 in
+  fun () ->
+    incr counter;
+    !counter
+
+let add_interpretation (symbol, args) appl_pattern =
+  let id = fresh_id () in
+  Hashtbl.add level2_patterns id (symbol, args);
+  pattern_matrix := ([appl_pattern], id) :: !pattern_matrix;
+  load_patterns !pattern_matrix;
+  id
+
+exception Interpretation_not_found
+
+let remove_interpretation id =
+  (try
+    Hashtbl.remove level2_patterns id;
+  with Not_found -> raise Interpretation_not_found);
+  pattern_matrix := List.filter (fun (_, id') -> id <> id') !pattern_matrix;
+  load_patterns !pattern_matrix
+
+let _ = load_patterns []
+