6
6
from numpy import ndarray
7
7
from scipy import linalg
8
8
9
+ from pytensor .link .numba .dispatch import numba_funcify
9
10
from pytensor .link .numba .dispatch .basic import numba_njit
10
11
from pytensor .link .numba .dispatch .linalg ._LAPACK import (
11
12
_LAPACK ,
20
21
_solve_check ,
21
22
_trans_char_to_int ,
22
23
)
24
+ from pytensor .tensor ._linalg .solve .tridiagonal import (
25
+ LUFactorTridiagonal ,
26
+ SolveLUFactorTridiagonal ,
27
+ )
23
28
24
29
25
30
@numba_njit
@@ -34,7 +39,12 @@ def tridiagonal_norm(du, d, dl):
34
39
35
40
36
41
def _gttrf (
37
- dl : ndarray , d : ndarray , du : ndarray
42
+ dl : ndarray ,
43
+ d : ndarray ,
44
+ du : ndarray ,
45
+ overwrite_dl : bool ,
46
+ overwrite_d : bool ,
47
+ overwrite_du : bool ,
38
48
) -> tuple [ndarray , ndarray , ndarray , ndarray , ndarray , int ]:
39
49
"""Placeholder for LU factorization of tridiagonal matrix."""
40
50
return # type: ignore
@@ -45,6 +55,9 @@ def gttrf_impl(
45
55
dl : ndarray ,
46
56
d : ndarray ,
47
57
du : ndarray ,
58
+ overwrite_dl : bool ,
59
+ overwrite_d : bool ,
60
+ overwrite_du : bool ,
48
61
) -> Callable [
49
62
[ndarray , ndarray , ndarray ], tuple [ndarray , ndarray , ndarray , ndarray , ndarray , int ]
50
63
]:
@@ -60,12 +73,24 @@ def impl(
60
73
dl : ndarray ,
61
74
d : ndarray ,
62
75
du : ndarray ,
76
+ overwrite_dl : bool ,
77
+ overwrite_d : bool ,
78
+ overwrite_du : bool ,
63
79
) -> tuple [ndarray , ndarray , ndarray , ndarray , ndarray , int ]:
64
80
n = np .int32 (d .shape [- 1 ])
65
81
ipiv = np .empty (n , dtype = np .int32 )
66
82
du2 = np .empty (n - 2 , dtype = dtype )
67
83
info = val_to_int_ptr (0 )
68
84
85
+ if not overwrite_dl or not dl .flags .f_contiguous :
86
+ dl = dl .copy ()
87
+
88
+ if not overwrite_d or not d .flags .f_contiguous :
89
+ d = d .copy ()
90
+
91
+ if not overwrite_du or not du .flags .f_contiguous :
92
+ du = du .copy ()
93
+
69
94
numba_gttrf (
70
95
val_to_int_ptr (n ),
71
96
dl .view (w_type ).ctypes ,
@@ -133,10 +158,23 @@ def impl(
133
158
nrhs = 1 if b .ndim == 1 else int (b .shape [- 1 ])
134
159
info = val_to_int_ptr (0 )
135
160
136
- if overwrite_b and b .flags .f_contiguous :
137
- b_copy = b
138
- else :
139
- b_copy = _copy_to_fortran_order_even_if_1d (b )
161
+ if not overwrite_b or not b .flags .f_contiguous :
162
+ b = _copy_to_fortran_order_even_if_1d (b )
163
+
164
+ if not dl .flags .f_contiguous :
165
+ dl = dl .copy ()
166
+
167
+ if not d .flags .f_contiguous :
168
+ d = d .copy ()
169
+
170
+ if not du .flags .f_contiguous :
171
+ du = du .copy ()
172
+
173
+ if not du2 .flags .f_contiguous :
174
+ du2 = du2 .copy ()
175
+
176
+ if not ipiv .flags .f_contiguous :
177
+ ipiv = ipiv .copy ()
140
178
141
179
numba_gttrs (
142
180
val_to_int_ptr (_trans_char_to_int (trans )),
@@ -147,12 +185,12 @@ def impl(
147
185
du .view (w_type ).ctypes ,
148
186
du2 .view (w_type ).ctypes ,
149
187
ipiv .ctypes ,
150
- b_copy .view (w_type ).ctypes ,
188
+ b .view (w_type ).ctypes ,
151
189
val_to_int_ptr (n ),
152
190
info ,
153
191
)
154
192
155
- return b_copy , int_ptr_to_val (info )
193
+ return b , int_ptr_to_val (info )
156
194
157
195
return impl
158
196
@@ -283,7 +321,9 @@ def impl(
283
321
284
322
anorm = tridiagonal_norm (du , d , dl )
285
323
286
- dl , d , du , du2 , IPIV , INFO = _gttrf (dl , d , du )
324
+ dl , d , du , du2 , IPIV , INFO = _gttrf (
325
+ dl , d , du , overwrite_dl = True , overwrite_d = True , overwrite_du = True
326
+ )
287
327
_solve_check (n , INFO )
288
328
289
329
X , INFO = _gttrs (
@@ -297,3 +337,48 @@ def impl(
297
337
return X
298
338
299
339
return impl
340
+
341
+
342
@numba_funcify.register(LUFactorTridiagonal)
def numba_funcify_LUFactorTridiagonal(op: LUFactorTridiagonal, node, **kwargs):
    """Build the numba-compiled function for a ``LUFactorTridiagonal`` op.

    The three ``overwrite_*`` settings are read from ``op`` once, here, and
    captured by the closure below. The returned function forwards its
    diagonals to ``_gttrf`` together with those settings.

    Parameters
    ----------
    op : LUFactorTridiagonal
        Op instance supplying ``overwrite_dl``, ``overwrite_d`` and
        ``overwrite_du``.
    node
        Not used by this function.
    **kwargs
        Not used by this function.

    Returns
    -------
    callable
        ``lu_factor_tridiagonal(dl, d, du)``, which returns the first five
        values produced by ``_gttrf`` as ``(dl, d, du, du2, ipiv)``.
    """
    may_overwrite_dl = op.overwrite_dl
    may_overwrite_d = op.overwrite_d
    may_overwrite_du = op.overwrite_du

    @numba_njit(cache=False)
    def lu_factor_tridiagonal(dl, d, du):
        # ``_gttrf`` returns six values; the sixth is deliberately dropped
        # and never inspected here.
        lower, diag, upper, upper2, pivots, _ = _gttrf(
            dl,
            d,
            du,
            overwrite_dl=may_overwrite_dl,
            overwrite_d=may_overwrite_d,
            overwrite_du=may_overwrite_du,
        )
        return lower, diag, upper, upper2, pivots

    return lu_factor_tridiagonal
361
+
362
+
363
@numba_funcify.register(SolveLUFactorTridiagonal)
def numba_funcify_SolveLUFactorTridiagonal(
    op: SolveLUFactorTridiagonal, node, **kwargs
):
    """Build the numba-compiled function for a ``SolveLUFactorTridiagonal`` op.

    ``op.overwrite_b`` and ``op.transposed`` are read once, here, and
    captured by the closure below, which passes them to ``_gttrs`` as
    ``overwrite_b`` and ``trans`` respectively.

    Parameters
    ----------
    op : SolveLUFactorTridiagonal
        Op instance supplying ``overwrite_b`` and ``transposed``.
    node
        Not used by this function.
    **kwargs
        Not used by this function.

    Returns
    -------
    callable
        ``solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b)``, which
        returns the first of the two values produced by ``_gttrs``.
    """
    may_overwrite_b = op.overwrite_b
    use_transpose = op.transposed

    @numba_njit(cache=False)
    def solve_lu_factor_tridiagonal(dl, d, du, du2, ipiv, b):
        # ``_gttrs`` returns a pair; the second element is deliberately
        # dropped and never inspected here.
        solution, _ = _gttrs(
            dl,
            d,
            du,
            du2,
            ipiv,
            b,
            overwrite_b=may_overwrite_b,
            trans=use_transpose,
        )
        return solution

    return solve_lu_factor_tridiagonal
0 commit comments