Skip to content

Commit 96ec4fb

Browse files
pbelevichfacebook-github-bot
authored andcommitted
Allow decryption output tensor to be less than input(skipping padding) (#95)
Summary: Pull Request resolved: #95 Test Plan: Imported from OSS Reviewed By: heitorschueroff Differential Revision: D25506728 Pulled By: pbelevich fbshipit-source-id: 1b2f576ab5d2552f71691921109f731fe41c13e5
1 parent abf5bbc commit 96ec4fb

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

test/test_csprng.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -377,11 +377,12 @@ def create_aes(m, k):
377377
for initial_size in [0, 4, 8, 15, 16, 23, 42]:
378378
initial = torch.empty(initial_size, dtype=initial_dtype).random_()
379379
initial_np = initial.numpy().view(np.int8)
380+
initial_size_bytes = initial_size * sizeof(initial_dtype)
380381
for encrypted_dtype in self.all_dtypes:
381-
encrypted_size = (initial_size * sizeof(initial_dtype) + block_size_bytes - 1) // block_size_bytes * block_size_bytes // sizeof(encrypted_dtype)
382+
encrypted_size = (initial_size_bytes + block_size_bytes - 1) // block_size_bytes * block_size_bytes // sizeof(encrypted_dtype)
382383
encrypted = torch.zeros(encrypted_size, dtype=encrypted_dtype)
383384
for decrypted_dtype in self.all_dtypes:
384-
decrypted_size = (encrypted_size * sizeof(encrypted_dtype) + block_size_bytes - 1) // block_size_bytes * block_size_bytes // sizeof(decrypted_dtype)
385+
decrypted_size = (initial_size_bytes + sizeof(decrypted_dtype) - 1) // sizeof(decrypted_dtype)
385386
decrypted = torch.zeros(decrypted_size, dtype=decrypted_dtype)
386387
for mode in ["ecb", "ctr"]:
387388
for device in self.all_devices:
@@ -399,16 +400,13 @@ def create_aes(m, k):
399400
self.assertTrue(np.array_equal(encrypted_np, encrypted_expected))
400401

401402
csprng.decrypt(encrypted, decrypted, key, "aes128", mode)
402-
decrypted_np = decrypted.cpu().numpy().view(np.int8)
403+
decrypted_np = decrypted.cpu().numpy().view(np.int8)[:initial_size_bytes]
403404

404405
aes = create_aes(mode, key_np)
405406

406-
decrypted_expected = np.frombuffer(aes.decrypt(pad(encrypted_np.tobytes(), block_size_bytes)), dtype=np.int8)
407+
decrypted_expected = np.frombuffer(aes.decrypt(pad(encrypted_np.tobytes(), block_size_bytes)), dtype=np.int8)[:initial_size_bytes]
407408
self.assertTrue(np.array_equal(decrypted_np, decrypted_expected))
408409

409-
padding_size_bytes = initial_size * sizeof(initial_dtype) - decrypted_size * sizeof(decrypted_dtype)
410-
if padding_size_bytes != 0:
411-
decrypted_np = decrypted_np[:padding_size_bytes]
412410
self.assertTrue(np.array_equal(initial_np, decrypted_np))
413411

414412
if __name__ == '__main__':

torchcsprng/csrc/kernels_body.inc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -420,7 +420,8 @@ Tensor decrypt(Tensor input, Tensor output, Tensor key, const std::string& ciphe
420420
TORCH_CHECK(input.device() == output.device() && input.device() == key.device(), "input, output and key tensors must have the same device");
421421
const auto output_size_bytes = output.numel() * output.itemsize();
422422
const auto input_size_bytes = input.numel() * input.itemsize();
423-
TORCH_CHECK(output_size_bytes == input_size_bytes, "input and output tensors must have the same size in byte");
423+
const auto diff = input_size_bytes - output_size_bytes;
424+
TORCH_CHECK(0 <= diff && diff < aes::block_t_size, "output tensor size in bytes must be less then or equal to input tensor size in bytes, the difference must be less than block size");
424425
TORCH_CHECK(input_size_bytes % aes::block_t_size == 0, "input tensor size in bytes must divisible by cipher block size in bytes");
425426
check_cipher(cipher, key);
426427
const auto key_bytes = reinterpret_cast<uint8_t*>(key.contiguous().data_ptr());

0 commit comments

Comments
 (0)