@@ -27,14 +27,13 @@
 # OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 # ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from torch import Tensor
-from torch.utils.data.dataloader import default_collate
 
 from compressai.ans import BufferedRansEncoder, RansDecoder
 from compressai.entropy_models import GaussianConditional
@@ -47,6 +46,9 @@
     "RasterScanLatentCodec",
 ]
 
+K = TypeVar("K")
+V = TypeVar("V")
+
 
 @register_module("RasterScanLatentCodec")
 class RasterScanLatentCodec(LatentCodec):
@@ -309,3 +311,26 @@ def _pad_2d(x: Tensor, padding: int) -> Tensor:
 def _reduce_seq(xs):
     assert all(x == xs[0] for x in xs)
     return xs[0]
+
+
+def default_collate(batch: List[Dict[K, V]]) -> Dict[K, List[V]]:
+    if not isinstance(batch, list) or any(not isinstance(d, dict) for d in batch):
+        raise NotImplementedError
+
+    result = _ld_to_dl(batch)
+
+    for k, vs in result.items():
+        if all(isinstance(v, Tensor) for v in vs):
+            result[k] = torch.stack(vs)
+
+    return result
+
+
+def _ld_to_dl(ld: List[Dict[K, V]]) -> Dict[K, List[V]]:
+    dl = {}
+    for d in ld:
+        for k, v in d.items():
+            if k not in dl:
+                dl[k] = []
+            dl[k].append(v)
+    return dl
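
For reference, here is a minimal usage sketch of the collation helpers added at the end of the file. It assumes the new module-level default_collate (not torch.utils.data.dataloader.default_collate, whose import is removed above) is in scope; the dict keys, byte strings, and tensor shapes are made up for illustration.

import torch
from torch import Tensor

# A batch given as a list of per-sample dicts.
batch = [
    {"strings": [b"abc"], "y_hat": torch.zeros(1, 192, 4, 4)},
    {"strings": [b"def"], "y_hat": torch.ones(1, 192, 4, 4)},
]

out = default_collate(batch)

# Non-tensor values are gathered into plain lists, one entry per sample:
assert out["strings"] == [[b"abc"], [b"def"]]

# Values that are tensors in every sample are stacked along a new leading dim:
assert isinstance(out["y_hat"], Tensor)
assert out["y_hat"].shape == (2, 1, 192, 4, 4)

Unlike the torch collate function it replaces, this helper accepts only a list of dicts (anything else raises NotImplementedError) and leaves non-tensor values as plain lists rather than collating them recursively.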