diff --git a/extension/android/executorch_android/build.gradle b/extension/android/executorch_android/build.gradle index fac08588740..2fa0b9fd57c 100644 --- a/extension/android/executorch_android/build.gradle +++ b/extension/android/executorch_android/build.gradle @@ -49,6 +49,7 @@ dependencies { implementation 'com.facebook.soloader:nativeloader:0.10.5' implementation libs.core.ktx testImplementation 'junit:junit:4.12' + testImplementation 'org.assertj:assertj-core:3.27.2' androidTestImplementation 'androidx.test.ext:junit:1.1.5' androidTestImplementation 'androidx.test:rules:1.2.0' androidTestImplementation 'commons-io:commons-io:2.4' diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java index 1a30baba2f1..62535156a52 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java @@ -394,7 +394,9 @@ public byte[] getDataAsByteArray() { */ public byte[] getDataAsUnsignedByteArray() { throw new IllegalStateException( - "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); + "Tensor of type " + + getClass().getSimpleName() + + " cannot return data as unsigned byte array."); } /** diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt index 0e56480d621..7e9fea9a699 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt @@ -7,218 +7,211 @@ */ package org.pytorch.executorch -import org.junit.Assert +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -/** Unit tests for [EValue]. */ +/** Unit tests for [EValue]. */ @RunWith(JUnit4::class) class EValueTest { - @Test - fun testNone() { - val evalue = EValue.optionalNone() - Assert.assertTrue(evalue.isNone) + @Test + fun testNone() { + val evalue = EValue.optionalNone() + assertTrue(evalue.isNone) + } + + @Test + fun testTensorValue() { + val data = longArrayOf(1, 2, 3) + val shape = longArrayOf(1, 3) + val evalue = EValue.from(Tensor.fromBlob(data, shape)) + assertTrue(evalue.isTensor) + assertTrue(evalue.toTensor().shape.contentEquals(shape)) + assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data)) + } + + @Test + fun testBoolValue() { + val evalue = EValue.from(true) + assertTrue(evalue.isBool) + assertTrue(evalue.toBool()) + } + + @Test + fun testIntValue() { + val evalue = EValue.from(1) + assertTrue(evalue.isInt) + assertEquals(evalue.toInt(), 1) + } + + @Test + fun testDoubleValue() { + val evalue = EValue.from(0.1) + assertTrue(evalue.isDouble) + assertEquals(evalue.toDouble(), 0.1, 0.0001) + } + + @Test + fun testStringValue() { + val evalue = EValue.from("a") + assertTrue(evalue.isString) + assertEquals(evalue.toStr(), "a") + } + + @Test + fun testAllIllegalCast() { + val evalue = EValue.optionalNone() + assertTrue(evalue.isNone) + + // try Tensor + assertFalse(evalue.isTensor) + assertThatThrownBy { evalue.toTensor() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Tensor, actual type None") + + // try bool + assertFalse(evalue.isBool) + assertThatThrownBy { evalue.toBool() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Bool, actual type None") + + // try int + assertFalse(evalue.isInt) + assertThatThrownBy { evalue.toInt() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Int, actual type None") + + // try double + assertFalse(evalue.isDouble) + assertThatThrownBy { evalue.toDouble() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Double, actual type None") + + // try string + assertFalse(evalue.isString) + assertThatThrownBy { evalue.toStr() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type String, actual type None") + } + + @Test + fun testNoneSerde() { + val evalue = EValue.optionalNone() + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isNone, true) + } + + @Test + fun testBoolSerde() { + val evalue = EValue.from(true) + val bytes = evalue.toByteArray() + assertEquals(1, bytes[1].toLong()) + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isBool, true) + assertEquals(deser.toBool(), true) + } + + @Test + fun testBoolSerde2() { + val evalue = EValue.from(false) + val bytes = evalue.toByteArray() + assertEquals(0, bytes[1].toLong()) + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isBool, true) + assertEquals(deser.toBool(), false) + } + + @Test + fun testIntSerde() { + val evalue = EValue.from(1) + val bytes = evalue.toByteArray() + assertEquals(0, bytes[1].toLong()) + assertEquals(0, bytes[2].toLong()) + assertEquals(0, bytes[3].toLong()) + assertEquals(0, bytes[4].toLong()) + assertEquals(0, bytes[5].toLong()) + assertEquals(0, bytes[6].toLong()) + assertEquals(0, bytes[7].toLong()) + assertEquals(1, bytes[8].toLong()) + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isInt, true) + assertEquals(deser.toInt(), 1) + } + + @Test + fun testLargeIntSerde() { + val evalue = EValue.from(256000) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isInt, true) + assertEquals(deser.toInt(), 256000) + } + + @Test + fun testDoubleSerde() { + val evalue = EValue.from(1.345e-2) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isDouble, true) + assertEquals(1.345e-2, deser.toDouble(), 1e-6) + } + + @Test + fun testLongTensorSerde() { + val data = longArrayOf(1, 2, 3, 4) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + + val evalue = EValue.from(tensor) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isTensor, true) + val deserTensor = deser.toTensor() + val deserShape = deserTensor.shape() + val deserData = deserTensor.dataAsLongArray + + for (i in data.indices) { + assertEquals(data[i], deserData[i]) } - @Test - fun testTensorValue() { - val data = longArrayOf(1, 2, 3) - val shape = longArrayOf(1, 3) - val evalue = EValue.from(Tensor.fromBlob(data, shape)) - Assert.assertTrue(evalue.isTensor) - Assert.assertTrue(evalue.toTensor().shape.contentEquals(shape)) - Assert.assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data)) + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } - @Test - fun testBoolValue() { - val evalue = EValue.from(true) - Assert.assertTrue(evalue.isBool) - Assert.assertTrue(evalue.toBool()) - } - - @Test - fun testIntValue() { - val evalue = EValue.from(1) - Assert.assertTrue(evalue.isInt) - Assert.assertEquals(evalue.toInt(), 1) - } - - @Test - fun testDoubleValue() { - val evalue = EValue.from(0.1) - Assert.assertTrue(evalue.isDouble) - Assert.assertEquals(evalue.toDouble(), 0.1, 0.0001) - } - - @Test - fun testStringValue() { - val evalue = EValue.from("a") - Assert.assertTrue(evalue.isString) - Assert.assertEquals(evalue.toStr(), "a") - } - - @Test - fun testAllIllegalCast() { - val evalue = EValue.optionalNone() - Assert.assertTrue(evalue.isNone) - - // try Tensor - Assert.assertFalse(evalue.isTensor) - try { - evalue.toTensor() - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - } - - // try bool - Assert.assertFalse(evalue.isBool) - try { - evalue.toBool() - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - } - - // try int - Assert.assertFalse(evalue.isInt) - try { - evalue.toInt() - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - } - - // try double - Assert.assertFalse(evalue.isDouble) - try { - evalue.toDouble() - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - } - - // try string - Assert.assertFalse(evalue.isString) - try { - evalue.toStr() - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - } - } - - @Test - fun testNoneSerde() { - val evalue = EValue.optionalNone() - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isNone, true) - } - - @Test - fun testBoolSerde() { - val evalue = EValue.from(true) - val bytes = evalue.toByteArray() - Assert.assertEquals(1, bytes[1].toLong()) - - val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isBool, true) - Assert.assertEquals(deser.toBool(), true) - } - - @Test - fun testBoolSerde2() { - val evalue = EValue.from(false) - val bytes = evalue.toByteArray() - Assert.assertEquals(0, bytes[1].toLong()) + @Test + fun testFloatTensorSerde() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) - val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isBool, true) - Assert.assertEquals(deser.toBool(), false) - } - - @Test - fun testIntSerde() { - val evalue = EValue.from(1) - val bytes = evalue.toByteArray() - Assert.assertEquals(0, bytes[1].toLong()) - Assert.assertEquals(0, bytes[2].toLong()) - Assert.assertEquals(0, bytes[3].toLong()) - Assert.assertEquals(0, bytes[4].toLong()) - Assert.assertEquals(0, bytes[5].toLong()) - Assert.assertEquals(0, bytes[6].toLong()) - Assert.assertEquals(0, bytes[7].toLong()) - Assert.assertEquals(1, bytes[8].toLong()) - - val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isInt, true) - Assert.assertEquals(deser.toInt(), 1) - } - - @Test - fun testLargeIntSerde() { - val evalue = EValue.from(256000) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isInt, true) - Assert.assertEquals(deser.toInt(), 256000) - } - - @Test - fun testDoubleSerde() { - val evalue = EValue.from(1.345e-2) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isDouble, true) - Assert.assertEquals(1.345e-2, deser.toDouble(), 1e-6) - } - - @Test - fun testLongTensorSerde() { - val data = longArrayOf(1, 2, 3, 4) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - - val evalue = EValue.from(tensor) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isTensor, true) - val deserTensor = deser.toTensor() - val deserShape = deserTensor.shape() - val deserData = deserTensor.dataAsLongArray + val evalue = EValue.from(tensor) + val bytes = evalue.toByteArray() - for (i in data.indices) { - Assert.assertEquals(data[i], deserData[i]) - } + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isTensor, true) + val deserTensor = deser.toTensor() + val deserShape = deserTensor.shape() + val deserData = deserTensor.dataAsFloatArray - for (i in shape.indices) { - Assert.assertEquals(shape[i], deserShape[i]) - } + for (i in data.indices) { + assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) } - @Test - fun testFloatTensorSerde() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - - val evalue = EValue.from(tensor) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isTensor, true) - val deserTensor = deser.toTensor() - val deserShape = deserTensor.shape() - val deserData = deserTensor.dataAsFloatArray - - for (i in data.indices) { - Assert.assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) - } - - for (i in shape.indices) { - Assert.assertEquals(shape[i], deserShape[i]) - } + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } } diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt index 4b206c8efbd..e59b40030d7 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt @@ -7,290 +7,271 @@ */ package org.pytorch.executorch -import org.junit.Assert +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.Assert.assertEquals import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -/** Unit tests for [Tensor]. */ +/** Unit tests for [Tensor]. */ @RunWith(JUnit4::class) class TensorTest { - @Test - fun testFloatTensor() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - var tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.FLOAT) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) - Assert.assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) - Assert.assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) - Assert.assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) - - val floatBuffer = Tensor.allocateFloatBuffer(4) - floatBuffer.put(data) - tensor = Tensor.fromBlob(floatBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.FLOAT) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) - Assert.assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) - Assert.assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) - Assert.assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) - } + @Test + fun testFloatTensor() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.FLOAT) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) + assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) + assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) + assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) - @Test - fun testIntTensor() { - val data = intArrayOf(Int.MIN_VALUE, 0, 1, Int.MAX_VALUE) - val shape = longArrayOf(1, 4, 1) - var tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.INT32) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) - - val intBuffer = Tensor.allocateIntBuffer(4) - intBuffer.put(data) - tensor = Tensor.fromBlob(intBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.INT32) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) - } + val floatBuffer = Tensor.allocateFloatBuffer(4) + floatBuffer.put(data) + tensor = Tensor.fromBlob(floatBuffer, shape) + assertEquals(tensor.dtype(), DType.FLOAT) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) + assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) + assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) + assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + } - @Test - fun testDoubleTensor() { - val data = doubleArrayOf(Double.MIN_VALUE, 0.0, 0.1, Double.MAX_VALUE) - val shape = longArrayOf(1, 4) - var tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.DOUBLE) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) - Assert.assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) - Assert.assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) - Assert.assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) - - val doubleBuffer = Tensor.allocateDoubleBuffer(4) - doubleBuffer.put(data) - tensor = Tensor.fromBlob(doubleBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.DOUBLE) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) - Assert.assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) - Assert.assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) - Assert.assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) - } + @Test + fun testIntTensor() { + val data = intArrayOf(Int.MIN_VALUE, 0, 1, Int.MAX_VALUE) + val shape = longArrayOf(1, 4, 1) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.INT32) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) - @Test - fun testLongTensor() { - val data = longArrayOf(Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE) - val shape = longArrayOf(4, 1) - var tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.INT64) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0], tensor.dataAsLongArray[0]) - Assert.assertEquals(data[1], tensor.dataAsLongArray[1]) - Assert.assertEquals(data[2], tensor.dataAsLongArray[2]) - Assert.assertEquals(data[3], tensor.dataAsLongArray[3]) - - val longBuffer = Tensor.allocateLongBuffer(4) - longBuffer.put(data) - tensor = Tensor.fromBlob(longBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.INT64) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0], tensor.dataAsLongArray[0]) - Assert.assertEquals(data[1], tensor.dataAsLongArray[1]) - Assert.assertEquals(data[2], tensor.dataAsLongArray[2]) - Assert.assertEquals(data[3], tensor.dataAsLongArray[3]) - } + val intBuffer = Tensor.allocateIntBuffer(4) + intBuffer.put(data) + tensor = Tensor.fromBlob(intBuffer, shape) + assertEquals(tensor.dtype(), DType.INT32) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + } - @Test - fun testSignedByteTensor() { - val data = byteArrayOf(Byte.MIN_VALUE, 0.toByte(), 1.toByte(), Byte.MAX_VALUE) - val shape = longArrayOf(1, 1, 4) - var tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.INT8) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) - - val byteBuffer = Tensor.allocateByteBuffer(4) - byteBuffer.put(data) - tensor = Tensor.fromBlob(byteBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.INT8) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) - } + @Test + fun testDoubleTensor() { + val data = doubleArrayOf(Double.MIN_VALUE, 0.0, 0.1, Double.MAX_VALUE) + val shape = longArrayOf(1, 4) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.DOUBLE) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) + assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) + assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) + assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) - @Test - fun testUnsignedByteTensor() { - val data = byteArrayOf(0.toByte(), 1.toByte(), 2.toByte(), 255.toByte()) - val shape = longArrayOf(4, 1, 1) - var tensor = Tensor.fromBlobUnsigned(data, shape) - Assert.assertEquals(tensor.dtype(), DType.UINT8) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) - - val byteBuffer = Tensor.allocateByteBuffer(4) - byteBuffer.put(data) - tensor = Tensor.fromBlobUnsigned(byteBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.UINT8) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) - } + val doubleBuffer = Tensor.allocateDoubleBuffer(4) + doubleBuffer.put(data) + tensor = Tensor.fromBlob(doubleBuffer, shape) + assertEquals(tensor.dtype(), DType.DOUBLE) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) + assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) + assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) + assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) + } - @Test - fun testIllegalDataTypeException() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.FLOAT) - - try { - tensor.dataAsByteArray - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - // expected - } - try { - tensor.dataAsUnsignedByteArray - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - // expected - } - try { - tensor.dataAsIntArray - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - // expected - } - try { - tensor.dataAsDoubleArray - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - // expected - } - try { - tensor.dataAsLongArray - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - // expected - } - } + @Test + fun testLongTensor() { + val data = longArrayOf(Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE) + val shape = longArrayOf(4, 1) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.INT64) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsLongArray[0]) + assertEquals(data[1], tensor.dataAsLongArray[1]) + assertEquals(data[2], tensor.dataAsLongArray[2]) + assertEquals(data[3], tensor.dataAsLongArray[3]) - @Test - fun testIllegalArguments() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shapeWithNegativeValues = longArrayOf(-1, 2) - val mismatchShape = longArrayOf(1, 2) - - try { - val tensor = Tensor.fromBlob(null as FloatArray?, mismatchShape) - Assert.fail("Should have thrown an exception") - } catch (e: IllegalArgumentException) { - // expected - } - try { - val tensor = Tensor.fromBlob(data, null) - Assert.fail("Should have thrown an exception") - } catch (e: IllegalArgumentException) { - // expected - } - try { - val tensor = Tensor.fromBlob(data, shapeWithNegativeValues) - Assert.fail("Should have thrown an exception") - } catch (e: IllegalArgumentException) { - // expected - } - try { - val tensor = Tensor.fromBlob(data, mismatchShape) - Assert.fail("Should have thrown an exception") - } catch (e: IllegalArgumentException) { - // expected - } - } + val longBuffer = Tensor.allocateLongBuffer(4) + longBuffer.put(data) + tensor = Tensor.fromBlob(longBuffer, shape) + assertEquals(tensor.dtype(), DType.INT64) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsLongArray[0]) + assertEquals(data[1], tensor.dataAsLongArray[1]) + assertEquals(data[2], tensor.dataAsLongArray[2]) + assertEquals(data[3], tensor.dataAsLongArray[3]) + } + + @Test + fun testSignedByteTensor() { + val data = byteArrayOf(Byte.MIN_VALUE, 0.toByte(), 1.toByte(), Byte.MAX_VALUE) + val shape = longArrayOf(1, 1, 4) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.INT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) + + val byteBuffer = Tensor.allocateByteBuffer(4) + byteBuffer.put(data) + tensor = Tensor.fromBlob(byteBuffer, shape) + assertEquals(tensor.dtype(), DType.INT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) + } + + @Test + fun testUnsignedByteTensor() { + val data = byteArrayOf(0.toByte(), 1.toByte(), 2.toByte(), 255.toByte()) + val shape = longArrayOf(4, 1, 1) + var tensor = Tensor.fromBlobUnsigned(data, shape) + assertEquals(tensor.dtype(), DType.UINT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + + val byteBuffer = Tensor.allocateByteBuffer(4) + byteBuffer.put(data) + tensor = Tensor.fromBlobUnsigned(byteBuffer, shape) + assertEquals(tensor.dtype(), DType.UINT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + } + + @Test + fun testIllegalDataTypeException() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.FLOAT) - @Test - fun testLongTensorSerde() { - val data = longArrayOf(1, 2, 3, 4) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - val bytes = tensor.toByteArray() + assertThatThrownBy { tensor.dataAsByteArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.") - val deser = Tensor.fromByteArray(bytes) - val deserShape = deser.shape() - val deserData = deser.dataAsLongArray + assertThatThrownBy { tensor.dataAsUnsignedByteArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.") - for (i in data.indices) { - Assert.assertEquals(data[i], deserData[i]) - } + assertThatThrownBy { tensor.dataAsIntArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as int array.") - for (i in shape.indices) { - Assert.assertEquals(shape[i], deserShape[i]) - } + assertThatThrownBy { tensor.dataAsDoubleArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as double array.") + + assertThatThrownBy { tensor.dataAsLongArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as long array.") + } + + @Test + fun testIllegalArguments() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shapeWithNegativeValues = longArrayOf(-1, 2) + val mismatchShape = longArrayOf(1, 2) + + assertThatThrownBy { Tensor.fromBlob(null as FloatArray?, mismatchShape) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Data array must be not null") + + assertThatThrownBy { Tensor.fromBlob(data, null) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Shape must be not null") + + assertThatThrownBy { Tensor.fromBlob(data, shapeWithNegativeValues) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Shape elements must be non negative") + + assertThatThrownBy { Tensor.fromBlob(data, mismatchShape) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]") + } + + @Test + fun testLongTensorSerde() { + val data = longArrayOf(1, 2, 3, 4) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + val bytes = tensor.toByteArray() + + val deser = Tensor.fromByteArray(bytes) + val deserShape = deser.shape() + val deserData = deser.dataAsLongArray + + for (i in data.indices) { + assertEquals(data[i], deserData[i]) } - @Test - fun testFloatTensorSerde() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - val bytes = tensor.toByteArray() + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) + } + } + + @Test + fun testFloatTensorSerde() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + val bytes = tensor.toByteArray() - val deser = Tensor.fromByteArray(bytes) - val deserShape = deser.shape() - val deserData = deser.dataAsFloatArray + val deser = Tensor.fromByteArray(bytes) + val deserShape = deser.shape() + val deserData = deser.dataAsFloatArray - for (i in data.indices) { - Assert.assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) - } + for (i in data.indices) { + assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) + } - for (i in shape.indices) { - Assert.assertEquals(shape[i], deserShape[i]) - } + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } }