From 6dec17bfd77d4f0b603a96aff550427df17e3f3c Mon Sep 17 00:00:00 2001 From: Haiting Pu Date: Thu, 8 May 2025 13:49:11 -0700 Subject: [PATCH 1/3] Introduce assertj test lib to make the throw exception test more accurate --- .../android/executorch_android/build.gradle | 1 + .../java/org/pytorch/executorch/Tensor.java | 2 +- .../java/org/pytorch/executorch/EValueTest.kt | 130 ++++---- .../java/org/pytorch/executorch/TensorTest.kt | 296 ++++++++---------- 4 files changed, 195 insertions(+), 234 deletions(-) diff --git a/extension/android/executorch_android/build.gradle b/extension/android/executorch_android/build.gradle index fac08588740..2fa0b9fd57c 100644 --- a/extension/android/executorch_android/build.gradle +++ b/extension/android/executorch_android/build.gradle @@ -49,6 +49,7 @@ dependencies { implementation 'com.facebook.soloader:nativeloader:0.10.5' implementation libs.core.ktx testImplementation 'junit:junit:4.12' + testImplementation 'org.assertj:assertj-core:3.27.2' androidTestImplementation 'androidx.test.ext:junit:1.1.5' androidTestImplementation 'androidx.test:rules:1.2.0' androidTestImplementation 'commons-io:commons-io:2.4' diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java index 1a30baba2f1..174d08c0365 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java @@ -394,7 +394,7 @@ public byte[] getDataAsByteArray() { */ public byte[] getDataAsUnsignedByteArray() { throw new IllegalStateException( - "Tensor of type " + getClass().getSimpleName() + " cannot return data as byte array."); + "Tensor of type " + getClass().getSimpleName() + " cannot return data as unsigned byte array."); } /** diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt index 0e56480d621..a2cc328fa37 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt @@ -7,7 +7,10 @@ */ package org.pytorch.executorch -import org.junit.Assert +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 @@ -18,7 +21,7 @@ class EValueTest { @Test fun testNone() { val evalue = EValue.optionalNone() - Assert.assertTrue(evalue.isNone) + assertTrue(evalue.isNone) } @Test @@ -26,83 +29,68 @@ class EValueTest { val data = longArrayOf(1, 2, 3) val shape = longArrayOf(1, 3) val evalue = EValue.from(Tensor.fromBlob(data, shape)) - Assert.assertTrue(evalue.isTensor) - Assert.assertTrue(evalue.toTensor().shape.contentEquals(shape)) - Assert.assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data)) + assertTrue(evalue.isTensor) + assertTrue(evalue.toTensor().shape.contentEquals(shape)) + assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data)) } @Test fun testBoolValue() { val evalue = EValue.from(true) - Assert.assertTrue(evalue.isBool) - Assert.assertTrue(evalue.toBool()) + assertTrue(evalue.isBool) + assertTrue(evalue.toBool()) } @Test fun testIntValue() { val evalue = EValue.from(1) - Assert.assertTrue(evalue.isInt) - Assert.assertEquals(evalue.toInt(), 1) + assertTrue(evalue.isInt) + assertEquals(evalue.toInt(), 1) } @Test fun testDoubleValue() { val evalue = EValue.from(0.1) - Assert.assertTrue(evalue.isDouble) - Assert.assertEquals(evalue.toDouble(), 0.1, 0.0001) + assertTrue(evalue.isDouble) + assertEquals(evalue.toDouble(), 0.1, 0.0001) } @Test fun testStringValue() { val evalue = EValue.from("a") - Assert.assertTrue(evalue.isString) - Assert.assertEquals(evalue.toStr(), "a") + assertTrue(evalue.isString) + assertEquals(evalue.toStr(), "a") } @Test fun testAllIllegalCast() { val evalue = EValue.optionalNone() - Assert.assertTrue(evalue.isNone) + assertTrue(evalue.isNone) // try Tensor - Assert.assertFalse(evalue.isTensor) - try { - evalue.toTensor() - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - } + assertFalse(evalue.isTensor) + assertThatThrownBy { + evalue.toTensor() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Tensor, actual type None") // try bool - Assert.assertFalse(evalue.isBool) - try { - evalue.toBool() - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - } + assertFalse(evalue.isBool) + assertThatThrownBy { + evalue.toBool() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Bool, actual type None") // try int - Assert.assertFalse(evalue.isInt) - try { - evalue.toInt() - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - } + assertFalse(evalue.isInt) + assertThatThrownBy { + evalue.toInt() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Int, actual type None") // try double - Assert.assertFalse(evalue.isDouble) - try { - evalue.toDouble() - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - } + assertFalse(evalue.isDouble) + assertThatThrownBy { + evalue.toDouble() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Double, actual type None") // try string - Assert.assertFalse(evalue.isString) - try { - evalue.toStr() - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - } + assertFalse(evalue.isString) + assertThatThrownBy { + evalue.toStr() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type String, actual type None") } @Test @@ -111,47 +99,47 @@ class EValueTest { val bytes = evalue.toByteArray() val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isNone, true) + assertEquals(deser.isNone, true) } @Test fun testBoolSerde() { val evalue = EValue.from(true) val bytes = evalue.toByteArray() - Assert.assertEquals(1, bytes[1].toLong()) + assertEquals(1, bytes[1].toLong()) val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isBool, true) - Assert.assertEquals(deser.toBool(), true) + assertEquals(deser.isBool, true) + assertEquals(deser.toBool(), true) } @Test fun testBoolSerde2() { val evalue = EValue.from(false) val bytes = evalue.toByteArray() - Assert.assertEquals(0, bytes[1].toLong()) + assertEquals(0, bytes[1].toLong()) val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isBool, true) - Assert.assertEquals(deser.toBool(), false) + assertEquals(deser.isBool, true) + assertEquals(deser.toBool(), false) } @Test fun testIntSerde() { val evalue = EValue.from(1) val bytes = evalue.toByteArray() - Assert.assertEquals(0, bytes[1].toLong()) - Assert.assertEquals(0, bytes[2].toLong()) - Assert.assertEquals(0, bytes[3].toLong()) - Assert.assertEquals(0, bytes[4].toLong()) - Assert.assertEquals(0, bytes[5].toLong()) - Assert.assertEquals(0, bytes[6].toLong()) - Assert.assertEquals(0, bytes[7].toLong()) - Assert.assertEquals(1, bytes[8].toLong()) + assertEquals(0, bytes[1].toLong()) + assertEquals(0, bytes[2].toLong()) + assertEquals(0, bytes[3].toLong()) + assertEquals(0, bytes[4].toLong()) + assertEquals(0, bytes[5].toLong()) + assertEquals(0, bytes[6].toLong()) + assertEquals(0, bytes[7].toLong()) + assertEquals(1, bytes[8].toLong()) val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isInt, true) - Assert.assertEquals(deser.toInt(), 1) + assertEquals(deser.isInt, true) + assertEquals(deser.toInt(), 1) } @Test @@ -160,8 +148,8 @@ class EValueTest { val bytes = evalue.toByteArray() val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isInt, true) - Assert.assertEquals(deser.toInt(), 256000) + assertEquals(deser.isInt, true) + assertEquals(deser.toInt(), 256000) } @Test @@ -170,8 +158,8 @@ class EValueTest { val bytes = evalue.toByteArray() val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isDouble, true) - Assert.assertEquals(1.345e-2, deser.toDouble(), 1e-6) + assertEquals(deser.isDouble, true) + assertEquals(1.345e-2, deser.toDouble(), 1e-6) } @Test @@ -184,17 +172,17 @@ class EValueTest { val bytes = evalue.toByteArray() val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isTensor, true) + assertEquals(deser.isTensor, true) val deserTensor = deser.toTensor() val deserShape = deserTensor.shape() val deserData = deserTensor.dataAsLongArray for (i in data.indices) { - Assert.assertEquals(data[i], deserData[i]) + assertEquals(data[i], deserData[i]) } for (i in shape.indices) { - Assert.assertEquals(shape[i], deserShape[i]) + assertEquals(shape[i], deserShape[i]) } } @@ -208,17 +196,17 @@ class EValueTest { val bytes = evalue.toByteArray() val deser = EValue.fromByteArray(bytes) - Assert.assertEquals(deser.isTensor, true) + assertEquals(deser.isTensor, true) val deserTensor = deser.toTensor() val deserShape = deserTensor.shape() val deserData = deserTensor.dataAsFloatArray for (i in data.indices) { - Assert.assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) + assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) } for (i in shape.indices) { - Assert.assertEquals(shape[i], deserShape[i]) + assertEquals(shape[i], deserShape[i]) } } } diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt index 4b206c8efbd..45265b52362 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt @@ -7,7 +7,8 @@ */ package org.pytorch.executorch -import org.junit.Assert +import org.assertj.core.api.Assertions.assertThatThrownBy +import org.junit.Assert.assertEquals import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 @@ -20,26 +21,26 @@ class TensorTest { val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) val shape = longArrayOf(2, 2) var tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.FLOAT) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) - Assert.assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) - Assert.assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) - Assert.assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + assertEquals(tensor.dtype(), DType.FLOAT) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) + assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) + assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) + assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) val floatBuffer = Tensor.allocateFloatBuffer(4) floatBuffer.put(data) tensor = Tensor.fromBlob(floatBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.FLOAT) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) - Assert.assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) - Assert.assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) - Assert.assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + assertEquals(tensor.dtype(), DType.FLOAT) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) + assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) + assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) + assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) } @Test @@ -47,28 +48,28 @@ class TensorTest { val data = intArrayOf(Int.MIN_VALUE, 0, 1, Int.MAX_VALUE) val shape = longArrayOf(1, 4, 1) var tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.INT32) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + assertEquals(tensor.dtype(), DType.INT32) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) val intBuffer = Tensor.allocateIntBuffer(4) intBuffer.put(data) tensor = Tensor.fromBlob(intBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.INT32) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + assertEquals(tensor.dtype(), DType.INT32) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) } @Test @@ -76,26 +77,26 @@ class TensorTest { val data = doubleArrayOf(Double.MIN_VALUE, 0.0, 0.1, Double.MAX_VALUE) val shape = longArrayOf(1, 4) var tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.DOUBLE) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) - Assert.assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) - Assert.assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) - Assert.assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) + assertEquals(tensor.dtype(), DType.DOUBLE) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) + assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) + assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) + assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) val doubleBuffer = Tensor.allocateDoubleBuffer(4) doubleBuffer.put(data) tensor = Tensor.fromBlob(doubleBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.DOUBLE) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) - Assert.assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) - Assert.assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) - Assert.assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) + assertEquals(tensor.dtype(), DType.DOUBLE) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) + assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) + assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) + assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) } @Test @@ -103,26 +104,26 @@ class TensorTest { val data = longArrayOf(Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE) val shape = longArrayOf(4, 1) var tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.INT64) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0], tensor.dataAsLongArray[0]) - Assert.assertEquals(data[1], tensor.dataAsLongArray[1]) - Assert.assertEquals(data[2], tensor.dataAsLongArray[2]) - Assert.assertEquals(data[3], tensor.dataAsLongArray[3]) + assertEquals(tensor.dtype(), DType.INT64) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsLongArray[0]) + assertEquals(data[1], tensor.dataAsLongArray[1]) + assertEquals(data[2], tensor.dataAsLongArray[2]) + assertEquals(data[3], tensor.dataAsLongArray[3]) val longBuffer = Tensor.allocateLongBuffer(4) longBuffer.put(data) tensor = Tensor.fromBlob(longBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.INT64) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0], tensor.dataAsLongArray[0]) - Assert.assertEquals(data[1], tensor.dataAsLongArray[1]) - Assert.assertEquals(data[2], tensor.dataAsLongArray[2]) - Assert.assertEquals(data[3], tensor.dataAsLongArray[3]) + assertEquals(tensor.dtype(), DType.INT64) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsLongArray[0]) + assertEquals(data[1], tensor.dataAsLongArray[1]) + assertEquals(data[2], tensor.dataAsLongArray[2]) + assertEquals(data[3], tensor.dataAsLongArray[3]) } @Test @@ -130,28 +131,28 @@ class TensorTest { val data = byteArrayOf(Byte.MIN_VALUE, 0.toByte(), 1.toByte(), Byte.MAX_VALUE) val shape = longArrayOf(1, 1, 4) var tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.INT8) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) + assertEquals(tensor.dtype(), DType.INT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) val byteBuffer = Tensor.allocateByteBuffer(4) byteBuffer.put(data) tensor = Tensor.fromBlob(byteBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.INT8) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) + assertEquals(tensor.dtype(), DType.INT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) } @Test @@ -159,28 +160,28 @@ class TensorTest { val data = byteArrayOf(0.toByte(), 1.toByte(), 2.toByte(), 255.toByte()) val shape = longArrayOf(4, 1, 1) var tensor = Tensor.fromBlobUnsigned(data, shape) - Assert.assertEquals(tensor.dtype(), DType.UINT8) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + assertEquals(tensor.dtype(), DType.UINT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) val byteBuffer = Tensor.allocateByteBuffer(4) byteBuffer.put(data) tensor = Tensor.fromBlobUnsigned(byteBuffer, shape) - Assert.assertEquals(tensor.dtype(), DType.UINT8) - Assert.assertEquals(shape[0], tensor.shape()[0]) - Assert.assertEquals(shape[1], tensor.shape()[1]) - Assert.assertEquals(shape[2], tensor.shape()[2]) - Assert.assertEquals(4, tensor.numel()) - Assert.assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) - Assert.assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) - Assert.assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) - Assert.assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + assertEquals(tensor.dtype(), DType.UINT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) } @Test @@ -188,38 +189,22 @@ class TensorTest { val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) val shape = longArrayOf(2, 2) val tensor = Tensor.fromBlob(data, shape) - Assert.assertEquals(tensor.dtype(), DType.FLOAT) + assertEquals(tensor.dtype(), DType.FLOAT) - try { - tensor.dataAsByteArray - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - // expected - } - try { - tensor.dataAsUnsignedByteArray - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - // expected - } - try { - tensor.dataAsIntArray - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - // expected - } - try { - tensor.dataAsDoubleArray - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - // expected - } - try { - tensor.dataAsLongArray - Assert.fail("Should have thrown an exception") - } catch (e: IllegalStateException) { - // expected - } + assertThatThrownBy { + tensor.dataAsByteArray }.isInstanceOf(IllegalStateException::class.java).hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.") + + assertThatThrownBy { + tensor.dataAsUnsignedByteArray }.isInstanceOf(IllegalStateException::class.java).hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.") + + assertThatThrownBy { + tensor.dataAsIntArray }.isInstanceOf(IllegalStateException::class.java).hasMessage("Tensor of type Tensor_float32 cannot return data as int array.") + + assertThatThrownBy { + tensor.dataAsDoubleArray }.isInstanceOf(IllegalStateException::class.java).hasMessage("Tensor of type Tensor_float32 cannot return data as double array.") + + assertThatThrownBy { + tensor.dataAsLongArray }.isInstanceOf(IllegalStateException::class.java).hasMessage("Tensor of type Tensor_float32 cannot return data as long array.") } @Test @@ -228,30 +213,17 @@ class TensorTest { val shapeWithNegativeValues = longArrayOf(-1, 2) val mismatchShape = longArrayOf(1, 2) - try { - val tensor = Tensor.fromBlob(null as FloatArray?, mismatchShape) - Assert.fail("Should have thrown an exception") - } catch (e: IllegalArgumentException) { - // expected - } - try { - val tensor = Tensor.fromBlob(data, null) - Assert.fail("Should have thrown an exception") - } catch (e: IllegalArgumentException) { - // expected - } - try { - val tensor = Tensor.fromBlob(data, shapeWithNegativeValues) - Assert.fail("Should have thrown an exception") - } catch (e: IllegalArgumentException) { - // expected - } - try { - val tensor = Tensor.fromBlob(data, mismatchShape) - Assert.fail("Should have thrown an exception") - } catch (e: IllegalArgumentException) { - // expected - } + assertThatThrownBy { + Tensor.fromBlob(null as FloatArray?, mismatchShape) }.isInstanceOf(IllegalArgumentException::class.java).hasMessage("Data array must be not null") + + assertThatThrownBy { + Tensor.fromBlob(data, null) }.isInstanceOf(IllegalArgumentException::class.java).hasMessage("Shape must be not null") + + assertThatThrownBy { + Tensor.fromBlob(data, shapeWithNegativeValues) }.isInstanceOf(IllegalArgumentException::class.java).hasMessage("Shape elements must be non negative") + + assertThatThrownBy { + Tensor.fromBlob(data, mismatchShape) }.isInstanceOf(IllegalArgumentException::class.java).hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]") } @Test @@ -266,11 +238,11 @@ class TensorTest { val deserData = deser.dataAsLongArray for (i in data.indices) { - Assert.assertEquals(data[i], deserData[i]) + assertEquals(data[i], deserData[i]) } for (i in shape.indices) { - Assert.assertEquals(shape[i], deserShape[i]) + assertEquals(shape[i], deserShape[i]) } } @@ -286,11 +258,11 @@ class TensorTest { val deserData = deser.dataAsFloatArray for (i in data.indices) { - Assert.assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) + assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) } for (i in shape.indices) { - Assert.assertEquals(shape[i], deserShape[i]) + assertEquals(shape[i], deserShape[i]) } } } From 05675eebeb043241b3bb3bb1cc5387237f39fd27 Mon Sep 17 00:00:00 2001 From: Haiting Pu Date: Fri, 9 May 2025 14:43:05 -0700 Subject: [PATCH 2/3] apply google java format --- .../java/org/pytorch/executorch/EValueTest.kt | 20 ++++++++--- .../java/org/pytorch/executorch/TensorTest.kt | 35 ++++++++++++++----- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt index a2cc328fa37..edbf6fac9e5 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt @@ -70,27 +70,37 @@ class EValueTest { // try Tensor assertFalse(evalue.isTensor) assertThatThrownBy { - evalue.toTensor() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Tensor, actual type None") + evalue.toTensor() + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Tensor, actual type None") // try bool assertFalse(evalue.isBool) assertThatThrownBy { - evalue.toBool() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Bool, actual type None") + evalue.toBool() + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Bool, actual type None") // try int assertFalse(evalue.isInt) assertThatThrownBy { - evalue.toInt() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Int, actual type None") + evalue.toInt() + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Int, actual type None") // try double assertFalse(evalue.isDouble) assertThatThrownBy { - evalue.toDouble() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type Double, actual type None") + evalue.toDouble() + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Double, actual type None") // try string assertFalse(evalue.isString) assertThatThrownBy { - evalue.toStr() }.isInstanceOf(IllegalStateException::class.java).hasMessage("Expected EValue type String, actual type None") + evalue.toStr() + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type String, actual type None") } @Test diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt index 45265b52362..adc39ebac70 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt @@ -192,19 +192,29 @@ class TensorTest { assertEquals(tensor.dtype(), DType.FLOAT) assertThatThrownBy { - tensor.dataAsByteArray }.isInstanceOf(IllegalStateException::class.java).hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.") + tensor.dataAsByteArray + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.") assertThatThrownBy { - tensor.dataAsUnsignedByteArray }.isInstanceOf(IllegalStateException::class.java).hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.") + tensor.dataAsUnsignedByteArray + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.") assertThatThrownBy { - tensor.dataAsIntArray }.isInstanceOf(IllegalStateException::class.java).hasMessage("Tensor of type Tensor_float32 cannot return data as int array.") + tensor.dataAsIntArray + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as int array.") assertThatThrownBy { - tensor.dataAsDoubleArray }.isInstanceOf(IllegalStateException::class.java).hasMessage("Tensor of type Tensor_float32 cannot return data as double array.") + tensor.dataAsDoubleArray + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as double array.") assertThatThrownBy { - tensor.dataAsLongArray }.isInstanceOf(IllegalStateException::class.java).hasMessage("Tensor of type Tensor_float32 cannot return data as long array.") + tensor.dataAsLongArray + }.isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as long array.") } @Test @@ -214,16 +224,23 @@ class TensorTest { val mismatchShape = longArrayOf(1, 2) assertThatThrownBy { - Tensor.fromBlob(null as FloatArray?, mismatchShape) }.isInstanceOf(IllegalArgumentException::class.java).hasMessage("Data array must be not null") + Tensor.fromBlob(null as FloatArray?, mismatchShape) + }.isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Data array must be not null") assertThatThrownBy { - Tensor.fromBlob(data, null) }.isInstanceOf(IllegalArgumentException::class.java).hasMessage("Shape must be not null") + Tensor.fromBlob(data, null) + }.isInstanceOf(IllegalArgumentException::class.java).hasMessage("Shape must be not null") assertThatThrownBy { - Tensor.fromBlob(data, shapeWithNegativeValues) }.isInstanceOf(IllegalArgumentException::class.java).hasMessage("Shape elements must be non negative") + Tensor.fromBlob(data, shapeWithNegativeValues) + }.isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Shape elements must be non negative") assertThatThrownBy { - Tensor.fromBlob(data, mismatchShape) }.isInstanceOf(IllegalArgumentException::class.java).hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]") + Tensor.fromBlob(data, mismatchShape) + }.isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]") } @Test From c82c62452b984d477cc8a309ed174228dfb42af2 Mon Sep 17 00:00:00 2001 From: Haiting Pu Date: Fri, 9 May 2025 16:24:43 -0700 Subject: [PATCH 3/3] Fix the code format --- .../java/org/pytorch/executorch/Tensor.java | 4 +- .../java/org/pytorch/executorch/EValueTest.kt | 379 +++++++------ .../java/org/pytorch/executorch/TensorTest.kt | 508 +++++++++--------- 3 files changed, 440 insertions(+), 451 deletions(-) diff --git a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java index 174d08c0365..62535156a52 100644 --- a/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java +++ b/extension/android/executorch_android/src/main/java/org/pytorch/executorch/Tensor.java @@ -394,7 +394,9 @@ public byte[] getDataAsByteArray() { */ public byte[] getDataAsUnsignedByteArray() { throw new IllegalStateException( - "Tensor of type " + getClass().getSimpleName() + " cannot return data as unsigned byte array."); + "Tensor of type " + + getClass().getSimpleName() + + " cannot return data as unsigned byte array."); } /** diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt index edbf6fac9e5..7e9fea9a699 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/EValueTest.kt @@ -15,208 +15,203 @@ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -/** Unit tests for [EValue]. */ +/** Unit tests for [EValue]. */ @RunWith(JUnit4::class) class EValueTest { - @Test - fun testNone() { - val evalue = EValue.optionalNone() - assertTrue(evalue.isNone) + @Test + fun testNone() { + val evalue = EValue.optionalNone() + assertTrue(evalue.isNone) + } + + @Test + fun testTensorValue() { + val data = longArrayOf(1, 2, 3) + val shape = longArrayOf(1, 3) + val evalue = EValue.from(Tensor.fromBlob(data, shape)) + assertTrue(evalue.isTensor) + assertTrue(evalue.toTensor().shape.contentEquals(shape)) + assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data)) + } + + @Test + fun testBoolValue() { + val evalue = EValue.from(true) + assertTrue(evalue.isBool) + assertTrue(evalue.toBool()) + } + + @Test + fun testIntValue() { + val evalue = EValue.from(1) + assertTrue(evalue.isInt) + assertEquals(evalue.toInt(), 1) + } + + @Test + fun testDoubleValue() { + val evalue = EValue.from(0.1) + assertTrue(evalue.isDouble) + assertEquals(evalue.toDouble(), 0.1, 0.0001) + } + + @Test + fun testStringValue() { + val evalue = EValue.from("a") + assertTrue(evalue.isString) + assertEquals(evalue.toStr(), "a") + } + + @Test + fun testAllIllegalCast() { + val evalue = EValue.optionalNone() + assertTrue(evalue.isNone) + + // try Tensor + assertFalse(evalue.isTensor) + assertThatThrownBy { evalue.toTensor() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Tensor, actual type None") + + // try bool + assertFalse(evalue.isBool) + assertThatThrownBy { evalue.toBool() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Bool, actual type None") + + // try int + assertFalse(evalue.isInt) + assertThatThrownBy { evalue.toInt() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Int, actual type None") + + // try double + assertFalse(evalue.isDouble) + assertThatThrownBy { evalue.toDouble() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type Double, actual type None") + + // try string + assertFalse(evalue.isString) + assertThatThrownBy { evalue.toStr() } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Expected EValue type String, actual type None") + } + + @Test + fun testNoneSerde() { + val evalue = EValue.optionalNone() + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isNone, true) + } + + @Test + fun testBoolSerde() { + val evalue = EValue.from(true) + val bytes = evalue.toByteArray() + assertEquals(1, bytes[1].toLong()) + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isBool, true) + assertEquals(deser.toBool(), true) + } + + @Test + fun testBoolSerde2() { + val evalue = EValue.from(false) + val bytes = evalue.toByteArray() + assertEquals(0, bytes[1].toLong()) + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isBool, true) + assertEquals(deser.toBool(), false) + } + + @Test + fun testIntSerde() { + val evalue = EValue.from(1) + val bytes = evalue.toByteArray() + assertEquals(0, bytes[1].toLong()) + assertEquals(0, bytes[2].toLong()) + assertEquals(0, bytes[3].toLong()) + assertEquals(0, bytes[4].toLong()) + assertEquals(0, bytes[5].toLong()) + assertEquals(0, bytes[6].toLong()) + assertEquals(0, bytes[7].toLong()) + assertEquals(1, bytes[8].toLong()) + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isInt, true) + assertEquals(deser.toInt(), 1) + } + + @Test + fun testLargeIntSerde() { + val evalue = EValue.from(256000) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isInt, true) + assertEquals(deser.toInt(), 256000) + } + + @Test + fun testDoubleSerde() { + val evalue = EValue.from(1.345e-2) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isDouble, true) + assertEquals(1.345e-2, deser.toDouble(), 1e-6) + } + + @Test + fun testLongTensorSerde() { + val data = longArrayOf(1, 2, 3, 4) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + + val evalue = EValue.from(tensor) + val bytes = evalue.toByteArray() + + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isTensor, true) + val deserTensor = deser.toTensor() + val deserShape = deserTensor.shape() + val deserData = deserTensor.dataAsLongArray + + for (i in data.indices) { + assertEquals(data[i], deserData[i]) } - @Test - fun testTensorValue() { - val data = longArrayOf(1, 2, 3) - val shape = longArrayOf(1, 3) - val evalue = EValue.from(Tensor.fromBlob(data, shape)) - assertTrue(evalue.isTensor) - assertTrue(evalue.toTensor().shape.contentEquals(shape)) - assertTrue(evalue.toTensor().dataAsLongArray.contentEquals(data)) + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } - @Test - fun testBoolValue() { - val evalue = EValue.from(true) - assertTrue(evalue.isBool) - assertTrue(evalue.toBool()) - } - - @Test - fun testIntValue() { - val evalue = EValue.from(1) - assertTrue(evalue.isInt) - assertEquals(evalue.toInt(), 1) - } - - @Test - fun testDoubleValue() { - val evalue = EValue.from(0.1) - assertTrue(evalue.isDouble) - assertEquals(evalue.toDouble(), 0.1, 0.0001) - } - - @Test - fun testStringValue() { - val evalue = EValue.from("a") - assertTrue(evalue.isString) - assertEquals(evalue.toStr(), "a") - } - - @Test - fun testAllIllegalCast() { - val evalue = EValue.optionalNone() - assertTrue(evalue.isNone) - - // try Tensor - assertFalse(evalue.isTensor) - assertThatThrownBy { - evalue.toTensor() - }.isInstanceOf(IllegalStateException::class.java) - .hasMessage("Expected EValue type Tensor, actual type None") - - // try bool - assertFalse(evalue.isBool) - assertThatThrownBy { - evalue.toBool() - }.isInstanceOf(IllegalStateException::class.java) - .hasMessage("Expected EValue type Bool, actual type None") - - // try int - assertFalse(evalue.isInt) - assertThatThrownBy { - evalue.toInt() - }.isInstanceOf(IllegalStateException::class.java) - .hasMessage("Expected EValue type Int, actual type None") - - // try double - assertFalse(evalue.isDouble) - assertThatThrownBy { - evalue.toDouble() - }.isInstanceOf(IllegalStateException::class.java) - .hasMessage("Expected EValue type Double, actual type None") - - // try string - assertFalse(evalue.isString) - assertThatThrownBy { - evalue.toStr() - }.isInstanceOf(IllegalStateException::class.java) - .hasMessage("Expected EValue type String, actual type None") - } - - @Test - fun testNoneSerde() { - val evalue = EValue.optionalNone() - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isNone, true) - } - - @Test - fun testBoolSerde() { - val evalue = EValue.from(true) - val bytes = evalue.toByteArray() - assertEquals(1, bytes[1].toLong()) - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isBool, true) - assertEquals(deser.toBool(), true) - } - - @Test - fun testBoolSerde2() { - val evalue = EValue.from(false) - val bytes = evalue.toByteArray() - assertEquals(0, bytes[1].toLong()) + @Test + fun testFloatTensorSerde() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isBool, true) - assertEquals(deser.toBool(), false) - } - - @Test - fun testIntSerde() { - val evalue = EValue.from(1) - val bytes = evalue.toByteArray() - assertEquals(0, bytes[1].toLong()) - assertEquals(0, bytes[2].toLong()) - assertEquals(0, bytes[3].toLong()) - assertEquals(0, bytes[4].toLong()) - assertEquals(0, bytes[5].toLong()) - assertEquals(0, bytes[6].toLong()) - assertEquals(0, bytes[7].toLong()) - assertEquals(1, bytes[8].toLong()) - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isInt, true) - assertEquals(deser.toInt(), 1) - } - - @Test - fun testLargeIntSerde() { - val evalue = EValue.from(256000) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isInt, true) - assertEquals(deser.toInt(), 256000) - } - - @Test - fun testDoubleSerde() { - val evalue = EValue.from(1.345e-2) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isDouble, true) - assertEquals(1.345e-2, deser.toDouble(), 1e-6) - } - - @Test - fun testLongTensorSerde() { - val data = longArrayOf(1, 2, 3, 4) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - - val evalue = EValue.from(tensor) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isTensor, true) - val deserTensor = deser.toTensor() - val deserShape = deserTensor.shape() - val deserData = deserTensor.dataAsLongArray + val evalue = EValue.from(tensor) + val bytes = evalue.toByteArray() - for (i in data.indices) { - assertEquals(data[i], deserData[i]) - } + val deser = EValue.fromByteArray(bytes) + assertEquals(deser.isTensor, true) + val deserTensor = deser.toTensor() + val deserShape = deserTensor.shape() + val deserData = deserTensor.dataAsFloatArray - for (i in shape.indices) { - assertEquals(shape[i], deserShape[i]) - } + for (i in data.indices) { + assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) } - @Test - fun testFloatTensorSerde() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - - val evalue = EValue.from(tensor) - val bytes = evalue.toByteArray() - - val deser = EValue.fromByteArray(bytes) - assertEquals(deser.isTensor, true) - val deserTensor = deser.toTensor() - val deserShape = deserTensor.shape() - val deserData = deserTensor.dataAsFloatArray - - for (i in data.indices) { - assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) - } - - for (i in shape.indices) { - assertEquals(shape[i], deserShape[i]) - } + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } } diff --git a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt index adc39ebac70..e59b40030d7 100644 --- a/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt +++ b/extension/android/executorch_android/src/test/java/org/pytorch/executorch/TensorTest.kt @@ -13,273 +13,265 @@ import org.junit.Test import org.junit.runner.RunWith import org.junit.runners.JUnit4 -/** Unit tests for [Tensor]. */ +/** Unit tests for [Tensor]. */ @RunWith(JUnit4::class) class TensorTest { - @Test - fun testFloatTensor() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - var tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.FLOAT) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) - assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) - assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) - assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) - - val floatBuffer = Tensor.allocateFloatBuffer(4) - floatBuffer.put(data) - tensor = Tensor.fromBlob(floatBuffer, shape) - assertEquals(tensor.dtype(), DType.FLOAT) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) - assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) - assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) - assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + @Test + fun testFloatTensor() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.FLOAT) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) + assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) + assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) + assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + + val floatBuffer = Tensor.allocateFloatBuffer(4) + floatBuffer.put(data) + tensor = Tensor.fromBlob(floatBuffer, shape) + assertEquals(tensor.dtype(), DType.FLOAT) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toDouble(), tensor.dataAsFloatArray[0].toDouble(), 1e-5) + assertEquals(data[1].toDouble(), tensor.dataAsFloatArray[1].toDouble(), 1e-5) + assertEquals(data[2].toDouble(), tensor.dataAsFloatArray[2].toDouble(), 1e-5) + assertEquals(data[3].toDouble(), tensor.dataAsFloatArray[3].toDouble(), 1e-5) + } + + @Test + fun testIntTensor() { + val data = intArrayOf(Int.MIN_VALUE, 0, 1, Int.MAX_VALUE) + val shape = longArrayOf(1, 4, 1) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.INT32) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + + val intBuffer = Tensor.allocateIntBuffer(4) + intBuffer.put(data) + tensor = Tensor.fromBlob(intBuffer, shape) + assertEquals(tensor.dtype(), DType.INT32) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + } + + @Test + fun testDoubleTensor() { + val data = doubleArrayOf(Double.MIN_VALUE, 0.0, 0.1, Double.MAX_VALUE) + val shape = longArrayOf(1, 4) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.DOUBLE) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) + assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) + assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) + assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) + + val doubleBuffer = Tensor.allocateDoubleBuffer(4) + doubleBuffer.put(data) + tensor = Tensor.fromBlob(doubleBuffer, shape) + assertEquals(tensor.dtype(), DType.DOUBLE) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) + assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) + assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) + assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) + } + + @Test + fun testLongTensor() { + val data = longArrayOf(Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE) + val shape = longArrayOf(4, 1) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.INT64) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsLongArray[0]) + assertEquals(data[1], tensor.dataAsLongArray[1]) + assertEquals(data[2], tensor.dataAsLongArray[2]) + assertEquals(data[3], tensor.dataAsLongArray[3]) + + val longBuffer = Tensor.allocateLongBuffer(4) + longBuffer.put(data) + tensor = Tensor.fromBlob(longBuffer, shape) + assertEquals(tensor.dtype(), DType.INT64) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(4, tensor.numel()) + assertEquals(data[0], tensor.dataAsLongArray[0]) + assertEquals(data[1], tensor.dataAsLongArray[1]) + assertEquals(data[2], tensor.dataAsLongArray[2]) + assertEquals(data[3], tensor.dataAsLongArray[3]) + } + + @Test + fun testSignedByteTensor() { + val data = byteArrayOf(Byte.MIN_VALUE, 0.toByte(), 1.toByte(), Byte.MAX_VALUE) + val shape = longArrayOf(1, 1, 4) + var tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.INT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) + + val byteBuffer = Tensor.allocateByteBuffer(4) + byteBuffer.put(data) + tensor = Tensor.fromBlob(byteBuffer, shape) + assertEquals(tensor.dtype(), DType.INT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) + } + + @Test + fun testUnsignedByteTensor() { + val data = byteArrayOf(0.toByte(), 1.toByte(), 2.toByte(), 255.toByte()) + val shape = longArrayOf(4, 1, 1) + var tensor = Tensor.fromBlobUnsigned(data, shape) + assertEquals(tensor.dtype(), DType.UINT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + + val byteBuffer = Tensor.allocateByteBuffer(4) + byteBuffer.put(data) + tensor = Tensor.fromBlobUnsigned(byteBuffer, shape) + assertEquals(tensor.dtype(), DType.UINT8) + assertEquals(shape[0], tensor.shape()[0]) + assertEquals(shape[1], tensor.shape()[1]) + assertEquals(shape[2], tensor.shape()[2]) + assertEquals(4, tensor.numel()) + assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) + assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) + assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) + assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) + } + + @Test + fun testIllegalDataTypeException() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + assertEquals(tensor.dtype(), DType.FLOAT) + + assertThatThrownBy { tensor.dataAsByteArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.") + + assertThatThrownBy { tensor.dataAsUnsignedByteArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.") + + assertThatThrownBy { tensor.dataAsIntArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as int array.") + + assertThatThrownBy { tensor.dataAsDoubleArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as double array.") + + assertThatThrownBy { tensor.dataAsLongArray } + .isInstanceOf(IllegalStateException::class.java) + .hasMessage("Tensor of type Tensor_float32 cannot return data as long array.") + } + + @Test + fun testIllegalArguments() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shapeWithNegativeValues = longArrayOf(-1, 2) + val mismatchShape = longArrayOf(1, 2) + + assertThatThrownBy { Tensor.fromBlob(null as FloatArray?, mismatchShape) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Data array must be not null") + + assertThatThrownBy { Tensor.fromBlob(data, null) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Shape must be not null") + + assertThatThrownBy { Tensor.fromBlob(data, shapeWithNegativeValues) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Shape elements must be non negative") + + assertThatThrownBy { Tensor.fromBlob(data, mismatchShape) } + .isInstanceOf(IllegalArgumentException::class.java) + .hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]") + } + + @Test + fun testLongTensorSerde() { + val data = longArrayOf(1, 2, 3, 4) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + val bytes = tensor.toByteArray() + + val deser = Tensor.fromByteArray(bytes) + val deserShape = deser.shape() + val deserData = deser.dataAsLongArray + + for (i in data.indices) { + assertEquals(data[i], deserData[i]) } - @Test - fun testIntTensor() { - val data = intArrayOf(Int.MIN_VALUE, 0, 1, Int.MAX_VALUE) - val shape = longArrayOf(1, 4, 1) - var tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.INT32) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) - - val intBuffer = Tensor.allocateIntBuffer(4) - intBuffer.put(data) - tensor = Tensor.fromBlob(intBuffer, shape) - assertEquals(tensor.dtype(), DType.INT32) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsIntArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsIntArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsIntArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsIntArray[3].toLong()) + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } - @Test - fun testDoubleTensor() { - val data = doubleArrayOf(Double.MIN_VALUE, 0.0, 0.1, Double.MAX_VALUE) - val shape = longArrayOf(1, 4) - var tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.DOUBLE) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) - assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) - assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) - assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) - - val doubleBuffer = Tensor.allocateDoubleBuffer(4) - doubleBuffer.put(data) - tensor = Tensor.fromBlob(doubleBuffer, shape) - assertEquals(tensor.dtype(), DType.DOUBLE) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0], tensor.dataAsDoubleArray[0], 1e-5) - assertEquals(data[1], tensor.dataAsDoubleArray[1], 1e-5) - assertEquals(data[2], tensor.dataAsDoubleArray[2], 1e-5) - assertEquals(data[3], tensor.dataAsDoubleArray[3], 1e-5) - } - - @Test - fun testLongTensor() { - val data = longArrayOf(Long.MIN_VALUE, 0L, 1L, Long.MAX_VALUE) - val shape = longArrayOf(4, 1) - var tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.INT64) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0], tensor.dataAsLongArray[0]) - assertEquals(data[1], tensor.dataAsLongArray[1]) - assertEquals(data[2], tensor.dataAsLongArray[2]) - assertEquals(data[3], tensor.dataAsLongArray[3]) - - val longBuffer = Tensor.allocateLongBuffer(4) - longBuffer.put(data) - tensor = Tensor.fromBlob(longBuffer, shape) - assertEquals(tensor.dtype(), DType.INT64) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(4, tensor.numel()) - assertEquals(data[0], tensor.dataAsLongArray[0]) - assertEquals(data[1], tensor.dataAsLongArray[1]) - assertEquals(data[2], tensor.dataAsLongArray[2]) - assertEquals(data[3], tensor.dataAsLongArray[3]) - } - - @Test - fun testSignedByteTensor() { - val data = byteArrayOf(Byte.MIN_VALUE, 0.toByte(), 1.toByte(), Byte.MAX_VALUE) - val shape = longArrayOf(1, 1, 4) - var tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.INT8) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) - - val byteBuffer = Tensor.allocateByteBuffer(4) - byteBuffer.put(data) - tensor = Tensor.fromBlob(byteBuffer, shape) - assertEquals(tensor.dtype(), DType.INT8) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsByteArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsByteArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsByteArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsByteArray[3].toLong()) - } - - @Test - fun testUnsignedByteTensor() { - val data = byteArrayOf(0.toByte(), 1.toByte(), 2.toByte(), 255.toByte()) - val shape = longArrayOf(4, 1, 1) - var tensor = Tensor.fromBlobUnsigned(data, shape) - assertEquals(tensor.dtype(), DType.UINT8) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) - - val byteBuffer = Tensor.allocateByteBuffer(4) - byteBuffer.put(data) - tensor = Tensor.fromBlobUnsigned(byteBuffer, shape) - assertEquals(tensor.dtype(), DType.UINT8) - assertEquals(shape[0], tensor.shape()[0]) - assertEquals(shape[1], tensor.shape()[1]) - assertEquals(shape[2], tensor.shape()[2]) - assertEquals(4, tensor.numel()) - assertEquals(data[0].toLong(), tensor.dataAsUnsignedByteArray[0].toLong()) - assertEquals(data[1].toLong(), tensor.dataAsUnsignedByteArray[1].toLong()) - assertEquals(data[2].toLong(), tensor.dataAsUnsignedByteArray[2].toLong()) - assertEquals(data[3].toLong(), tensor.dataAsUnsignedByteArray[3].toLong()) - } - - @Test - fun testIllegalDataTypeException() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - assertEquals(tensor.dtype(), DType.FLOAT) - - assertThatThrownBy { - tensor.dataAsByteArray - }.isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as byte array.") - - assertThatThrownBy { - tensor.dataAsUnsignedByteArray - }.isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as unsigned byte array.") - - assertThatThrownBy { - tensor.dataAsIntArray - }.isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as int array.") - - assertThatThrownBy { - tensor.dataAsDoubleArray - }.isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as double array.") - - assertThatThrownBy { - tensor.dataAsLongArray - }.isInstanceOf(IllegalStateException::class.java) - .hasMessage("Tensor of type Tensor_float32 cannot return data as long array.") - } - - @Test - fun testIllegalArguments() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shapeWithNegativeValues = longArrayOf(-1, 2) - val mismatchShape = longArrayOf(1, 2) - - assertThatThrownBy { - Tensor.fromBlob(null as FloatArray?, mismatchShape) - }.isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Data array must be not null") - - assertThatThrownBy { - Tensor.fromBlob(data, null) - }.isInstanceOf(IllegalArgumentException::class.java).hasMessage("Shape must be not null") - - assertThatThrownBy { - Tensor.fromBlob(data, shapeWithNegativeValues) - }.isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Shape elements must be non negative") - - assertThatThrownBy { - Tensor.fromBlob(data, mismatchShape) - }.isInstanceOf(IllegalArgumentException::class.java) - .hasMessage("Inconsistent data capacity:4 and shape number elements:2 shape:[1, 2]") - } - - @Test - fun testLongTensorSerde() { - val data = longArrayOf(1, 2, 3, 4) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - val bytes = tensor.toByteArray() + @Test + fun testFloatTensorSerde() { + val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) + val shape = longArrayOf(2, 2) + val tensor = Tensor.fromBlob(data, shape) + val bytes = tensor.toByteArray() - val deser = Tensor.fromByteArray(bytes) - val deserShape = deser.shape() - val deserData = deser.dataAsLongArray + val deser = Tensor.fromByteArray(bytes) + val deserShape = deser.shape() + val deserData = deser.dataAsFloatArray - for (i in data.indices) { - assertEquals(data[i], deserData[i]) - } - - for (i in shape.indices) { - assertEquals(shape[i], deserShape[i]) - } + for (i in data.indices) { + assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) } - @Test - fun testFloatTensorSerde() { - val data = floatArrayOf(Float.MIN_VALUE, 0f, 0.1f, Float.MAX_VALUE) - val shape = longArrayOf(2, 2) - val tensor = Tensor.fromBlob(data, shape) - val bytes = tensor.toByteArray() - - val deser = Tensor.fromByteArray(bytes) - val deserShape = deser.shape() - val deserData = deser.dataAsFloatArray - - for (i in data.indices) { - assertEquals(data[i].toDouble(), deserData[i].toDouble(), 1e-5) - } - - for (i in shape.indices) { - assertEquals(shape[i], deserShape[i]) - } + for (i in shape.indices) { + assertEquals(shape[i], deserShape[i]) } + } }