]> matita.cs.unibo.it Git - pkg-cerco/acc-trusted.git/blob - extracted/untrusted/myMap.ml
Imported Upstream version 0.1
[pkg-cerco/acc-trusted.git] / extracted / untrusted / myMap.ml
1 (***********************************************************************)
2 (*                                                                     *)
3 (*                           Objective Caml                            *)
4 (*                                                                     *)
5 (*            Xavier Leroy, projet Cristal, INRIA Rocquencourt         *)
6 (*                                                                     *)
7 (*  Copyright 1996 Institut National de Recherche en Informatique et   *)
8 (*  en Automatique.  All rights reserved.  This file is distributed    *)
9 (*  under the terms of the GNU Library General Public License, with    *)
10 (*  the special exception on linking described in file ../LICENSE.     *)
11 (*                                                                     *)
12 (***********************************************************************)
13
14 (* $Id: myMap.ml,v 1.3 2006/02/17 16:19:52 pottier Exp $ *)
15
16 module type OrderedType =
17   sig
18     type t
19     val compare: t -> t -> int
20   end
21
22 module type S =
23   sig
24     type key
25     type +'a t
26     val empty: 'a t
27     val is_empty: 'a t -> bool
28     val add: key -> 'a -> 'a t -> 'a t
29     val find: key -> 'a t -> 'a
30     val remove: key -> 'a t -> 'a t
31     val mem:  key -> 'a t -> bool
32     val iter: (key -> 'a -> unit) -> 'a t -> unit
33     val map: ('a -> 'b) -> 'a t -> 'b t
34     val mapi: (key -> 'a -> 'b) -> 'a t -> 'b t
35     val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
36     val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int
37     val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
38     type interval = key option * key option
39     val split: interval -> 'a t -> 'a t
40     val minimum: 'a t -> key * 'a
41     val find_remove: key -> 'a t -> 'a * 'a t
42     val update: key -> ('a -> 'a) -> 'a t -> 'a t
43     val restrict: (key -> bool) -> 'a t -> 'a t
44   end
45
46 module Make(Ord: OrderedType) = struct
47
48     type key = Ord.t
49
50     type 'a t =
51         Empty
52       | Node of 'a t * key * 'a * 'a t * int
53
54     let height = function
55         Empty -> 0
56       | Node(_,_,_,_,h) -> h
57
58     let create l x d r =
59       let hl = height l and hr = height r in
60       Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1))
61
62     let bal l x d r =
63       let hl = match l with Empty -> 0 | Node(_,_,_,_,h) -> h in
64       let hr = match r with Empty -> 0 | Node(_,_,_,_,h) -> h in
65       if hl > hr + 2 then begin
66         match l with
67           Empty -> invalid_arg "Map.bal"
68         | Node(ll, lv, ld, lr, _) ->
69             if height ll >= height lr then
70               create ll lv ld (create lr x d r)
71             else begin
72               match lr with
73                 Empty -> invalid_arg "Map.bal"
74               | Node(lrl, lrv, lrd, lrr, _)->
75                   create (create ll lv ld lrl) lrv lrd (create lrr x d r)
76             end
77       end else if hr > hl + 2 then begin
78         match r with
79           Empty -> invalid_arg "Map.bal"
80         | Node(rl, rv, rd, rr, _) ->
81             if height rr >= height rl then
82               create (create l x d rl) rv rd rr
83             else begin
84               match rl with
85                 Empty -> invalid_arg "Map.bal"
86               | Node(rll, rlv, rld, rlr, _) ->
87                   create (create l x d rll) rlv rld (create rlr rv rd rr)
88             end
89       end else
90         Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1))
91
92     let empty = Empty
93
94     let is_empty = function Empty -> true | _ -> false
95
96     let rec add x data = function
97         Empty ->
98           Node(Empty, x, data, Empty, 1)
99       | Node(l, v, d, r, h) ->
100           let c = Ord.compare x v in
101           if c = 0 then
102             Node(l, x, data, r, h)
103           else if c < 0 then
104             bal (add x data l) v d r
105           else
106             bal l v d (add x data r)
107
108     (* Same as create and bal, but no assumptions are made on the
109        relative heights of l and r. *)
110
111     let rec join l v d r =
112       match (l, r) with
113         (Empty, _) -> add v d r
114       | (_, Empty) -> add v d l
115       | (Node(ll, lv, ld, lr, lh), Node(rl, rv, rd, rr, rh)) ->
116           if lh > rh + 2 then bal ll lv ld (join lr v d r) else
117           if rh > lh + 2 then bal (join l v d rl) rv rd rr else
118           create l v d r
119
120     let rec find x = function
121         Empty ->
122           raise Not_found
123       | Node(l, v, d, r, _) ->
124           let c = Ord.compare x v in
125           if c = 0 then d
126           else find x (if c < 0 then l else r)
127
128     let rec mem x = function
129         Empty ->
130           false
131       | Node(l, v, d, r, _) ->
132           let c = Ord.compare x v in
133           c = 0 || mem x (if c < 0 then l else r)
134
135     let rec min_binding = function
136         Empty -> raise Not_found
137       | Node(Empty, x, d, r, _) -> (x, d)
138       | Node(l, x, d, r, _) -> min_binding l
139
140     let rec remove_min_binding = function
141         Empty -> invalid_arg "Map.remove_min_elt"
142       | Node(Empty, x, d, r, _) -> r
143       | Node(l, x, d, r, _) -> bal (remove_min_binding l) x d r
144
145     let merge t1 t2 =
146       match (t1, t2) with
147         (Empty, t) -> t
148       | (t, Empty) -> t
149       | (_, _) ->
150           let (x, d) = min_binding t2 in
151           bal t1 x d (remove_min_binding t2)
152
153     let rec remove x = function
154         Empty ->
155           Empty
156       | Node(l, v, d, r, h) ->
157           let c = Ord.compare x v in
158           if c = 0 then
159             merge l r
160           else if c < 0 then
161             bal (remove x l) v d r
162           else
163             bal l v d (remove x r)
164
165     let rec iter f = function
166         Empty -> ()
167       | Node(l, v, d, r, _) ->
168           iter f l; f v d; iter f r
169
170     let rec map f = function
171         Empty               -> Empty
172       | Node(l, v, d, r, h) -> Node(map f l, v, f d, map f r, h)
173
174     let rec mapi f = function
175         Empty               -> Empty
176       | Node(l, v, d, r, h) -> Node(mapi f l, v, f v d, mapi f r, h)
177
178     let rec fold f m accu =
179       match m with
180         Empty -> accu
181       | Node(l, v, d, r, _) ->
182           fold f r (f v d (fold f l accu))
183
184     type 'a enumeration = End | More of key * 'a * 'a t * 'a enumeration
185
186     let rec cons_enum m e =
187       match m with
188         Empty -> e
189       | Node(l, v, d, r, _) -> cons_enum l (More(v, d, r, e))
190
191     let compare cmp m1 m2 =
192       let rec compare_aux e1 e2 =
193           match (e1, e2) with
194           (End, End) -> 0
195         | (End, _)  -> -1
196         | (_, End) -> 1
197         | (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) ->
198             let c = Ord.compare v1 v2 in
199             if c <> 0 then c else
200             let c = cmp d1 d2 in
201             if c <> 0 then c else
202             compare_aux (cons_enum r1 e1) (cons_enum r2 e2)
203       in compare_aux (cons_enum m1 End) (cons_enum m2 End)
204
205     let equal cmp m1 m2 =
206       let rec equal_aux e1 e2 =
207           match (e1, e2) with
208           (End, End) -> true
209         | (End, _)  -> false
210         | (_, End) -> false
211         | (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) ->
212             Ord.compare v1 v2 = 0 && cmp d1 d2 &&
213             equal_aux (cons_enum r1 e1) (cons_enum r2 e2)
214       in equal_aux (cons_enum m1 End) (cons_enum m2 End)
215
216   (* Intervals for splitting. An interval consists of a lower bound
217      and an upper bound, each of which can be absent. A key is
218      considered to lie within the interval if it is both greater than
219      (or equal to) the lower bound (if present) and less than (or
220      equal to) the upper bound (if present). *)
221
222   type interval =
223       key option * key option
224
225   (* Splitting. split interval m returns a new map consisting of
226      all bindings in m whose keys are within interval. *)
227
228   let rec split ((lo, hi) as interval) = function
229       Empty ->
230         Empty
231     | Node(l, v, d, r, _) ->
232         let clo = Ord.compare v lo in
233         if clo < 0 then
234           (* v < lo *)
235           split interval r
236         else if clo = 0 then
237           (* v = lo *)
238           add v d (splithi hi r)
239         else
240           (* v > lo *)
241           let chi = Ord.compare v hi in
242           if chi < 0 then
243             (* v < hi *)
244             join (splitlo lo l) v d (splithi hi r)
245           else if chi = 0 then
246             (* v = hi *)
247             add v d (splitlo lo l)
248           else
249             (* v > hi *)
250             split interval l
251
252   and splitlo lo = function
253       Empty ->
254         Empty
255     | Node(l, v, d, r, _) ->
256         let c = Ord.compare v lo in
257         if c < 0 then
258           (* v < lo *)
259           splitlo lo r
260         else if c = 0 then
261           (* v = lo *)
262           add v d r
263         else
264           (* v > lo *)
265           join (splitlo lo l) v d r
266
267   and splithi hi = function
268       Empty ->
269         Empty
270     | Node(l, v, d, r, _) ->
271         let c = Ord.compare v hi in
272         if c < 0 then
273           (* v < hi *)
274           join l v d (splithi hi r)
275         else if c = 0 then
276           (* v = hi *)
277           add v d l
278         else
279           (* v > hi *)
280           splithi hi l
281
282   (* Splitting. This is the public entry point. *)
283
284   let split interval m =
285     match interval with
286     | None, None ->
287         m
288     | Some lo, None ->
289         splitlo lo m
290     | None, Some hi ->
291         splithi hi m
292     | Some lo, Some hi ->
293         split (lo, hi) m
294
295   (* Finding the minimum key in a map. *)
296
297   let rec minimum key data m =
298     match m with
299     | Empty ->
300         (key, data)
301     | Node (l, k, d, _, _) ->
302         minimum k d l
303
304   let minimum = function
305     | Empty ->
306         raise Not_found
307     | Node (l, k, d, _, _) ->
308         minimum k d l
309
310   (* Finding an element and removing it in one single traversal. *)
311
312   let find_remove x m =
313     let data = ref None in
314     let rec remove = function
315       | Empty ->
316           raise Not_found
317       | Node(l, v, d, r, h) ->
318           let c = Ord.compare x v in
319           if c = 0 then begin
320             data := Some d;
321             merge l r
322           end
323           else if c < 0 then
324             bal (remove l) v d r
325           else
326             bal l v d (remove r)
327     in
328     let m = remove m in
329     match !data with
330     | None ->
331         assert false
332     | Some d ->
333         d, m
334
335   (* Updating the data associated with an element in one single traversal. *)
336
337   exception Unmodified
338
339   let rec update x f m =
340     let rec update = function
341       | Empty ->
342           assert false
343       | Node(l, v, d, r, h) ->
344           let c = Ord.compare x v in
345           if c = 0 then
346             let d' = f d in
347             if d == d' then
348               raise Unmodified
349             else
350               Node (l, v, d', r, h)
351           else if c < 0 then
352             Node (update l, v, d, r, h)
353           else
354             Node (l, v, d, update r, h)
355     in
356     try
357       update m
358     with Unmodified ->
359       m
360
361   (* Restricting the domain of a map. *)
362
363   let restrict p m =
364     fold (fun x d m ->
365       if p x then
366         add x d m
367       else
368         m
369     ) m empty
370
371
372 end