30namespace nda::lapack {
35 template <
bool run_on_device>
36 auto getri_batch_impl(
auto &&a,
auto const &ipiv, [[maybe_unused]]
auto &&work) {
38 auto const [m, n, n_b] = a.shape();
40 EXPECTS(ipiv.extent(0) == n);
41 EXPECTS(ipiv.extent(1) == n_b);
44 EXPECTS(a.indexmap().min_stride() == 1);
45 EXPECTS(ipiv.indexmap().min_stride() == 1);
49 if constexpr (run_on_device) {
50 using arr_t = std::remove_cvref_t<
decltype(a)>;
54 EXPECTS(work.indexmap().min_stride() == 1);
62 blas::device::getri_batch(n, a_ptrs.data(),
get_ld(a(range::all, range::all, 0)), ipiv.data(), c_ptrs.data(),
63 get_ld(c(range::all, range::all, 0)), info_d.data(), n_b);
70 for (
int i = 0; i < n_b; ++i) {
71 auto a_i = a(range::all, range::all, i);
72 auto ipiv_i = ipiv(range::all, i);
73 info(i) =
getri(a_i, ipiv_i, work);
109 template <BlasArray<3> A, PivotArrayFor<A, 2> IPIV, BlasArrayFor<A, 1> W = vector_value_t<A>>
119 return detail::getri_batch_impl<run_on_device>(
transpose(a), ipiv, std::forward<W>(work));
121 return detail::getri_batch_impl<run_on_device>(std::forward<A>(a), ipiv, std::forward<W>(work));
132 template <BlasArray<3> A, PivotArrayFor<A, 2> IPIV, BlasArrayFor<A, 1> W = vector_value_t<A>>
135 return getri_batch(std::forward<A>(a), ipiv, std::forward<W>(work));
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides a C++ interface for various BLAS routines.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a generic interface to the LAPACK getri routine.
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
auto transpose(A &&a)
Transpose the memory layout of an nda::MemoryArray or an nda::expr_call.
decltype(auto) to_device(A &&a)
Convert an nda::MemoryArray to its regular type on device memory.
basic_array< ValueType, Rank, Layout, 'A', ContainerPolicy > array
Alias template of an nda::basic_array with an 'A' algebra.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
basic_array_view< ValueType, Rank, Layout, 'A', default_accessor, borrowed< mem::Device > > cuarray_view
Similar to nda::array_view except the memory is stored on the device.
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
auto batch_ptrs(A &&a)
Given a 2- or 3-dimensional array, get an array of pointers to each of the submatrices/subvectors indexed by the last dimension.
static constexpr bool has_C_layout
Constexpr variable that is true if all given nda::Array types have nda::C_layout.
vector< get_value_t< A >, heap< mem::get_addr_space< A > > > vector_value_t
Alias for an nda::vector with the same value type and address space as the given type.
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
int getri(A &&a, IPIV const &ipiv, W &&work=vector_value_t< A >{})
Interface to the LAPACK getri routine.
auto getri_batch(A &&a, IPIV const &ipiv, W &&work=vector_value_t< A >{})
Interface to batched versions of the LAPACK/cuSOLVER getri routine.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
Provides a C++ interface for various LAPACK routines.
Macros used in the nda library.
Provides type traits for the nda library.