TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
getrs_batch.hpp
Go to the documentation of this file.
1// Copyright (c) 2021--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "./getrs.hpp"
15#include "../basic_array.hpp"
17#include "../blas/tools.hpp"
18#include "../concepts.hpp"
19#include "../declarations.hpp"
20#include "../device.hpp"
22#include "../macros.hpp"
24#include "../traits.hpp"
25
26#include <algorithm>
27#include <tuple>
28#include <type_traits>
29#include <utility>
30
31namespace nda::lapack {
32
33 namespace detail {
34
35 // Implementation of the batched getrs routine.
36 template <bool run_on_device>
37 int getrs_batch_impl(auto const &a, auto &b, auto const &ipiv, char op) {
38 // get underlying array in case it is given as a lazy conjugate expression
39 auto &a_arr = get_array(a);
40
41 // check the dimensions of the input/output arrays/views
42 auto const [m, n, n_b] = a_arr.shape();
43 auto const [k, nrhs, n_b_2] = b.shape();
44 EXPECTS(m == n);
45 EXPECTS(n == k);
46 EXPECTS(n_b == n_b_2);
47 EXPECTS(ipiv.extent(0) == n);
48 EXPECTS(ipiv.extent(1) == n_b);
49
50 // arrays/views must be LAPACK compatible
51 EXPECTS(a_arr.indexmap().min_stride() == 1);
52 EXPECTS(b.indexmap().min_stride() == 1);
53 EXPECTS(ipiv.indexmap().min_stride() == 1);
54
55 // perform actual library call(s)
56 int info = 0;
57 if constexpr (run_on_device) {
58 auto a_ptrs = to_device(batch_ptrs(a_arr));
59 auto b_ptrs = to_device(batch_ptrs(b));
60 blas::device::getrs_batch(op, n, nrhs, a_ptrs.data(), get_ld(a_arr(range::all, range::all, 0)), ipiv.data(), b_ptrs.data(),
61 get_ld(b(range::all, range::all, 0)), info, n_b);
62 } else {
63 // for host, fall back to looping over batches
64 for (int i = 0; i < n_b; ++i) {
65 auto a_i = a_arr(range::all, range::all, i);
66 auto b_i = b(range::all, range::all, i);
67 auto ipiv_i = ipiv(range::all, i);
68 int local_info = 0;
69 f77::getrs(op, get_ncols(a_i), get_ncols(b_i), a_i.data(), get_ld(a_i), ipiv_i.data(), b_i.data(), get_ld(b_i), local_info);
70 if (local_info != 0 && info == 0) info = local_info;
71 }
72 }
73 return info;
74 }
75
76 } // namespace detail
77
112 template <BlasArrayOrConj<3> A, BlasArrayFor<A, 3> B, PivotArrayFor<A, 2> IPIV>
114 int getrs_batch(A const &a, B &&b, IPIV const &ipiv) { // NOLINT (temporary views are allowed here)
115 constexpr bool run_on_device = mem::have_device_compatible_addr_space<A, B, IPIV>;
116
117 // transpose ipiv array/view if necessary
118 if constexpr (has_C_layout<IPIV>) return getrs_batch(a, b, transpose(ipiv));
119
120 // transpose A array/view if necessary and call the implementation with the correct cuBLAS op flag
121 constexpr char op = (has_C_layout<A> ? (is_conj_array_expr<A> ? 'C' : 'T') : 'N');
122 if constexpr (has_C_layout<A>) {
123 return detail::getrs_batch_impl<run_on_device>(transpose(a), std::forward<B>(b), ipiv, op);
124 } else {
125 return detail::getrs_batch_impl<run_on_device>(a, std::forward<B>(b), ipiv, op);
126 }
127 }
128
136 template <BlasArrayOrConj<3> A, BlasArrayFor<A, 3> B, PivotArrayFor<A, 2> IPIV>
137 requires((has_F_layout<A> or has_C_layout<A>) and has_F_layout<B> and (not is_conj_array_expr<A> or has_C_layout<A>))
138 int getrs(A const &a, B &&b, IPIV const &ipiv) { // NOLINT (temporary views are allowed here)
139 return getrs_batch(a, std::forward<B>(b), ipiv);
140 }
141
142} // namespace nda::lapack
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides a C++ interface for various BLAS routines.
Provides various traits and utilities for the BLAS interface.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a generic interface to the LAPACK/cuSOLVER getrs routine.
auto transpose(A &&a)
Transpose the memory layout of an nda::MemoryArray or an nda::expr_call.
decltype(auto) to_device(A &&a)
Convert an nda::MemoryArray to its regular type on device memory.
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an nda::MemoryArray.
Definition tools.hpp:68
static constexpr bool is_conj_array_expr
Constexpr variable that is true if the given type is a conjugate lazy expression.
Definition tools.hpp:47
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
Definition tools.hpp:128
auto batch_ptrs(A &&a)
Given a 2- or 3-dimensional array get an array of pointers to each of the submatrices/subvectors inde...
Definition tools.hpp:182
static constexpr bool has_C_layout
Constexpr variable that is true if all given nda::Array types have nda::C_layout.
Definition tools.hpp:89
int get_ncols(A const &a)
Get the number of columns of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
Definition tools.hpp:148
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
Definition tools.hpp:79
int getrs_batch(A const &a, B &&b, IPIV const &ipiv)
Interface to batched versions of the LAPACK/cuSOLVER getrs routine.
int getrs(A const &a, B &&b, IPIV const &ipiv)
Interface to the LAPACK/cuSOLVER getrs routine.
Definition getrs.hpp:56
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
Provides a C++ interface for various LAPACK routines.
Provides functions to transform the memory layout of an nda::basic_array or nda::basic_array_view.
Macros used in the nda library.
Provides type traits for the nda library.