Skip to content

Commit 3d82a1d

Browse files
malfetpytorchmergebot
authored andcommitted
Add checks for empty tensor list (pytorch#155383)
Vibe-coded with Codex, after collecting a backtrace, see https://chatgpt.com/s/cd_68438be8a1248191adbfa0a5f000e60b Even though, check for empty tensor list exists in `at::cat` crash might happens while resolving named dimension to position, by calling `dimname_to_position(tensors[0], dim)`, see backtrace below ``` (lldb) up frame #1: 0x00000001101146dc libtorch_cpu.dylib`at::TensorBase::has_names(this=0x0000000000000000) const at TensorBase.h:559:10 556 bool has_names() const { 557 // If a user is using unnamed tensors, then we can short-circuit right here. 558 // Otherwise, impl::has_names attempts to retrieve names. -> 559 if (!impl_->has_named_tensor_meta()) { 560 return false; 561 } 562 return impl::has_names(unsafeGetTensorImpl()); (lldb) up frame #2: 0x00000001101144c4 libtorch_cpu.dylib`at::dimname_to_position(tensor=0x0000000000000000, dim=Dimname @ 0x000000016fdfe348) at NamedTensorUtils.cpp:23:3 20 int64_t dimname_to_position(const Tensor& tensor, Dimname dim) { 21 TORCH_CHECK(dim.type() != NameType::WILDCARD, 22 "Please look up dimensions by name, got: name = None."); -> 23 TORCH_CHECK(tensor.has_names(), 24 "Name ", dim, " not found in ", toDimnameRepr(tensor), "."); 25 const auto names = tensor.names(); 26 ``` TODOs: - May be move test from `test_tensor_creation.py` to OpInfo (not sure which one is more readable) - Replace `TORCH_CHECK` with `TORCH_CHECK_VALUE` and adjust unit tests Fixes pytorch#155306 Pull Request resolved: pytorch#155383 Approved by: https://github.com/cyyever, https://github.com/ezyang ghstack dependencies: pytorch#155382
1 parent 95448b2 commit 3d82a1d

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

aten/src/ATen/native/TensorShape.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,10 +774,12 @@ Tensor cat(TensorList tensors, Dimname dim) {
774774

775775
// torch.concat, alias for torch.cat
776776
Tensor& concat_out(TensorList tensors, Dimname dim, Tensor& result) {
777+
TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
777778
return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim));
778779
}
779780

780781
Tensor concat(TensorList tensors, Dimname dim) {
782+
TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
781783
return at::cat(tensors, dimname_to_position(tensors[0], dim));
782784
}
783785

@@ -791,10 +793,12 @@ Tensor concat(TensorList tensors, int64_t dim) {
791793

792794
// torch.concatenate, alias for torch.cat
793795
Tensor& concatenate_out(TensorList tensors, Dimname dim, Tensor& result) {
796+
TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
794797
return at::cat_out(result, tensors, dimname_to_position(tensors[0], dim));
795798
}
796799

797800
Tensor concatenate(TensorList tensors, Dimname dim) {
801+
TORCH_CHECK(!tensors.empty(), "expected a non-empty list of Tensors");
798802
return at::cat(tensors, dimname_to_position(tensors[0], dim));
799803
}
800804

test/test_tensor_creation_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,14 @@ def test_cat_empty(self, device):
533533
res1 = torch.cat([empty, empty], dim=1)
534534
self.assertEqual(res1, empty)
535535

536+
def test_concat_empty_list_error(self, device):
537+
# Regression test for https://github.com/pytorch/pytorch/issues/155306
538+
msg = "expected a non-empty list of Tensors"
539+
with self.assertRaisesRegex(RuntimeError, msg):
540+
torch.concat([], dim='N')
541+
with self.assertRaisesRegex(RuntimeError, msg):
542+
torch.concatenate([], dim='N')
543+
536544
def test_cat_out(self, device):
537545
x = torch.zeros((0), device=device)
538546
y = torch.randn((4, 6), device=device)

0 commit comments

Comments
 (0)