Skip to content
This repository was archived by the owner on Oct 27, 2023. It is now read-only.

Commit 67b39a6

Browse files
author
Will Feng
authored
Remove legacy function usgae in numpy_extensions_tutorial.py (pytorch#561)
1 parent 0bc39f2 commit 67b39a6

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

advanced_source/numpy_extensions_tutorial.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@
3535

3636

3737
class BadFFTFunction(Function):
38-
39-
def forward(self, input):
38+
@staticmethod
39+
def forward(ctx, input):
4040
numpy_input = input.detach().numpy()
4141
result = abs(rfft2(numpy_input))
4242
return input.new(result)
4343

44-
def backward(self, grad_output):
44+
@staticmethod
45+
def backward(ctx, grad_output):
4546
numpy_go = grad_output.numpy()
4647
result = irfft2(numpy_go)
4748
return grad_output.new(result)
@@ -51,7 +52,7 @@ def backward(self, grad_output):
5152

5253

5354
def incorrect_fft(input):
54-
return BadFFTFunction()(input)
55+
return BadFFTFunction.apply(input)
5556

5657
###############################################################
5758
# **Example usage of the created layer:**

0 commit comments

Comments
 (0)