TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
getri_batch.hpp
Go to the documentation of this file.
1// Copyright (c) 2021--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "./getri.hpp"
15#include "../basic_array.hpp"
17#include "../blas/tools.hpp"
18#include "../concepts.hpp"
19#include "../declarations.hpp"
20#include "../device.hpp"
22#include "../macros.hpp"
24#include "../traits.hpp"
25
26#include <algorithm>
27#include <type_traits>
28#include <utility>
29
30namespace nda::lapack {
31
32 namespace detail {
33
34 // Implementation of the batched getri routine.
35 template <bool run_on_device>
36 auto getri_batch_impl(auto &&a, auto const &ipiv, [[maybe_unused]] auto &&work) {
37 // check the dimensions of the input/output arrays/views
38 auto const [m, n, n_b] = a.shape();
39 EXPECTS(m == n);
40 EXPECTS(ipiv.extent(0) == n);
41 EXPECTS(ipiv.extent(1) == n_b);
42
43 // arrays/views must be LAPACK compatible
44 EXPECTS(a.indexmap().min_stride() == 1);
45 EXPECTS(ipiv.indexmap().min_stride() == 1);
46
47 // perform actual library call(s)
48 auto info = array<int, 1>(n_b, 0);
49 if constexpr (run_on_device) {
50 using arr_t = std::remove_cvref_t<decltype(a)>;
51
52 // resize/check work buffer
53 resize_or_check_if_view(work, {a.size()});
54 EXPECTS(work.indexmap().min_stride() == 1);
55
56 // output buffer for inverted matrices
57 auto c = cuarray_view<get_value_t<arr_t>, 3, F_layout>(a.shape(), work.data());
58
59 auto a_ptrs = to_device(batch_ptrs(a));
60 auto c_ptrs = to_device(batch_ptrs(c));
61 auto info_d = vector<int, heap<mem::get_addr_space<arr_t>>>(n_b, 0);
62 blas::device::getri_batch(n, a_ptrs.data(), get_ld(a(range::all, range::all, 0)), ipiv.data(), c_ptrs.data(),
63 get_ld(c(range::all, range::all, 0)), info_d.data(), n_b);
64 info = info_d;
65
66 // copy result back to a
67 a = c;
68 } else {
69 // for host, fall back to looping over batches
70 for (int i = 0; i < n_b; ++i) {
71 auto a_i = a(range::all, range::all, i);
72 auto ipiv_i = ipiv(range::all, i);
73 info(i) = getri(a_i, ipiv_i, work);
74 }
75 }
76 return info;
77 }
78
79 } // namespace detail
80
109 template <BlasArray<3> A, PivotArrayFor<A, 2> IPIV, BlasArrayFor<A, 1> W = vector_value_t<A>>
111 auto getri_batch(A &&a, IPIV const &ipiv, W &&work = vector_value_t<A>{}) { // NOLINT (temporary views are allowed here)
112 constexpr bool run_on_device = mem::have_device_compatible_addr_space<A, IPIV, W>;
113
114 // transpose ipiv array/view if necessary
115 if constexpr (has_C_layout<IPIV>) return getri_batch(a, transpose(ipiv), work);
116
117 // transpose A array/view if necessary and call the implementation
118 if constexpr (has_C_layout<A>) {
119 return detail::getri_batch_impl<run_on_device>(transpose(a), ipiv, std::forward<W>(work));
120 } else {
121 return detail::getri_batch_impl<run_on_device>(std::forward<A>(a), ipiv, std::forward<W>(work));
122 }
123 }
124
132 template <BlasArray<3> A, PivotArrayFor<A, 2> IPIV, BlasArrayFor<A, 1> W = vector_value_t<A>>
134 auto getri(A &&a, IPIV const &ipiv, W &&work = vector_value_t<A>{}) { // NOLINT (temporary views are allowed here)
135 return getri_batch(std::forward<A>(a), ipiv, std::forward<W>(work));
136 }
137
138} // namespace nda::lapack
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides a C++ interface for various BLAS routines.
Provides various traits and utilities for the BLAS interface.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a generic interface to the LAPACK getri routine.
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
auto transpose(A &&a)
Transpose the memory layout of an nda::MemoryArray or an nda::expr_call.
decltype(auto) to_device(A &&a)
Convert an nda::MemoryArray to its regular type on device memory.
basic_array< ValueType, Rank, Layout, 'A', ContainerPolicy > array
Alias template of an nda::basic_array with an 'A' algebra.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
basic_array_view< ValueType, Rank, Layout, 'A', default_accessor, borrowed< mem::Device > > cuarray_view
Similar to nda::array_view except the memory is stored on the device.
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
Definition tools.hpp:128
auto batch_ptrs(A &&a)
Given a 2- or 3-dimensional array get an array of pointers to each of the submatrices/subvectors inde...
Definition tools.hpp:182
static constexpr bool has_C_layout
Constexpr variable that is true if all given nda::Array types have nda::C_layout.
Definition tools.hpp:89
vector< get_value_t< A >, heap< mem::get_addr_space< A > > > vector_value_t
Alias for an nda::vector with the same value type and address space as the given type.
Definition tools.hpp:161
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
Definition tools.hpp:79
int getri(A &&a, IPIV const &ipiv, W &&work=vector_value_t< A >{})
Interface to the LAPACK getri routine.
Definition getri.hpp:53
auto getri_batch(A &&a, IPIV const &ipiv, W &&work=vector_value_t< A >{})
Interface to batched versions of the LAPACK/cuSOLVER getri routine.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
Provides a C++ interface for various LAPACK routines.
Provides functions to transform the memory layout of an nda::basic_array or nda::basic_array_view.
Macros used in the nda library.
Provides type traits for the nda library.