open NTacStatus
open Continuationals.Stack
-let debug = false
+let debug = true
let pp =
if debug then (fun x -> prerr_endline (Lazy.force x)) else (fun _ -> ())
(* creates the discrimination = injection+contradiction principle *)
exception ConstructorTooBig of string;;
-let mk_discriminator ~use_jmeq name it leftno status baseuri =
+let mk_discriminator ~use_jmeq ?(force=false) name it leftno status baseuri =
let itnargs =
let _,_,arity,_ = it in
List.length (arg_list 0 arity) in
(* PHASE 1: derive the type of the discriminator (we'll name it "principle") *)
- let mk_eq tys ts us es n =
+ let mk_eq tys ts us es n deps =
if use_jmeq then
mk_appl [mk_id "jmeq";
NotationPt.Implicit `JustOne; List.nth ts n;
NotationPt.Implicit `JustOne; List.nth us n]
else
+ (* we use deps in an attempt to remove unnecessary rewritings when the type
+ is not maximally dependent *)
(* eqty = Tn u0 e0...un-1 en-1 *)
let eqty = mk_appl
- (List.nth tys n :: iter (fun i acc ->
- List.nth us i::
- List.nth es i:: acc)
- (n-1) []) in
-
- (* params = [T0;t0;...;Tn;tn;u0;e0;un-1;en-1] *)
+ (List.nth tys n :: iter
+ (fun i acc ->
+ if (List.mem (List.nth ts i) deps)
+ then List.nth us i::
+ List.nth es i::acc
+ else acc)
+ (n-1) []) in
+
+ (* params = [T0;t0;...;Tn;tn;u0;e0;...;un-1;en-1] *)
let params = iter (fun i acc ->
- List.nth tys i ::
- List.nth ts i :: acc) n
- (iter (fun i acc ->
- List.nth us i::
- List.nth es i:: acc) (n-1) []) in
+ if (List.mem (List.nth ts i) deps)
+ then List.nth tys i ::
+ List.nth ts i :: acc
+ else acc) (n-1)
+ (List.nth tys n::List.nth ts n::
+ iter (fun i acc ->
+ if (List.mem (List.nth ts i) deps)
+ then List.nth us i::
+ List.nth es i::acc
+ else acc) (n-1) []) in
+ let nrewrites = List.length deps in
mk_appl [mk_id "eq"; eqty;
- mk_appl (mk_id ("R" ^ string_of_int n) :: params);
- List.nth us n]
+ mk_appl (mk_id ("R" ^ string_of_int nrewrites) :: params);
+ List.nth us n]
in
name
in
- let branch i j ts us =
+ let branch i j ts us deps =
let nargs = nargs it leftno i in
let es = List.map (fun x -> mk_id ("e" ^ string_of_int x)) (HExtlib.list_seq 0 nargs) in
+ let ndepargs k =
+ let tk = List.nth ts k in
+ List.length (List.assoc tk deps)
+ in
let tys = List.map
- (fun x -> iter
+ (fun x ->
+ iter
(fun i acc ->
NotationPt.Binder (`Lambda, (mk_id ("x" ^ string_of_int i), None),
NotationPt.Binder (`Lambda, (mk_id ("p" ^ string_of_int i), None),
- acc))) (x-1)
+ acc))) ((ndepargs x) - 1)
(NotationPt.Implicit (`Tagged ("T" ^ (string_of_int x)))))
(HExtlib.list_seq 0 nargs) in
let tys = tys @
if i = j then
NotationPt.Binder (`Forall, (mk_id "_",
Some (iter (fun i acc ->
- NotationPt.Binder (`Forall, (List.nth es i, Some (mk_eq tys ts us es i)), acc))
+ NotationPt.Binder (`Forall, (List.nth es i, Some (mk_eq tys ts us es i (List.assoc (List.nth ts i) deps))), acc))
(nargs-1)
(** (NotationPt.Binder (`Forall, (mk_id "_",
Some (mk_eq tys ts us es nargs)),*)
else mk_id "P")
in
- let inner i ts = NotationPt.Case
+ let inner i ts deps =
+ NotationPt.Case
(mk_id "y",None,
(*Some (NotationPt.Binder (`Lambda, (mk_id "y",None),
NotationPt.Binder (`Forall, (mk_id "_", Some
NotationPt.Pattern (kname j,
None,
List.combine us nones),
- branch i j ts us)
+ branch i j ts us deps)
(HExtlib.list_seq 0 (List.length cl)))
in
let outer = NotationPt.Case
(mk_id "x",None,
None ,
List.map
- (fun i ->
+ (fun i ->
+ let _,_,kty = List.nth cl i in
let nargs_kty = nargs it leftno i in
- if (nargs_kty > 5 && not use_jmeq) then raise (ConstructorTooBig (kname i));
- let ts = iter (fun m acc -> mk_id ("t" ^ (string_of_int m))::acc)
+ let ts = iter (fun m acc -> ("t" ^ (string_of_int m))::acc)
(nargs_kty - 1) [] in
- let nones =
+ (* this context is used to compute dependencies between constructor arguments *)
+ let kctx = List.map2 (fun t ty -> t, NCic.Decl ty) (List.rev ts) (List.rev (arg_list leftno kty)) in
+ let ts = List.map mk_id ts in
+ (* compute graph of dependencies *)
+ let deps,_ = List.fold_left
+ (fun (acc,i) (t,_) ->
+ let name,_ = List.nth kctx (i-1) in
+ (name,fst (cascade_select_in_ctx status ~subst:[] kctx [] i))::acc,i+1) ([],1) kctx
+ in
+ (* transpose graph *)
+ let deps = List.fold_left
+ (fun acc (t,_) ->
+ let t_deps = List.map fst (List.filter (fun (name,rev_deps) -> List.mem t rev_deps) deps) in
+ (t,t_deps)::acc) [] deps
+ in
+ prerr_endline ("deps dump!");
+ List.iter (fun (x,xs) -> prerr_endline (x ^ ": " ^ (String.concat ", " xs))) deps;
+ let deps = List.map (fun (x,xs) -> mk_id x, (List.map mk_id) xs) deps in
+ let max_dep = List.fold_left max 0 (List.map (fun (_,xs) -> List.length xs) deps) in
+ if (max_dep > 4 && not use_jmeq && not force) then raise (ConstructorTooBig (kname i));
+
+ let nones =
iter (fun _ acc -> None::acc) (nargs_kty - 1) [] in
NotationPt.Pattern (kname i,
None,
List.combine ts nones),
- inner i ts)
+ inner i ts deps)
(HExtlib.list_seq 0 (List.length cl))) in
let principle =
mk_prods params (NotationPt.Binder (`Forall, (mk_id "x",