]> matita.cs.unibo.it Git - helm.git/commitdiff
This patch allows generation of minimally dependent discrimination principles
authorWilmer Ricciotti <ricciott@cs.unibo.it>
Fri, 5 Oct 2012 10:08:05 +0000 (10:08 +0000)
committerWilmer Ricciotti <ricciott@cs.unibo.it>
Fri, 5 Oct 2012 10:08:05 +0000 (10:08 +0000)
for inductive types, in the case where Leibniz equality is used.

matita/components/grafite_engine/grafiteEngine.ml
matita/components/ng_tactics/nDestructTac.ml
matita/components/ng_tactics/nDestructTac.mli

index 5891698404d4661e9b4f83fa631b539f684f6055..15e20277f6f1ad5cb6d132c76686671f14512716 100644 (file)
@@ -856,7 +856,7 @@ let rec eval_ncommand ~include_paths opts status (text,prefix_len,cmd) =
           | _ -> prerr_endline ("engine: indty expected... (fix this error message)"); assert false in
         let (_,ind_name,_,_ as it) = List.nth tys indtyno in
         let status,obj =  
-          NDestructTac.mk_discriminator ~use_jmeq:true (ind_name ^ "_jmdiscr")
+          NDestructTac.mk_discriminator ~use_jmeq:true ~force:true (ind_name ^ "_jmdiscr")
            it leftno status status#baseuri in
         let _,_,menv,_,_ = obj in
           (match menv with
index 990cc672bbede15aaa17f8f9a2a088e8c452319e..d85b9ec502081d14c6d57f7e028e7371be60488b 100644 (file)
@@ -28,7 +28,7 @@
 open NTacStatus
 open Continuationals.Stack
 
-let debug = false
+let debug = true
 let pp = 
   if debug then (fun x -> prerr_endline (Lazy.force x)) else (fun _ -> ())
 
@@ -174,7 +174,7 @@ let hp_pattern_jm n =
 (* creates the discrimination = injection+contradiction principle *)
 exception ConstructorTooBig of string;;
 
-let mk_discriminator ~use_jmeq name it leftno status baseuri =
+let mk_discriminator ~use_jmeq ?(force=false) name it leftno status baseuri =
   let itnargs = 
     let _,_,arity,_ = it in 
     List.length (arg_list 0 arity) in
@@ -184,29 +184,40 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
 
   (* PHASE 1: derive the type of the discriminator (we'll name it "principle") *)
 
-  let mk_eq tys ts us es n =
+  let mk_eq tys ts us es n deps =
     if use_jmeq then
       mk_appl [mk_id "jmeq";
                NotationPt.Implicit `JustOne; List.nth ts n;
                NotationPt.Implicit `JustOne; List.nth us n] 
     else
+    (* we use deps in an attempt to remove unnecessary rewritings when the type 
+       is not maximally dependent *)
     (* eqty = Tn u0 e0...un-1 en-1 *)
     let eqty = mk_appl 
-                 (List.nth tys n :: iter (fun i acc -> 
-                                           List.nth us i::
-                                           List.nth es i:: acc) 
-                                     (n-1) []) in
-
-    (* params = [T0;t0;...;Tn;tn;u0;e0;un-1;en-1] *)
+                 (List.nth tys n :: iter 
+                   (fun i acc ->
+                      if (List.mem (List.nth ts i) deps)
+                         then List.nth us i::
+                              List.nth es i::acc
+                         else acc) 
+                   (n-1) []) in
+
+    (* params = [T0;t0;...;Tn;tn;u0;e0;...;un-1;en-1] *)
     let params = iter (fun i acc -> 
-                         List.nth tys i ::
-                         List.nth ts i :: acc) n
-                     (iter (fun i acc ->
-                            List.nth us i::
-                            List.nth es i:: acc) (n-1) []) in
+                         if (List.mem (List.nth ts i) deps)
+                            then List.nth tys i ::
+                                 List.nth ts i :: acc
+                            else acc) (n-1)
+                     (List.nth tys n::List.nth ts n::
+                      iter (fun i acc ->
+                            if (List.mem (List.nth ts i) deps)
+                               then List.nth us i::
+                                    List.nth es i::acc
+                               else acc) (n-1) []) in
+    let nrewrites = List.length deps in
     mk_appl [mk_id "eq"; eqty;
-                        mk_appl (mk_id ("R" ^ string_of_int n) :: params);
-                        List.nth us n]
+             mk_appl (mk_id ("R" ^ string_of_int nrewrites) :: params);
+             List.nth us n]
 
   in
     
@@ -217,15 +228,20 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
     name
   in
 
-  let branch i j ts us = 
+  let branch i j ts us deps 
     let nargs = nargs it leftno i in
     let es = List.map (fun x -> mk_id ("e" ^ string_of_int x)) (HExtlib.list_seq 0 nargs) in
+    let ndepargs k = 
+      let tk = List.nth ts k in
+      List.length (List.assoc tk deps) 
+    in
     let tys = List.map 
-                (fun x -> iter 
+                (fun x -> 
+                  iter
                   (fun i acc -> 
                     NotationPt.Binder (`Lambda, (mk_id ("x" ^ string_of_int i), None),
                     NotationPt.Binder (`Lambda, (mk_id ("p" ^ string_of_int i), None),
-                    acc))) (x-1) 
+                    acc))) ((ndepargs x) - 1) 
                  (NotationPt.Implicit (`Tagged ("T" ^ (string_of_int x)))))
                (HExtlib.list_seq 0 nargs) in
     let tys = tys @ 
@@ -255,7 +271,7 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
       if i = j then 
        NotationPt.Binder (`Forall, (mk_id "_",
         Some (iter (fun i acc -> 
-              NotationPt.Binder (`Forall, (List.nth es i, Some (mk_eq tys ts us es i)), acc))
+              NotationPt.Binder (`Forall, (List.nth es i, Some (mk_eq tys ts us es i (List.assoc (List.nth ts i) deps))), acc))
               (nargs-1) 
               (** (NotationPt.Binder (`Forall, (mk_id "_", 
                 Some (mk_eq tys ts us es nargs)),*)
@@ -263,7 +279,8 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
       else mk_id "P")
   in
 
-  let inner i ts = NotationPt.Case 
+  let inner i ts deps = 
+    NotationPt.Case 
               (mk_id "y",None,
                (*Some (NotationPt.Binder (`Lambda, (mk_id "y",None), 
                  NotationPt.Binder (`Forall, (mk_id "_", Some
@@ -282,24 +299,45 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
                      NotationPt.Pattern (kname j,
                                             None,
                                             List.combine us nones), 
-                                branch i j ts us)
+                                branch i j ts us deps)
                   (HExtlib.list_seq 0 (List.length cl)))
   in
   let outer = NotationPt.Case
                 (mk_id "x",None,
                  None ,
                  List.map
-                   (fun i -> 
+                   (fun i ->
+                      let _,_,kty = List.nth cl i in 
                       let nargs_kty = nargs it leftno i in
-                      if (nargs_kty > 5 && not use_jmeq) then raise (ConstructorTooBig (kname i)); 
-                      let ts = iter (fun m acc -> mk_id ("t" ^ (string_of_int m))::acc)
+                      let ts = iter (fun m acc -> ("t" ^ (string_of_int m))::acc)
                                  (nargs_kty - 1) [] in
-                     let nones = 
+                      (* this context is used to compute dependencies between constructor arguments *)
+                      let kctx = List.map2 (fun t ty -> t, NCic.Decl ty) (List.rev ts) (List.rev (arg_list leftno kty)) in
+                      let ts = List.map mk_id ts in
+                      (* compute graph of dependencies *)
+                      let deps,_ = List.fold_left 
+                        (fun (acc,i) (t,_) -> 
+                         let name,_ = List.nth kctx (i-1) in
+                         (name,fst (cascade_select_in_ctx status ~subst:[] kctx [] i))::acc,i+1) ([],1) kctx
+                      in
+                      (* transpose graph *)
+                      let deps = List.fold_left
+                        (fun acc (t,_) -> 
+                           let t_deps = List.map fst (List.filter (fun (name,rev_deps) -> List.mem t rev_deps) deps) in
+                           (t,t_deps)::acc) [] deps
+                      in 
+                      prerr_endline ("deps dump!");
+                      List.iter (fun (x,xs) -> prerr_endline (x ^ ": " ^ (String.concat ", " xs))) deps;
+                     let deps = List.map (fun (x,xs) -> mk_id x, (List.map mk_id) xs) deps in
+                      let max_dep = List.fold_left max 0 (List.map (fun (_,xs) -> List.length xs) deps) in
+                      if (max_dep > 4 && not use_jmeq && not force) then raise (ConstructorTooBig (kname i)); 
+                      
+                     let nones =
                        iter (fun _ acc -> None::acc) (nargs_kty - 1) [] in
                       NotationPt.Pattern (kname i,
                                              None,
                                              List.combine ts nones),
-                                inner i ts)
+                                inner i ts deps)
                    (HExtlib.list_seq 0 (List.length cl))) in
   let principle = 
     mk_prods params (NotationPt.Binder (`Forall, (mk_id "x",
index f753fa61e41e2b662622e383e70fc51983e5f400..0a324ec2d4bbebd53e9abdb7d9b8e2fe5b7907a2 100644 (file)
@@ -13,7 +13,7 @@
 
 val destruct_tac : string list option -> string list -> 's NTacStatus.tactic
 
-val mk_discriminator: (use_jmeq: bool) ->
+val mk_discriminator: (use_jmeq: bool) -> ?force:bool ->
   string -> NCic.inductiveType -> int ->  
   (#NTacStatus.tac_status as 's) -> string -> 's * NCic.obj