X-Git-Url: http://matita.cs.unibo.it/gitweb/?a=blobdiff_plain;f=matita%2Fcomponents%2Fextlib%2Fdiscrimination_tree.ml;h=1a8147f7db8d62cb4eaf51111daa4a77e5587607;hb=74c6905907b0bca229366d52450e2a6982b5b8be;hp=4caf38de19b545b9e3937336debf6b61abf79e2a;hpb=3ba4306ecd693b48f70ecbe9916aec6975373549;p=helm.git diff --git a/matita/components/extlib/discrimination_tree.ml b/matita/components/extlib/discrimination_tree.ml index 4caf38de1..1a8147f7d 100644 --- a/matita/components/extlib/discrimination_tree.ml +++ b/matita/components/extlib/discrimination_tree.ml @@ -70,8 +70,17 @@ module type DiscriminationTree = val in_index : t -> input -> (data -> bool) -> bool val retrieve_generalizations : t -> input -> dataset val retrieve_unifiables : t -> input -> dataset - val retrieve_generalizations_sorted : t -> input -> (data * int) list - val retrieve_unifiables_sorted : t -> input -> (data * int) list + + module type Collector = sig + type t + val empty : t + val union : t -> t -> t + val inter : t -> t -> data list + val to_list : t -> data list + end + module Collector : Collector + val retrieve_generalizations_sorted : t -> input -> Collector.t + val retrieve_unifiables_sorted : t -> input -> Collector.t end module Make (I:Indexable) (A:Set.S) : DiscriminationTree @@ -148,8 +157,8 @@ and type data = A.elt and type dataset = A.t = (* the equivalent of skip, but on the index, thus the list of trees that are rooted just after the term represented by the tree root are returned (we are skipping the root) *) - let skip_root = function DiscriminationTree.Node (value, map) -> - let rec get n = function DiscriminationTree.Node (v, m) as tree -> + let skip_root = function DiscriminationTree.Node (_value, map) -> + let rec get n = function DiscriminationTree.Node (_v, m) as tree -> if n = 0 then [tree] else PSMap.fold (fun k v res -> (get (n-1 + arity_of k) v) @ res) m [] in @@ -162,7 +171,7 @@ and type data = A.elt and type dataset = A.t = match tree, path with | DiscriminationTree.Node (Some s, _), [] -> s | DiscriminationTree.Node (None, _), [] -> A.empty - | DiscriminationTree.Node (_, map), Variable::path when unif -> + | DiscriminationTree.Node (_, _map), Variable::path when unif -> List.fold_left A.union A.empty (List.map (retrieve path) (skip_root tree)) | DiscriminationTree.Node (_, map), node::path -> @@ -184,17 +193,68 @@ and type data = A.elt and type dataset = A.t = module O = struct type t = A.t * int - let compare (_,a) (_,b) = compare a b + let compare (sa,wa) (sb,wb) = + let c = compare wb wa in + if c <> 0 then c else A.compare sb sa end module S = Set.Make(O) + (* TASSI: here we should think of a smarted data structure *) + module type Collector = sig + type t + val empty : t + val union : t -> t -> t + val inter : t -> t -> data list + val to_list : t -> data list + end + module Collector : Collector with type t = S.t = struct + type t = S.t + let union = S.union + let empty = S.empty + + let merge l = + let rec aux s w = function + | [] -> [s,w] + | (t, wt)::tl when w = wt -> aux (A.union s t) w tl + | (t, wt)::tl -> (s, w) :: aux t wt tl + in + match l with + | [] -> [] + | (s, w) :: l -> aux s w l + + let rec undup ~eq = function + | [] -> [] + | x :: tl -> x :: undup ~eq (List.filter (fun y -> not(eq x y)) tl) + + let to_list t = + undup ~eq:(fun x y -> A.equal (A.singleton x) (A.singleton y)) + (List.flatten (List.map + (fun x,_ -> A.elements x) (merge (S.elements t)))) + + let inter t1 t2 = + let l1 = merge (S.elements t1) in + let l2 = merge (S.elements t2) in + let res = + List.flatten + (List.map + (fun s, w -> + HExtlib.filter_map (fun x -> + try Some (x, w + snd (List.find (fun (s,_w) -> A.mem x s) l2)) + with Not_found -> None) + (A.elements s)) + l1) + in + undup ~eq:(fun x y -> A.equal (A.singleton x) (A.singleton y)) + (List.map fst (List.sort (fun (_,x) (_,y) -> y - x) res)) + end + let retrieve_sorted unif tree term = let path = I.path_string_of term in let rec retrieve n path tree = match tree, path with | DiscriminationTree.Node (Some s, _), [] -> S.singleton (s, n) | DiscriminationTree.Node (None, _), [] -> S.empty - | DiscriminationTree.Node (_, map), Variable::path when unif -> + | DiscriminationTree.Node (_, _map), Variable::path when unif -> List.fold_left S.union S.empty (List.map (retrieve n path) (skip_root tree)) | DiscriminationTree.Node (_, map), node::path -> @@ -209,9 +269,7 @@ and type data = A.elt and type dataset = A.t = | no, path -> retrieve n path no with Not_found -> S.empty) in - List.flatten - (List.map (fun x -> List.map (fun y -> y, snd x) (A.elements (fst x))) - (S.elements (retrieve 0 path tree))) + retrieve 0 path tree ;; let retrieve_generalizations_sorted tree term =