Skip to content

Commit 1fa9c8e

Browse files
authored
Merge pull request #175 from dsyme/funcs5
einsum
2 parents 0ae34e0 + 22f545b commit 1fa9c8e

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

src/Native/LibTorchSharp/THSTensor.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,11 @@ Tensor THSTensor_div_scalar_(const Tensor left, const Scalar right)
573573
CATCH_TENSOR(left->div_(*right));
574574
}
575575

576+
Tensor THSTensor_einsum(const char* equation, const Tensor* tensors, const int length)
577+
{
578+
CATCH_TENSOR(torch::einsum(equation, toTensors<at::Tensor>((torch::Tensor**)tensors, length)));
579+
}
580+
576581
int64_t THSTensor_element_size(const Tensor tensor)
577582
{
578583
CATCH_RETURN(int64_t, 0, tensor->element_size());

src/Native/LibTorchSharp/THSTensor.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,8 @@ EXPORT_API(Tensor) THSTensor_div_scalar(const Tensor left, const Scalar right);
215215

216216
EXPORT_API(Tensor) THSTensor_div_scalar_(const Tensor left, const Scalar right);
217217

218+
EXPORT_API(Tensor) THSTensor_einsum(const char* equation, const Tensor* tensors, const int length);
219+
218220
EXPORT_API(int64_t) THSTensor_element_size(const Tensor tensor);
219221

220222
EXPORT_API(Tensor) THSTensor_empty(

src/TorchSharp/Tensor/TorchTensor.cs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3746,23 +3746,39 @@ public static TorchTensor Cat(this TorchTensor[] tensors, long dimension)
37463746
return tensors[0];
37473747
}
37483748

3749-
var parray = new PinnedArray<IntPtr>();
3750-
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
3749+
using (var parray = new PinnedArray<IntPtr>()) {
3750+
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
37513751

3752-
return new TorchTensor(THSTensor_cat(tensorsRef, parray.Array.Length, dimension));
3752+
return new TorchTensor(THSTensor_cat(tensorsRef, parray.Array.Length, dimension));
3753+
}
37533754
}
37543755

37553756
[DllImport("LibTorchSharp")]
37563757
extern static IntPtr THSTensor_stack(IntPtr tensor, int len, long dim);
37573758

37583759
public static TorchTensor Stack(this TorchTensor[] tensors, long dimension)
37593760
{
3760-
var parray = new PinnedArray<IntPtr>();
3761-
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
3761+
using (var parray = new PinnedArray<IntPtr>()) {
3762+
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
37623763

3763-
var res = THSTensor_stack(tensorsRef, parray.Array.Length, dimension);
3764-
Torch.CheckForErrors();
3765-
return new TorchTensor(res);
3764+
var res = THSTensor_stack(tensorsRef, parray.Array.Length, dimension);
3765+
Torch.CheckForErrors();
3766+
return new TorchTensor(res);
3767+
}
3768+
}
3769+
3770+
[DllImport("LibTorchSharp")]
3771+
extern static IntPtr THSTensor_einsum([MarshalAs(UnmanagedType.LPStr)] string location, IntPtr tensors, int len);
3772+
3773+
public static TorchTensor Einsum(string equation, params TorchTensor[] tensors)
3774+
{
3775+
using (var parray = new PinnedArray<IntPtr>()) {
3776+
IntPtr tensorsRef = parray.CreateArray(tensors.Select(p => p.Handle).ToArray());
3777+
3778+
var res = THSTensor_einsum(equation, tensorsRef, parray.Array.Length);
3779+
Torch.CheckForErrors();
3780+
return new TorchTensor(res);
3781+
}
37663782
}
37673783

37683784
}

0 commit comments

Comments
 (0)