Skip to content

Commit ed51238

Browse files
rlizzolexierule
authored andcommitted
partial cherry pick of 2f7daac
1 parent b3f6977 commit ed51238

File tree

10 files changed

+359
-49
lines changed

10 files changed

+359
-49
lines changed

src/lightning_app/testing/testing.py

Lines changed: 48 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,29 @@
11
import asyncio
22
import json
3+
import logging
34
import os
45
import shutil
56
import sys
67
import tempfile
78
import time
9+
import traceback
810
from contextlib import contextmanager
911
from subprocess import Popen
1012
from time import sleep
11-
from typing import Any, Callable, Dict, Generator, List, Type
13+
from typing import Any, Callable, Dict, Generator, List, Optional, Type
1214

1315
import requests
1416
from lightning_cloud.openapi.rest import ApiException
1517
from requests import Session
1618
from rich import print
19+
from rich.color import ANSI_COLOR_NAMES
1720

1821
from lightning_app import LightningApp, LightningFlow
1922
from lightning_app.cli.lightning_cli import run_app
2023
from lightning_app.core.constants import LIGHTNING_CLOUD_PROJECT_ID
2124
from lightning_app.runners.multiprocess import MultiProcessRuntime
2225
from lightning_app.testing.config import Config
26+
from lightning_app.utilities.app_logs import _app_logs_reader
2327
from lightning_app.utilities.cloud import _get_project
2428
from lightning_app.utilities.enum import CacheCallsKeys
2529
from lightning_app.utilities.imports import _is_playwright_available, requires
@@ -31,6 +35,9 @@
3135
from playwright.sync_api import HttpCredentials, sync_playwright
3236

3337

38+
_logger = logging.getLogger(__name__)
39+
40+
3441
class LightningTestApp(LightningApp):
3542
def __init__(self, *args, **kwargs):
3643
super().__init__(*args, **kwargs)
@@ -259,20 +266,6 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator:
259266
var scrollingElement = (document.scrollingElement || document.body);
260267
scrollingElement.scrollTop = scrollingElement.scrollHeight;
261268
}, 200);
262-
263-
if (!window._logs) {
264-
window._logs = [];
265-
}
266-
267-
if (window.logTerminals) {
268-
Object.entries(window.logTerminals).forEach(
269-
([key, value]) => {
270-
window.logTerminals[key]._onLightningWritelnHandler = function (data) {
271-
window._logs = window._logs.concat([data]);
272-
}
273-
}
274-
);
275-
}
276269
"""
277270
)
278271

@@ -286,8 +279,46 @@ def run_app_in_cloud(app_folder: str, app_name: str = "app.py") -> Generator:
286279
except (playwright._impl._api_types.Error, playwright._impl._api_types.TimeoutError):
287280
pass
288281

289-
def fetch_logs() -> str:
290-
return admin_page.evaluate("window._logs;")
282+
client = LightningClient()
283+
project = _get_project(client)
284+
identifiers = []
285+
rich_colors = list(ANSI_COLOR_NAMES)
286+
287+
def fetch_logs(component_names: Optional[List[str]] = None) -> Generator:
288+
"""This methods creates websockets connection in threads and returns the logs to the main thread."""
289+
app_id = admin_page.url.split("/")[-1]
290+
291+
if not component_names:
292+
works = client.lightningwork_service_list_lightningwork(
293+
project_id=project.project_id,
294+
app_id=app_id,
295+
).lightningworks
296+
component_names = ["flow"] + [w.name for w in works]
297+
298+
def on_error_callback(ws_app, *_):
299+
print(traceback.print_exc())
300+
ws_app.close()
301+
302+
colors = {c: rich_colors[i + 1] for i, c in enumerate(component_names)}
303+
gen = _app_logs_reader(
304+
client=client,
305+
project_id=project.project_id,
306+
app_id=app_id,
307+
component_names=component_names,
308+
follow=False,
309+
on_error_callback=on_error_callback,
310+
)
311+
max_length = max(len(c.replace("root.", "")) for c in component_names)
312+
for log_event in gen:
313+
message = log_event.message
314+
identifier = f"{log_event.timestamp}{log_event.message}"
315+
if identifier not in identifiers:
316+
date = log_event.timestamp.strftime("%m/%d/%Y %H:%M:%S")
317+
identifiers.append(identifier)
318+
color = colors[log_event.component_name]
319+
padding = (max_length - len(log_event.component_name)) * " "
320+
print(f"[{color}]{log_event.component_name}{padding}[/{color}] {date} {message}")
321+
yield message
291322

292323
# 5. Print your application ID
293324
print(
@@ -300,11 +331,6 @@ def fetch_logs() -> str:
300331
pass
301332
finally:
302333
print("##################################################")
303-
printed_logs = []
304-
for log in fetch_logs():
305-
if log not in printed_logs:
306-
printed_logs.append(log)
307-
print(log.split("[0m")[-1])
308334
button = admin_page.locator('[data-cy="stop"]')
309335
try:
310336
button.wait_for(timeout=3 * 1000)
@@ -314,8 +340,6 @@ def fetch_logs() -> str:
314340
context.close()
315341
browser.close()
316342

317-
client = LightningClient()
318-
project = _get_project(client)
319343
list_lightningapps = client.lightningapp_instance_service_list_lightningapp_instances(project.project_id)
320344

321345
for lightningapp in list_lightningapps.lightningapps:
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import json
2+
import queue
3+
import sys
4+
from dataclasses import dataclass
5+
from datetime import datetime, timedelta
6+
from json import JSONDecodeError
7+
from threading import Thread
8+
from typing import Callable, Iterator, List, Optional
9+
10+
import dateutil.parser
11+
from websocket import WebSocketApp
12+
13+
from lightning_app.utilities.logs_socket_api import _LightningLogsSocketAPI
14+
from lightning_app.utilities.network import LightningClient
15+
16+
17+
@dataclass
18+
class _LogEventLabels:
19+
app: str
20+
container: str
21+
filename: str
22+
job: str
23+
namespace: str
24+
node_name: str
25+
pod: str
26+
stream: Optional[str] = None
27+
28+
29+
@dataclass
30+
class _LogEvent:
31+
message: str
32+
timestamp: datetime
33+
component_name: str
34+
labels: _LogEventLabels
35+
36+
def __ge__(self, other: "_LogEvent") -> bool:
37+
return self.timestamp >= other.timestamp
38+
39+
def __gt__(self, other: "_LogEvent") -> bool:
40+
return self.timestamp > other.timestamp
41+
42+
43+
def _push_log_events_to_read_queue_callback(component_name: str, read_queue: queue.PriorityQueue):
44+
"""Pushes _LogEvents from websocket to read_queue.
45+
46+
Returns callback function used with `on_message_callback` of websocket.WebSocketApp.
47+
"""
48+
49+
def callback(ws_app: WebSocketApp, msg: str):
50+
# We strongly trust that the contract on API will hold atm :D
51+
event_dict = json.loads(msg)
52+
labels = _LogEventLabels(**event_dict["labels"])
53+
54+
if "message" in event_dict:
55+
message = event_dict["message"]
56+
timestamp = dateutil.parser.isoparse(event_dict["timestamp"])
57+
event = _LogEvent(
58+
message=message,
59+
timestamp=timestamp,
60+
component_name=component_name,
61+
labels=labels,
62+
)
63+
read_queue.put(event)
64+
65+
return callback
66+
67+
68+
def _error_callback(ws_app: WebSocketApp, error: Exception):
69+
errors = {
70+
KeyError: "Malformed log message, missing key",
71+
JSONDecodeError: "Malformed log message",
72+
TypeError: "Malformed log format",
73+
ValueError: "Malformed date format",
74+
}
75+
print(f"Error while reading logs ({errors.get(type(error), 'Unknown')})", file=sys.stderr)
76+
ws_app.close()
77+
78+
79+
def _app_logs_reader(
80+
client: LightningClient,
81+
project_id: str,
82+
app_id: str,
83+
component_names: List[str],
84+
follow: bool,
85+
on_error_callback: Optional[Callable] = None,
86+
) -> Iterator[_LogEvent]:
87+
88+
read_queue = queue.PriorityQueue()
89+
logs_api_client = _LightningLogsSocketAPI(client.api_client)
90+
91+
# We will use a socket per component
92+
log_sockets = [
93+
logs_api_client.create_lightning_logs_socket(
94+
project_id=project_id,
95+
app_id=app_id,
96+
component=component_name,
97+
on_message_callback=_push_log_events_to_read_queue_callback(component_name, read_queue),
98+
on_error_callback=on_error_callback or _error_callback,
99+
)
100+
for component_name in component_names
101+
]
102+
103+
# And each socket on separate thread pushing log event to print queue
104+
# run_forever() will run until we close() the connection from outside
105+
log_threads = [Thread(target=work.run_forever) for work in log_sockets]
106+
107+
# Establish connection and begin pushing logs to the print queue
108+
for th in log_threads:
109+
th.start()
110+
111+
# Print logs from queue when log event is available
112+
user_log_start = "<<< BEGIN USER_RUN_FLOW SECTION >>>"
113+
start_timestamp = None
114+
115+
# Print logs from queue when log event is available
116+
try:
117+
while True:
118+
log_event = read_queue.get(timeout=None if follow else 1.0)
119+
if user_log_start in log_event.message:
120+
start_timestamp = log_event.timestamp + timedelta(seconds=0.5)
121+
122+
if start_timestamp and log_event.timestamp > start_timestamp:
123+
yield log_event
124+
125+
except queue.Empty:
126+
# Empty is raised by queue.get if timeout is reached. Follow = False case.
127+
pass
128+
129+
except KeyboardInterrupt:
130+
# User pressed CTRL+C to exit, we sould respect that
131+
pass
132+
133+
finally:
134+
# Close connections - it will cause run_forever() to finish -> thread as finishes aswell
135+
for socket in log_sockets:
136+
socket.close()
137+
138+
# Because all socket were closed, we can just wait for threads to finish.
139+
for th in log_threads:
140+
th.join()
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from typing import Callable, Optional
2+
from urllib.parse import urlparse
3+
4+
from lightning_cloud.openapi import ApiClient, AuthServiceApi, V1LoginRequest
5+
from websocket import WebSocketApp
6+
7+
from lightning_app.utilities.login import Auth
8+
9+
10+
class _LightningLogsSocketAPI:
11+
def __init__(self, api_client: ApiClient):
12+
self.api_client = api_client
13+
self._auth = Auth()
14+
self._auth.authenticate()
15+
self._auth_service = AuthServiceApi(api_client)
16+
17+
def _get_api_token(self) -> str:
18+
token_resp = self._auth_service.auth_service_login(
19+
body=V1LoginRequest(
20+
username=self._auth.username,
21+
api_key=self._auth.api_key,
22+
)
23+
)
24+
return token_resp.token
25+
26+
@staticmethod
27+
def _socket_url(host: str, project_id: str, app_id: str, token: str, component: str) -> str:
28+
return (
29+
f"wss://{host}/v1/projects/{project_id}/appinstances/{app_id}/logs?"
30+
f"token={token}&component={component}&follow=true"
31+
)
32+
33+
def create_lightning_logs_socket(
34+
self,
35+
project_id: str,
36+
app_id: str,
37+
component: str,
38+
on_message_callback: Callable[[WebSocketApp, str], None],
39+
on_error_callback: Optional[Callable[[Exception, str], None]] = None,
40+
) -> WebSocketApp:
41+
"""Creates and returns WebSocketApp to listen to lightning app logs.
42+
43+
.. code-block:: python
44+
# Synchronous reading, run_forever() is blocking
45+
def print_log_msg(ws_app, msg):
46+
print(msg)
47+
48+
49+
flow_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "flow", print_log_msg)
50+
flow_socket.run_forever()
51+
.. code-block:: python
52+
# Asynchronous reading (with Threads)
53+
def print_log_msg(ws_app, msg):
54+
print(msg)
55+
56+
57+
flow_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "flow", print_log_msg)
58+
work_logs_socket = client.create_lightning_logs_socket("project_id", "app_id", "work_1", print_log_msg)
59+
flow_logs_thread = Thread(target=flow_logs_socket.run_forever)
60+
work_logs_thread = Thread(target=work_logs_socket.run_forever)
61+
flow_logs_thread.start()
62+
work_logs_thread.start()
63+
# .......
64+
flow_logs_socket.close()
65+
work_logs_thread.close()
66+
Arguments:
67+
project_id: Project ID.
68+
app_id: Application ID.
69+
component: Component name eg flow.
70+
on_message_callback: Callback object which is called when received data.
71+
on_error_callback: Callback object which is called when we get error.
72+
Returns:
73+
WebSocketApp of the wanted socket
74+
"""
75+
_token = self._get_api_token()
76+
clean_ws_host = urlparse(self.api_client.configuration.host).netloc
77+
socket_url = self._socket_url(
78+
host=clean_ws_host,
79+
project_id=project_id,
80+
app_id=app_id,
81+
token=_token,
82+
component=component,
83+
)
84+
85+
return WebSocketApp(socket_url, on_message=on_message_callback, on_error=on_error_callback)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from datetime import datetime
2+
from unittest.mock import MagicMock
3+
4+
from lightning_app.utilities.app_logs import _LogEvent
5+
6+
7+
def test_log_event():
8+
event_1 = _LogEvent("", datetime.now(), MagicMock(), MagicMock())
9+
event_2 = _LogEvent("", datetime.now(), MagicMock(), MagicMock())
10+
assert event_1 < event_2
11+
assert event_1 <= event_2

0 commit comments

Comments
 (0)