File tree Expand file tree Collapse file tree 3 files changed +19
-0
lines changed
pytorch_lightning/overrides Expand file tree Collapse file tree 3 files changed +19
-0
lines changed Original file line number Diff line number Diff line change @@ -44,6 +44,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4444- Added ` ddp_fully_sharded ` support ([ #7487 ] ( https://github.com/PyTorchLightning/pytorch-lightning/pull/7487 ) )
4545
4646
47+ - Added ` __len__ ` to ` IndexBatchSamplerWrapper ` ([ #7681 ] ( https://github.com/PyTorchLightning/pytorch-lightning/pull/7681 ) )
48+
49+
4750- Added ` should_rank_save_checkpoint ` property to Training Plugins ([ #7684 ] ( https://github.com/PyTorchLightning/pytorch-lightning/pull/7684 ) )
4851
4952
Original file line number Diff line number Diff line change @@ -132,6 +132,9 @@ def __iter__(self) -> Iterator[List[int]]:
132132 self .batch_indices = batch
133133 yield batch
134134
135+ def __len__ (self ) -> int :
136+ return len (self ._sampler )
137+
135138 @property
136139 def drop_last (self ) -> bool :
137140 return self ._sampler .drop_last
Original file line number Diff line number Diff line change 1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ from collections .abc import Iterable
15+
1416import pytest
1517from torch .utils .data import BatchSampler , SequentialSampler
1618
1719from pytorch_lightning import seed_everything
1820from pytorch_lightning .overrides .distributed import IndexBatchSamplerWrapper , UnrepeatedDistributedSampler
21+ from pytorch_lightning .utilities .data import has_len
1922
2023
2124@pytest .mark .parametrize ("shuffle" , [False , True ])
@@ -54,3 +57,13 @@ def test_index_batch_sampler(tmpdir):
5457
5558 for batch in index_batch_sampler :
5659 assert index_batch_sampler .batch_indices == batch
60+
61+
62+ def test_index_batch_sampler_methods ():
63+ dataset = range (15 )
64+ sampler = SequentialSampler (dataset )
65+ batch_sampler = BatchSampler (sampler , 3 , False )
66+ index_batch_sampler = IndexBatchSamplerWrapper (batch_sampler )
67+
68+ assert isinstance (index_batch_sampler , Iterable )
69+ assert has_len (index_batch_sampler )
You can’t perform that action at this time.
0 commit comments