diff --git a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift index 461ef30b653..d4e2c4e9e82 100644 --- a/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift +++ b/extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift @@ -821,9 +821,8 @@ public class Tensor: Equatable { lhs.anyTensor == rhs.anyTensor } - // MARK: Internal - - let anyTensor: AnyTensor + // Wrapped AnyTensor instance. + public let anyTensor: AnyTensor } @available(*, deprecated, message: "This API is experimental.") diff --git a/extension/apple/ExecuTorch/__tests__/TensorTest.swift b/extension/apple/ExecuTorch/__tests__/TensorTest.swift index f596fe49400..407a9ee03e7 100644 --- a/extension/apple/ExecuTorch/__tests__/TensorTest.swift +++ b/extension/apple/ExecuTorch/__tests__/TensorTest.swift @@ -678,6 +678,66 @@ class TensorTest: XCTestCase { XCTAssertEqual(try tensor.scalars().first, 42) } + func testExtractAnyTensorMatchesOriginalDataAndMetadata() { + let tensor = Tensor([1, 2, 3, 4], shape: [2, 2]) + let anyTensor = tensor.anyTensor + XCTAssertEqual(anyTensor.shape, tensor.shape) + XCTAssertEqual(anyTensor.strides, tensor.strides) + XCTAssertEqual(anyTensor.dimensionOrder, tensor.dimensionOrder) + XCTAssertEqual(anyTensor.count, tensor.count) + XCTAssertEqual(anyTensor.dataType, tensor.dataType) + XCTAssertEqual(anyTensor.shapeDynamism, tensor.shapeDynamism) + let newTensor = Tensor(anyTensor) + XCTAssertEqual(newTensor, tensor) + } + + func testReconstructGenericTensorViaInitAndAsTensor() { + let tensor = Tensor([5, 6, 7]) + let anyTensor = tensor.anyTensor + let tensorInit = Tensor(anyTensor) + let tensorFromAny: Tensor = anyTensor.asTensor()! + XCTAssertEqual(tensorInit, tensorFromAny) + } + + func testAsTensorMismatchedTypeReturnsNil() { + let tensor = Tensor([8, 9, 10]) + let anyTensor = tensor.anyTensor + let wrongTypedTensor: Tensor? = anyTensor.asTensor() + XCTAssertNil(wrongTypedTensor) + } + + func testViewSharesDataAndResizeAltersShapeNotData() throws { + var scalars = [11, 12, 13, 14] + let tensor = Tensor(&scalars, shape: [2, 2]) + let viewTensor = Tensor(tensor) + let scalarsAddress = scalars.withUnsafeBufferPointer { $0.baseAddress } + let tensorDataAddress = try tensor.withUnsafeBytes { $0.baseAddress } + let viewTensorDataAddress = try viewTensor.withUnsafeBytes { $0.baseAddress } + XCTAssertEqual(tensorDataAddress, scalarsAddress) + XCTAssertEqual(tensorDataAddress, viewTensorDataAddress) + + scalars[2] = 42 + XCTAssertEqual(try tensor.scalars(), scalars) + XCTAssertEqual(try viewTensor.scalars(), scalars) + + XCTAssertNoThrow(try viewTensor.resize(to: [4, 1])) + XCTAssertEqual(viewTensor.shape, [4, 1]) + XCTAssertEqual(tensor.shape, [2, 2]) + XCTAssertEqual(try tensor.scalars(), scalars) + XCTAssertEqual(try viewTensor.scalars(), scalars) + } + + func testMultipleGenericFromAnyReflectChanges() { + let tensor = Tensor([2, 4, 6, 8], shape: [2, 2]) + let anyTensor = tensor.anyTensor + let tensor1: Tensor = anyTensor.asTensor()! + let tensor2: Tensor = anyTensor.asTensor()! + + XCTAssertEqual(tensor1, tensor2) + XCTAssertNoThrow(try tensor1.withUnsafeMutableBytes { $0[1] = 42 }) + XCTAssertEqual(try tensor2.withUnsafeBytes { $0[1] }, 42) + } + func testEmpty() { let tensor = Tensor.empty(shape: [3, 4]) XCTAssertEqual(tensor.shape, [3, 4])