]> matita.cs.unibo.it Git - helm.git/blobdiff - matita/components/extlib/discrimination_tree.ml
Use of standard OCaml syntax
[helm.git] / matita / components / extlib / discrimination_tree.ml
index f96a0de5695be7f414408b9a582561cbc05a2c2d..beb6d51fd280916683b85e4c417d7fe8eb2b0833 100644 (file)
@@ -70,6 +70,17 @@ module type DiscriminationTree =
       val in_index : t -> input -> (data -> bool) -> bool
       val retrieve_generalizations : t -> input -> dataset
       val retrieve_unifiables : t -> input -> dataset
+
+      module type Collector = sig
+        type t
+        val empty : t
+        val union : t -> t -> t
+        val inter : t -> t -> data list
+        val to_list : t -> data list
+      end
+      module Collector : Collector
+      val retrieve_generalizations_sorted : t -> input -> Collector.t
+      val retrieve_unifiables_sorted : t -> input -> Collector.t
     end
 
 module Make (I:Indexable) (A:Set.S) : DiscriminationTree 
@@ -146,8 +157,8 @@ and type data = A.elt and type dataset = A.t =
       (* the equivalent of skip, but on the index, thus the list of trees
          that are rooted just after the term represented by the tree root
          are returned (we are skipping the root) *)
-      let skip_root = function DiscriminationTree.Node (value, map) ->
-        let rec get n = function DiscriminationTree.Node (v, m) as tree ->
+      let skip_root = function DiscriminationTree.Node (_value, map) ->
+        let rec get n = function DiscriminationTree.Node (_v, m) as tree ->
            if n = 0 then [tree] else 
            PSMap.fold (fun k v res -> (get (n-1 + arity_of k) v) @ res) m []
         in
@@ -160,7 +171,7 @@ and type data = A.elt and type dataset = A.t =
           match tree, path with
           | DiscriminationTree.Node (Some s, _), [] -> s
           | DiscriminationTree.Node (None, _), [] -> A.empty 
-          | DiscriminationTree.Node (_, map), Variable::path when unif ->
+          | DiscriminationTree.Node (_, _map), Variable::path when unif ->
               List.fold_left A.union A.empty
                 (List.map (retrieve path) (skip_root tree))
           | DiscriminationTree.Node (_, map), node::path ->
@@ -179,6 +190,92 @@ and type data = A.elt and type dataset = A.t =
 
       let retrieve_generalizations tree term = retrieve false tree term;;
       let retrieve_unifiables tree term = retrieve true tree term;;
+
+      module O = struct
+        type t = A.t * int
+        let compare (sa,wa) (sb,wb) = 
+          let c = compare wb wa in
+          if c <> 0 then c else A.compare sb sa
+      end
+      module S = Set.Make(O)
+
+      (* TASSI: here we should think of a smarted data structure *)
+      module type Collector = sig
+        type t
+        val empty : t
+        val union : t -> t -> t
+        val inter : t -> t -> data list
+        val to_list : t -> data list
+      end
+      module Collector : Collector with type t = S.t = struct
+        type t = S.t
+        let union = S.union
+        let empty = S.empty
+
+        let merge l = 
+          let rec aux s w = function
+            | [] -> [s,w]
+            | (t, wt)::tl when w = wt -> aux (A.union s t) w tl
+            | (t, wt)::tl -> (s, w) :: aux t wt tl
+          in
+          match l with
+          | [] -> []
+          | (s, w) :: l -> aux s w l
+          
+        let rec undup ~eq = function
+          | [] -> []
+          | x :: tl -> x :: undup ~eq (List.filter (fun y -> not(eq x y)) tl)
+
+        let to_list t =
+          undup ~eq:(fun x y -> A.equal (A.singleton x) (A.singleton y)) 
+            (List.flatten (List.map 
+              (fun (x,_) -> A.elements x) (merge (S.elements t))))
+
+        let inter t1 t2 =
+          let l1 = merge (S.elements t1) in
+          let l2 = merge (S.elements t2) in
+          let res = 
+           List.flatten
+            (List.map
+              (fun (s, w) ->
+                 HExtlib.filter_map (fun x ->
+                   try Some (x, w + snd (List.find (fun (s,_w) -> A.mem x s) l2))
+                   with Not_found -> None)
+                   (A.elements s))
+              l1)
+          in
+          undup ~eq:(fun x y -> A.equal (A.singleton x) (A.singleton y)) 
+            (List.map fst (List.sort (fun (_,x) (_,y) -> y - x) res))
+      end
+
+      let retrieve_sorted unif tree term =
+        let path = I.path_string_of term in
+        let rec retrieve n path tree =
+          match tree, path with
+          | DiscriminationTree.Node (Some s, _), [] -> S.singleton (s, n)
+          | DiscriminationTree.Node (None, _), [] -> S.empty
+          | DiscriminationTree.Node (_, _map), Variable::path when unif ->
+              List.fold_left S.union S.empty
+                (List.map (retrieve n path) (skip_root tree))
+          | DiscriminationTree.Node (_, map), node::path ->
+              S.union
+                 (if not unif && node = Variable then S.empty else
+                  try retrieve (n+1) path (PSMap.find node map)
+                  with Not_found -> S.empty)
+                 (try
+                    match PSMap.find Variable map,skip (arity_of node) path with
+                    | DiscriminationTree.Node (Some s, _), [] -> 
+                       S.singleton (s, n)
+                    | no, path -> retrieve n path no
+                  with Not_found -> S.empty)
+       in
+        retrieve 0 path tree
+      ;;
+
+      let retrieve_generalizations_sorted tree term = 
+        retrieve_sorted false tree term;;
+      let retrieve_unifiables_sorted tree term = 
+        retrieve_sorted true tree term;;
   end
 ;;