]> matita.cs.unibo.it Git - helm.git/blobdiff - matita/components/ng_tactics/nDestructTac.ml
Removes debug prints that were left from last commit.
[helm.git] / matita / components / ng_tactics / nDestructTac.ml
index de4c23a1c92ae039174ff20035fb76d96a3f6e69..2b6f4688ac988c706087a6896bf40b2050e02a8e 100644 (file)
@@ -174,7 +174,7 @@ let hp_pattern_jm n =
 (* creates the discrimination = injection+contradiction principle *)
 exception ConstructorTooBig of string;;
 
-let mk_discriminator ~use_jmeq name it leftno status baseuri =
+let mk_discriminator ~use_jmeq ?(force=false) name it leftno status baseuri =
   let itnargs = 
     let _,_,arity,_ = it in 
     List.length (arg_list 0 arity) in
@@ -184,29 +184,40 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
 
   (* PHASE 1: derive the type of the discriminator (we'll name it "principle") *)
 
-  let mk_eq tys ts us es n =
+  let mk_eq tys ts us es n deps =
     if use_jmeq then
       mk_appl [mk_id "jmeq";
                NotationPt.Implicit `JustOne; List.nth ts n;
                NotationPt.Implicit `JustOne; List.nth us n] 
     else
+    (* we use deps in an attempt to remove unnecessary rewritings when the type 
+       is not maximally dependent *)
     (* eqty = Tn u0 e0...un-1 en-1 *)
     let eqty = mk_appl 
-                 (List.nth tys n :: iter (fun i acc -> 
-                                           List.nth us i::
-                                           List.nth es i:: acc) 
-                                     (n-1) []) in
-
-    (* params = [T0;t0;...;Tn;tn;u0;e0;un-1;en-1] *)
+                 (List.nth tys n :: iter 
+                   (fun i acc ->
+                      if (List.mem (List.nth ts i) deps)
+                         then List.nth us i::
+                              List.nth es i::acc
+                         else acc) 
+                   (n-1) []) in
+
+    (* params = [T0;t0;...;Tn;tn;u0;e0;...;un-1;en-1] *)
     let params = iter (fun i acc -> 
-                         List.nth tys i ::
-                         List.nth ts i :: acc) n
-                     (iter (fun i acc ->
-                            List.nth us i::
-                            List.nth es i:: acc) (n-1) []) in
+                         if (List.mem (List.nth ts i) deps)
+                            then List.nth tys i ::
+                                 List.nth ts i :: acc
+                            else acc) (n-1)
+                     (List.nth tys n::List.nth ts n::
+                      iter (fun i acc ->
+                            if (List.mem (List.nth ts i) deps)
+                               then List.nth us i::
+                                    List.nth es i::acc
+                               else acc) (n-1) []) in
+    let nrewrites = List.length deps in
     mk_appl [mk_id "eq"; eqty;
-                        mk_appl (mk_id ("R" ^ string_of_int n) :: params);
-                        List.nth us n]
+             mk_appl (mk_id ("R" ^ string_of_int nrewrites) :: params);
+             List.nth us n]
 
   in
     
@@ -217,15 +228,20 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
     name
   in
 
-  let branch i j ts us = 
+  let branch i j ts us deps 
     let nargs = nargs it leftno i in
     let es = List.map (fun x -> mk_id ("e" ^ string_of_int x)) (HExtlib.list_seq 0 nargs) in
+    let ndepargs k = 
+      let tk = List.nth ts k in
+      List.length (List.assoc tk deps) 
+    in
     let tys = List.map 
-                (fun x -> iter 
+                (fun x -> 
+                  iter
                   (fun i acc -> 
                     NotationPt.Binder (`Lambda, (mk_id ("x" ^ string_of_int i), None),
                     NotationPt.Binder (`Lambda, (mk_id ("p" ^ string_of_int i), None),
-                    acc))) (x-1) 
+                    acc))) ((ndepargs x) - 1) 
                  (NotationPt.Implicit (`Tagged ("T" ^ (string_of_int x)))))
                (HExtlib.list_seq 0 nargs) in
     let tys = tys @ 
@@ -255,7 +271,7 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
       if i = j then 
        NotationPt.Binder (`Forall, (mk_id "_",
         Some (iter (fun i acc -> 
-              NotationPt.Binder (`Forall, (List.nth es i, Some (mk_eq tys ts us es i)), acc))
+              NotationPt.Binder (`Forall, (List.nth es i, Some (mk_eq tys ts us es i (List.assoc (List.nth ts i) deps))), acc))
               (nargs-1) 
               (** (NotationPt.Binder (`Forall, (mk_id "_", 
                 Some (mk_eq tys ts us es nargs)),*)
@@ -263,7 +279,8 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
       else mk_id "P")
   in
 
-  let inner i ts = NotationPt.Case 
+  let inner i ts deps = 
+    NotationPt.Case 
               (mk_id "y",None,
                (*Some (NotationPt.Binder (`Lambda, (mk_id "y",None), 
                  NotationPt.Binder (`Forall, (mk_id "_", Some
@@ -282,24 +299,43 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
                      NotationPt.Pattern (kname j,
                                             None,
                                             List.combine us nones), 
-                                branch i j ts us)
+                                branch i j ts us deps)
                   (HExtlib.list_seq 0 (List.length cl)))
   in
   let outer = NotationPt.Case
                 (mk_id "x",None,
                  None ,
                  List.map
-                   (fun i -> 
+                   (fun i ->
+                      let _,_,kty = List.nth cl i in 
                       let nargs_kty = nargs it leftno i in
-                      if (nargs_kty > 5 && not use_jmeq) then raise (ConstructorTooBig (kname i)); 
-                      let ts = iter (fun m acc -> mk_id ("t" ^ (string_of_int m))::acc)
+                      let ts = iter (fun m acc -> ("t" ^ (string_of_int m))::acc)
                                  (nargs_kty - 1) [] in
-                     let nones = 
+                      (* this context is used to compute dependencies between constructor arguments *)
+                      let kctx = List.map2 (fun t ty -> t, NCic.Decl ty) (List.rev ts) (List.rev (arg_list leftno kty)) in
+                      let ts = List.map mk_id ts in
+                      (* compute graph of dependencies *)
+                      let deps,_ = List.fold_left 
+                        (fun (acc,i) (t,_) -> 
+                         let name,_ = List.nth kctx (i-1) in
+                         (name,fst (cascade_select_in_ctx status ~subst:[] kctx [] i))::acc,i+1) ([],1) kctx
+                      in
+                      (* transpose graph *)
+                      let deps = List.fold_left
+                        (fun acc (t,_) -> 
+                           let t_deps = List.map fst (List.filter (fun (name,rev_deps) -> List.mem t rev_deps) deps) in
+                           (t,t_deps)::acc) [] deps
+                      in 
+                     let deps = List.map (fun (x,xs) -> mk_id x, (List.map mk_id) xs) deps in
+                      let max_dep = List.fold_left max 0 (List.map (fun (_,xs) -> List.length xs) deps) in
+                      if (max_dep > 4 && not use_jmeq && not force) then raise (ConstructorTooBig (kname i)); 
+                      
+                     let nones =
                        iter (fun _ acc -> None::acc) (nargs_kty - 1) [] in
                       NotationPt.Pattern (kname i,
                                              None,
                                              List.combine ts nones),
-                                inner i ts)
+                                inner i ts deps)
                    (HExtlib.list_seq 0 (List.length cl))) in
   let principle = 
     mk_prods params (NotationPt.Binder (`Forall, (mk_id "x",
@@ -340,7 +376,8 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
     let nlist = HExtlib.list_seq 0 (nargs it leftno consno) in
     (* (\forall ...\forall P.\forall DH : ( ... = ... -> P). P) *)
     let params = List.map (fun x -> NTactics.intro_tac ("a" ^ string_of_int x)) nlist in
-        NTactics.reduce_tac ~reduction:(`Normalize true) ~where:default_pattern::
+        (* NTactics.reduce_tac ~reduction:(`Normalize true)
+         * ~where:default_pattern::*)
         params @ [
         NTactics.intro_tac "P";
         NTactics.intro_tac "DH";
@@ -364,8 +401,8 @@ let mk_discriminator ~use_jmeq name it leftno status baseuri =
 
   let status =
    NTactics.block_tac (
-    [print_tac (lazy "ci sono");
-     NTactics.reduce_tac ~reduction:(`Normalize true) ~where:default_pattern 
+    [print_tac (lazy "ci sono") (*;
+     NTactics.reduce_tac ~reduction:(`Normalize true) ~where:default_pattern *)
     ]
   @ List.map (fun x -> NTactics.intro_tac x) params @
     [NTactics.intro_tac "x";
@@ -468,6 +505,13 @@ let subst_tac ~context ~dir skip cur_eq =
       | NCic.Rel var ->
         cascade_select_in_ctx status ~subst:(get_subst status) context skip (var+cur_eq)
       | _ -> cascade_select_in_ctx status ~subst:(get_subst status) context skip cur_eq in
+    let varname = match var with
+      | NCic.Rel var -> 
+          let name,_ = List.nth context (var+cur_eq-1) in
+         HLog.warn (Printf.sprintf "destruct: trying to remove variable: %s" name);
+         [name]
+      | _ -> []
+    in      
     let names_to_gen = List.filter (fun n -> n <> eq_name) names_to_gen in
     if (List.exists (fun x -> List.mem x skip) names_to_gen)
       then oldstatus
@@ -486,7 +530,12 @@ let subst_tac ~context ~dir skip cur_eq =
                    ~what:("",0,mk_id eq_name) ~where:default_pattern;
 (*                 NTactics.reduce_tac ~reduction:(`Normalize true)
                    ~where:default_pattern;*)
+                 (* XXX: temo che la clear multipla funzioni bene soltanto se
+                  * gli identificatori sono nell'ordine giusto.
+                  * Per non saper né leggere né scrivere, usiamo due clear
+                  * invece di una *)
                  NTactics.try_tac (NTactics.clear_tac [eq_name]);
+                NTactics.try_tac (NTactics.clear_tac varname);
 ]@
                  (List.map NTactics.intro_tac (List.rev names_to_gen))) status
 ;;
@@ -699,7 +748,9 @@ let rec destruct_tac0 tags acc domain skip status goal =
     let has_cleared = 
      try 
        let _ = NTactics.find_in_context eq_name (get_ctx status' newgoal) in false
-     with _ -> true in
+     with 
+      | Sys.Break as e -> raise e
+      |_ -> true in
     let rm_eq b l = if b then List.filter (fun x -> x <> eq_name) l else l in
     let acc = rm_eq has_cleared acc in
     let skip = rm_eq has_cleared skip in
@@ -715,7 +766,9 @@ let rec destruct_tac0 tags acc domain skip status goal =
       let has_cleared = 
        try 
          let _ = NTactics.find_in_context eq_name (get_ctx status' newgoal) in false
-       with _ -> true in
+       with 
+         | Sys.Break as e -> raise e         
+        | _ -> true in
       let rm_eq b l = if b then List.filter (fun x -> x <> eq_name) l else l in
       let acc = rm_eq has_cleared acc in
       let skip = rm_eq has_cleared skip in