|
29 | 29 | V1QueueServerType, |
30 | 30 | V1SourceType, |
31 | 31 | V1UserRequestedComputeConfig, |
| 32 | + V1UserRequestedFlowComputeConfig, |
32 | 33 | V1Work, |
33 | 34 | ) |
34 | 35 |
|
|
37 | 38 | from lightning_app.storage import Drive, Mount |
38 | 39 | from lightning_app.utilities.cloud import _get_project |
39 | 40 | from lightning_app.utilities.dependency_caching import get_hash |
| 41 | +from lightning_app.utilities.packaging.cloud_compute import CloudCompute |
40 | 42 |
|
41 | 43 |
|
42 | 44 | class MyWork(LightningWork): |
@@ -66,6 +68,47 @@ def run(self): |
66 | 68 | class TestAppCreationClient: |
67 | 69 | """Testing the calls made using GridRestClient to create the app.""" |
68 | 70 |
|
| 71 | + @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock()) |
| 72 | + def test_run_with_custom_flow_compute_config(self, monkeypatch): |
| 73 | + mock_client = mock.MagicMock() |
| 74 | + mock_client.projects_service_list_memberships.return_value = V1ListMembershipsResponse( |
| 75 | + memberships=[V1Membership(name="test-project", project_id="test-project-id")] |
| 76 | + ) |
| 77 | + mock_client.lightningapp_instance_service_list_lightningapp_instances.return_value = ( |
| 78 | + V1ListLightningappInstancesResponse(lightningapps=[]) |
| 79 | + ) |
| 80 | + cloud_backend = mock.MagicMock() |
| 81 | + cloud_backend.client = mock_client |
| 82 | + monkeypatch.setattr(backends, "CloudBackend", mock.MagicMock(return_value=cloud_backend)) |
| 83 | + monkeypatch.setattr(cloud, "LocalSourceCodeDir", mock.MagicMock()) |
| 84 | + app = mock.MagicMock() |
| 85 | + app.flows = [] |
| 86 | + app.frontend = {} |
| 87 | + app.flow_cloud_compute = CloudCompute(name="t2.medium") |
| 88 | + cloud_runtime = cloud.CloudRuntime(app=app, entrypoint_file="entrypoint.py") |
| 89 | + cloud_runtime._check_uploaded_folder = mock.MagicMock() |
| 90 | + |
| 91 | + monkeypatch.setattr(Path, "is_file", lambda *args, **kwargs: False) |
| 92 | + monkeypatch.setattr(cloud, "Path", Path) |
| 93 | + cloud_runtime.dispatch() |
| 94 | + body = Body8( |
| 95 | + app_entrypoint_file=mock.ANY, |
| 96 | + enable_app_server=True, |
| 97 | + flow_servers=[], |
| 98 | + image_spec=None, |
| 99 | + works=[], |
| 100 | + local_source=True, |
| 101 | + dependency_cache_key=mock.ANY, |
| 102 | + user_requested_flow_compute_config=V1UserRequestedFlowComputeConfig( |
| 103 | + name="t2.medium", |
| 104 | + preemptible=False, |
| 105 | + shm_size=0, |
| 106 | + ), |
| 107 | + ) |
| 108 | + cloud_runtime.backend.client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with( |
| 109 | + project_id="test-project-id", app_id=mock.ANY, body=body |
| 110 | + ) |
| 111 | + |
69 | 112 | @mock.patch("lightning_app.runners.backends.cloud.LightningClient", mock.MagicMock()) |
70 | 113 | def test_run_on_byoc_cluster(self, monkeypatch): |
71 | 114 | mock_client = mock.MagicMock() |
@@ -100,6 +143,7 @@ def test_run_on_byoc_cluster(self, monkeypatch): |
100 | 143 | works=[], |
101 | 144 | local_source=True, |
102 | 145 | dependency_cache_key=mock.ANY, |
| 146 | + user_requested_flow_compute_config=mock.ANY, |
103 | 147 | ) |
104 | 148 | cloud_runtime.backend.client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with( |
105 | 149 | project_id="default-project-id", app_id=mock.ANY, body=body |
@@ -142,6 +186,7 @@ def test_requirements_file(self, monkeypatch): |
142 | 186 | works=[], |
143 | 187 | local_source=True, |
144 | 188 | dependency_cache_key=mock.ANY, |
| 189 | + user_requested_flow_compute_config=mock.ANY, |
145 | 190 | ) |
146 | 191 | cloud_runtime.backend.client.lightningapp_v2_service_create_lightningapp_release.assert_called_once_with( |
147 | 192 | project_id="test-project-id", app_id=mock.ANY, body=body |
@@ -264,6 +309,7 @@ def test_call_with_work_app(self, lightningapps, monkeypatch, tmpdir): |
264 | 309 | enable_app_server=True, |
265 | 310 | flow_servers=[], |
266 | 311 | dependency_cache_key=get_hash(requirements_file), |
| 312 | + user_requested_flow_compute_config=mock.ANY, |
267 | 313 | image_spec=Gridv1ImageSpec( |
268 | 314 | dependency_file_info=V1DependencyFileInfo( |
269 | 315 | package_manager=V1PackageManager.PIP, path="requirements.txt" |
@@ -431,6 +477,7 @@ def test_call_with_work_app_and_attached_drives(self, lightningapps, monkeypatch |
431 | 477 | enable_app_server=True, |
432 | 478 | flow_servers=[], |
433 | 479 | dependency_cache_key=get_hash(requirements_file), |
| 480 | + user_requested_flow_compute_config=mock.ANY, |
434 | 481 | image_spec=Gridv1ImageSpec( |
435 | 482 | dependency_file_info=V1DependencyFileInfo( |
436 | 483 | package_manager=V1PackageManager.PIP, path="requirements.txt" |
@@ -590,6 +637,7 @@ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, mo |
590 | 637 | enable_app_server=True, |
591 | 638 | flow_servers=[], |
592 | 639 | dependency_cache_key=get_hash(requirements_file), |
| 640 | + user_requested_flow_compute_config=mock.ANY, |
593 | 641 | image_spec=Gridv1ImageSpec( |
594 | 642 | dependency_file_info=V1DependencyFileInfo( |
595 | 643 | package_manager=V1PackageManager.PIP, path="requirements.txt" |
@@ -623,6 +671,7 @@ def test_call_with_work_app_and_multiple_attached_drives(self, lightningapps, mo |
623 | 671 | enable_app_server=True, |
624 | 672 | flow_servers=[], |
625 | 673 | dependency_cache_key=get_hash(requirements_file), |
| 674 | + user_requested_flow_compute_config=mock.ANY, |
626 | 675 | image_spec=Gridv1ImageSpec( |
627 | 676 | dependency_file_info=V1DependencyFileInfo( |
628 | 677 | package_manager=V1PackageManager.PIP, path="requirements.txt" |
@@ -756,6 +805,7 @@ def test_call_with_work_app_and_attached_mount_and_drive(self, lightningapps, mo |
756 | 805 | package_manager=V1PackageManager.PIP, path="requirements.txt" |
757 | 806 | ) |
758 | 807 | ), |
| 808 | + user_requested_flow_compute_config=mock.ANY, |
759 | 809 | works=[ |
760 | 810 | V1Work( |
761 | 811 | name="test-work", |
|
0 commit comments