Skip to content

Commit 8888b66

Browse files
Merge 5b5a2a5 into 8584a2f
2 parents 8584a2f + 5b5a2a5 commit 8888b66

File tree

4 files changed

+25
-6
lines changed

4 files changed

+25
-6
lines changed

dpnp/backend/extensions/lapack/getrs.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ std::pair<sycl::event, sycl::event>
166166
const dpctl::tensor::usm_ndarray &a_array,
167167
const dpctl::tensor::usm_ndarray &ipiv_array,
168168
const dpctl::tensor::usm_ndarray &b_array,
169+
const int trans_code,
169170
const std::vector<sycl::event> &depends)
170171
{
171172
const int a_array_nd = a_array.get_ndim();
@@ -264,11 +265,20 @@ std::pair<sycl::event, sycl::event>
264265
const std::int64_t lda = std::max<size_t>(1UL, n);
265266
const std::int64_t ldb = std::max<size_t>(1UL, n);
266267

267-
// Use transpose::T if the LU-factorized array is passed as C-contiguous.
268-
// For F-contiguous we use transpose::N.
269-
oneapi::mkl::transpose trans = is_a_array_c_contig
270-
? oneapi::mkl::transpose::T
271-
: oneapi::mkl::transpose::N;
268+
oneapi::mkl::transpose trans;
269+
switch (trans_code) {
270+
case 0:
271+
trans = oneapi::mkl::transpose::N;
272+
break;
273+
case 1:
274+
trans = oneapi::mkl::transpose::T;
275+
break;
276+
case 2:
277+
trans = oneapi::mkl::transpose::C;
278+
break;
279+
default:
280+
throw py::value_error("`trans_code` must be 0 (N), 1 (T), or 2 (C)");
281+
}
272282

273283
char *a_array_data = a_array.get_data();
274284
char *b_array_data = b_array.get_data();

dpnp/backend/extensions/lapack/getrs.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ extern std::pair<sycl::event, sycl::event>
3737
const dpctl::tensor::usm_ndarray &a_array,
3838
const dpctl::tensor::usm_ndarray &ipiv_array,
3939
const dpctl::tensor::usm_ndarray &b_array,
40+
const int trans_code,
4041
const std::vector<sycl::event> &depends = {});
4142

4243
extern void init_getrs_dispatch_vector(void);

dpnp/backend/extensions/lapack/lapack_py.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,8 @@ PYBIND11_MODULE(_lapack_impl, m)
160160
"the solves of linear equations with an LU-factored "
161161
"square coefficient matrix, with multiple right-hand sides",
162162
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
163-
py::arg("b_array"), py::arg("depends") = py::list());
163+
py::arg("b_array"), py::arg("trans_code"),
164+
py::arg("depends") = py::list());
164165

165166
m.def("_orgqr_batch", &lapack_ext::orgqr_batch,
166167
"Call `_orgqr_batch` from OneMKL LAPACK library to return "

dpnp/linalg/dpnp_utils_linalg.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2632,6 +2632,12 @@ def dpnp_solve(a, b):
26322632
_manager = dpu.SequentialOrderManager[exec_q]
26332633
dev_evs = _manager.submitted_events
26342634

2635+
# TODO: remove after PR #2558 is merged
2636+
# Temporarily set trans_code=1 (transpose) because the LU-factorized
2637+
# array is C-contiguous.
2638+
# For F-contiguous arrays use 0 (non-transpose)
2639+
trans_code = 1
2640+
26352641
# use DPCTL tensor function to fill the copy of the input array
26362642
# from the input array
26372643
ht_ev, a_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
@@ -2688,6 +2694,7 @@ def dpnp_solve(a, b):
26882694
a_h.get_array(),
26892695
ipiv_h.get_array(),
26902696
b_h.get_array(),
2697+
trans_code,
26912698
depends=[b_copy_ev, getrf_ev],
26922699
)
26932700
_manager.add_event_pair(ht_ev, getrs_ev)

0 commit comments

Comments
 (0)