]> matita.cs.unibo.it Git - helm.git/blob - helm/software/components/ng_tactics/nDestructTac.ml
776d5dace0be53f410266bcf2fb776da1515d6c7
[helm.git] / helm / software / components / ng_tactics / nDestructTac.ml
1 (* Copyright (C) 2002, HELM Team.
2  * 
3  * This file is part of HELM, an Hypertextual, Electronic
4  * Library of Mathematics, developed at the Computer Science
5  * Department, University of Bologna, Italy.
6  * 
7  * HELM is free software; you can redistribute it and/or
8  * modify it under the terms of the GNU General Public License
9  * as published by the Free Software Foundation; either version 2
10  * of the License, or (at your option) any later version.
11  * 
12  * HELM is distributed in the hope that it will be useful,
13  * but WITHOUT ANY WARRANTY; without even the implied warranty of
14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15  * GNU General Public License for more details.
16  *
17  * You should have received a copy of the GNU General Public License
18  * along with HELM; if not, write to the Free Software
19  * Foundation, Inc., 59 Temple Place - Suite 330, Boston,
20  * MA  02111-1307, USA.
21  * 
22  * For details, see the HELM World-Wide-Web page,
23  * http://cs.unibo.it/helm/.
24  *)
25
26 (* $Id: destructTactic.ml 9774 2009-05-15 19:37:08Z sacerdot $ *)
27
28 open NTacStatus
29
30 let debug = true 
31 let pp = 
32   if debug then (fun x -> prerr_endline (Lazy.force x)) else (fun _ -> ())
33
34 let fresh_name =
35  let i = ref 0 in
36  function () ->
37   incr i;
38   "z" ^ string_of_int !i
39 ;;
40
41 let mk_id id =
42  let id = if id = "_" then fresh_name () else id in
43   CicNotationPt.Ident (id,None)
44 ;;
45
46 let mk_appl =
47  function
48     [] -> assert false
49   | [x] -> x
50   | l -> CicNotationPt.Appl l
51 ;;
52
53 let rec iter f n acc =
54   if n < 0 then acc
55   else iter f (n-1) (f n acc)
56 ;;
57
58 let subst_metasenv_and_fix_names status =
59   let u,h,metasenv, subst,o = status#obj in
60   let o = 
61     NCicUntrusted.map_obj_kind ~skip_body:true 
62      (NCicUntrusted.apply_subst subst []) o
63   in
64    status#set_obj(u,h,NCicUntrusted.apply_subst_metasenv subst metasenv,subst,o)
65 ;;
66
67 (* input: nome della variabile riscritta
68  * output: lista dei nomi delle variabili il cui tipo dipende dall'input *)
69 let cascade_select_in_ctx ~subst ctx iname =
70   prerr_endline "C";
71   let lctx, rctx = HExtlib.split_nth (iname - 1) ctx in
72   let lctx = List.rev lctx in
73   let rec rm_last = function
74       [] | [_] -> []
75     | hd::tl -> hd::(rm_last tl)
76   in
77
78   let indices,_ = List.fold_left
79        (fun (acc,context) item -> 
80          prerr_endline "C2";
81           match item with
82             | n,(NCic.Decl s | NCic.Def (s,_)) 
83                   when not (List.for_all (fun x -> NCicTypeChecker.does_not_occur ~subst context (x-1) x s) acc) ->
84                 List.iter (fun m -> prerr_endline ("acc has " ^ (string_of_int m))) acc;
85                 prerr_endline ("acc occurs in the type of " ^ n);
86                 (1::List.map ((+) 1) acc, item::context)
87             | _ -> (List.map ((+) 1) acc, item::context))
88        ([1], rctx) lctx in
89     prerr_endline "C3:";
90     List.iter (fun n -> prerr_endline (string_of_int n)) indices;
91     let indices = match rm_last indices with
92       | [] -> []
93       | _::tl -> tl in
94     let res = List.map (fun n -> let s,_ = List.nth ctx (n-1) in s) indices in
95     prerr_endline "C4:";
96     List.iter (fun n -> prerr_endline n) res;
97     prerr_endline (NCicPp.ppcontext ~metasenv:[] ~subst ctx);
98     res, indices
99 ;;
100
101 let rec mk_fresh_name ctx firstch n =
102   let candidate = (String.make 1 firstch) ^ (string_of_int n) in
103   if (List.for_all (fun (s,_) -> s <> candidate) ctx) then candidate
104   else mk_fresh_name ctx firstch (n+1)
105 ;;
106
107 let arg_list nleft t =
108   let rec drop_prods n t =
109     if n <= 0 then t
110     else match t with
111       | NCic.Prod (_,_,ta) -> drop_prods (n-1) ta
112       | _ -> raise (Failure "drop_prods")
113   in
114   let rec aux = function
115     | NCic.Prod (_,so,ta) -> so::aux ta
116     | _ -> []
117   in aux (drop_prods nleft t)
118 ;;
119
120 let nargs it nleft consno =
121   prerr_endline (Printf.sprintf "nargs %d %d" nleft consno);
122   let _,indname,_,cl = it in
123   let _,_,t_k = List.nth cl consno in
124   List.length (arg_list nleft t_k) ;;
125
126 let default_pattern = "",0,(None,[],Some CicNotationPt.UserInput);;
127
128 (* returns the discrimination = injection+contradiction principle *)
129 (* FIXME: mi riservo di considerare tipi con parametri sx alla fine *)
130
131 let mk_discriminator it status =
132   let nleft = 0 in
133   let _,indname,_,cl = it in
134
135
136   let mk_eq tys ts us es n =
137     (* eqty = Tn u0 e0...un-1 en-1 *)
138     let eqty = mk_appl 
139                  (List.nth tys n :: iter (fun i acc -> 
140                                            List.nth us i::
141                                            List.nth es i:: acc) 
142                                      (n-1) []) in
143
144     (* params = [T0;t0;...;Tn;tn;u0;e0;un-1;en-1] *)
145     let params = iter (fun i acc -> 
146                          List.nth tys i ::
147                          List.nth ts i :: acc) n
148                      (iter (fun i acc ->
149                             List.nth us i::
150                             List.nth es i:: acc) (n-1) []) in
151     mk_appl [mk_id "eq"; eqty;
152                         mk_appl (mk_id ("R" ^ string_of_int n) :: params);
153                         List.nth us n] 
154   in
155
156   let kname it j =
157     let _,_,_,cl = it in
158     let _,name,_ = List.nth cl j in
159     name
160   in
161
162   let branch i j ts us = 
163     let nargs = nargs it nleft i in
164     let es = List.map (fun x -> mk_id ("e" ^ string_of_int x)) (HExtlib.list_seq 0 nargs) in
165     let tys = List.map 
166                 (fun x -> CicNotationPt.Implicit (`Tagged ("T" ^ (string_of_int x)))) 
167                 (HExtlib.list_seq 0 nargs) in
168     let tys = tys @ 
169       [iter (fun i acc -> 
170         CicNotationPt.Binder (`Lambda, (mk_id ("x" ^ string_of_int i), None),
171         CicNotationPt.Binder (`Lambda, (mk_id ("p" ^ string_of_int i), None),
172         acc))) (nargs-1)
173         (mk_appl [mk_id "eq"; CicNotationPt.Implicit `JustOne;
174           mk_appl (mk_id (kname it i)::
175            List.map (fun x -> mk_id ("x" ^string_of_int x))
176               (HExtlib.list_seq 0 (List.length ts)));
177               mk_appl (mk_id (kname it j)::us)])]
178     in
179     CicNotationPt.Binder (`Lambda, (mk_id "e", 
180       Some (mk_appl 
181         [mk_id "eq";
182          CicNotationPt.Implicit `JustOne;
183          mk_appl (mk_id (kname it i)::ts);
184          mk_appl (mk_id (kname it j)::us)])),
185     let ts = ts @ [mk_id "e"] in
186     let refl2 = mk_appl
187                   [mk_id "refl";
188                    CicNotationPt.Implicit `JustOne;
189                    mk_appl (mk_id (kname it j)::us)] in
190     let us = us @ [refl2] in
191     CicNotationPt.Binder (`Forall, (mk_id "P", Some (CicNotationPt.Sort (`NType "1") )),
192       if i = j then 
193        CicNotationPt.Binder (`Forall, (mk_id "_",
194         Some (iter (fun i acc -> 
195               CicNotationPt.Binder (`Forall, (List.nth es i, Some (mk_eq tys ts us es i)), acc))
196               (nargs-1) 
197               (CicNotationPt.Binder (`Forall, (mk_id "_", 
198                 Some (mk_eq tys ts us es nargs)),
199                 mk_id "P")))), mk_id "P")
200       else mk_id "P"))
201   in
202
203   let inner i ts = CicNotationPt.Case 
204               (mk_id "y",None,
205                Some (CicNotationPt.Binder (`Lambda, (mk_id "y",None), 
206                  CicNotationPt.Binder (`Forall, (mk_id "_", Some
207                  (mk_appl [mk_id "eq";CicNotationPt.Implicit
208                  `JustOne;CicNotationPt.Implicit `JustOne;mk_id "y"])),
209                  CicNotationPt.Implicit `JustOne ))),
210                   List.map
211                   (fun j -> 
212                      let nargs_kty = nargs it nleft j in
213                      let us = iter (fun m acc -> mk_id ("u" ^ (string_of_int m))::acc) 
214                                 (nargs_kty - 1) [] in
215                      let nones = 
216                        iter (fun _ acc -> None::acc) (nargs_kty - 1) [] in
217                      CicNotationPt.Pattern (kname it j,
218                                             None,
219                                             List.combine us nones), 
220                                 branch i j ts us)
221                   (HExtlib.list_seq 0 (List.length cl)))
222   in
223   let outer = CicNotationPt.Case
224                 (mk_id "x",None,
225                  Some (CicNotationPt.Binder (`Lambda, (mk_id "_",None),
226                  (*CicNotationPt.Sort (`NType "2")*) CicNotationPt.Implicit
227                  `JustOne)) ,
228                  List.map
229                    (fun i -> 
230                       let nargs_kty = nargs it nleft i in
231                       let ts = iter (fun m acc -> mk_id ("t" ^ (string_of_int m))::acc)
232                                  (nargs_kty - 1) [] in
233                      let nones = 
234                        iter (fun _ acc -> None::acc) (nargs_kty - 1) [] in
235                       CicNotationPt.Pattern (kname it i,
236                                              None,
237                                              List.combine ts nones),
238                                 inner i ts)
239                    (HExtlib.list_seq 0 (List.length cl))) in
240   let principle = CicNotationPt.Binder (`Lambda, (mk_id "x", Some (mk_id indname)),
241                         CicNotationPt.Binder (`Lambda, (mk_id "y", Some (mk_id indname)), outer))
242   in
243   pp (lazy ("discriminator = " ^ (CicNotationPp.pp_term principle)));
244   
245   status, principle 
246 ;;
247
248 let hd_of_term = function
249   | NCic.Appl (hd::_) -> hd
250   | t -> t
251 ;;
252
253 let name_of_rel ~context rel =
254   let s, _ = List.nth context (rel-1) in s
255 ;;
256
257 (* let lookup_in_ctx ~context n =
258   List.nth context ((List.length context) - n - 1)
259 ;;*)
260
261 let discriminate_tac ~context cur_eq status =
262   pp (lazy (Printf.sprintf "discriminate: equation %s" (name_of_rel ~context cur_eq)));
263
264   let dbranch it leftno consno =
265     prerr_endline (Printf.sprintf "dbranch %d %d" leftno consno);
266     let nlist = HExtlib.list_seq 0 (nargs it leftno consno) in
267     (* (\forall ...\forall P.\forall DH : ( ... = ... -> P). P) *)
268     let params = List.map (fun x -> prerr_endline (Printf.sprintf "dbranch param a%d" x); NTactics.intro_tac ("a" ^ string_of_int x)) nlist in
269         NTactics.reduce_tac ~reduction:(`Normalize true) ~where:default_pattern::
270         params @ [
271         NTactics.intro_tac "P";
272         NTactics.intro_tac "DH";
273         NTactics.apply_tac ("",0,mk_id "DH");
274         NTactics.apply_tac ("",0,mk_id "refl");
275     ] in
276   let dbranches it leftno =
277     prerr_endline (Printf.sprintf "dbranches %d" leftno);
278     let _,_,_,cl = it in
279     let nbranches = List.length cl in 
280     let branches = iter (fun n acc -> 
281       let m = nbranches - n - 1 in
282       if m = 0 then (prerr_endline "no shift"; acc @ (dbranch it leftno m))
283       else (prerr_endline "sì shift"; acc @ NTactics.shift_tac :: (dbranch it
284       leftno m)))
285       (nbranches-1) [] in
286     if nbranches > 1 then
287       (prerr_endline "sì branch";
288          NTactics.branch_tac:: branches @ [NTactics.merge_tac])
289     else
290       (prerr_endline "no branch";
291       branches)
292   in
293   
294   let eq_name,(NCic.Decl s | NCic.Def (s,_)) = List.nth context (cur_eq-1) in
295   let _,ctx' = HExtlib.split_nth cur_eq context in
296   let status, s = NTacStatus.whd status ctx' (mk_cic_term ctx' s) in
297   let status, s = term_of_cic_term status s ctx' in
298   let status, leftno, it =
299     let it, t1, t2 = match s with
300       | NCic.Appl [_;it;t1;t2] -> it,t1,t2
301       | _ -> assert false in
302     (* XXX: serve? ho già fatto whd *)
303     let status, it = whd status ctx' (mk_cic_term ctx' it) in
304     let status, it = term_of_cic_term status it ctx' in
305     let _uri,indtyno,its = match it with
306         NCic.Const (NReference.Ref (uri, NReference.Ind (_,indtyno,_)) as r) -> 
307         uri, indtyno, NCicEnvironment.get_checked_indtys r
308       | _ -> prerr_endline ("discriminate: indty ="  ^ NCicPp.ppterm
309                   ~metasenv:[] ~subst:[] ~context:[] it) ; assert false in
310     let _,leftno,its,_,_ = its in
311     status, leftno, List.nth its indtyno
312   in
313
314   NTactics.block_tac (
315     [(fun status ->
316      let status, discr = mk_discriminator it status in
317       NTactics.cut_tac ("",0, CicNotationPt.Binder (`Forall, (mk_id "x", None),
318                          CicNotationPt.Binder (`Forall, (mk_id "y", None),
319                          CicNotationPt.Binder (`Forall, (mk_id "e", 
320                            Some (mk_appl [mk_id "eq";CicNotationPt.Implicit `JustOne; mk_id "x"; mk_id "y"])),
321                            mk_appl [discr; mk_id "x"; mk_id "y";
322                            mk_id "e"]))))
323       status);
324     NTactics.branch_tac;
325      NTactics.reduce_tac ~reduction:(`Normalize true) ~where:default_pattern;
326      NTactics.intro_tac "x";
327      NTactics.intro_tac "y";
328      NTactics.intro_tac "Deq";
329      NTactics.rewrite_tac ~dir:`RightToLeft ~what:("",0,mk_id "Deq") ~where:default_pattern;
330      NTactics.cases_tac ~what:("",0,mk_id "x") ~where:default_pattern]
331   @ dbranches it leftno  @ 
332    [NTactics.shift_tac;
333     NTactics.intro_tac "discriminate";
334     NTactics.apply_tac ("",0,mk_appl [mk_id "discriminate";
335                                 CicNotationPt.Implicit `JustOne;  
336                                 CicNotationPt.Implicit `JustOne; mk_id eq_name ]);
337                                 NTactics.reduce_tac ~reduction:(`Normalize true)
338                                 ~where:default_pattern;
339     NTactics.clear_tac ["discriminate"];
340     NTactics.merge_tac] 
341   ) status
342 ;;
343       
344 let subst_tac ~context ~dir cur_eq =
345   fun status ->
346   let eq_name,(NCic.Decl s | NCic.Def (s,_)) = List.nth context (cur_eq-1) in
347   let _,ctx' = HExtlib.split_nth cur_eq context in
348   let status, s = NTacStatus.whd status ctx' (mk_cic_term ctx' s) in
349   let status, s = term_of_cic_term status s ctx' in
350   pp (lazy (Printf.sprintf "subst: equation %s" eq_name));
351     let l, r = match s with
352       | NCic.Appl [_;_;t1;t2] -> t1,t2
353       | _ -> assert false in
354     let var = match dir with
355       | `LeftToRight -> l
356       | `RightToLeft -> r in
357     let var = match var with
358       | NCic.Rel i -> i
359       | _ -> assert false in
360     let names_to_gen, indices_to_gen = 
361       cascade_select_in_ctx ~subst:(get_subst status) context (var+cur_eq) in
362     let moved_indices = List.fold_left
363       (fun acc x -> if x > cur_eq then acc+1 else acc) 0 indices_to_gen in
364     let gen_tac x = 
365       NTactics.generalize_tac 
366       ~where:("",0,(Some (mk_id x),[], Some CicNotationPt.UserInput)) in
367     NTactics.block_tac ((List.map gen_tac names_to_gen)@
368                 [NTactics.clear_tac names_to_gen;
369                  NTactics.rewrite_tac ~dir 
370                    ~what:("",0,mk_id eq_name) ~where:default_pattern;
371                  NTactics.reduce_tac ~reduction:(`Normalize true)
372                    ~where:default_pattern]@
373                  (List.map NTactics.intro_tac (List.rev names_to_gen))) status,
374                  (List.length context - cur_eq + 1 - moved_indices)
375 ;;
376
377 let get_ctx status =
378     let ref_ctx = ref [] in
379     let status = NTactics.distribute_tac 
380       (fun st goal ->
381          let ctx = ctx_of (get_goalty st goal) in
382          ref_ctx := ctx; st) status in
383     !ref_ctx
384 ;;
385
386 let rec select_eq ctx i status acc =
387   try
388     match (List.nth ctx (List.length ctx - i - 1)) with
389     | n, (NCic.Decl s | NCic.Def (s,_)) ->
390         (let _,ctx_s = HExtlib.split_nth (List.length ctx - i) ctx in 
391          let status, s = NTacStatus.whd status ctx_s (mk_cic_term ctx_s s) in
392          let status, s = term_of_cic_term status s ctx_s in
393          pp (lazy (Printf.sprintf "select_eq tries %s" (NCicPp.ppterm ~context:ctx_s ~subst:[] ~metasenv:[] s)));
394          if (List.for_all (fun x -> x <> n) acc) then 
395            match s with
396            | NCic.Appl [NCic.Const (NReference.Ref (u,_)) ;_;_;_] ->
397                if NUri.name_of_uri u = "eq" then status, Some (List.length ctx - i)
398                else select_eq ctx (i+1) status acc
399            | _ -> select_eq ctx (i+1) status acc
400          else select_eq ctx (i+1) status acc)
401   with Failure _ | Invalid_argument _ -> status, None
402 ;;
403
404 let classify ~subst ctx i status =
405   let  _, (NCic.Decl s | NCic.Def (s,_)) = List.nth ctx (i-1) in
406   let _,ctx' = HExtlib.split_nth i ctx in
407   let status, s = NTacStatus.whd status ctx' (mk_cic_term ctx' s) in
408   let status, s = term_of_cic_term status s ctx' in
409   match s with 
410     | NCic.Appl [_;_;l;r] ->
411         (* FIXME: metasenv *)
412         if NCicReduction.are_convertible ~metasenv:[] ~subst ctx' l r 
413         then status, `Identity
414         else status, (match hd_of_term l, hd_of_term r with
415           | NCic.Const (NReference.Ref (_,NReference.Con (_,ki,nleft)) as kref),
416             NCic.Const (NReference.Ref (_,NReference.Con (_,kj,_))) -> 
417               if ki != kj then `Discriminate (0,true)
418               else
419                 let rit = NReference.mk_indty true kref in
420                 let _,_,its,_,itno = NCicEnvironment.get_checked_indtys rit in 
421                 let it = List.nth its itno in
422                 let newprods = (nargs it nleft (ki-1)) + 1 in
423                 `Discriminate (newprods, false) 
424           | NCic.Rel j, _  
425               when NCicTypeChecker.does_not_occur ~subst ctx' (j-1) j r -> 
426                 `Subst `LeftToRight
427           | _, NCic.Rel j 
428               when NCicTypeChecker.does_not_occur ~subst ctx' (j-1) j l -> 
429                 `Subst `RightToLeft
430           | (NCic.Rel _, _ | _, NCic.Rel _ ) -> `Cycle
431           | _ -> `Blob)
432     | _ -> raise (Failure "classify")
433 ;;
434
435 let rec destruct_tac0 nprods i status acc =
436   let ctx = get_ctx status in
437   let subst = get_subst status in
438   let status, selection = select_eq ctx i status acc in
439   match selection with
440   | None -> 
441     pp (lazy (Printf.sprintf "destruct: nprods is %d, i is %d, no selection, context is %s" nprods i (NCicPp.ppcontext ~metasenv:[] ~subst ctx)));
442       if nprods > 0  then 
443         let status' = NTactics.intro_tac (mk_fresh_name ctx 'e' 0) status in
444         destruct_tac0 (nprods-1) (List.length ctx) status' acc
445       else
446         status
447   | Some cur_eq -> pp (lazy (Printf.sprintf 
448         "destruct: nprods is %d, i is %d, selection is %s, context is %s" 
449         nprods i (name_of_rel ~context:ctx cur_eq) (NCicPp.ppcontext ~metasenv:[] ~subst ctx)));
450       match classify ~subst ctx cur_eq status with
451       | status,`Discriminate (newprods,conflict) -> 
452           let status' = discriminate_tac ~context:ctx cur_eq status in
453           if conflict then status'
454           else destruct_tac0 (nprods+newprods) (List.length ctx - cur_eq + 1)
455                  status' (name_of_rel ~context:ctx cur_eq::acc)
456       | status, `Subst dir ->
457           let status', next_i = subst_tac ~context:ctx ~dir cur_eq status in
458           destruct_tac0 nprods next_i status' acc
459       | status, `Identity
460       | status, `Cycle (* TODO *)
461       | status, `Blob ->
462           destruct_tac0 nprods (cur_eq+1) status acc
463 ;;
464
465 let destruct_tac status = destruct_tac0 0 0 status [];;