Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -443,15 +443,46 @@ export class Tensor {
// Precompute strides
const stride = this.stride();

for (let i = 0; i < newBufferSize; ++i) {
let originalIndex = 0;
for (let j = newDims.length - 1, num = i; j >= 0; --j) {
const size = newDims[j];
originalIndex += ((num % size) + newOffsets[j][0]) * stride[j];
num = Math.floor(num / size);
// Detect if the slice is contiguous
let isContiguous = true;
for (let i = 1; i < newDims.length; ++i) {
if (newOffsets[i][0] !== 0 || newOffsets[i][1] !== this.dims[i]) {
isContiguous = false;
break;
}
data[i] = this_data[originalIndex];
}

if (isContiguous) {
// Perform bulk copy for contiguous slices to improve performance
const start = newOffsets[0][0] * stride[0];
const end = newOffsets[0][1] * stride[0];

if (ArrayBuffer.isView(this_data)) {
// If this.data is a TypedArray, use subarray
// @ts-ignore
data.set(this_data.subarray(start, end));
} else if (Array.isArray(this_data)) {
// If this.data is a plain array, use slice
const slicedData = this_data.slice(start, end);
for (let i = 0; i < slicedData.length; ++i) {
data[i] = slicedData[i];
}
} else {
throw new Error("Unsupported data type for slicing");
}
} else {
// Fallback to manual copying for non-contiguous slices
for (let i = 0; i < newBufferSize; ++i) {
let originalIndex = 0;
for (let j = newDims.length - 1, num = i; j >= 0; --j) {
const size = newDims[j];
originalIndex += ((num % size) + newOffsets[j][0]) * stride[j];
num = Math.floor(num / size);
}
data[i] = this_data[originalIndex];
}
}

return new Tensor(this.type, data, newTensorDims);
}

Expand Down
62 changes: 59 additions & 3 deletions tests/utils/tensor.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,13 @@ describe("Tensor operations", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(1);
const target = new Tensor("float32", [3, 4], [2]);

compare(t2, target);
});

it("should return a range of rows", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice([1, 3]);
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);

compare(t2, target);
});

Expand All @@ -78,9 +76,67 @@ describe("Tensor operations", () => {
[4, 7],
);
const t2 = t1.slice([1, -1], [1, -1]);

const target = new Tensor("float32", [9, 10, 11, 12, 13, 16, 17, 18, 19, 20], [2, 5]);
compare(t2, target);
});

it("should return the whole tensor when all indices are null/unset", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice();
compare(t2, t1);
});

it("should return the whole dimension when index is null", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(null);
compare(t2, t1);
});

it("should slice from index to end when [start, null] is used", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice([1, null]);
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);
compare(t2, target);
});

it("should slice from beginning to index when [null, end] is used", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice([null, 2]);
const target = new Tensor("float32", [1, 2, 3, 4], [2, 2]);
compare(t2, target);
});

it("should handle [null, null] as full slice", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice([null, null]);
compare(t2, t1);
});

it("should select a single element when a number is used in slice", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(2, 1);
const target = new Tensor("float32", [6], []);
compare(t2, target);
});

it("should select a single row when a number is used in slice", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(0);
const target = new Tensor("float32", [1, 2], [2]);
compare(t2, target);
});

it("should select a single column when a number is used in slice", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(null, 1);
const target = new Tensor("float32", [2, 4, 6], [3]);
compare(t2, target);
});

it("should handle negative indices in slice", () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(-1);
const target = new Tensor("float32", [5, 6], [2]);
compare(t2, target);
});
});
Expand Down