Skip to content

Expose type-erased tensor from generic one. #11962

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -821,9 +821,8 @@ public class Tensor<T: Scalar>: Equatable {
lhs.anyTensor == rhs.anyTensor
}

// MARK: Internal

let anyTensor: AnyTensor
// Wrapped AnyTensor instance.
public let anyTensor: AnyTensor
}

@available(*, deprecated, message: "This API is experimental.")
Expand Down
60 changes: 60 additions & 0 deletions extension/apple/ExecuTorch/__tests__/TensorTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,66 @@ class TensorTest: XCTestCase {
XCTAssertEqual(try tensor.scalars().first, 42)
}

func testExtractAnyTensorMatchesOriginalDataAndMetadata() {
let tensor = Tensor([1, 2, 3, 4], shape: [2, 2])
let anyTensor = tensor.anyTensor
XCTAssertEqual(anyTensor.shape, tensor.shape)
XCTAssertEqual(anyTensor.strides, tensor.strides)
XCTAssertEqual(anyTensor.dimensionOrder, tensor.dimensionOrder)
XCTAssertEqual(anyTensor.count, tensor.count)
XCTAssertEqual(anyTensor.dataType, tensor.dataType)
XCTAssertEqual(anyTensor.shapeDynamism, tensor.shapeDynamism)
let newTensor = Tensor<Int>(anyTensor)
XCTAssertEqual(newTensor, tensor)
}

func testReconstructGenericTensorViaInitAndAsTensor() {
let tensor = Tensor([5, 6, 7])
let anyTensor = tensor.anyTensor
let tensorInit = Tensor<Int>(anyTensor)
let tensorFromAny: Tensor<Int> = anyTensor.asTensor()!
XCTAssertEqual(tensorInit, tensorFromAny)
}

func testAsTensorMismatchedTypeReturnsNil() {
let tensor = Tensor([8, 9, 10])
let anyTensor = tensor.anyTensor
let wrongTypedTensor: Tensor<Float>? = anyTensor.asTensor()
XCTAssertNil(wrongTypedTensor)
}

func testViewSharesDataAndResizeAltersShapeNotData() throws {
var scalars = [11, 12, 13, 14]
let tensor = Tensor(&scalars, shape: [2, 2])
let viewTensor = Tensor(tensor)
let scalarsAddress = scalars.withUnsafeBufferPointer { $0.baseAddress }
let tensorDataAddress = try tensor.withUnsafeBytes { $0.baseAddress }
let viewTensorDataAddress = try viewTensor.withUnsafeBytes { $0.baseAddress }
XCTAssertEqual(tensorDataAddress, scalarsAddress)
XCTAssertEqual(tensorDataAddress, viewTensorDataAddress)

scalars[2] = 42
XCTAssertEqual(try tensor.scalars(), scalars)
XCTAssertEqual(try viewTensor.scalars(), scalars)

XCTAssertNoThrow(try viewTensor.resize(to: [4, 1]))
XCTAssertEqual(viewTensor.shape, [4, 1])
XCTAssertEqual(tensor.shape, [2, 2])
XCTAssertEqual(try tensor.scalars(), scalars)
XCTAssertEqual(try viewTensor.scalars(), scalars)
}

func testMultipleGenericFromAnyReflectChanges() {
let tensor = Tensor([2, 4, 6, 8], shape: [2, 2])
let anyTensor = tensor.anyTensor
let tensor1: Tensor<Int> = anyTensor.asTensor()!
let tensor2: Tensor<Int> = anyTensor.asTensor()!

XCTAssertEqual(tensor1, tensor2)
XCTAssertNoThrow(try tensor1.withUnsafeMutableBytes { $0[1] = 42 })
XCTAssertEqual(try tensor2.withUnsafeBytes { $0[1] }, 42)
}

func testEmpty() {
let tensor = Tensor<Float>.empty(shape: [3, 4])
XCTAssertEqual(tensor.shape, [3, 4])
Expand Down
Loading