TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
getrf_batch.hpp
Go to the documentation of this file.
1// Copyright (c) 2021--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "./getrf.hpp"
15#include "../basic_array.hpp"
18#include "../blas/tools.hpp"
19#include "../concepts.hpp"
20#include "../declarations.hpp"
21#include "../device.hpp"
23#include "../macros.hpp"
25#include "../traits.hpp"
26
27#include <algorithm>
28#include <type_traits>
29#include <utility>
30
31namespace nda::lapack {
32
33 namespace detail {
34
35 // Implementation of the batched getrf routine.
36 template <bool run_on_device>
37 auto getrf_batch_impl(auto &&a, auto &&ipiv, [[maybe_unused]] auto &&work) {
38 // check the dimensions of the input/output arrays/views and resize if necessary
39 auto const m = a.extent(0);
40 auto const n = a.extent(1);
41 auto const n_b = a.extent(2);
42 resize_or_check_if_view(ipiv, {std::min(m, n), n_b});
43
44 // arrays/views must be LAPACK compatible
45 EXPECTS(a.indexmap().min_stride() == 1);
46 EXPECTS(ipiv.indexmap().min_stride() == 1);
47
48#if defined(__has_feature)
49#if __has_feature(memory_sanitizer)
50 ipiv = 0;
51#endif
52#endif
53
54 // loop over batches and call getrf for each matrix
55 auto loop_getrf = [n_b, &a, &ipiv, &work](auto &info) {
56 for (int i = 0; i < n_b; ++i) {
57 auto a_i = a(range::all, range::all, i);
58 auto ipiv_i = ipiv(range::all, i);
59 info(i) = getrf(a_i, ipiv_i, work);
60 }
61 };
62
63 // perform actual library call(s)
64 auto info = array<int, 1>(n_b, 0);
65 if constexpr (run_on_device) {
66 if (m == n) {
67 // for square matrices on the device use cuBLAS
68 using arr_t = std::remove_cvref_t<decltype(a)>;
69 auto ptr_d = to_device(batch_ptrs(a));
70 auto info_d = vector<int, heap<mem::get_addr_space<arr_t>>>(n_b, 0);
71 blas::device::getrf_batch(n, ptr_d.data(), get_ld(a(range::all, range::all, 0)), ipiv.data(), info_d.data(), n_b);
72 info = info_d;
73 } else {
74 // for rectangular matrices on the device, fall back to looping over batches
75 loop_getrf(info);
76 }
77 } else {
78 // for host, fall back to looping over batches
79 loop_getrf(info);
80 }
81 return info;
82 }
83
84 } // namespace detail
85
121 template <BlasArray<3> A, PivotArrayFor<A, 2> IPIV, BlasArrayFor<A, 1> W = vector_value_t<A>>
123 auto getrf_batch(A &&a, IPIV &&ipiv, W &&work = vector_value_t<A>{}) { // NOLINT (temporary views are allowed here)
124 constexpr bool run_on_device = mem::have_device_compatible_addr_space<A, IPIV, W>;
125
126 // transpose ipiv array/view if necessary
127 if constexpr (has_C_layout<IPIV>) return getrf_batch(a, transpose(ipiv), work);
128
129 // transpose A array/view if necessary and call the implementation
130 if constexpr (has_C_layout<A>) {
131 return detail::getrf_batch_impl<run_on_device>(transpose(a), std::forward<IPIV>(ipiv), std::forward<W>(work));
132 } else {
133 return detail::getrf_batch_impl<run_on_device>(std::forward<A>(a), std::forward<IPIV>(ipiv), std::forward<W>(work));
134 }
135 }
136
144 template <BlasArray<3> A, PivotArrayFor<A, 2> IPIV, BlasArrayFor<A, 1> W = vector_value_t<A>>
146 auto getrf(A &&a, IPIV &&ipiv, W &&work = vector_value_t<A>{}) { // NOLINT (temporary views are allowed here)
147 return getrf_batch(std::forward<A>(a), std::forward<IPIV>(ipiv), std::forward<W>(work));
148 }
149
150} // namespace nda::lapack
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
Provides a C++ interface for various BLAS routines.
Provides various traits and utilities for the BLAS interface.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a generic interface to the LAPACK/cuSOLVER getrf routine.
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
auto transpose(A &&a)
Transpose the memory layout of an nda::MemoryArray or an nda::expr_call.
decltype(auto) to_device(A &&a)
Convert an nda::MemoryArray to its regular type on device memory.
basic_array< ValueType, Rank, Layout, 'A', ContainerPolicy > array
Alias template of an nda::basic_array with an 'A' algebra.
basic_array< ValueType, 1, C_layout, 'V', ContainerPolicy > vector
Alias template of an nda::basic_array with rank 1 and a 'V' algebra.
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
Definition tools.hpp:128
auto batch_ptrs(A &&a)
Given a 2- or 3-dimensional array get an array of pointers to each of the submatrices/subvectors inde...
Definition tools.hpp:182
static constexpr bool has_C_layout
Constexpr variable that is true if all given nda::Array types have nda::C_layout.
Definition tools.hpp:89
vector< get_value_t< A >, heap< mem::get_addr_space< A > > > vector_value_t
Alias for an nda::vector with the same value type and address space as the given type.
Definition tools.hpp:161
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
Definition tools.hpp:79
auto getrf_batch(A &&a, IPIV &&ipiv, W &&work=vector_value_t< A >{})
Interface to batched versions of the LAPACK/cuSOLVER getrf routine.
int getrf(A &&a, IPIV &&ipiv, W &&work=vector_value_t< A >{})
Interface to the LAPACK/cuSOLVER getrf routine.
Definition getrf.hpp:61
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
Provides a C++ interface for various LAPACK routines.
Provides functions to transform the memory layout of an nda::basic_array or nda::basic_array_view.
Macros used in the nda library.
Provides type traits for the nda library.