PyORAm
[iotcloud.git] / PyORAM / src / pyoram / tests / test_aes.py
1 import unittest2
2
3 from pyoram.crypto.aes import AES
4
class TestAES(unittest2.TestCase):
    """Round-trip and sanity tests for the ``AES`` helper wrapper.

    Covers ``AES.KeyGen`` and the CTR and GCM encrypt/decrypt pairs for
    every key size listed in ``AES.key_sizes``.
    """

    # Number of keys generated per key size in test_KeyGen.
    _num_keygen_trials = 10
    # Plaintext lengths, as multiples of AES.block_size, used by the
    # encrypt/decrypt round-trip tests. Fractional factors exercise inputs
    # whose length is not a multiple of the block size.
    _blocksize_factors = (0.5, 1, 1.5, 2, 2.5)

    @staticmethod
    def _repeated_byte_plaintext(i, size):
        """Return ``size`` bytes, each with value ``i`` (0 <= i < 256)."""
        return bytes(bytearray([i]) * size)

    def test_KeyGen(self):
        """KeyGen returns keys of the requested length that never repeat."""
        self.assertEqual(len(AES.key_sizes), 3)
        self.assertEqual(len(set(AES.key_sizes)), 3)
        for keysize in AES.key_sizes:
            keys = []
            for _ in range(self._num_keygen_trials):
                k = AES.KeyGen(keysize)
                self.assertEqual(len(k), keysize)
                keys.append(k)
            self.assertEqual(len(keys), self._num_keygen_trials)
            # make sure every key is unique
            self.assertEqual(len(set(keys)), len(keys))

    def test_CTR(self):
        """CTR mode round-trips and adds one block of ciphertext overhead."""
        self._test_Enc_Dec(
            AES.CTREnc,
            AES.CTRDec,
            self._repeated_byte_plaintext,
            expected_overhead=AES.block_size)

    def test_GCM(self):
        """GCM mode round-trips and adds two blocks of ciphertext overhead."""
        self._test_Enc_Dec(
            AES.GCMEnc,
            AES.GCMDec,
            self._repeated_byte_plaintext,
            expected_overhead=2 * AES.block_size)

    def _test_Enc_Dec(self,
                      enc_func,
                      dec_func,
                      get_plaintext,
                      expected_overhead=None):
        """Check that ``dec_func`` inverts ``enc_func`` for every key size.

        Args:
            enc_func: callable ``(key, plaintext) -> ciphertext``.
            dec_func: callable ``(key, ciphertext) -> plaintext``.
            get_plaintext: callable ``(index, size) -> bytes`` producing the
                plaintext for the ``index``-th test length.
            expected_overhead: how many bytes longer each ciphertext is than
                its plaintext. When omitted it is inferred: one block for
                ``AES.CTREnc``, two blocks for ``AES.GCMEnc``; any other
                ``enc_func`` then fails the test.
        """
        block_size = AES.block_size

        if expected_overhead is None:
            # Compare with ``==`` rather than ``is``: if these were ever
            # classmethods, each attribute access would produce a fresh
            # bound method and an identity check would always fail.
            if enc_func == AES.CTREnc:
                expected_overhead = block_size
            else:
                self.assertEqual(enc_func, AES.GCMEnc)
                expected_overhead = 2 * block_size

        plaintext_blocks = []
        for i, factor in enumerate(self._blocksize_factors):
            size = int(round(block_size * factor))
            if int(factor) != factor:
                # Fixture sanity: fractional factors must give lengths that
                # are not block-aligned. (A bare ``assert`` here would be
                # stripped under ``python -O``.)
                self.assertNotEqual(size % block_size, 0)
            plaintext_blocks.append(get_plaintext(i, size))

        self.assertGreater(len(AES.key_sizes), 0)

        # Encrypt every plaintext under one fresh key per key size.
        keys = {}
        ciphertext_blocks = {}
        for keysize in AES.key_sizes:
            key = AES.KeyGen(keysize)
            keys[keysize] = key
            ciphertext_blocks[keysize] = [enc_func(key, block)
                                          for block in plaintext_blocks]

        self.assertEqual(len(ciphertext_blocks), len(AES.key_sizes))
        self.assertEqual(len(keys), len(AES.key_sizes))

        # Decrypt each ciphertext with the key that produced it.
        plaintext_decrypted_blocks = {}
        for keysize, key in keys.items():
            plaintext_decrypted_blocks[keysize] = [
                dec_func(key, block) for block in ciphertext_blocks[keysize]]

        self.assertEqual(len(plaintext_decrypted_blocks),
                         len(AES.key_sizes))

        for i, plaintext in enumerate(plaintext_blocks):
            for keysize in AES.key_sizes:
                key = keys[keysize]
                ciphertext = ciphertext_blocks[keysize][i]
                self.assertEqual(
                    plaintext,
                    plaintext_decrypted_blocks[keysize][i])
                # Strictly stronger than assertNotEqual, which is vacuous
                # here because the lengths always differ: the plaintext must
                # not appear verbatim anywhere inside the ciphertext.
                self.assertNotIn(plaintext, ciphertext)
                self.assertEqual(
                    len(ciphertext),
                    len(plaintext) + expected_overhead)
                # check IND-CPA: re-encrypting the same plaintext under the
                # same key must give a different ciphertext of equal length,
                # differing both in the first block and in the remainder.
                alt_ciphertext = enc_func(key, plaintext)
                self.assertNotEqual(ciphertext, alt_ciphertext)
                self.assertEqual(len(ciphertext), len(alt_ciphertext))
                self.assertNotEqual(
                    ciphertext[:block_size],
                    alt_ciphertext[:block_size])
                self.assertNotEqual(
                    ciphertext[block_size:],
                    alt_ciphertext[block_size:])
105                     alt_ciphertext[AES.block_size:])
106
107 if __name__ == "__main__":
108     unittest2.main()                                    # pragma: no cover