PyORAm
[iotcloud.git] / PyORAM / src / pyoram / storage / heap_storage.py
diff --git a/PyORAM/src/pyoram/storage/heap_storage.py b/PyORAM/src/pyoram/storage/heap_storage.py
new file mode 100644 (file)
index 0000000..46059c2
--- /dev/null
@@ -0,0 +1,245 @@
+__all__ = ('HeapStorage',)
+
+import struct
+
+from pyoram.util.virtual_heap import SizedVirtualHeap
+from pyoram.storage.block_storage import (BlockStorageInterface,
+                                          BlockStorageTypeFactory)
+
+class HeapStorageInterface(object):
+
+    def __enter__(self):
+        return self
+    def __exit__(self, *args):
+        self.close()
+
+    #
+    # Abstract Interface
+    #
+
+    def clone_device(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+
+    @classmethod
+    def compute_storage_size(cls, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+    @classmethod
+    def setup(cls, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+
+    @property
+    def header_data(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+    @property
+    def bucket_count(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+    @property
+    def bucket_size(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+    @property
+    def blocks_per_bucket(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+    @property
+    def storage_name(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+    @property
+    def virtual_heap(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+    @property
+    def bucket_storage(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+
+    def update_header_data(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+    def close(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+    def read_path(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+    def write_path(self, *args, **kwds):
+        raise NotImplementedError                      # pragma: no cover
+
+    @property
+    def bytes_sent(self):
+        raise NotImplementedError                      # pragma: no cover
+    @property
+    def bytes_received(self):
+        raise NotImplementedError                      # pragma: no cover
+
+class HeapStorage(HeapStorageInterface):
+
+    _header_struct_string = "!LLL"
+    _header_offset = struct.calcsize(_header_struct_string)
+
+    def _new_storage(self, storage, **kwds):
+        storage_type = kwds.pop('storage_type', 'file')
+
+
+    def __init__(self, storage, **kwds):
+        if isinstance(storage, BlockStorageInterface):
+            self._storage = storage
+            if len(kwds):
+                raise ValueError(
+                    "Keywords not used when initializing "
+                    "with a storage device: %s"
+                    % (str(kwds)))
+        else:
+            storage_type = kwds.pop('storage_type', 'file')
+            self._storage = BlockStorageTypeFactory(storage_type)\
+                            (storage, **kwds)
+
+        heap_base, heap_height, blocks_per_bucket = \
+            struct.unpack(
+                self._header_struct_string,
+                self._storage.header_data[:self._header_offset])
+        self._vheap = SizedVirtualHeap(
+            heap_base,
+            heap_height,
+            blocks_per_bucket=blocks_per_bucket)
+
+    #
+    # Define HeapStorageInterface Methods
+    #
+
+    def clone_device(self):
+        return HeapStorage(self._storage.clone_device())
+
+    @classmethod
+    def compute_storage_size(cls,
+                             block_size,
+                             heap_height,
+                             blocks_per_bucket=1,
+                             heap_base=2,
+                             ignore_header=False,
+                             storage_type='file',
+                             **kwds):
+        assert (block_size > 0) and (block_size == int(block_size))
+        assert heap_height >= 0
+        assert blocks_per_bucket >= 1
+        assert heap_base >= 2
+        assert 'block_count' not in kwds
+        vheap = SizedVirtualHeap(
+            heap_base,
+            heap_height,
+            blocks_per_bucket=blocks_per_bucket)
+        if ignore_header:
+            return BlockStorageTypeFactory(storage_type).\
+                compute_storage_size(
+                    vheap.blocks_per_bucket * block_size,
+                    vheap.bucket_count(),
+                    ignore_header=True,
+                    **kwds)
+        else:
+            return cls._header_offset + \
+                BlockStorageTypeFactory(storage_type).\
+                compute_storage_size(
+                    vheap.blocks_per_bucket * block_size,
+                    vheap.bucket_count(),
+                    ignore_header=False,
+                    **kwds)
+
+    @classmethod
+    def setup(cls,
+              storage_name,
+              block_size,
+              heap_height,
+              blocks_per_bucket=1,
+              heap_base=2,
+              storage_type='file',
+              **kwds):
+        if 'block_count' in kwds:
+            raise ValueError("'block_count' keyword is not accepted")
+        if heap_height < 0:
+            raise ValueError(
+                "heap height must be 0 or greater. Invalid value: %s"
+                % (heap_height))
+        if blocks_per_bucket < 1:
+            raise ValueError(
+                "blocks_per_bucket must be 1 or greater. "
+                "Invalid value: %s" % (blocks_per_bucket))
+        if heap_base < 2:
+            raise ValueError(
+                "heap base must be 2 or greater. Invalid value: %s"
+                % (heap_base))
+
+        vheap = SizedVirtualHeap(
+            heap_base,
+            heap_height,
+            blocks_per_bucket=blocks_per_bucket)
+
+        user_header_data = kwds.pop('header_data', bytes())
+        if type(user_header_data) is not bytes:
+            raise TypeError(
+                "'header_data' must be of type bytes. "
+                "Invalid type: %s" % (type(user_header_data)))
+        kwds['header_data'] = \
+            struct.pack(cls._header_struct_string,
+                        heap_base,
+                        heap_height,
+                        blocks_per_bucket) + \
+            user_header_data
+
+        return HeapStorage(
+            BlockStorageTypeFactory(storage_type).setup(
+                storage_name,
+                vheap.blocks_per_bucket * block_size,
+                vheap.bucket_count(),
+                **kwds))
+
+    @property
+    def header_data(self):
+        return self._storage.header_data[self._header_offset:]
+
+    @property
+    def bucket_count(self):
+        return self._storage.block_count
+
+    @property
+    def bucket_size(self):
+        return self._storage.block_size
+
+    @property
+    def blocks_per_bucket(self):
+        return self._vheap.blocks_per_bucket
+
+    @property
+    def storage_name(self):
+        return self._storage.storage_name
+
+    @property
+    def virtual_heap(self):
+        return self._vheap
+
+    @property
+    def bucket_storage(self):
+        return self._storage
+
+    def update_header_data(self, new_header_data):
+        self._storage.update_header_data(
+            self._storage.header_data[:self._header_offset] + \
+            new_header_data)
+
+    def close(self):
+        print("Non enc heap close 1")
+        self._storage.close()
+        print("Non enc heap close 2")
+
+    def read_path(self, b, level_start=0):
+        assert 0 <= b < self._vheap.bucket_count()
+        bucket_list = self._vheap.Node(b).bucket_path_from_root()
+        assert 0 <= level_start < len(bucket_list)
+        return self._storage.read_blocks(bucket_list[level_start:])
+
+    def write_path(self, b, buckets, level_start=0):
+        assert 0 <= b < self._vheap.bucket_count()
+        bucket_list = self._vheap.Node(b).bucket_path_from_root()
+        assert 0 <= level_start < len(bucket_list)
+        self._storage.write_blocks(bucket_list[level_start:],
+                                   buckets)
+
+    @property
+    def bytes_sent(self):
+        return self._storage.bytes_sent
+
+    @property
+    def bytes_received(self):
+        return self._storage.bytes_received