TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
geqrf_batch.hpp
Go to the documentation of this file.
1// Copyright (c) 2021--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
#include "./geqrf.hpp"
#include "../basic_array.hpp"
#include "../blas/tools.hpp"
#include "../concepts.hpp"
#include "../declarations.hpp"
#include "../device.hpp"
#include "../macros.hpp"
#include "../mem/address_space.hpp"
#include "../traits.hpp"

#include <algorithm>
#include <type_traits>
#include <utility>
30
31namespace nda::lapack {
32
65 template <BlasArray<3> A, BlasArrayFor<A, 2> TAU, BlasArrayFor<A, 1> W = vector_value_t<A>>
66 requires(has_F_layout<A, TAU>)
67 int geqrf_batch(A &&a, TAU &&tau, [[maybe_unused]] W &&work = vector_value_t<A>{}) { // NOLINT (temporary views are allowed here)
68 // check the dimensions of the input/output arrays/views and resize if necessary
69 auto const [m, n, n_b] = a.shape();
70 resize_or_check_if_view(tau, {std::min(m, n), n_b});
71
72 // arrays/views must be LAPACK compatible
73 EXPECTS(a.indexmap().min_stride() == 1);
74 EXPECTS(tau.indexmap().min_stride() == 1);
75
76#if defined(__has_feature)
77#if __has_feature(memory_sanitizer)
78 tau = get_value_t<A>{0};
79#endif
80#endif
81
82 // perform actual library call(s)
83 int info = 0;
85 // get pointers to each matrix/vector in the batch
86 auto a_ptrs = to_device(batch_ptrs(a));
87 auto tau_ptrs = to_device(batch_ptrs(tau));
88
89 blas::device::geqrf_batch(m, n, a_ptrs.data(), get_ld(a(range::all, range::all, 0)), tau_ptrs.data(), info, n_b);
90 } else {
91 // for host, fall back to looping over batches
92 for (int i = 0; i < n_b; ++i) {
93 auto a_i = a(range::all, range::all, i);
94 auto tau_i = tau(range::all, i);
95 int local_info = geqrf(a_i, tau_i, work);
96 if (local_info != 0 && info == 0) info = local_info;
97 }
98 }
99 return info;
100 }
101
109 template <BlasArray<3> A, BlasArrayFor<A, 2> TAU, BlasArrayFor<A, 1> W = vector_value_t<A>>
110 requires(has_F_layout<A, TAU>)
111 int geqrf(A &&a, TAU &&tau, W &&work = vector_value_t<A>{}) { // NOLINT (temporary views are allowed here)
112 return geqrf_batch(std::forward<A>(a), std::forward<TAU>(tau), std::forward<W>(work));
113 }
114
115} // namespace nda::lapack
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
Provides a C++ interface for various BLAS routines.
Provides various traits and utilities for the BLAS interface.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
Provides a generic interface to the LAPACK/cuSOLVER geqrf routine.
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
decltype(auto) to_device(A &&a)
Convert an nda::MemoryArray to its regular type on device memory.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
Definition tools.hpp:128
auto batch_ptrs(A &&a)
Given a 2- or 3-dimensional array get an array of pointers to each of the submatrices/subvectors inde...
Definition tools.hpp:182
vector< get_value_t< A >, heap< mem::get_addr_space< A > > > vector_value_t
Alias for an nda::vector with the same value type and address space as the given type.
Definition tools.hpp:161
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
Definition tools.hpp:79
int geqrf(A &&a, TAU &&tau, W &&work=vector_value_t< A >{})
Interface to the LAPACK/cuSOLVER geqrf routine.
Definition geqrf.hpp:72
int geqrf_batch(A &&a, TAU &&tau, W &&work=vector_value_t< A >{})
Interface to batched versions of the LAPACK/cuSOLVER geqrf routine.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
Provides a C++ interface for various LAPACK routines.
Provides functions to transform the memory layout of an nda::basic_array or nda::basic_array_view.
Macros used in the nda library.
Provides type traits for the nda library.