1 (***********************************************************************)
5 (* Xavier Leroy, projet Cristal, INRIA Rocquencourt *)
7 (* Copyright 1996 Institut National de Recherche en Informatique et *)
8 (* en Automatique. All rights reserved. This file is distributed *)
9 (* under the terms of the GNU Library General Public License, with *)
10 (* the special exception on linking described in file ../LICENSE. *)
12 (***********************************************************************)
14 (* $Id: myMap.ml,v 1.3 2006/02/17 16:19:52 pottier Exp $ *)
16 module type OrderedType =
19 val compare: t -> t -> int
27 val is_empty: 'a t -> bool
28 val add: key -> 'a -> 'a t -> 'a t
29 val find: key -> 'a t -> 'a
30 val remove: key -> 'a t -> 'a t
31 val mem: key -> 'a t -> bool
32 val iter: (key -> 'a -> unit) -> 'a t -> unit
33 val map: ('a -> 'b) -> 'a t -> 'b t
34 val mapi: (key -> 'a -> 'b) -> 'a t -> 'b t
35 val fold: (key -> 'a -> 'b -> 'b) -> 'a t -> 'b -> 'b
36 val compare: ('a -> 'a -> int) -> 'a t -> 'a t -> int
37 val equal: ('a -> 'a -> bool) -> 'a t -> 'a t -> bool
38 type interval = key option * key option
39 val split: interval -> 'a t -> 'a t
40 val minimum: 'a t -> key * 'a
41 val find_remove: key -> 'a t -> 'a * 'a t
42 val update: key -> ('a -> 'a) -> 'a t -> 'a t
43 val restrict: (key -> bool) -> 'a t -> 'a t
46 module Make(Ord: OrderedType) = struct
52 | Node of 'a t * key * 'a * 'a t * int
56 | Node(_,_,_,_,h) -> h
59 let hl = height l and hr = height r in
60 Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1))
63 let hl = match l with Empty -> 0 | Node(_,_,_,_,h) -> h in
64 let hr = match r with Empty -> 0 | Node(_,_,_,_,h) -> h in
65 if hl > hr + 2 then begin
67 Empty -> invalid_arg "Map.bal"
68 | Node(ll, lv, ld, lr, _) ->
69 if height ll >= height lr then
70 create ll lv ld (create lr x d r)
73 Empty -> invalid_arg "Map.bal"
74 | Node(lrl, lrv, lrd, lrr, _)->
75 create (create ll lv ld lrl) lrv lrd (create lrr x d r)
77 end else if hr > hl + 2 then begin
79 Empty -> invalid_arg "Map.bal"
80 | Node(rl, rv, rd, rr, _) ->
81 if height rr >= height rl then
82 create (create l x d rl) rv rd rr
85 Empty -> invalid_arg "Map.bal"
86 | Node(rll, rlv, rld, rlr, _) ->
87 create (create l x d rll) rlv rld (create rlr rv rd rr)
90 Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1))
(* [is_empty m] is [true] exactly when [m] is the [Empty] tree, and
   [false] for any [Node]. *)
let is_empty m =
  match m with
  | Empty -> true
  | Node _ -> false
96 let rec add x data = function
98 Node(Empty, x, data, Empty, 1)
99 | Node(l, v, d, r, h) ->
100 let c = Ord.compare x v in
102 Node(l, x, data, r, h)
104 bal (add x data l) v d r
106 bal l v d (add x data r)
108 (* Same as create and bal, but no assumptions are made on the
109 relative heights of l and r. *)
111 let rec join l v d r =
113 (Empty, _) -> add v d r
114 | (_, Empty) -> add v d l
115 | (Node(ll, lv, ld, lr, lh), Node(rl, rv, rd, rr, rh)) ->
116 if lh > rh + 2 then bal ll lv ld (join lr v d r) else
117 if rh > lh + 2 then bal (join l v d rl) rv rd rr else
120 let rec find x = function
123 | Node(l, v, d, r, _) ->
124 let c = Ord.compare x v in
126 else find x (if c < 0 then l else r)
128 let rec mem x = function
131 | Node(l, v, d, r, _) ->
132 let c = Ord.compare x v in
133 c = 0 || mem x (if c < 0 then l else r)
(* [min_binding m] returns the (key, data) pair stored in the leftmost
   node of [m], found by following left subtrees until one is [Empty].
   Raises [Not_found] when [m] itself is [Empty]. *)
let rec min_binding tree =
  match tree with
  | Empty -> raise Not_found
  | Node (Empty, key, data, _, _) -> (key, data)
  | Node (left, _, _, _, _) -> min_binding left
(* [remove_min_binding m] returns [m] without its leftmost node.  When
   the leftmost node is the root, the result is simply its right
   subtree; otherwise the left subtree is shrunk recursively and the
   node is rebuilt with [bal].
   Raises [Invalid_argument] when [m] is [Empty]. *)
let rec remove_min_binding tree =
  match tree with
  | Empty -> invalid_arg "Map.remove_min_elt"
  | Node (Empty, _, _, right, _) -> right
  | Node (left, key, data, right, _) ->
      let shrunk_left = remove_min_binding left in
      bal shrunk_left key data right
150 let (x, d) = min_binding t2 in
151 bal t1 x d (remove_min_binding t2)
153 let rec remove x = function
156 | Node(l, v, d, r, h) ->
157 let c = Ord.compare x v in
161 bal (remove x l) v d r
163 bal l v d (remove x r)
165 let rec iter f = function
167 | Node(l, v, d, r, _) ->
168 iter f l; f v d; iter f r
170 let rec map f = function
172 | Node(l, v, d, r, h) -> Node(map f l, v, f d, map f r, h)
174 let rec mapi f = function
176 | Node(l, v, d, r, h) -> Node(mapi f l, v, f v d, mapi f r, h)
178 let rec fold f m accu =
181 | Node(l, v, d, r, _) ->
182 fold f r (f v d (fold f l accu))
(* A pending in-order traversal of a map.  [More (k, d, r, rest)] means:
   the next binding is [(k, d)], then the bindings of subtree [r] remain
   to be visited, followed by whatever [rest] describes.  [End] means
   nothing is left.  Values are built by [cons_enum] and consumed by the
   map comparison functions. *)
type 'a enumeration =
  | End
  | More of key * 'a * 'a t * 'a enumeration
186 let rec cons_enum m e =
189 | Node(l, v, d, r, _) -> cons_enum l (More(v, d, r, e))
191 let compare cmp m1 m2 =
192 let rec compare_aux e1 e2 =
197 | (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) ->
198 let c = Ord.compare v1 v2 in
199 if c <> 0 then c else
201 if c <> 0 then c else
202 compare_aux (cons_enum r1 e1) (cons_enum r2 e2)
203 in compare_aux (cons_enum m1 End) (cons_enum m2 End)
205 let equal cmp m1 m2 =
206 let rec equal_aux e1 e2 =
211 | (More(v1, d1, r1, e1), More(v2, d2, r2, e2)) ->
212 Ord.compare v1 v2 = 0 && cmp d1 d2 &&
213 equal_aux (cons_enum r1 e1) (cons_enum r2 e2)
214 in equal_aux (cons_enum m1 End) (cons_enum m2 End)
216 (* Intervals for splitting. An interval consists of a lower bound
217 and an upper bound, each of which can be absent. A key is
218 considered to lie within the interval if it is both greater than
219 (or equal to) the lower bound (if present) and less than (or
220 equal to) the upper bound (if present). *)
223 key option * key option
225 (* Splitting. split interval m returns a new map consisting of
226 all bindings in m whose keys are within interval. *)
228 let rec split ((lo, hi) as interval) = function
231 | Node(l, v, d, r, _) ->
232 let clo = Ord.compare v lo in
238 add v d (splithi hi r)
241 let chi = Ord.compare v hi in
244 join (splitlo lo l) v d (splithi hi r)
247 add v d (splitlo lo l)
252 and splitlo lo = function
255 | Node(l, v, d, r, _) ->
256 let c = Ord.compare v lo in
265 join (splitlo lo l) v d r
267 and splithi hi = function
270 | Node(l, v, d, r, _) ->
271 let c = Ord.compare v hi in
274 join l v d (splithi hi r)
282 (* Splitting. This is the public entry point. *)
284 let split interval m =
292 | Some lo, Some hi ->
295 (* Finding the minimum key in a map. *)
297 let rec minimum key data m =
301 | Node (l, k, d, _, _) ->
304 let minimum = function
307 | Node (l, k, d, _, _) ->
310 (* Finding an element and removing it in one single traversal. *)
312 let find_remove x m =
313 let data = ref None in
314 let rec remove = function
317 | Node(l, v, d, r, h) ->
318 let c = Ord.compare x v in
335 (* Updating the data associated with an element in one single traversal. *)
339 let rec update x f m =
340 let rec update = function
343 | Node(l, v, d, r, h) ->
344 let c = Ord.compare x v in
350 Node (l, v, d', r, h)
352 Node (update l, v, d, r, h)
354 Node (l, v, d, update r, h)
361 (* Restricting the domain of a map. *)