]> matita.cs.unibo.it Git - helm.git/blobdiff - matita/components/extlib/discrimination_tree.ml
coercion lookup now returns coercions ranked using the number of symbols matched...
[helm.git] / matita / components / extlib / discrimination_tree.ml
index d69228b0c257e2784801f24f2e1cb22bf52a7faa..cdc498e9cfb48db2bd981b95009162f615ea84b0 100644 (file)
@@ -70,8 +70,17 @@ module type DiscriminationTree =
       val in_index : t -> input -> (data -> bool) -> bool
       val retrieve_generalizations : t -> input -> dataset
       val retrieve_unifiables : t -> input -> dataset
-      val retrieve_generalizations_sorted : t -> input -> (data * int) list
-      val retrieve_unifiables_sorted : t -> input -> (data * int) list
+
+      module type Collector = sig
+        type t
+        val empty : t
+        val union : t -> t -> t
+        val inter : t -> t -> data list
+        val to_list : t -> data list
+      end
+      module Collector : Collector
+      val retrieve_generalizations_sorted : t -> input -> Collector.t
+      val retrieve_unifiables_sorted : t -> input -> Collector.t
     end
 
 module Make (I:Indexable) (A:Set.S) : DiscriminationTree 
@@ -184,10 +193,61 @@ and type data = A.elt and type dataset = A.t =
 
       module O = struct
         type t = A.t * int
-        let compare (_,a) (_,b) = compare b a
+        let compare (sa,wa) (sb,wb) = 
+          let c = compare wb wa in
+          if c <> 0 then c else A.compare sb sa
       end
       module S = Set.Make(O)
 
+      (* TASSI: here we should think of a smarted data structure *)
+      module type Collector = sig
+        type t
+        val empty : t
+        val union : t -> t -> t
+        val inter : t -> t -> data list
+        val to_list : t -> data list
+      end
+      module Collector : Collector with type t = S.t = struct
+        type t = S.t
+        let union = S.union
+        let empty = S.empty
+
+        let merge l = 
+          let rec aux s w = function
+            | [] -> [s,w]
+            | (t, wt)::tl when w = wt -> aux (A.union s t) w tl
+            | (t, wt)::tl -> (s, w) :: aux t wt tl
+          in
+          match l with
+          | [] -> []
+          | (s, w) :: l -> aux s w l
+          
+        let rec undup ~eq = function
+          | [] -> []
+          | x :: tl -> x :: undup ~eq (List.filter (fun y -> not(eq x y)) tl)
+
+        let to_list t =
+          undup ~eq:(fun x y -> A.equal (A.singleton x) (A.singleton y)) 
+            (List.flatten (List.map 
+              (fun x,_ -> A.elements x) (merge (S.elements t))))
+
+        let inter t1 t2 =
+          let l1 = merge (S.elements t1) in
+          let l2 = merge (S.elements t2) in
+          let res = 
+           List.flatten
+            (List.map
+              (fun s, w ->
+                 HExtlib.filter_map (fun x ->
+                   try Some (x, w + snd (List.find (fun (s,w) -> A.mem x s) l2))
+                   with Not_found -> None)
+                   (A.elements s))
+              l1)
+          in
+          undup ~eq:(fun x y -> A.equal (A.singleton x) (A.singleton y)) 
+            (List.map fst (List.sort (fun (_,x) (_,y) -> y - x) res))
+      end
+
       let retrieve_sorted unif tree term =
         let path = I.path_string_of term in
         let rec retrieve n path tree =
@@ -209,9 +269,7 @@ and type data = A.elt and type dataset = A.t =
                     | no, path -> retrieve n path no
                   with Not_found -> S.empty)
        in
-        List.flatten 
-         (List.map (fun x -> List.map (fun y -> y, snd x) (A.elements (fst x))) 
-          (S.elements (retrieve 0 path tree)))
+        retrieve 0 path tree
       ;;
 
       let retrieve_generalizations_sorted tree term =