TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
gesvd.hpp
Go to the documentation of this file.
1// Copyright (c) 2020--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
14#include "../basic_array.hpp"
16#include "../blas/tools.hpp"
17#include "../concepts.hpp"
18#include "../declarations.hpp"
19#include "../device.hpp"
21#include "../macros.hpp"
23#include "../mem/policies.hpp"
24#include "../traits.hpp"
25
26#include <algorithm>
27#include <cmath>
28#include <complex>
29#include <concepts>
30#include <utility>
31
32namespace nda::lapack {
33
68 template <BlasArray<2> A, BlasArrayRealFor<A, 1> S, BlasArrayFor<A, 2> U, BlasArrayFor<A, 2> VH, BlasArrayFor<A, 1> W1 = vector_value_t<A>,
69 BlasArrayRealFor<A, 1> W2 = vector_fp_t<A>>
71 int gesvd(A &&a, S &&s, U &&u, VH &&vh, W1 &&work = vector_value_t<A>{}, W2 &&rwork = vector_fp_t<A>{}) { // NOLINT (tmp views)
73
74 // check the dimensions of the output arrays/views and resize if necessary
75 auto [m, n] = a.shape();
76 auto const k = std::min(m, n);
78 resize_or_check_if_view(u, {m, m});
79 resize_or_check_if_view(vh, {n, n});
80 resize_or_check_work_buffer(rwork, 5 * k);
81
82 // arrays/views must be LAPACK compatible
83 EXPECTS(a.indexmap().min_stride() == 1);
84 EXPECTS(s.indexmap().min_stride() == 1);
85 EXPECTS(u.indexmap().min_stride() == 1);
86 EXPECTS(vh.indexmap().min_stride() == 1);
87
88 // take care of C-layouts by swapping U and V^H
89 auto u_data = u.data();
90 auto vh_data = vh.data();
91 auto u_ld = get_ld(u);
92 auto vh_ld = get_ld(vh);
93 if constexpr (has_C_layout<A>) {
94 std::swap(u_data, vh_data);
95 std::swap(u_ld, vh_ld);
96 std::swap(m, n);
97 }
98
99 // cusolverDn?gesvd only supports matrices with m >= n
100 if constexpr (run_on_device) { EXPECTS(m >= n); }
101
102 // first call to get the optimal buffer size
103 auto tmp_lwork = get_value_t<A>{};
104 int info = 0;
105 if constexpr (run_on_device) {
106 tmp_lwork = device::gesvd_buffer_size(m, n, a.data());
107 } else {
108 f77::gesvd('A', 'A', m, n, a.data(), get_ld(a), s.data(), u_data, u_ld, vh_data, vh_ld, &tmp_lwork, -1, rwork.data(), info);
109 }
110 int lwork = static_cast<int>(std::ceil(std::real(tmp_lwork)));
111
112 // resize/check work buffer
113 resize_or_check_work_buffer(work, lwork);
114
115 // perform actual library call
116 if constexpr (run_on_device) {
117 device::gesvd('A', 'A', m, n, a.data(), get_ld(a), s.data(), u_data, u_ld, vh_data, vh_ld, work.data(), lwork, rwork.data(), info);
118 } else {
119 f77::gesvd('A', 'A', m, n, a.data(), get_ld(a), s.data(), u_data, u_ld, vh_data, vh_ld, work.data(), lwork, rwork.data(), info);
120 }
121
122 return info;
123 }
124
125} // namespace nda::lapack
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
void swap(nda::basic_array_view< V1, R1, LP1, A1, AP1, OP1 > &a, nda::basic_array_view< V2, R2, LP2, A2, AP2, OP2 > &b)=delete
std::swap is deleted for nda::basic_array_view.
Provides basic functions to create and manipulate arrays and views.
Provides various traits and utilities for the BLAS interface.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides GPU and non-GPU specific functionality.
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
Definition tools.hpp:128
static constexpr bool has_C_layout
Constexpr variable that is true if all given nda::Array types have nda::C_layout.
Definition tools.hpp:89
void resize_or_check_work_buffer(A &a, long min_size)
Resize or check the size of a 1D array/view.
Definition tools.hpp:207
vector< get_value_t< A >, heap< mem::get_addr_space< A > > > vector_value_t
Alias for an nda::vector with the same value type and address space as the given type.
Definition tools.hpp:161
vector< get_fp_t< A >, heap< mem::get_addr_space< A > > > vector_fp_t
Alias for an nda::vector with the same address space as the given type and with value type nda::get_fp_t<A>.
Definition tools.hpp:170
int gesvd(A &&a, S &&s, U &&u, VH &&vh, W1 &&work=vector_value_t< A >{}, W2 &&rwork=vector_fp_t< A >{})
Interface to the LAPACK/cuSOLVER gesvd routine.
Definition gesvd.hpp:71
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
Provides a C++ interface for various LAPACK routines.
Provides definitions of various layout policies.
Macros used in the nda library.
Defines various memory handling policies.
Provides type traits for the nda library.