TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
ger.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Miguel Morales, Olivier Parcollet, Nils Wentzell
16
17/**
18 * @file
19 * @brief Provides a generic interface to the BLAS `ger` routine and an outer product routine.
20 */
21
22#pragma once
23
24#include "./interface/cxx_interface.hpp"
25#include "./tools.hpp"
26#include "../basic_functions.hpp"
27#include "../concepts.hpp"
28#include "../exceptions.hpp"
29#include "../layout_transforms.hpp"
30#include "../macros.hpp"
31#include "../mem/address_space.hpp"
32#include "../stdutil/array.hpp"
33#include "../traits.hpp"
34
35#ifndef NDA_HAVE_DEVICE
36#include "../device.hpp"
37#endif
38
39#include <array>
40
41namespace nda::blas {
42
43 /**
44 * @addtogroup linalg_blas
45 * @{
46 */
47
48 /**
49 * @brief Interface to the BLAS `ger` routine.
50 *
51 * @details This function performs the rank 1 operation
52 * \f[
53 * \mathbf{M} \leftarrow \alpha \mathbf{x} \mathbf{y}^H + \mathbf{M} ;,
54 * \f]
55 * where \f$ \alpha \f$ is a scalar, \f$ \mathbf{x} \f$ is an m element vector, \f$ \mathbf{y} \f$ is an n element
56 * vector and \f$ \mathbf{M} \f$ is an m-by-n matrix.
57 *
58 * @tparam X nda::MemoryVector type.
59 * @tparam Y nda::MemoryVector type.
60 * @tparam M nda::MemoryMatrix type.
61 * @param alpha Input scalar.
62 * @param x Input left vector (column vector) of size m.
63 * @param y Input right vector (row vector) of size n.
64 * @param m Input/Output matrix of size m-by-n to which the outer product is added.
65 */
66 template <MemoryVector X, MemoryVector Y, MemoryMatrix M>
67 requires(have_same_value_type_v<X, Y, M> and mem::have_compatible_addr_space<X, Y, M> and is_blas_lapack_v<get_value_t<X>>)
68 void ger(get_value_t<X> alpha, X const &x, Y const &y, M &&m) { // NOLINT (temporary views are allowed here)
69 EXPECTS(m.extent(0) == x.extent(0));
70 EXPECTS(m.extent(1) == y.extent(0));
71
72 // must be lapack compatible
73 EXPECTS(m.indexmap().min_stride() == 1);
74
75 // if in C, we need to call fortran with transposed matrix
76 if (has_C_layout<M>) {
77 ger(alpha, y, x, transpose(m));
78 return;
79 }
80
81 if constexpr (mem::have_device_compatible_addr_space<X, Y, M>) {
82#if defined(NDA_HAVE_DEVICE)
83 device::ger(m.extent(0), m.extent(1), alpha, x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0], m.data(), get_ld(m));
84#else
86#endif
87 } else {
88 f77::ger(m.extent(0), m.extent(1), alpha, x.data(), x.indexmap().strides()[0], y.data(), y.indexmap().strides()[0], m.data(), get_ld(m));
89 }
90 }
91
92 /**
93 * @brief Calculate the outer product of two contiguous arrays/views/scalars.
94 *
95 * @details For general multidimensional arrays/views, it calculates their tensor outer product, i.e.
96 * ```
97 * c(i,j,k,...,u,v,w,...) = a(i,j,k,...) * b(u,v,w,...)
98 * ```
99 * If one of the arguments is a scalar, it multiplies each element of the other argument by the scalar which returns a
100 * lazy nda::expr object.
101 *
102 * If both arguments are scalars, it returns their products.
103 *
104 * @tparam A nda::ArrayOrScalar type.
105 * @tparam B nda::ArrayOrScalar type.
106 * @param a Input array/scalar.
107 * @param b Input array/scalar.
108 * @return (Lazy) Outer product.
109 */
110 template <ArrayOrScalar A, ArrayOrScalar B>
111 auto outer_product(A const &a, B const &b) {
112 if constexpr (Scalar<A> or Scalar<B>) {
113 return a * b;
114 } else {
115 if (not a.is_contiguous()) NDA_RUNTIME_ERROR << "Error in nda::blas::outer_product: First argument has non-contiguous layout";
116 if (not b.is_contiguous()) NDA_RUNTIME_ERROR << "Error in nda::blas::outer_product: Second argument has non-contiguous layout";
117
118 // use BLAS ger to calculate the outer product
119 auto res = zeros<get_value_t<A>, mem::common_addr_space<A, B>>(stdutil::join(a.shape(), b.shape()));
120 auto a_vec = reshape(a, std::array{a.size()});
121 auto b_vec = reshape(b, std::array{b.size()});
122 auto mat = reshape(res, std::array{a.size(), b.size()});
123 ger(1.0, a_vec, b_vec, mat);
124
125 return res;
126 }
127 }
128
129 /** @} */
130
131} // namespace nda::blas
#define NDA_RUNTIME_ERROR
auto outer_product(A const &a, B const &b)
Calculate the outer product of two contiguous arrays/views/scalars.
Definition ger.hpp:111
void ger(get_value_t< X > alpha, X const &x, Y const &y, M &&m)
Interface to the BLAS ger routine.
Definition ger.hpp:68
void compile_error_no_gpu()
Trigger a compilation error in case GPU specific functionality is used without configuring the projec...
Definition device.hpp:47
#define EXPECTS(X)
Definition macros.hpp:59