Skip to content

Commit b10cc7c

Browse files
YodaEmbeddingfracape
authored andcommitted
fix: default_collate changes in PyTorch 2.0
As shown in PyTorch issue 99227, default_collate behaves differently in PyTorch v2.0. Thus, this commit manually reimplements the desired behavior.
1 parent 50ddf91 commit b10cc7c

File tree

1 file changed

+27
-2
lines changed

1 file changed

+27
-2
lines changed

compressai/latent_codecs/rasterscan.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,13 @@
2727
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
2828
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2929

30-
from typing import Any, Callable, Dict, List, Optional, Tuple
30+
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
3131

3232
import torch
3333
import torch.nn as nn
3434
import torch.nn.functional as F
3535

3636
from torch import Tensor
37-
from torch.utils.data.dataloader import default_collate
3837

3938
from compressai.ans import BufferedRansEncoder, RansDecoder
4039
from compressai.entropy_models import GaussianConditional
@@ -47,6 +46,9 @@
4746
"RasterScanLatentCodec",
4847
]
4948

49+
K = TypeVar("K")
50+
V = TypeVar("V")
51+
5052

5153
@register_module("RasterScanLatentCodec")
5254
class RasterScanLatentCodec(LatentCodec):
@@ -309,3 +311,26 @@ def _pad_2d(x: Tensor, padding: int) -> Tensor:
309311
def _reduce_seq(xs):
310312
assert all(x == xs[0] for x in xs)
311313
return xs[0]
314+
315+
316+
def default_collate(batch: List[Dict[K, V]]) -> Dict[K, List[V]]:
317+
if not isinstance(batch, list) or any(not isinstance(d, dict) for d in batch):
318+
raise NotImplementedError
319+
320+
result = _ld_to_dl(batch)
321+
322+
for k, vs in result.items():
323+
if all(isinstance(v, Tensor) for v in vs):
324+
result[k] = torch.stack(vs)
325+
326+
return result
327+
328+
329+
def _ld_to_dl(ld: List[Dict[K, V]]) -> Dict[K, List[V]]:
330+
dl = {}
331+
for d in ld:
332+
for k, v in d.items():
333+
if k not in dl:
334+
dl[k] = []
335+
dl[k].append(v)
336+
return dl

0 commit comments

Comments
 (0)