Skip to content

Commit 7c0c8ca

Browse files
committed
!squash
1 parent 9d2d986 commit 7c0c8ca

File tree

6 files changed

+54
-28
lines changed

6 files changed

+54
-28
lines changed

libtmux/common.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import typing as t
1414
from collections.abc import MutableMapping
1515
from distutils.version import LooseVersion
16-
from typing import Any, Dict, List, Optional, Union
16+
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union, overload
1717

1818
from . import exc
1919
from ._compat import console_to_str, str_from_console
@@ -315,8 +315,11 @@ def __getattr__(self, key: str) -> str:
315315
raise AttributeError(f"{self.__class__} has no property {key}")
316316

317317

318-
class TmuxRelationalObject:
318+
O = TypeVar("O", "Pane", "Window", "Session")
319+
D = TypeVar("D", "PaneDict", "WindowDict", "SessionDict")
319320

321+
322+
class TmuxRelationalObject(Generic[O, D]):
320323
"""Base Class for managing tmux object child entities. .. # NOQA
321324
322325
Manages collection of child objects (a :class:`Server` has a collection of
@@ -345,6 +348,9 @@ class TmuxRelationalObject:
345348
================ ================================== ==============
346349
"""
347350

351+
children: t.List[O]
352+
child_id_attribute: str
353+
348354
def find_where(
349355
self, attrs: Dict[str, str]
350356
) -> Optional[Union["Pane", "Window", "Session"]]:
@@ -359,9 +365,20 @@ def find_where(
359365
except IndexError:
360366
return None
361367

362-
def where(
363-
self, attrs: Dict[str, str], first: bool = False
364-
) -> List[Union["Session", "Pane", "Window", t.Any]]:
368+
@overload
369+
def where(self, attrs: Dict[str, str], first: t.Literal[True]) -> O:
370+
...
371+
372+
@overload
373+
def where(self, attrs: Dict[str, str], first: t.Literal[False]) -> t.List[O]:
374+
...
375+
376+
@overload
377+
def where(self, attrs: Dict[str, str]) -> t.List[O]:
378+
...
379+
380+
def where(self, attrs: Dict[str, str], first: bool = False) -> t.Union[List[O], O]:
381+
# ) -> List[Union["Session", "Pane", "Window", t.Any]]:
365382
"""
366383
Return objects matching child objects properties.
367384
@@ -376,7 +393,7 @@ def where(
376393
"""
377394

378395
# from https://github.com/serkanyersen/underscore.py
379-
def by(val: WindowDict) -> bool:
396+
def by(val: D) -> bool:
380397
for key in attrs.keys():
381398
try:
382399
if attrs[key] != val[key]:

libtmux/pane.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
logger = logging.getLogger(__name__)
2424

2525

26-
class Pane(TmuxMappingObject, TmuxRelationalObject):
26+
class Pane(TmuxMappingObject):
2727
"""
2828
A :term:`tmux(1)` :term:`Pane` [pane_manual]_.
2929

libtmux/server.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@
2424

2525
logger = logging.getLogger(__name__)
2626

27+
if t.TYPE_CHECKING:
28+
from libtmux.common import SessionDict
29+
from libtmux.window import Window
2730

28-
class Server(TmuxRelationalObject, EnvironmentMixin):
31+
32+
class Server(TmuxRelationalObject["Session", "SessionDict"], EnvironmentMixin):
2933

3034
"""
3135
The :term:`tmux(1)` :term:`Server` [server_manual]_.
@@ -193,7 +197,7 @@ def sessions(self) -> t.List[Session]:
193197
return self.list_sessions()
194198

195199
#: Alias :attr:`sessions` for :class:`~libtmux.common.TmuxRelationalObject`
196-
children = sessions
200+
children = sessions # type: ignore
197201

198202
def _list_windows(self) -> t.List[WindowDict]:
199203
"""

libtmux/session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
logger = logging.getLogger(__name__)
3333

3434

35-
class Session(TmuxMappingObject, TmuxRelationalObject, EnvironmentMixin):
35+
class Session(
36+
TmuxMappingObject, TmuxRelationalObject["Window", "WindowDict"], EnvironmentMixin
37+
):
3638
"""
3739
A :term:`tmux(1)` :term:`Session` [session_manual]_.
3840
@@ -319,7 +321,7 @@ def windows(self) -> t.List[Window]:
319321
return self.list_windows()
320322

321323
#: Alias :attr:`windows` for :class:`~libtmux.common.TmuxRelationalObject`
322-
children = windows
324+
children = windows # type: ignore # mypy#1362
323325

324326
@property
325327
def attached_window(self) -> Window:

libtmux/window.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,14 @@
2424
)
2525

2626
if t.TYPE_CHECKING:
27+
from .common import PaneDict
2728
from .server import Server
2829
from .session import Session
2930

3031
logger = logging.getLogger(__name__)
3132

3233

33-
class Window(TmuxMappingObject, TmuxRelationalObject):
34+
class Window(TmuxMappingObject, TmuxRelationalObject["Pane", "PaneDict"]):
3435
"""
3536
A :term:`tmux(1)` :term:`Window` [window_manual]_.
3637
@@ -550,4 +551,4 @@ def panes(self) -> t.List[Pane]:
550551
return self.list_panes()
551552

552553
#: Alias :attr:`panes` for :class:`~libtmux.common.TmuxRelationalObject`
553-
children = panes
554+
children = panes # type:ignore # mypy#1362

tests/test_tmuxobject.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,14 @@ def test_where(server: Server, session: Session) -> None:
109109
session_name = session.get("session_name")
110110
assert session_name is not None
111111

112-
where = server.where({"session_id": session_id, "session_name": session_name})
112+
server_sessions = server.where(
113+
{"session_id": session_id, "session_name": session_name}
114+
)
113115

114-
assert len(where) == 1
115-
assert isinstance(where, list)
116-
assert where[0] == session
117-
assert isinstance(where[0], Session)
116+
assert len(server_sessions) == 1
117+
assert isinstance(server_sessions, list)
118+
assert server_sessions[0] == session
119+
assert isinstance(server_sessions[0], Session)
118120

119121
# session.where
120122
for window in session.windows:
@@ -124,14 +126,14 @@ def test_where(server: Server, session: Session) -> None:
124126
window_index = window.get("window_index")
125127
assert window_index is not None
126128

127-
where = session.where(
129+
session_windows = session.where(
128130
{"window_id": window_id, "window_index": window_index}
129131
)
130132

131-
assert len(where) == 1
132-
assert isinstance(where, list)
133-
assert where[0] == window
134-
assert isinstance(where[0], Window)
133+
assert len(session_windows) == 1
134+
assert isinstance(session_windows, list)
135+
assert session_windows[0] == window
136+
assert isinstance(session_windows[0], Window)
135137

136138
# window.where
137139
for pane in window.panes:
@@ -141,12 +143,12 @@ def test_where(server: Server, session: Session) -> None:
141143
pane_tty = pane.get("pane_tty")
142144
assert pane_tty is not None
143145

144-
where = window.where({"pane_id": pane_id, "pane_tty": pane_tty})
146+
window_panes = window.where({"pane_id": pane_id, "pane_tty": pane_tty})
145147

146-
assert len(where) == 1
147-
assert isinstance(where, list)
148-
assert where[0] == pane
149-
assert isinstance(where[0], Pane)
148+
assert len(window_panes) == 1
149+
assert isinstance(window_panes, list)
150+
assert window_panes[0] == pane
151+
assert isinstance(window_panes[0], Pane)
150152

151153

152154
def test_get_by_id(server: Server, session: Session) -> None:

0 commit comments

Comments
 (0)