Skip to content

Commit b4d926b

Browse files
VinhLoiITrohitgr7
andauthored
Fix reset TensorRunningAccum (#5106)
* Fix reset TensorRunningAccum * add test for TensorRunningAccum's reset method * fix CI failed due to PEP8 Co-authored-by: Rohit Gupta <[email protected]>
1 parent afe5da7 commit b4d926b

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

pytorch_lightning/trainer/supporters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def __init__(self, window_length: int):
5050

5151
def reset(self) -> None:
5252
"""Empty the accumulator."""
53-
self = TensorRunningAccum(self.window_length)
53+
self.__init__(self.window_length)
5454

5555
def last(self):
5656
"""Get the last added element."""

tests/trainer/test_supporters.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytest
15+
import torch
16+
17+
from pytorch_lightning.trainer.supporters import TensorRunningAccum
18+
19+
20+
def test_tensor_running_accum_reset():
21+
""" Test that reset would set all attributes to the initialization state """
22+
23+
window_length = 10
24+
25+
accum = TensorRunningAccum(window_length=window_length)
26+
assert accum.last() is None
27+
assert accum.mean() is None
28+
29+
accum.append(torch.tensor(1.5))
30+
assert accum.last() == torch.tensor(1.5)
31+
assert accum.mean() == torch.tensor(1.5)
32+
33+
accum.reset()
34+
assert accum.window_length == window_length
35+
assert accum.memory is None
36+
assert accum.current_idx == 0
37+
assert accum.last_idx is None
38+
assert not accum.rotated

0 commit comments

Comments
 (0)