4 changes: 4 additions & 0 deletions RELEASENOTES.md
@@ -11,8 +11,12 @@ __Breaking Changes__:

__API Changes__:

- #1291 `Tensor.grad()` and `Tensor.set_grad()` have been replaced by a new property `Tensor.grad`.
- A potential memory leak caused by `set_grad` has been resolved.

__Bug Fixes__:

- #1300 `Adadelta`, `Adam` and `AdamW` will no longer throw `NullReferenceException` when `maximize` is `true` and `grad` is `null`.
- `torch.normal` will now correctly return a leaf tensor.

# NuGet Version 0.102.4
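For illustration only (not part of the diff): a minimal sketch of the API change described in the release notes, assuming the usual `torch.ones` overload that accepts a `requires_grad` argument.

```csharp
using TorchSharp;
using static TorchSharp.torch;

// A leaf tensor that accumulates gradients.
var x = torch.ones(3, requires_grad: true);
var loss = (x * x).sum();
loss.backward();

// Old API (removed by this PR):
//   Tensor? g = x.grad();
//   x.set_grad(torch.zeros(3));

// New API (property):
Tensor? g = x.grad;       // null until backward() has populated it
x.grad = torch.zeros(3);  // replaces the stored gradient
```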
2 changes: 1 addition & 1 deletion src/Examples/AdversarialExampleGeneration.cs
@@ -133,7 +133,7 @@ private static double Test(
model.zero_grad();
loss.backward();

var perturbed = Attack(data, ε, data.grad());
var perturbed = Attack(data, ε, data.grad);

using (var final = model.call(perturbed)) {

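Illustrative sketch (not part of the diff) of how `data.grad` feeds the FGSM-style attack in this example; the `Perturb` helper below is hypothetical and stands in for the example's `Attack` method.

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Hypothetical stand-in for the example's Attack(data, eps, grad) helper:
// x' = x + eps * sign(d loss / d x), clamped back to the valid pixel range.
static Tensor Perturb(Tensor input, double eps, Tensor gradient)
{
    return (input + gradient.sign().mul(eps)).clamp(0, 1);
}

// Usage after loss.backward(), assuming `data` was created with requires_grad:
// var perturbed = Perturb(data, 0.05, data.grad);
```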
2 changes: 1 addition & 1 deletion src/FSharp.Examples/AdversarialExampleGeneration.fs
@@ -79,7 +79,7 @@ let test (model:MNIST.Model) (eps:float) (data:Dataset) size =
model.zero_grad()
loss.backward()

use perturbed = attack input (eps.ToScalar()) (input.grad())
use perturbed = attack input (eps.ToScalar()) (input.grad)
use final = perturbed --> model
correct <- correct + final.argmax(1L).eq(labels).sum().ToInt32()
end
24 changes: 12 additions & 12 deletions src/TorchSharp/NN/Module.cs
@@ -239,8 +239,10 @@ private void _toEpilog(ScalarType? dtype, Device device)
.ToDictionary(field => field.ComponentName());

foreach (var (name, param) in named_parameters(false).ToList()) {
using var grad = param.grad;

if (!param.toWillCopy(dtype ?? param.dtype, device ?? param.device) &&
(param.grad() is null || !param.grad().toWillCopy(dtype ?? param.dtype, device ?? param.device)))
(grad is null || !grad.toWillCopy(dtype ?? param.dtype, device ?? param.device)))
continue;

Parameter p;
@@ -252,20 +254,19 @@ private void _toEpilog(ScalarType? dtype, Device device)
// disable grad we would need to call .detach() on the moved tensor.
using (var d = torch.no_grad()) {
p = new Parameter(
param.to(paramType, device ?? param.device).DetachFromDisposeScope(), param.requires_grad)
.DetachFromDisposeScope() as Parameter;
data: param.to(paramType, device ?? param.device),
requires_grad: param.requires_grad);
_ = p.DetachFromDisposeScope();

// Copy the gradient over as well, if it exists
var grad = param.grad();
if (grad is not null) {
p.set_grad(grad.to(paramType, device ?? param.device)
.with_requires_grad(grad.requires_grad)
.MoveToOtherDisposeScope(p));
using var newGrad = grad.to(paramType, device ?? param.device)
.with_requires_grad(grad.requires_grad);
p.grad = newGrad;
}

// Dispose the param and gradient
// Dispose the param
param.Dispose();
grad?.Dispose();
}
ConditionallyRegisterParameter(name, p);

@@ -360,11 +361,10 @@ public virtual void zero_grad(bool set_to_none = true)
CheckForErrors();

foreach (var (_, p) in named_parameters()) {
var grad = p.grad();
using var grad = p.grad;
if (grad is not null) {
if (set_to_none) {
p.set_grad(null);
grad.DetachFromDisposeScope().Dispose();
p.grad = null;
} else {
grad.zero_();
}
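A brief usage sketch (not part of the diff) of the `zero_grad` behavior shown above: with `set_to_none: true` (the default) the module now simply assigns `p.grad = null`, while `set_to_none: false` zeroes the existing gradient tensors in place.

```csharp
using TorchSharp;
using static TorchSharp.torch;

var lin = torch.nn.Linear(4, 2);
lin.call(torch.rand(8, 4)).sum().backward();

lin.zero_grad();                       // default set_to_none: true -> parameter .grad becomes null
// lin.zero_grad(set_to_none: false);  // keeps the gradient tensors, but fills them with zeros
```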
2 changes: 1 addition & 1 deletion src/TorchSharp/Optimizers/ASGD.cs
@@ -145,7 +145,7 @@ public override Tensor step(Func<Tensor> closure = null)

foreach (var param in group.Parameters) {

var grad = param.grad();
var grad = param.grad;

if (grad is null) continue;

4 changes: 3 additions & 1 deletion src/TorchSharp/Optimizers/Adadelta.cs
@@ -136,10 +136,12 @@ public override Tensor step(Func<Tensor> closure = null)

foreach (var param in group.Parameters) {

var grad = (maximize) ? -param.grad() : param.grad();
var grad = param.grad;

if (grad is null) continue;

if (maximize) grad = -grad;

if (grad.is_sparse) throw new ArgumentException("Adadelta does not support sparse gradients");

var state = (State)_state[param.handle];
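Illustration (not part of the diff) of the `maximize` fix: with the reordered null check, a parameter that never received a gradient no longer triggers a `NullReferenceException` when `maximize` is `true`. The `Adadelta` factory's `lr`/`maximize` arguments are assumed here based on the `step()` code above.

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Only `used` participates in the loss, so `unused.grad` stays null.
var used = torch.nn.Parameter(torch.rand(3));
var unused = torch.nn.Parameter(torch.rand(3));

var optimizer = torch.optim.Adadelta(new[] { used, unused }, lr: 1.0, maximize: true);

used.sum().backward();
optimizer.step();   // previously threw while negating the null gradient of `unused`
```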
2 changes: 1 addition & 1 deletion src/TorchSharp/Optimizers/Adagrad.cs
@@ -147,7 +147,7 @@ public override Tensor step(Func<Tensor> closure = null)

var state = (State)_state[param.handle];

var grad = param.grad();
var grad = param.grad;

if (grad is null) continue;

4 changes: 3 additions & 1 deletion src/TorchSharp/Optimizers/Adam.cs
@@ -164,10 +164,12 @@ public override Tensor step(Func<Tensor> closure = null)

var state = (State)_state[param.handle];

var grad = (maximize) ? -param.grad() : param.grad();
var grad = param.grad;

if (grad is null) continue;

if (maximize) grad = -grad;

state.step += 1;

var bias_correction1 = 1 - Math.Pow(beta1, state.step);
4 changes: 3 additions & 1 deletion src/TorchSharp/Optimizers/AdamW.cs
@@ -164,10 +164,12 @@ public override Tensor step(Func<Tensor> closure = null)

var state = (State)_state[param.handle];

var grad = (maximize) ? -param.grad() : param.grad();
var grad = param.grad;

if (grad is null) continue;

if (maximize) grad = -grad;

state.step += 1;

param.mul_(1 - lr * weight_decay);
2 changes: 1 addition & 1 deletion src/TorchSharp/Optimizers/Adamax.cs
@@ -148,7 +148,7 @@ public override Tensor step(Func<Tensor> closure = null)

foreach (var param in group.Parameters) {

var grad = param.grad();
var grad = param.grad;

if (grad is null) continue;

2 changes: 1 addition & 1 deletion src/TorchSharp/Optimizers/NAdam.cs
@@ -154,7 +154,7 @@ public override Tensor step(Func<Tensor> closure = null)

foreach (var param in group.Parameters) {

var grad = param.grad();
var grad = param.grad;

if (grad is null) continue;

9 changes: 2 additions & 7 deletions src/TorchSharp/Optimizers/Optimizer.cs
@@ -396,14 +396,9 @@ public void to(Device device)
public override void zero_grad()
{
foreach (var g in _parameter_groups) {

foreach (var p in g.Parameters) {

using var grad = p.grad();

if (grad is null) continue;

grad.zero_().Dispose();
using var grad = p.grad;
_ = grad?.zero_();
}
}
}
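A caller-side sketch (not part of the diff) of the same pattern the rewritten `zero_grad` uses: each access to the `grad` property returns a fresh `Tensor` wrapper around the native gradient, so the wrapper is disposed after use while `zero_()` mutates the underlying gradient in place.

```csharp
using TorchSharp;
using static TorchSharp.torch;

var lin = torch.nn.Linear(4, 2);
lin.call(torch.rand(8, 4)).sum().backward();

foreach (var (_, p) in lin.named_parameters()) {
    using var grad = p.grad;   // new wrapper around the native gradient (or null)
    _ = grad?.zero_();         // zero in place; only the wrapper is disposed here
}
```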
2 changes: 1 addition & 1 deletion src/TorchSharp/Optimizers/RAdam.cs
@@ -147,7 +147,7 @@ public override Tensor step(Func<Tensor> closure = null)

foreach (var param in group.Parameters) {

var grad = param.grad();
var grad = param.grad;

if (grad is null) continue;

Expand Down
2 changes: 1 addition & 1 deletion src/TorchSharp/Optimizers/RMSprop.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ public override Tensor step(Func<Tensor> closure = null)

var state = (State)_state[param.handle];

var grad = param.grad();
var grad = param.grad;

if (grad is null) continue;

Expand Down
2 changes: 1 addition & 1 deletion src/TorchSharp/Optimizers/Rprop.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public override Tensor step(Func<Tensor> closure = null)

foreach (var param in group.Parameters) {

var grad = param.grad();
var grad = param.grad;

if (grad is null) continue;

Expand Down
2 changes: 1 addition & 1 deletion src/TorchSharp/Optimizers/SGD.cs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ public override Tensor step(Func<Tensor> closure = null)

var state = (State)_state[param.handle];

var grad = param.grad();
var grad = param.grad;

if (grad is null) continue;

Expand Down
30 changes: 13 additions & 17 deletions src/TorchSharp/Tensor/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1340,25 +1340,21 @@ public Tensor pin_memory()
/// This attribute is null by default and becomes a Tensor the first time a call to backward() computes gradients for the tensor.
/// The attribute will then contain the gradients computed and future calls to backward() will accumulate (add) gradients into it.
/// </summary>
public Tensor? grad()
{
var res = NativeMethods.THSTensor_grad(Handle);
CheckForErrors();

if (res == IntPtr.Zero)
return null;
public Tensor? grad {
get {
var res = NativeMethods.THSTensor_grad(Handle);

return new Tensor(res);
}
if (res == IntPtr.Zero) {
CheckForErrors();
return null;
}

/// <summary>
/// This function will set the `tensor.grad()` attribute to a custom tensor.
/// </summary>
/// <param name="grad">The new gradient tensor</param>
public void set_grad(Tensor grad)
{
NativeMethods.THSTensor_set_grad(Handle, grad?.DetachFromDisposeScope().Handle ?? IntPtr.Zero);
CheckForErrors();
return new Tensor(res);
}
set {
NativeMethods.THSTensor_set_grad(Handle, value?.Handle ?? IntPtr.Zero);
CheckForErrors();
}
}

internal void EncodeIndices(TensorIndex[] indices,
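A short usage sketch (not part of the diff) of the property semantics implemented above: the getter returns null until backward() has produced a gradient, and the setter replaces the stored gradient (assigning null clears it, as `zero_grad(set_to_none: true)` now does). The `torch.ones` overload with `requires_grad` is assumed.

```csharp
using System;
using TorchSharp;
using static TorchSharp.torch;

var w = torch.ones(3, requires_grad: true);

Console.WriteLine(w.grad is null);   // True: no gradient before backward()

(w * w).sum().backward();
Console.WriteLine(w.grad is null);   // False: backward() populated the gradient

w.grad = null;                       // clears the stored gradient
```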