Skip to content

Commit ade6be2

Browse files
committed
TST: Add one for the update sequence
1 parent f921ea9 commit ade6be2

File tree

1 file changed

+283
-0
lines changed

1 file changed

+283
-0
lines changed

test/test_scs_concurrent_solve.py

Lines changed: 283 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,283 @@
1+
import pytest
2+
import threading
3+
import numpy as np
4+
import scipy.sparse as sp
5+
import scs
6+
import sys
7+
from numpy.testing import assert_almost_equal
8+
import time
9+
import queue
10+
from concurrent.futures import ThreadPoolExecutor, as_completed
11+
import gen_random_cone_prob as tools
12+
13+
# --- Global constant ---
14+
FAIL = "failure"
15+
16+
SHARED_DATA_FOR_TEST = {
17+
"A": sp.csc_matrix([1.0, -1.0]).T.tocsc(),
18+
"b": np.array([1.0, 0.0]),
19+
"c": np.array([-1.0])
20+
}
21+
SHARED_CONE_CONFIG_FOR_TEST = {"q": [], "l": 2}
22+
EXPECTED_X0_FOR_SHARED_PROBLEM_TEST = 1.0
23+
NUM_CONCURRENT_SOLVES=8
24+
25+
# Cone definition
26+
K_CONFIG = {
27+
"z": 5,
28+
"l": 10,
29+
"q": [3, 4],
30+
"s": [2, 3],
31+
"ep": 4,
32+
"ed": 4,
33+
"p": [-0.25, 0.5],
34+
}
35+
36+
SOLVER_PARAMS_CONFIG = {
37+
"verbose": False,
38+
"eps_abs": 1e-5,
39+
"eps_rel": 1e-5,
40+
"eps_infeas": 1e-5,
41+
"max_iters": 3500,
42+
}
43+
44+
UPDATE_TEST_C_NEW = np.array([1.0])
45+
UPDATE_TEST_B_NEW = np.array([1.0, 1.0])
46+
47+
EXPECTED_X1_UPDATE = 1.0
48+
EXPECTED_X2_UPDATE = 0.0
49+
EXPECTED_X3_UPDATE = -1.0
50+
51+
# --- Worker function executed by each thread ---
52+
def solve_one_random_cone_problem(cone_def, solver_params_def, worker_id):
53+
"""
54+
Generates a random feasible cone problem, solves it with SCS, and performs assertions.
55+
This function is intended to be run in a separate thread.
56+
Returns True on success, raises AssertionError on failure.
57+
"""
58+
thread_name = threading.current_thread().name
59+
print(f"[Worker {worker_id} on {thread_name}]")
60+
61+
m_dims = tools.get_scs_cone_dims(cone_def)
62+
n_vars = m_dims // 2
63+
if n_vars == 0: n_vars = 1
64+
65+
# Generate a new feasible problem for each worker
66+
data, p_star_expected = tools.gen_feasible(cone_def, n=n_vars, density=0.2)
67+
68+
print(f"[Worker {worker_id} on {thread_name}]: Problem generated. m={m_dims}, n={n_vars}. Expected p_star ~ {p_star_expected:.4f}")
69+
70+
# Create and run the SCS solver
71+
solver = scs.SCS(data, cone_def, use_indirect=False, gpu=False, **solver_params_def)
72+
sol = solver.solve()
73+
x_sol = sol["x"]
74+
y_sol = sol["y"]
75+
s_sol = sol["s"]
76+
info = sol["info"]
77+
78+
print(f"[Worker {worker_id} on {thread_name}]: Solved. Status: {info['status']}. Pobj: {info['pobj']:.4f}, Iters: {info['iter']}")
79+
80+
# Assertions (similar to test_solve_feasible)
81+
# 1. Objective value
82+
np.testing.assert_almost_equal(np.dot(data["c"], x_sol), p_star_expected, decimal=2,
83+
err_msg=f"Worker {worker_id}: Objective value mismatch.")
84+
85+
# 2. Primal feasibility (Ax - b + s = 0 => ||Ax - b + s|| ~ 0)
86+
# Relaxed tolerance from 1e-3 to 5e-3
87+
primal_residual_norm = np.linalg.norm(data["A"] @ x_sol - data["b"] + s_sol)
88+
np.testing.assert_array_less(primal_residual_norm, 5e-3,
89+
err_msg=f"Worker {worker_id}: Primal residual norm too high: {primal_residual_norm}")
90+
91+
# 3. Dual feasibility (A'y + c = 0 => ||A'y + c|| ~ 0 for LP part, more complex for cones)
92+
# Relaxed tolerance from 1e-3 to 5e-3
93+
dual_residual_norm = np.linalg.norm(data["A"].T @ y_sol + data["c"])
94+
np.testing.assert_array_less(dual_residual_norm, 5e-3,
95+
err_msg=f"Worker {worker_id}: Dual residual norm too high: {dual_residual_norm}")
96+
97+
# 4. Complementary slackness (s'y ~ 0)
98+
complementarity = s_sol.T @ y_sol
99+
np.testing.assert_almost_equal(complementarity, 0.0, decimal=3, # Check if close to zero
100+
err_msg=f"Worker {worker_id}: Complementary slackness violation: {complementarity}")
101+
102+
# 5. Slack variable s in primal cone K (s = proj_K(s))
103+
projected_s = tools.proj_cone(s_sol, cone_def)
104+
np.testing.assert_almost_equal(s_sol, projected_s, decimal=3,
105+
err_msg=f"Worker {worker_id}: Slack variable s not in primal cone.")
106+
107+
# 6. Dual variable y in dual cone K* (y = proj_K*(y))
108+
projected_y_dual = tools.proj_dual_cone(y_sol, cone_def)
109+
np.testing.assert_almost_equal(y_sol, projected_y_dual, decimal=3,
110+
err_msg=f"Worker {worker_id}: Dual variable y not in dual cone.")
111+
112+
print(f"[Worker {worker_id} on {thread_name}]: All assertions passed.")
113+
return {"id": worker_id, "status": "success", "pobj": info['pobj'], "iters": info['iter']}
114+
115+
# --- Pytest test function using ThreadPoolExecutor ---
116+
pytest.mark.skipif(sys._is_gil_enabled(), "Only for free threaded")
117+
def test_concurrent_independent_cone_solves():
118+
"""
119+
Tests running multiple independent SCS solves concurrently using ThreadPoolExecutor.
120+
Each solve uses the provided use_indirect and gpu flags.
121+
"""
122+
completed_solves = 0
123+
failed_solves_details = []
124+
125+
with ThreadPoolExecutor(max_workers=NUM_CONCURRENT_SOLVES) as executor:
126+
futures = []
127+
for i in range(NUM_CONCURRENT_SOLVES):
128+
worker_id = i + 1
129+
future = executor.submit(
130+
solve_one_random_cone_problem,
131+
K_CONFIG,
132+
SOLVER_PARAMS_CONFIG,
133+
worker_id
134+
)
135+
futures.append(future)
136+
print(f"pytest: Submitted task for worker {worker_id}.")
137+
138+
print(f"\npytest: All {NUM_CONCURRENT_SOLVES} tasks submitted. Waiting for completion...\n")
139+
140+
for future in as_completed(futures, timeout=NUM_CONCURRENT_SOLVES * 60.0):
141+
# Determine worker_id based on the future object's position in the original list.
142+
# This is a bit fragile if futures list were modified, but common for simple cases.
143+
# A more robust way would be to wrap future with its ID if needed for complex scenarios.
144+
worker_id_from_future = -1 # Default / placeholder
145+
for idx, f_item in enumerate(futures):
146+
if f_item == future:
147+
worker_id_from_future = idx + 1
148+
break
149+
150+
try:
151+
result = future.result(timeout=60.0)
152+
print(f"pytest: Worker {result.get('id', worker_id_from_future)} completed successfully: {result}")
153+
completed_solves += 1
154+
except Exception as e:
155+
error_detail = f"Worker {worker_id_from_future} failed: {type(e).__name__}: {e}"
156+
print(f"pytest: ERROR - {error_detail}")
157+
failed_solves_details.append(error_detail)
158+
159+
print(f"\npytest: Test execution finished.")
160+
print(f"Total solves attempted: {NUM_CONCURRENT_SOLVES}")
161+
print(f"Successful solves: {completed_solves}")
162+
print(f"Failed solves: {len(failed_solves_details)}")
163+
164+
if failed_solves_details:
165+
pytest.fail(f"{len(failed_solves_details)} out of {NUM_CONCURRENT_SOLVES} concurrent solves failed.\n"
166+
f"Failures:\n" + "\n".join(failed_solves_details))
167+
168+
assert completed_solves == NUM_CONCURRENT_SOLVES, \
169+
f"Expected {NUM_CONCURRENT_SOLVES} successful concurrent solves, but got {completed_solves}."
170+
171+
print(f"pytest: All {NUM_CONCURRENT_SOLVES} concurrent solves passed.")
172+
173+
def worker_perform_solve_update_sequence(solver_params_def, worker_id):
174+
"""
175+
Performs a sequence of solve and update operations on an SCS instance.
176+
"""
177+
thread_name = threading.current_thread().name
178+
print(f"[Worker {worker_id} (UpdateSeq) on {thread_name}]: Starting")
179+
180+
solver = scs.SCS(SHARED_DATA_FOR_TEST, SHARED_CONE_CONFIG_FOR_TEST,
181+
use_indirect=False, gpu=False, **solver_params_def)
182+
183+
print(f"[Worker {worker_id} (UpdateSeq) on {thread_name}]: Performing initial solve.")
184+
sol1 = solver.solve()
185+
np.testing.assert_almost_equal(sol1["x"][0], EXPECTED_X1_UPDATE, decimal=2,
186+
err_msg=f"Worker {worker_id} (UpdateSeq): Initial solve failed.")
187+
print(f"[Worker {worker_id} (UpdateSeq) on {thread_name}]: Initial solve OK, x={sol1['x'][0]:.2f}")
188+
189+
print(f"[Worker {worker_id} (UpdateSeq) on {thread_name}]: Updating c and solving.")
190+
solver.update(c=UPDATE_TEST_C_NEW)
191+
sol2 = solver.solve()
192+
np.testing.assert_almost_equal(sol2["x"][0], EXPECTED_X2_UPDATE, decimal=2,
193+
err_msg=f"Worker {worker_id} (UpdateSeq): Solve after c update failed.")
194+
print(f"[Worker {worker_id} (UpdateSeq) on {thread_name}]: Solve after c update OK, x={sol2['x'][0]:.2f}")
195+
196+
print(f"[Worker {worker_id} (UpdateSeq) on {thread_name}]: Updating b and solving.")
197+
solver.update(b=UPDATE_TEST_B_NEW)
198+
sol3 = solver.solve()
199+
np.testing.assert_almost_equal(sol3["x"][0], EXPECTED_X3_UPDATE, decimal=2,
200+
err_msg=f"Worker {worker_id} (UpdateSeq): Solve after b update failed.")
201+
print(f"[Worker {worker_id} (UpdateSeq) on {thread_name}]: Solve after b update OK, x={sol3['x'][0]:.2f}")
202+
203+
print(f"[Worker {worker_id} (UpdateSeq) on {thread_name}]: All update sequence assertions passed.")
204+
return {"id": worker_id, "type": "UpdateSeq", "status": "success"}
205+
206+
207+
# --- Test for Concurrent Solve and Update Sequences ---
208+
pytest.mark.skipif(sys._is_gil_enabled(), "Only for free threaded")
209+
def test_concurrent_solve_update_sequences():
210+
"""
211+
Tests running multiple SCS solve-update-solve sequences concurrently.
212+
"""
213+
print(f"\npytest: Starting concurrent solve-update sequences test (use_indirect=False, gpu=False)")
214+
215+
completed_jobs = 0
216+
failed_jobs_details = []
217+
218+
with ThreadPoolExecutor(max_workers=NUM_CONCURRENT_SOLVES) as executor:
219+
futures = []
220+
for i in range(NUM_CONCURRENT_SOLVES):
221+
worker_id = i + 1
222+
future = executor.submit(
223+
worker_perform_solve_update_sequence,
224+
SOLVER_PARAMS_CONFIG, worker_id
225+
)
226+
futures.append(future)
227+
print(f"pytest: Submitted task for UpdateSeq worker {worker_id}.")
228+
229+
print(f"\npytest: All {NUM_CONCURRENT_SOLVES} UpdateSeq tasks submitted. Waiting for completion...\n")
230+
for future in as_completed(futures, timeout=NUM_CONCURRENT_SOLVES * 30.0):
231+
worker_id_from_future = futures.index(future) + 1
232+
try:
233+
result = future.result(timeout=30.0)
234+
print(f"pytest: UpdateSeq Worker {result.get('id', worker_id_from_future)} completed successfully: {result}")
235+
completed_jobs += 1
236+
except Exception as e:
237+
error_detail = f"UpdateSeq Worker {worker_id_from_future} failed: {type(e).__name__}: {e}"
238+
print(f"pytest: ERROR - {error_detail}")
239+
failed_jobs_details.append(error_detail)
240+
241+
print(f"\npytest: UpdateSeq test execution finished.")
242+
print(f"Total UpdateSeq jobs attempted: {NUM_CONCURRENT_SOLVES}, Successful: {completed_jobs}, Failed: {len(failed_jobs_details)}")
243+
244+
if failed_jobs_details:
245+
pytest.fail(f"{len(failed_jobs_details)} out of {NUM_CONCURRENT_SOLVES} concurrent UpdateSeq jobs failed.\nFailures:\n" + "\n".join(failed_jobs_details))
246+
assert completed_jobs == NUM_CONCURRENT_SOLVES, f"Expected {NUM_CONCURRENT_SOLVES} successful UpdateSeq jobs, got {completed_jobs}."
247+
print(f"pytest: All {NUM_CONCURRENT_SOLVES} concurrent UpdateSeq jobs passed.")
248+
249+
# --- Worker function for threads ---
250+
def worker_solve_on_shared_instance(test_id, shared_solver_instance, expected_x0, results_queue):
251+
"""
252+
Attempts to call solve() on a shared SCS solver instance.
253+
Reports result or exception to the main thread via a queue.
254+
"""
255+
print(f"[Thread {test_id}]: Attempting to call solve() on the shared solver instance.")
256+
try:
257+
sol = shared_solver_instance.solve(warm_start=False, x=None, y=None, s=None)
258+
if sol["info"]["status"] != "solved":
259+
# Report failure status
260+
results_queue.put({
261+
"id": test_id,
262+
"status": "solver_fail_status",
263+
"info": sol["info"],
264+
"x": sol.get("x")
265+
})
266+
print(f"[Thread {test_id}]: Solver status: {sol['info']['status']}.")
267+
return
268+
269+
assert_almost_equal(sol["x"][0], expected_x0, decimal=2)
270+
results_queue.put({
271+
"id": test_id,
272+
"status": "success",
273+
"x0": sol["x"][0],
274+
"info": sol["info"]
275+
})
276+
print(f"[Thread {test_id}]: Call to solve() completed. Expected x[0] ~ {expected_x0}, Got x[0] ~ {sol['x'][0]:.2f}.")
277+
278+
except AssertionError as e:
279+
results_queue.put({"id": test_id, "status": "assertion_error", "error": e})
280+
print(f"[Thread {test_id}]: TEST FAILED (result inconsistent). Assertion Error: {e}")
281+
except Exception as e:
282+
results_queue.put({"id": test_id, "status": "exception", "error": e, "type": type(e).__name__})
283+
print(f"[Thread {test_id}]: An unexpected error occurred: {type(e).__name__}: {e}.")

0 commit comments

Comments
 (0)