PyORAM
[iotcloud.git] / PyORAM / src / pyoram / util / virtual_heap.py
# Names exported by "from pyoram.util.virtual_heap import *".
__all__ = (
    "VirtualHeap",
    "SizedVirtualHeap",
)
3
4 import os
5 import sys
6 import subprocess
7 import random
8 import string
9 import tempfile
10
11 from six.moves import xrange
12
13 from pyoram.util._virtual_heap_helper import lib as _clib
14 from pyoram.util.misc import log2floor
15
# Digit alphabet used for base k labels: every printable,
# non-whitespace character except '+', '-', both quote
# characters, the backslash and the forward slash.
numerals = ''.join(
    c for c in string.printable
    if c not in string.whitespace and c not in '+-"\'\\/')

# Reverse lookup: digit character -> numeric value.
numeral_index = {c: i for i, c in enumerate(numerals)}

# The maximum heap base for which base k labels
# can be produced.
max_k_labeled = len(numerals)
26
def base10_integer_to_basek_string(k, x):
    """
    Convert a non-negative integer into a base k string.

    The digits are taken from the module-level 'numerals'
    alphabet, most significant digit first. Zero is
    rendered as the single digit numerals[0].

    Raises ValueError if k is outside [2, max_k_labeled]
    or if x is negative.
    """
    if not (2 <= k <= max_k_labeled):
        raise ValueError("k must be in range [2, %d]: %s"
                         % (max_k_labeled, k))
    # Floor division never drives a negative value to zero
    # (-1 // k == -1), so reject negatives explicitly rather
    # than looping forever.
    if x < 0:
        raise ValueError("x must be non-negative: %s" % (x,))
    if x == 0:
        return numerals[0]
    # Peel off the least significant digit on each pass;
    # iterating (instead of recursing) keeps arbitrarily
    # large integers clear of the interpreter recursion limit.
    digits = []
    while x > 0:
        x, remainder = divmod(x, k)
        digits.append(numerals[remainder])
    return ''.join(reversed(digits))
35
def basek_string_to_base10_integer(k, x):
    """
    Convert a base k string into an integer.

    Each character of x is looked up in 'numeral_index' and
    weighted by the power of k matching its position, the
    last character being the least significant digit.
    """
    assert 1 < k <= max_k_labeled
    total = 0
    weight = 1
    # Walk from the least significant digit upwards.
    for digit in reversed(x):
        total += numeral_index[digit] * weight
        weight *= k
    return total
41
42 # _clib defines a faster version of this function
def calculate_bucket_level(k, b):
    """
    Return the level of a k-ary heap that holds the
    0-based bucket b. Bucket 0 (the root) is on level 0.
    """
    assert k >= 2
    if k == 2:
        # binary heaps delegate to the log2floor helper
        return log2floor(b+1)
    threshold = (k - 1) * (b + 1) + 1
    level = 0
    # invariant: power == k**(level+1)
    power = k
    while power < threshold:
        power *= k
        level += 1
    return level
56
57 # _clib defines a faster version of this function
def calculate_last_common_level(k, b1, b2):
    """
    Return the deepest level shared by the root-to-bucket
    paths of b1 and b2, i.e. the level after which the
    two paths diverge.
    """
    level1 = calculate_bucket_level(k, b1)
    level2 = calculate_bucket_level(k, b2)
    # Lift the deeper bucket to its ancestor on the
    # shallower bucket's level.
    if level1 > level2:
        for _ in range(level1 - level2):
            b1 = (b1-1)//k
    elif level2 > level1:
        for _ in range(level2 - level1):
            b2 = (b2-1)//k
    common = min(level1, level2)
    # Climb both paths in lockstep until they meet.
    while b1 != b2:
        b1 = (b1-1)//k
        b2 = (b2-1)//k
        common -= 1
    return common
76
def calculate_necessary_heap_height(k, n):
    """
    Return the height a k-ary heap needs in order to
    hold n buckets: the level of its last bucket, n-1.
    """
    assert n >= 1
    last_bucket = n - 1
    return calculate_bucket_level(k, last_bucket)
84
def calculate_bucket_count_in_heap_with_height(k, h):
    """
    Return the total number of buckets in a k-ary
    heap of height h (levels 0 through h).
    """
    assert h >= 0
    # closed form of 1 + k + k**2 + ... + k**h
    numerator = pow(k, h + 1) - 1
    return numerator // (k - 1)
92
def calculate_bucket_count_in_heap_at_level(k, l):
    """
    Return the number of buckets found on level l
    of a k-ary heap.
    """
    assert l >= 0
    return pow(k, l)
100
def calculate_leaf_bucket_count_in_heap_with_height(k, h):
    """
    Return the number of leaf buckets of a k-ary heap
    of height h, i.e. the bucket count on level h.
    """
    leaf_level = h
    return calculate_bucket_count_in_heap_at_level(k, leaf_level)
107
def create_node_type(k):
    """
    Build and return a VirtualHeapNode class bound to the
    heap base k.

    k is stored as a class attribute on the returned class,
    so every node created from it performs its parent/child
    arithmetic for a k-ary heap.
    """

    class VirtualHeapNode(object):
        """A 0-based bucket position inside a k-ary heap."""
        __slots__ = ("bucket", "level")
        def __init__(self, bucket):
            assert bucket >= 0
            self.bucket = bucket
            self.level = _clib.calculate_bucket_level(self.k, self.bucket)

        # Hashing, int conversion and all comparisons are
        # forwarded to the bucket index.
        def __hash__(self):
            return self.bucket.__hash__()
        def __int__(self):
            return self.bucket
        def __lt__(self, other):
            return self.bucket < other
        def __le__(self, other):
            return self.bucket <= other
        def __eq__(self, other):
            return self.bucket == other
        def __ne__(self, other):
            return self.bucket != other
        def __gt__(self, other):
            return self.bucket > other
        def __ge__(self, other):
            return self.bucket >= other
        def last_common_level(self, n):
            """Return the last level shared by the paths to self and n."""
            return _clib.calculate_last_common_level(self.k,
                                                     self.bucket,
                                                     n.bucket)
        def child_node(self, c):
            """Return the c-th child node, with 0 <= c < k."""
            assert type(c) is int
            assert 0 <= c < self.k
            return VirtualHeapNode(self.k * self.bucket + 1 + c)
        def parent_node(self):
            """Return the parent node, or None for the root."""
            if self.bucket != 0:
                return VirtualHeapNode((self.bucket - 1)//self.k)
            return None
        def ancestor_node_at_level(self, level):
            """
            Return the node on this node's root path that lives
            on the given level (self when level == self.level),
            or None when no such node exists.
            """
            # A negative level lies above the root; without this
            # guard the loop below would walk past the root and
            # fail on None.
            if (level < 0) or (level > self.level):
                return None
            current = self
            while current.level != level:
                current = current.parent_node()
            return current
        def path_to_root(self):
            """Yield the nodes from self up to and including the root."""
            bucket = self.bucket
            yield self
            while bucket != 0:
                bucket = (bucket - 1)//self.k
                yield type(self)(bucket)
        def path_from_root(self):
            """Return the list of nodes from the root down to self."""
            return list(reversed(list(self.path_to_root())))
        def bucket_path_to_root(self):
            """Yield the bucket indices from self up to the root."""
            bucket = self.bucket
            yield bucket
            while bucket != 0:
                bucket = (bucket - 1)//self.k
                yield bucket
        def bucket_path_from_root(self):
            """Return the list of bucket indices from the root to self."""
            return list(reversed(list(self.bucket_path_to_root())))

        #
        # Expensive Functions
        #
        def __repr__(self):
            try:
                label = self.label()
            except ValueError:
                # presumably, k is too large
                label = "<unknown>"
            return ("VirtualHeapNode(k=%s, bucket=%s, level=%s, label=%r)"
                    % (self.k, self.bucket, self.level, label))
        def __str__(self):
            """Returns a tuple (<level>, <bucket offset within level>)."""
            if self.bucket != 0:
                return ("(%s, %s)"
                        % (self.level,
                           self.bucket -
                           calculate_bucket_count_in_heap_with_height(self.k,
                                                                self.level-1)))
            assert self.level == 0
            return "(0, 0)"

        def label(self):
            """
            Return the base k string of this node's offset within
            its level, zero-padded to 'level' characters. The root
            has the empty label. Raises ValueError (from the base k
            conversion) when k exceeds max_k_labeled.
            """
            assert 0 <= self.bucket
            if self.level == 0:
                return ''
            b_offset = self.bucket - \
                       calculate_bucket_count_in_heap_with_height(self.k,
                                                            self.level-1)
            basek = base10_integer_to_basek_string(self.k, b_offset)
            return basek.zfill(self.level)

        def is_node_on_path(self, n):
            """
            Return True when node n lies on the path from the
            root to this node (n may be this node itself).
            """
            if n.level > self.level:
                return False
            # Climb from self to n's level with integer arithmetic.
            # This answers the same question as comparing label
            # prefixes, but also works when k is too large for
            # labels to be produced.
            bucket = self.bucket
            for _ in range(self.level - n.level):
                bucket = (bucket - 1)//self.k
            return bucket == n.bucket

    VirtualHeapNode.k = k

    return VirtualHeapNode
212
class VirtualHeap(object):
    """
    Index arithmetic for a k-ary heap whose buckets each
    hold 'blocks_per_bucket' blocks. Buckets, blocks and
    levels are all 0-based; no height is fixed here.
    """

    clib = _clib
    random = random.SystemRandom()

    def __init__(self, k, blocks_per_bucket=1):
        assert k > 1
        assert blocks_per_bucket >= 1
        self.Node = create_node_type(k)
        self._blocks_per_bucket = blocks_per_bucket
        self._k = k

    @property
    def k(self):
        """The heap base (number of children per bucket)."""
        return self._k

    def node_label_to_bucket(self, label):
        """Map a node label string to its bucket index."""
        if len(label) == 0:
            # the empty label denotes the root
            return 0
        level_start = calculate_bucket_count_in_heap_with_height(
            self.k, len(label)-1)
        offset_in_level = basek_string_to_base10_integer(self.k, label)
        return level_start + offset_in_level

    #
    # Buckets (0-based integer, equivalent to block for heap
    # with blocks_per_bucket=1)
    #

    @property
    def blocks_per_bucket(self):
        """Number of blocks stored in every bucket."""
        return self._blocks_per_bucket

    def bucket_count_at_level(self, l):
        """Number of buckets on level l."""
        return calculate_bucket_count_in_heap_at_level(self.k, l)

    def first_bucket_at_level(self, l):
        """Index of the first bucket on level l."""
        if l <= 0:
            return 0
        return calculate_bucket_count_in_heap_with_height(self.k, l-1)

    def last_bucket_at_level(self, l):
        """Index of the last bucket on level l."""
        total = calculate_bucket_count_in_heap_with_height(self.k, l)
        return total - 1

    def random_bucket_up_to_level(self, l):
        """Random bucket index drawn from levels 0 through l."""
        lowest = self.first_bucket_at_level(0)
        highest = self.last_bucket_at_level(l)
        return self.random.randint(lowest, highest)

    def random_bucket_at_level(self, l):
        """Random bucket index drawn from level l only."""
        lowest = self.first_bucket_at_level(l)
        highest = self.first_bucket_at_level(l+1) - 1
        return self.random.randint(lowest, highest)

    #
    # Nodes (a class that helps with heap path calculations)
    #

    def root_node(self):
        """Node for the root bucket."""
        return self.first_node_at_level(0)

    def node_count_at_level(self, l):
        """Number of nodes on level l."""
        return self.bucket_count_at_level(l)

    def first_node_at_level(self, l):
        """Node for the first bucket on level l."""
        bucket = self.first_bucket_at_level(l)
        return self.Node(bucket)

    def last_node_at_level(self, l):
        """Node for the last bucket on level l."""
        bucket = self.last_bucket_at_level(l)
        return self.Node(bucket)

    def random_node_up_to_level(self, l):
        """Node for a random bucket from levels 0 through l."""
        bucket = self.random_bucket_up_to_level(l)
        return self.Node(bucket)

    def random_node_at_level(self, l):
        """Node for a random bucket on level l."""
        bucket = self.random_bucket_at_level(l)
        return self.Node(bucket)

    #
    # Block (0-based integer)
    #

    def bucket_to_block(self, b):
        """Index of the first block stored in bucket b."""
        assert b >= 0
        return self.blocks_per_bucket * b

    def block_to_bucket(self, s):
        """Index of the bucket that contains block s."""
        assert s >= 0
        return s // self.blocks_per_bucket

    def first_block_in_bucket(self, b):
        """Index of the first block of bucket b."""
        return self.bucket_to_block(b)

    def last_block_in_bucket(self, b):
        """Index of the last block of bucket b."""
        first = self.bucket_to_block(b)
        return first + self.blocks_per_bucket - 1

    def block_count_at_level(self, l):
        """Number of blocks on level l."""
        buckets = self.bucket_count_at_level(l)
        return buckets * self.blocks_per_bucket

    def first_block_at_level(self, l):
        """Index of the first block on level l."""
        first_bucket = self.first_bucket_at_level(l)
        return self.bucket_to_block(first_bucket)

    def last_block_at_level(self, l):
        """Index of the last block on level l."""
        # one before the first block of the next level
        next_level_bucket = self.first_bucket_at_level(l+1)
        return self.bucket_to_block(next_level_bucket) - 1
298
class SizedVirtualHeap(VirtualHeap):
    """
    A VirtualHeap with a fixed height, which makes total
    bucket/block counts and the leaf level well defined.
    Levels run from 0 (the root) to 'height' (the leaves).
    """

    def __init__(self, k, height, blocks_per_bucket=1):
        super(SizedVirtualHeap, self).\
            __init__(k, blocks_per_bucket=blocks_per_bucket)
        self._height = height

    #
    # Size properties
    #
    @property
    def height(self):
        """The heap height (index of the leaf level)."""
        return self._height
    @property
    def levels(self):
        """The number of levels, height + 1."""
        return self.height + 1
    @property
    def first_level(self):
        """Index of the root level (always 0)."""
        return 0
    @property
    def last_level(self):
        """Index of the leaf level (equal to height)."""
        return self.height

    #
    # Buckets (0-based integer, equivalent to block for heap
    # with blocks_per_bucket=1)
    #

    def bucket_count(self):
        """Total number of buckets in the heap."""
        return calculate_bucket_count_in_heap_with_height(self.k,
                                                          self.height)
    def leaf_bucket_count(self):
        """Number of buckets on the leaf level."""
        return calculate_leaf_bucket_count_in_heap_with_height(self.k,
                                                               self.height)
    def first_leaf_bucket(self):
        """Index of the first leaf bucket."""
        return self.first_bucket_at_level(self.height)
    def last_leaf_bucket(self):
        """Index of the last leaf bucket."""
        return self.last_bucket_at_level(self.height)
    def random_bucket(self):
        """Random bucket index drawn from the whole heap."""
        return self.random.randint(self.first_bucket_at_level(0),
                                   self.last_leaf_bucket())
    def random_leaf_bucket(self):
        """Random bucket index drawn from the leaf level."""
        return self.random_bucket_at_level(self.height)

    #
    # Nodes (a class that helps with heap path calculations)
    #

    def is_nil_node(self, n):
        """Return True when node n lies beyond the last bucket."""
        return n.bucket >= self.bucket_count()
    def node_count(self):
        """Total number of nodes in the heap."""
        return self.bucket_count()
    def leaf_node_count(self):
        """Number of nodes on the leaf level."""
        return self.leaf_bucket_count()
    def first_leaf_node(self):
        """Node for the first leaf bucket."""
        return self.Node(self.first_leaf_bucket())
    def last_leaf_node(self):
        """Node for the last leaf bucket."""
        return self.Node(self.last_leaf_bucket())
    def random_leaf_node(self):
        """Node for a random leaf bucket."""
        return self.Node(self.random_leaf_bucket())
    def random_node(self):
        """Node for a random bucket of the heap."""
        return self.Node(self.random_bucket())

    #
    # Block (0-based integer)
    #

    def block_count(self):
        """Total number of blocks in the heap."""
        return self.bucket_count() * self.blocks_per_bucket
    def leaf_block_count(self):
        """Number of blocks on the leaf level."""
        return self.leaf_bucket_count() * self.blocks_per_bucket
    def first_leaf_block(self):
        """Index of the first block on the leaf level."""
        return self.first_block_in_bucket(self.first_leaf_bucket())
    def last_leaf_block(self):
        """Index of the last block on the leaf level."""
        return self.last_block_in_bucket(self.last_leaf_bucket())

    #
    # Visualization
    #

    def write_as_dot(self, f, data=None, max_levels=None):
        """
        Write the tree in the dot language format to f.

        f is any object with a write() method. When data is
        given it is indexed by block number and each node shows
        the entries of its blocks; otherwise each node shows its
        label (or its string form when k is too large for
        labels). max_levels limits how many levels are written,
        starting at the root; None writes the whole heap and 0
        writes no nodes.
        """
        assert (max_levels is None) or (max_levels >= 0)
        def visit_node(n, levels):
            lbl = "{"
            if data is None:
                if self.k <= max_k_labeled:
                    # Escape the characters that carry meaning
                    # inside a record label, since they are
                    # valid label digits.
                    lbl = repr(n.label())
                    for special in "{}|<>":
                        lbl = lbl.replace(special, "\\" + special)
                else:
                    lbl = str(n)
            else:
                s = self.bucket_to_block(n.bucket)
                for i in xrange(self.blocks_per_bucket):
                    lbl += "{%s}" % (data[s+i])
                    if i + 1 != self.blocks_per_bucket:
                        lbl += "|"
            # NOTE(review): when data is None the opening "{" was
            # overwritten above, so this leaves an unmatched "}"
            # at the end of the label. Kept as-is so the emitted
            # text is unchanged -- confirm whether it is intended.
            lbl += "}"
            f.write("  %s [penwidth=%s,label=\"%s\"];\n"
                    % (n.bucket, 1, lbl))
            levels += 1
            if (max_levels is None) or (levels <= max_levels):
                for i in xrange(self.k):
                    cn = n.child_node(i)
                    if not self.is_nil_node(cn):
                        visit_node(cn, levels)
                        f.write("  %s -> %s ;\n" % (n.bucket, cn.bucket))

        f.write("// Created by SizedVirtualHeap.write_as_dot(...)\n")
        f.write("digraph heaptree {\n")
        f.write("node [shape=record]\n")

        if (max_levels is None) or (max_levels > 0):
            visit_node(self.root_node(), 1)
        f.write("}\n")

    def save_image_as_pdf(self, filename, data=None, max_levels=None):
        """
        Write the heap as PDF file.

        The heap is first written to a temporary DOT file,
        which is then converted by the external 'dot' program.
        '.pdf' is appended to filename when missing. data and
        max_levels are passed through to write_as_dot.

        Returns True on success. Returns False when 'dot' could
        not be run or exited with a non-zero status; in that
        case the temporary DOT file is kept and its path is
        reported on stderr.
        """
        assert (max_levels is None) or (max_levels >= 0)
        if not filename.endswith('.pdf'):
            filename = filename+'.pdf'
        tmpfd, tmpname = tempfile.mkstemp(suffix='.dot')
        try:
            # Wrap the descriptor returned by mkstemp instead of
            # opening the file a second time; the with-block
            # closes it even when writing fails.
            with os.fdopen(tmpfd, 'w') as f:
                self.write_as_dot(f, data=data, max_levels=max_levels)
        except Exception:
            # Do not leave a partial temp file behind.
            os.remove(tmpname)
            raise
        try:
            status = subprocess.call(['dot',
                                      tmpname,
                                      '-Tpdf',
                                      '-o',
                                      ('%s'%filename)])
        except OSError:
            # 'dot' could not be launched at all
            status = None
        if status != 0:
            sys.stderr.write(
                "DOT -> PDF conversion failed. See DOT file: %s\n"
                % (tmpname))
            return False
        os.remove(tmpname)
        return True