Skip to content

Commit 68ac7e2

Browse files
authored
Merge pull request #176 from dsyme/echk
avoid unnecessary p/invoke calls
2 parents 1fa9c8e + e344828 commit 68ac7e2

File tree

18 files changed

+477
-391
lines changed

18 files changed

+477
-391
lines changed

src/Native/LibTorchSharp/Utils.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,3 @@ const char * make_sharable_string(const std::string str)
2121
return result;
2222
}
2323

24-
Tensor ResultTensor(const at::Tensor& res)
25-
{
26-
if (res.defined())
27-
return new torch::Tensor(res);
28-
else
29-
return NULL;
30-
}

src/Native/LibTorchSharp/Utils.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,13 @@ typedef std::shared_ptr<torch::optim::Optimizer> * Optimizer;
4141
#define CATCH_RETURN_Tensor(stmt) CATCH_RETURN_RES(Tensor, NULL, stmt)
4242

4343
// Return undefined tensors as NULL to C#
44-
Tensor ResultTensor(const at::Tensor & res);
44+
inline Tensor ResultTensor(const at::Tensor & res)
45+
{
46+
if (res.defined())
47+
return new torch::Tensor(res);
48+
else
49+
return NULL;
50+
}
4551

4652
#define CATCH_TENSOR(expr) \
4753
at::Tensor res = at::Tensor(); \

src/TorchSharp/NN/AdaptiveAvgPool2D.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ internal AdaptiveAvgPool2D (IntPtr handle, IntPtr boxedHandle) : base (handle, b
2020
public TorchTensor Forward (TorchTensor tensor)
2121
{
2222
var res = THSNN_AdaptiveAvgPool2d_forward (handle.DangerousGetHandle (), tensor.Handle);
23-
Torch.CheckForErrors ();
23+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
2424
return new TorchTensor (res);
2525
}
2626
}
@@ -34,7 +34,7 @@ static public AdaptiveAvgPool2D AdaptiveAvgPool2D (long[] kernelSize)
3434
unsafe {
3535
fixed (long* pkernelSize = kernelSize) {
3636
var handle = THSNN_AdaptiveAvgPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length, out var boxedHandle);
37-
Torch.CheckForErrors ();
37+
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
3838
return new AdaptiveAvgPool2D (handle, boxedHandle);
3939
}
4040
}

src/TorchSharp/NN/AvgPool2D.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ internal AvgPool2D (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHand
2020
public TorchTensor Forward (TorchTensor tensor)
2121
{
2222
var res = THSNN_AvgPool2d_forward (handle.DangerousGetHandle (), tensor.Handle);
23-
Torch.CheckForErrors ();
23+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
2424
return new TorchTensor (res);
2525
}
2626
}
@@ -34,7 +34,7 @@ static public AvgPool2D AvgPool2D (long[] kernelSize, long[] strides = null)
3434
unsafe {
3535
fixed (long* pkernelSize = kernelSize, pstrides = strides) {
3636
var handle = THSNN_AvgPool2d_ctor ((IntPtr)pkernelSize, kernelSize.Length, (IntPtr)pstrides, (strides == null ? 0 : strides.Length), out var boxedHandle);
37-
Torch.CheckForErrors ();
37+
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
3838
return new AvgPool2D (handle, boxedHandle);
3939
}
4040
}

src/TorchSharp/NN/Conv2D.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ internal Conv2D (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
1515
public TorchTensor Forward (TorchTensor tensor)
1616
{
1717
var res = THSNN_Conv2d_forward (handle, tensor.Handle);
18-
Torch.CheckForErrors ();
18+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
1919
return new TorchTensor (res);
2020
}
2121
}
@@ -27,7 +27,7 @@ public static partial class Modules
2727
static public Conv2D Conv2D (long inputChannel, long outputChannel, long kernelSize, long stride = 1, long padding = 0)
2828
{
2929
var res = THSNN_Conv2d_ctor (inputChannel, outputChannel, kernelSize, stride, padding, out var boxedHandle);
30-
Torch.CheckForErrors ();
30+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
3131
return new Conv2D (res, boxedHandle);
3232
}
3333
}

src/TorchSharp/NN/Dropout.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ internal Dropout (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle
1818
public TorchTensor Forward (TorchTensor tensor)
1919
{
2020
var res = THSNN_Dropout_forward (handle, tensor.Handle);
21-
Torch.CheckForErrors ();
21+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
2222
return new TorchTensor (res);
2323
}
2424
}
@@ -30,7 +30,7 @@ public static partial class Modules
3030
static public Dropout Dropout (double probability = 0.5)
3131
{
3232
var handle = THSNN_Dropout_ctor (probability, out var boxedHandle);
33-
Torch.CheckForErrors ();
33+
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
3434
return new Dropout (handle, boxedHandle);
3535
}
3636
}

src/TorchSharp/NN/FeatureDropout.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ internal FeatureAlphaDropout (IntPtr handle, IntPtr boxedHandle) : base (handle,
2020
public TorchTensor Forward (TorchTensor tensor)
2121
{
2222
var res = THSNN_FeatureAlphaDropout_forward (handle, tensor.Handle);
23-
Torch.CheckForErrors ();
23+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
2424
return new TorchTensor (res);
2525
}
2626
}
@@ -32,7 +32,7 @@ public static partial class Modules
3232
static public FeatureAlphaDropout FeatureAlphaDropout (double probability = 0.5)
3333
{
3434
var handle = THSNN_FeatureAlphaDropout_ctor (probability, out var boxedHandle);
35-
Torch.CheckForErrors ();
35+
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
3636
return new FeatureAlphaDropout (handle, boxedHandle);
3737
}
3838
}

src/TorchSharp/NN/Linear.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ internal Linear (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
1515
public new static Linear Load (String modelPath)
1616
{
1717
var res = Module.Load (modelPath);
18-
Torch.CheckForErrors ();
1918
return new Linear (res.handle.DangerousGetHandle(), IntPtr.Zero);
2019
}
2120

@@ -25,7 +24,7 @@ internal Linear (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHandle)
2524
public TorchTensor Forward (TorchTensor tensor)
2625
{
2726
var res = THSNN_Linear_forward (handle, tensor.Handle);
28-
Torch.CheckForErrors ();
27+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
2928
return new TorchTensor (res);
3029
}
3130
[DllImport ("LibTorchSharp")]
@@ -36,7 +35,7 @@ public TorchTensor Forward (TorchTensor tensor)
3635
public TorchTensor? Bias {
3736
get {
3837
var res = THSNN_Linear_bias (handle);
39-
Torch.CheckForErrors ();
38+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
4039
return ((res == IntPtr.Zero) ? null : new TorchTensor (res));
4140
}
4241
set {
@@ -52,7 +51,7 @@ public TorchTensor? Bias {
5251
public TorchTensor Weight {
5352
get {
5453
var res = THSNN_Linear_weight (handle);
55-
Torch.CheckForErrors ();
54+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
5655
return new TorchTensor (res);
5756
}
5857
set {
@@ -69,7 +68,7 @@ public static partial class Modules
6968
static public Linear Linear (long inputSize, long outputSize, bool hasBias = true)
7069
{
7170
var res = THSNN_Linear_ctor (inputSize, outputSize, hasBias, out var boxedHandle);
72-
Torch.CheckForErrors ();
71+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
7372
return new Linear (res, boxedHandle);
7473
}
7574
}

src/TorchSharp/NN/LogSoftMax.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ internal LogSoftMax (IntPtr handle, IntPtr boxedHandle) : base (handle, boxedHan
2020
public TorchTensor Forward (TorchTensor tensor)
2121
{
2222
var res = THSNN_LogSoftMax_forward (handle, tensor.Handle);
23-
Torch.CheckForErrors ();
23+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
2424
return new TorchTensor (res);
2525
}
2626
}
@@ -32,7 +32,7 @@ public static partial class Modules
3232
static public LogSoftMax LogSoftMax (long dimension)
3333
{
3434
var handle = THSNN_LogSoftMax_ctor (dimension, out var boxedHandle);
35-
Torch.CheckForErrors ();
35+
if (handle == IntPtr.Zero) { Torch.CheckForErrors(); }
3636
return new LogSoftMax (handle, boxedHandle);
3737
}
3838
}

src/TorchSharp/NN/Losses.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public static Loss BCE (TorchTensor? weigths = null, Reduction reduction = Reduc
2222
{
2323
return (TorchTensor src, TorchTensor target) => {
2424
var res = THSNN_binary_cross_entropy (src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction);
25-
Torch.CheckForErrors ();
25+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
2626
return new TorchTensor (res);
2727
};
2828
}
@@ -34,7 +34,7 @@ public static Loss MSE (Reduction reduction = Reduction.Mean)
3434
{
3535
return (TorchTensor src, TorchTensor target) => {
3636
var res = THSNN_mse_loss (src.Handle, target.Handle, (long)reduction);
37-
Torch.CheckForErrors ();
37+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
3838
return new TorchTensor (res);
3939
};
4040
}
@@ -46,7 +46,7 @@ public static Loss NLL (TorchTensor? weigths = null, Reduction reduction = Reduc
4646
{
4747
return (TorchTensor src, TorchTensor target) => {
4848
var res = THSNN_nll_loss (src.Handle, target.Handle, weigths?.Handle ?? IntPtr.Zero, (long)reduction);
49-
Torch.CheckForErrors ();
49+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
5050
return new TorchTensor (res);
5151
};
5252
}
@@ -58,7 +58,7 @@ public static Loss PoissonNLL (bool logInput = true, bool full = false, float ep
5858
{
5959
return (TorchTensor src, TorchTensor target) => {
6060
var res = THSNN_poisson_loss (src.Handle, target.Handle, logInput, full, eps, (long)reduction);
61-
Torch.CheckForErrors ();
61+
if (res == IntPtr.Zero) { Torch.CheckForErrors(); }
6262
return new TorchTensor (res);
6363
};
6464
}

0 commit comments

Comments
 (0)