TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
gesvd.hpp
Go to the documentation of this file.
1// Copyright (c) 2020--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
14#include "../basic_array.hpp"
16#include "../concepts.hpp"
17#include "../declarations.hpp"
19#include "../macros.hpp"
21#include "../mem/policies.hpp"
22#include "../traits.hpp"
23
24#ifndef NDA_HAVE_DEVICE
25#include "../device.hpp"
26#endif // NDA_HAVE_DEVICE
27
28#include <algorithm>
29#include <cmath>
30#include <complex>
31#include <concepts>
32#include <utility>
33
34namespace nda::lapack {
35
/**
 * @brief Interface to the LAPACK `?gesvd` routine (singular value decomposition of a general matrix).
 *
 * @details Calls `?gesvd` with `jobu = 'A'` and `jobvt = 'A'`, i.e. it requests all columns of U and all rows of
 * V^H (LAPACK convention: A = U * Sigma * V^H, singular values returned in `s`). The library is called twice:
 * first as a workspace query (`lwork = -1`), then with a work buffer of the size reported by that query.
 *
 * C-layout matrices are handled without copying. A row-major m x n matrix occupies the same memory as the
 * column-major n x m matrix A^T. LAPACK is therefore called with the dimensions swapped (n, m), and the output
 * buffers are exchanged (`vt` is passed where LAPACK expects U, and `u` where it expects V^H). Read back in
 * row-major order, LAPACK's factors of A^T are exactly the U and V^H factors of A.
 *
 * NOTE(review): this listing comes from a documentation-page scrape. Several source lines are not visible here:
 * the head of the `requires` clause, the `static_assert` condition, the resize of `s`, the device
 * `if constexpr` conditions, the no-GPU fallback, and the declarations of `rwork` and `work`. Statements about
 * them below are hedged accordingly.
 *
 * @tparam A nda::MemoryMatrix type of the input matrix.
 * @tparam S nda::MemoryVector type; its value type must be `double` (see the visible constraint below).
 * @tparam U nda::MemoryMatrix type of the left singular vectors.
 * @tparam VT nda::MemoryMatrix type of the (conjugate-)transposed right singular vectors.
 * @param a Input matrix of shape (m, n). Its data pointer is handed to LAPACK as a writable buffer, so its
 * contents should be assumed to be overwritten on return.
 * @param s Output vector for the singular values. It is presumably resized or checked to length min(m, n) on a
 * line not visible here.
 * @param u Output matrix, resized (or checked, if it is a view) to shape (m, m).
 * @param vt Output matrix, resized (or checked, if it is a view) to shape (n, n).
 * @return The `info` value written by the second (actual) library call. This presumably follows the LAPACK
 * convention (0 on success); confirm against the f77/device wrappers.
 */
64 template <MemoryMatrix A, MemoryVector S, MemoryMatrix U, MemoryMatrix VT>
// NOTE(review): only the last condition of the requires clause is visible here (S must hold doubles).
66 and std::same_as<double, get_value_t<S>>)
67 int gesvd(A &&a, S &&s, U &&u, VT &&vt) { // NOLINT (temporary views are allowed here)
// NOTE(review): the opening/condition of this static_assert is not visible; per its message, the layouts of
// the matrix arguments must agree (the C-layout transpose trick below relies on that).
69 "Error in nda::lapack::gesvd: Matrix layouts have to be the same");
70
71 // check the dimensions of the output arrays/views and resize if necessary
72 auto const [m, n] = a.shape();
// k = min(m, n): the number of singular values (used when sizing `s`/`rwork` on lines not visible here).
73 auto const k = std::min(m, n);
75 resize_or_check_if_view(u, {m, m});
76 resize_or_check_if_view(vt, {n, n});
77
78 // cusolverDn?gesvd only supports matrices with m >= n
// NOTE(review): the enclosing if-constexpr condition is not visible; presumably it selects the device address
// space. For C layout the library sees A^T (n x m), which is why the check is reversed there.
80 if constexpr (has_C_layout<A>) {
81 EXPECTS(n >= m);
82 } else {
83 EXPECTS(m >= n);
84 }
85 }
86
87 // arrays/views must be LAPACK compatible
// Unit innermost stride is required so that the data pointer plus get_ld() describe the matrix/vector.
88 EXPECTS(a.indexmap().min_stride() == 1);
89 EXPECTS(s.indexmap().min_stride() == 1);
90 EXPECTS(u.indexmap().min_stride() == 1);
91 EXPECTS(vt.indexmap().min_stride() == 1);
92
93 // call host/device implementation depending on address space of input arrays/views
// The device branch forwards to lapack::device::gesvd, the host branch to lapack::f77::gesvd. The if-constexpr
// condition and the no-GPU fallback are on lines not visible here (presumably compile_error_no_gpu()).
94 auto gesvd_call = []<typename... Ts>(Ts &&...args) {
96#if defined(NDA_HAVE_DEVICE)
97 lapack::device::gesvd(std::forward<Ts>(args)...);
98#else
100#endif
101 } else {
102 lapack::f77::gesvd(std::forward<Ts>(args)...);
103 }
104 };
105
106 // first call to get the optimal buffer size
107 using value_type = get_value_t<A>;
// Receives the optimal lwork from the query call (lwork = -1 below).
108 value_type tmp_lwork{};
// NOTE(review): `rwork` is declared on a line not visible here. It is passed to both calls (presumably a real
// workspace, used by the complex variants).
110 int info = 0;
// C layout: dims swapped and vt/u exchanged, because LAPACK operates on the column-major view A^T.
111 if constexpr (has_C_layout<A>) {
112 gesvd_call('A', 'A', n, m, a.data(), get_ld(a), s.data(), vt.data(), get_ld(vt), u.data(), get_ld(u), &tmp_lwork, -1, rwork.data(), info);
113 } else {
114 gesvd_call('A', 'A', m, n, a.data(), get_ld(a), s.data(), u.data(), get_ld(u), vt.data(), get_ld(vt), &tmp_lwork, -1, rwork.data(), info);
115 }
// The optimal size is taken from the real part (value_type may be complex) and rounded up before converting.
// NOTE(review): `info` from the workspace query is not checked; it is overwritten by the call below.
116 int lwork = static_cast<int>(std::ceil(std::real(tmp_lwork)));
117
118 // allocate work buffer and perform actual library call
// NOTE(review): the declaration of `work` is not visible here (presumably lwork elements in A's address space).
120 if constexpr (has_C_layout<A>) {
121 gesvd_call('A', 'A', n, m, a.data(), get_ld(a), s.data(), vt.data(), get_ld(vt), u.data(), get_ld(u), work.data(), lwork, rwork.data(), info);
122 } else {
123 gesvd_call('A', 'A', m, n, a.data(), get_ld(a), s.data(), u.data(), get_ld(u), vt.data(), get_ld(vt), work.data(), lwork, rwork.data(), info);
124 }
125
126 return info;
127 }
128
129} // namespace nda::lapack
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
Provides basic functions to create and manipulate arrays and views.
ValueType const * data() const noexcept
Get a pointer to the actual data (in general this is not the beginning of the memory block for a view...
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_vi...
Provides GPU and non-GPU specific functionality.
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
basic_array< ValueType, Rank, Layout, 'A', ContainerPolicy > array
Alias template of an nda::basic_array with an 'A' algebra.
constexpr bool have_same_value_type_v
Constexpr variable that is true if all types in As have the same value type as A0.
Definition traits.hpp:186
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:182
int gesvd(A &&a, S &&s, U &&u, VT &&vt)
Interface to the LAPACK gesvd routine.
Definition gesvd.hpp:67
static constexpr bool have_compatible_addr_space
Constexpr variable that is true if all given types have compatible address spaces.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
void compile_error_no_gpu()
Trigger a compilation error in case GPU specific functionality is used without configuring the projec...
Definition device.hpp:36
constexpr bool is_blas_lapack_v
Alias for nda::is_double_or_complex_v.
Definition traits.hpp:92
Provides a C++ interface for various LAPACK routines.
static constexpr bool has_C_layout
Constexpr variable that is true if the given nda::Array type has nda::C_layout.
Definition tools.hpp:83
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
Definition tools.hpp:122
Provides definitions of various layout policies.
Macros used in the nda library.
Defines various memory handling policies.
Provides type traits for the nda library.