|
4 | 4 | from numba.core.extending import overload
|
5 | 5 | from numba.np.linalg import ensure_lapack
|
6 | 6 | from numpy import ndarray
|
7 |
| -from scipy import linalg |
8 | 7 |
|
9 | 8 | from pytensor.link.numba.dispatch.basic import numba_njit
|
10 | 9 | from pytensor.link.numba.dispatch.linalg._LAPACK import (
|
|
13 | 12 | int_ptr_to_val,
|
14 | 13 | val_to_int_ptr,
|
15 | 14 | )
|
16 |
| -from pytensor.link.numba.dispatch.linalg.solve.utils import _solve_check_input_shapes |
17 | 15 | from pytensor.link.numba.dispatch.linalg.utils import (
|
18 | 16 | _check_scipy_linalg_matrix,
|
19 | 17 | _copy_to_fortran_order_even_if_1d,
|
20 |
| - _solve_check, |
21 | 18 | _trans_char_to_int,
|
22 | 19 | )
|
23 | 20 |
|
@@ -227,72 +224,42 @@ def impl(
|
227 | 224 |
|
228 | 225 |
|
229 | 226 | def _solve_tridiagonal(
|
230 |
| - a: ndarray, |
231 |
| - b: ndarray, |
232 |
| - lower: bool, |
233 |
| - overwrite_a: bool, |
| 227 | + dl: ndarray, |
| 228 | + d: ndarray, |
| 229 | + ul: ndarray, |
| 230 | + B: ndarray, |
234 | 231 | overwrite_b: bool,
|
235 |
| - check_finite: bool, |
236 |
| - transposed: bool, |
237 | 232 | ):
|
238 | 233 | """
|
239 |
| - Solve a positive-definite linear system using the Cholesky decomposition. |
| 234 | + Solve a tridiagonal linear system. |
240 | 235 | """
|
241 |
| - return linalg.solve( |
242 |
| - a=a, |
243 |
| - b=b, |
244 |
| - lower=lower, |
245 |
| - overwrite_a=overwrite_a, |
246 |
| - overwrite_b=overwrite_b, |
247 |
| - check_finite=check_finite, |
248 |
| - transposed=transposed, |
249 |
| - assume_a="tridiagonal", |
250 |
| - ) |
| 236 | + return |
251 | 237 |
|
252 | 238 |
|
253 | 239 | @overload(_solve_tridiagonal)
|
254 |
| -def _tridiagonal_solve_impl( |
255 |
| - A: ndarray, |
| 240 | +def _solve_tridiagonal_impl( |
| 241 | + dl: ndarray, |
| 242 | + d: ndarray, |
| 243 | + du: ndarray, |
256 | 244 | B: ndarray,
|
257 |
| - lower: bool, |
258 |
| - overwrite_a: bool, |
259 | 245 | overwrite_b: bool,
|
260 |
| - check_finite: bool, |
261 |
| - transposed: bool, |
262 |
| -) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]: |
| 246 | +) -> Callable[[ndarray, ndarray, ndarray, ndarray, bool], ndarray]: |
263 | 247 | ensure_lapack()
|
264 |
| - _check_scipy_linalg_matrix(A, "solve") |
| 248 | + _check_scipy_linalg_matrix(dl, "solve_") |
| 249 | + _check_scipy_linalg_matrix(dl, "solve") |
| 250 | + _check_scipy_linalg_matrix(dl, "solve") |
265 | 251 | _check_scipy_linalg_matrix(B, "solve")
|
266 | 252 |
|
267 | 253 | def impl(
|
268 |
| - A: ndarray, |
| 254 | + dl: ndarray, |
| 255 | + d: ndarray, |
| 256 | + du: ndarray, |
269 | 257 | B: ndarray,
|
270 |
| - lower: bool, |
271 |
| - overwrite_a: bool, |
272 | 258 | overwrite_b: bool,
|
273 |
| - check_finite: bool, |
274 |
| - transposed: bool, |
275 | 259 | ) -> ndarray:
|
276 |
| - n = np.int32(A.shape[-1]) |
277 |
| - _solve_check_input_shapes(A, B) |
278 |
| - norm = "1" |
279 |
| - |
280 |
| - if transposed: |
281 |
| - A = A.T |
282 |
| - dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1) |
283 |
| - |
284 |
| - anorm = tridiagonal_norm(du, d, dl) |
285 |
| - |
286 |
| - dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du) |
287 |
| - _solve_check(n, INFO) |
288 |
| - |
289 |
| - X, INFO = _gttrs( |
290 |
| - dl, d, du, du2, IPIV, B, trans=transposed, overwrite_b=overwrite_b |
291 |
| - ) |
292 |
| - _solve_check(n, INFO) |
| 260 | + dl, d, du, du2, IPIV, _ = _gttrf(dl, d, du) |
293 | 261 |
|
294 |
| - RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm) |
295 |
| - _solve_check(n, INFO, True, RCOND) |
| 262 | + X, _ = _gttrs(dl, d, du, du2, IPIV, B, trans=0, overwrite_b=overwrite_b) |
296 | 263 |
|
297 | 264 | return X
|
298 | 265 |
|
|
0 commit comments