Skip to content
15 changes: 10 additions & 5 deletions src/ansys/dpf/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,16 @@


# register classes for collection types:
CustomTypeFieldsCollection:type = _CollectionFactory(CustomTypeField)
GenericDataContainersCollection:type = _CollectionFactory(GenericDataContainer)
StringFieldsCollection:type = _CollectionFactory(StringField)
OperatorsCollection: type = _CollectionFactory(Operator)
AnyCollection:type = _Collection
class CustomTypeFieldsCollection(_Collection[CustomTypeField]):
entries_type = CustomTypeField
class GenericDataContainersCollection(_Collection[GenericDataContainer]):
entries_type = GenericDataContainer
class StringFieldsCollection(_Collection[StringField]):
entries_type = StringField
class OperatorsCollection(_Collection[Operator]):
entries_type = Operator
class AnyCollection(_Collection[Any]):
entries_type = Any

# for matplotlib
# solves "QApplication: invalid style override passed, ignoring it."
Expand Down
32 changes: 27 additions & 5 deletions src/ansys/dpf/core/collection_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,19 +414,41 @@ def get_label_scoping(self, label="time"):
scoping = Scoping(self._api.collection_get_label_scoping(self, label), server=self._server)
return scoping

def __getitem__(self, index):
"""Retrieve the entry at a requested index value.
def __getitem__(self, index: int | slice):
"""Retrieve the entry at a requested index value or build a new collection from a slice.

Parameters
----------
index : int
index:
Index value.

Returns
-------
entry : Field , Scoping
Entry at the index value.
entry:
Entry at the index value or new collection for entries at requested slice.
"""
if isinstance(index, slice):
# handle slice
indices = list(
range(
index.start if index.start else 0, index.stop, index.step if index.step else 1
)
)
out_collection = self.__class__()
out_collection.set_labels(labels=self._get_labels())
if hasattr(out_collection, "add_entry"):
# For any direct subclass of Collection
func = out_collection.add_entry
else:
# For FieldsContainers, ScopingsContainers and MeshesContainers
# because they have dedicated APIs
func = out_collection._add_entry
for i in indices:
func(
label_space=self.get_label_space(index=i),
entry=self._get_entries(label_space_or_index=i),
)
return out_collection
self_len = len(self)
if index < 0:
# convert to a positive index
Expand Down
33 changes: 32 additions & 1 deletion tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def create_dummy_gdc(server_type, prop="hi"):
@dataclass
class CollectionTypeHelper:
type: type
instance_creator: object
instance_creator: callable
kwargs: dict = field(default_factory=dict)

@property
Expand Down Expand Up @@ -245,3 +245,34 @@ def test_connect_collection_workflow(server_type, subtype_creator):
out = op.get_output(0, subtype_creator.type)
assert out is not None
assert len(out) == 1


def test_generic_data_containers_collection_slice(server_type):
coll = GenericDataContainersCollection(server=server_type)

coll.labels = ["id1", "id2"]
for i in range(5):
coll.add_entry(
label_space={"id1": i, "id2": 0}, entry=create_dummy_gdc(server_type=server_type)
)
assert len(coll) == 5
print(coll)
sliced_coll = coll[:3]
assert len(sliced_coll) == 3
print(sliced_coll)


def test_string_containers_collection_slice(server_type):
coll = StringFieldsCollection(server=server_type)

coll.labels = ["id1", "id2"]
for i in range(5):
coll.add_entry(
label_space={"id1": i, "id2": 0},
entry=create_dummy_string_field(server_type=server_type),
)
assert len(coll) == 5
print(coll)
sliced_coll = coll[:3]
assert len(sliced_coll) == 3
print(sliced_coll)
6 changes: 6 additions & 0 deletions tests/test_fieldscontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,3 +602,9 @@ def test_get_entries_indices_fields_container(server_type):
assert np.allclose(fc.get_entries_indices({"time": 1, "complex": 0}), [0])
assert np.allclose(fc.get_entries_indices({"time": 2}), [1])
assert np.allclose(fc.get_entries_indices({"complex": 0}), range(0, 20))


def test_fields_container_slice(server_type, disp_fc):
print(disp_fc)
fc = disp_fc[:1]
assert len(fc) == 1
Loading