31namespace nda::lapack {
36 template <
bool run_on_device>
37 int getrs_batch_impl(
auto const &a,
auto &b,
auto const &ipiv,
char op) {
42 auto const [m, n, n_b] = a_arr.shape();
43 auto const [k, nrhs, n_b_2] = b.shape();
46 EXPECTS(n_b == n_b_2);
47 EXPECTS(ipiv.extent(0) == n);
48 EXPECTS(ipiv.extent(1) == n_b);
51 EXPECTS(a_arr.indexmap().min_stride() == 1);
52 EXPECTS(b.indexmap().min_stride() == 1);
53 EXPECTS(ipiv.indexmap().min_stride() == 1);
57 if constexpr (run_on_device) {
60 blas::device::getrs_batch(op, n, nrhs, a_ptrs.data(),
get_ld(a_arr(range::all, range::all, 0)), ipiv.data(), b_ptrs.data(),
61 get_ld(b(range::all, range::all, 0)), info, n_b);
64 for (
int i = 0; i < n_b; ++i) {
65 auto a_i = a_arr(range::all, range::all, i);
66 auto b_i = b(range::all, range::all, i);
67 auto ipiv_i = ipiv(range::all, i);
70 if (local_info != 0 && info == 0) info = local_info;
112 template <BlasArrayOrConj<3> A, BlasArrayFor<A, 3> B, PivotArrayFor<A, 2> IPIV>
123 return detail::getrs_batch_impl<run_on_device>(
transpose(a), std::forward<B>(b), ipiv, op);
125 return detail::getrs_batch_impl<run_on_device>(a, std::forward<B>(b), ipiv, op);
136 template <BlasArrayOrConj<3> A, BlasArrayFor<A, 3> B, PivotArrayFor<A, 2> IPIV>
138 int getrs(A
const &a, B &&b, IPIV
const &ipiv) {
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides a C++ interface for various BLAS routines.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a generic interface to the LAPACK/cuSOLVER getrs routine.
auto transpose(A &&a)
Transpose the memory layout of an nda::MemoryArray or an nda::expr_call.
decltype(auto) to_device(A &&a)
Convert an nda::MemoryArray to its regular type on device memory.
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression, or return the array itself in case it is an nda::MemoryArray.
static constexpr bool is_conj_array_expr
Constexpr variable that is true if the given type is a conjugate lazy expression.
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
auto batch_ptrs(A &&a)
Given a 2- or 3-dimensional array get an array of pointers to each of the submatrices/subvectors inde...
static constexpr bool has_C_layout
Constexpr variable that is true if all given nda::Array types have nda::C_layout.
int get_ncols(A const &a)
Get the number of columns of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
int getrs_batch(A const &a, B &&b, IPIV const &ipiv)
Interface to batched versions of the LAPACK/cuSOLVER getrs routine.
int getrs(A const &a, B &&b, IPIV const &ipiv)
Interface to the LAPACK/cuSOLVER getrs routine.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
Provides a C++ interface for various LAPACK routines.
Macros used in the nda library.
Provides type traits for the nda library.