TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
arithmetic.hpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "./concepts.hpp"
14#include "./declarations.hpp"
15#include "./linalg/inv.hpp"
16#include "./linalg/matmul.hpp"
18#include "./macros.hpp"
19#include "./stdutil/complex.hpp"
20#include "./traits.hpp"
21
22#include <functional>
23#include <type_traits>
24#include <utility>
25
26#ifdef NDA_ENFORCE_BOUNDCHECK
27#include "./exceptions.hpp"
28#endif // NDA_ENFORCE_BOUNDCHECK
29
30namespace nda {
31
36
48 template <char OP, Array A>
49 struct expr_unary {
50 static_assert(OP == '-', "Error in nda::expr_unary: Only negation is supported");
51
53 A a;
54
65 template <typename... Args>
66 auto operator()(Args &&...args) const {
67 return -a(std::forward<Args>(args)...);
68 }
69
74 [[nodiscard]] constexpr auto shape() const { return a.shape(); }
75
80 [[nodiscard]] constexpr long size() const { return a.size(); }
81 };
82
95 template <char OP, ArrayOrScalar L, ArrayOrScalar R>
96 struct expr {
98 L l;
99
101 R r;
102
104 using L_t = std::decay_t<L>;
105
107 using R_t = std::decay_t<R>;
108
109 // FIXME : we should use is_scalar_for_v but the trait needs work to accommodate scalar L or R
111 static constexpr bool l_is_scalar = nda::is_scalar_v<L>;
112
114 static constexpr bool r_is_scalar = nda::is_scalar_v<R>;
115
117 static constexpr char algebra = (l_is_scalar ? get_algebra<R> : get_algebra<L>);
118
124 if (l_is_scalar) return (algebra == 'A' ? get_layout_info<R> : layout_info_t{}); // 1 as an array has all flags, it is just 1
125 if (r_is_scalar) return (algebra == 'A' ? get_layout_info<L> : layout_info_t{}); // 1 as a matrix does not, as it is diagonal only.
126 return get_layout_info<R> & get_layout_info<L>; // default case. Take the logical and of all flags
127 }
128
133 [[nodiscard]] constexpr decltype(auto) shape() const {
134 if constexpr (l_is_scalar) {
135 return r.shape();
136 } else if constexpr (r_is_scalar) {
137 return l.shape();
138 } else {
139 EXPECTS(l.shape() == r.shape());
140 return l.shape();
141 }
142 }
143
148 [[nodiscard]] constexpr long size() const {
149 if constexpr (l_is_scalar) {
150 return r.size();
151 } else if constexpr (r_is_scalar) {
152 return l.size();
153 } else {
154 EXPECTS(l.size() == r.size());
155 return l.size();
156 }
157 }
158
169 template <typename... Args>
170 auto operator()(Args const &...args) const {
171 // addition
172 if constexpr (OP == '+') {
173 if constexpr (l_is_scalar) {
174 // lhs is a scalar
175 if constexpr (algebra == 'M')
176 // rhs is a matrix
177 return (std::equal_to{}(args...) ? l + r(args...) : r(args...));
178 else
179 // rhs is an array
180 return l + r(args...);
181 } else if constexpr (r_is_scalar) {
182 // rhs is a scalar
183 if constexpr (algebra == 'M')
184 // lhs is a matrix
185 return (std::equal_to{}(args...) ? l(args...) + r : l(args...));
186 else
187 // lhs is an array
188 return l(args...) + r;
189 } else
190 // both are arrays or matrices
191 return l(args...) + r(args...);
192 }
193
194 // subtraction
195 if constexpr (OP == '-') {
196 if constexpr (l_is_scalar) {
197 // lhs is a scalar
198 if constexpr (algebra == 'M')
199 // rhs is a matrix
200 return (std::equal_to{}(args...) ? l - r(args...) : -r(args...));
201 else
202 // rhs is an array
203 return l - r(args...);
204 } else if constexpr (r_is_scalar) {
205 // rhs is a scalar
206 if constexpr (algebra == 'M')
207 // lhs is a matrix
208 return (std::equal_to{}(args...) ? l(args...) - r : l(args...));
209 else
210 // lhs is an array
211 return l(args...) - r;
212 } else
213 // both are arrays or matrices
214 return l(args...) - r(args...);
215 }
216
217 // multiplication
218 if constexpr (OP == '*') {
219 if constexpr (l_is_scalar)
220 // lhs is a scalar
221 return l * r(args...);
222 else if constexpr (r_is_scalar)
223 // rhs is a scalar
224 return l(args...) * r;
225 else {
226 // both are arrays (matrix product is not supported here)
227 static_assert(algebra != 'M', "Error in nda::expr: Matrix algebra not supported");
228 return l(args...) * r(args...);
229 }
230 }
231
232 // division
233 if constexpr (OP == '/') {
234 if constexpr (l_is_scalar) {
235 // lhs is a scalar
236 static_assert(algebra != 'M', "Error in nda::expr: Matrix algebra not supported");
237 return l / r(args...);
238 } else if constexpr (r_is_scalar)
239 // rhs is a scalar
240 return l(args...) / r;
241 else {
242 // both are arrays (matrix division is not supported here)
243 static_assert(algebra != 'M', "Error in nda::expr: Matrix algebra not supported");
244 return l(args...) / r(args...);
245 }
246 }
247 }
248
258 template <typename Arg>
259 auto operator[](Arg &&arg) const {
260 static_assert(get_rank<expr> == 1, "Error in nda::expr: Subscript operator only available for expressions of rank 1");
261 return operator()(std::forward<Arg>(arg));
262 }
263 };
264
274 template <Array A>
275 expr_unary<'-', A> operator-(A &&a) {
276 return {std::forward<A>(a)};
277 }
278
290 template <Array L, Array R>
291 Array auto operator+(L &&l, R &&r) {
292 static_assert(get_rank<L> == get_rank<R>, "Error in lazy nda::operator+: Rank mismatch");
293 return expr<'+', L, R>{std::forward<L>(l), std::forward<R>(r)};
294 }
295
309 template <Array A, Scalar S>
310 Array auto operator+(A &&a, S &&s) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
311 return expr<'+', A, std::decay_t<S>>{std::forward<A>(a), s};
312 }
313
327 template <Scalar S, Array A>
328 Array auto operator+(S &&s, A &&a) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
329 return expr<'+', std::decay_t<S>, A>{s, std::forward<A>(a)};
330 }
331
343 template <Array L, Array R>
344 Array auto operator-(L &&l, R &&r) {
345 static_assert(get_rank<L> == get_rank<R>, "Error in lazy nda::operator-: Rank mismatch");
346 return expr<'-', L, R>{std::forward<L>(l), std::forward<R>(r)};
347 }
348
362 template <Array A, Scalar S>
363 Array auto operator-(A &&a, S &&s) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
364 return expr<'-', A, std::decay_t<S>>{std::forward<A>(a), s};
365 }
366
380 template <Scalar S, Array A>
381 Array auto operator-(S &&s, A &&a) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
382 return expr<'-', std::decay_t<S>, A>{s, std::forward<A>(a)};
383 }
384
402 template <Array L, Array R>
403 auto operator*(L &&l, R &&r) {
404 // allowed algebras: A * A or M * M or M * V
405 static constexpr char l_algebra = get_algebra<L>;
406 static constexpr char r_algebra = get_algebra<R>;
407 static_assert(l_algebra != 'V', "Error in nda::operator*: Can not multiply vector by an array or a matrix");
408
409 // two arrays: A * A
410 if constexpr (l_algebra == 'A') {
411 static_assert(r_algebra == 'A', "Error in nda::operator*: Both types need to be arrays");
412 static_assert(get_rank<L> == get_rank<R>, "Error in nda::operator*: Rank mismatch");
413#ifdef NDA_ENFORCE_BOUNDCHECK
414 if (l.shape() != r.shape()) NDA_RUNTIME_ERROR << "Error in nda::operator*: Dimension mismatch: " << l.shape() << " != " << r.shape();
415#endif
416 return expr<'*', L, R>{std::forward<L>(l), std::forward<R>(r)};
417 }
418
419 // two matrices: M * M
420 if constexpr (l_algebra == 'M') {
421 static_assert(r_algebra != 'A', "Error in nda::operator*: Can not multiply a matrix by an array");
422 if constexpr (r_algebra == 'M')
423 // matrix * matrix
424 return linalg::matmul(std::forward<L>(l), std::forward<R>(r));
425 else
426 // matrix * vector
427 return linalg::matvecmul(std::forward<L>(l), std::forward<R>(r));
428 }
429 }
430
442 template <Array A, Scalar S>
443 Array auto operator*(A &&a, S &&s) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
444 return expr<'*', A, std::decay_t<S>>{std::forward<A>(a), s};
445 }
446
458 template <Scalar S, Array A>
459 Array auto operator*(S &&s, A &&a) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
460 return expr<'*', std::decay_t<S>, A>{s, std::forward<A>(a)};
461 }
462
479 template <Array L, Array R>
480 Array auto operator/(L &&l, R &&r) {
481 // allowed algebras: A / A or M / M
482 static constexpr char l_algebra = get_algebra<L>;
483 static constexpr char r_algebra = get_algebra<R>;
484 static_assert(l_algebra != 'V', "Error in nda::operator/: Can not divide a vector by an array or a matrix");
485
486 // two arrays: A / A
487 if constexpr (l_algebra == 'A') {
488 static_assert(r_algebra == 'A', "Error in nda::operator/: Both types need to be arrays");
489 static_assert(get_rank<L> == get_rank<R>, "Error in nda::operator/: Rank mismatch");
490#ifdef NDA_ENFORCE_BOUNDCHECK
491 if (l.shape() != r.shape()) NDA_RUNTIME_ERROR << "Error in nda::operator/: Dimension mismatch: " << l.shape() << " != " << r.shape();
492#endif
493 return expr<'/', L, R>{std::forward<L>(l), std::forward<R>(r)};
494 }
495
496 // two matrices: M / M
497 if constexpr (l_algebra == 'M') {
498 static_assert(r_algebra == 'M', "Error in nda::operator*: Can not divide a matrix by an array/vector");
499 return std::forward<L>(l) * linalg::inv(matrix<get_value_t<R>>{std::forward<R>(r)});
500 }
501 }
502
514 template <Array A, Scalar S>
515 Array auto operator/(A &&a, S &&s) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
516 return expr<'/', A, std::decay_t<S>>{std::forward<A>(a), s};
517 }
518
532 template <Scalar S, Array A>
533 Array auto operator/(S &&s, A &&a) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
534 static constexpr char algebra = get_algebra<A>;
535 if constexpr (algebra == 'M')
536 return s * linalg::inv(matrix<get_value_t<A>>{std::forward<A>(a)});
537 else
538 return expr<'/', std::decay_t<S>, A>{s, std::forward<A>(a)};
539 }
540
542
543} // namespace nda
Provides additional operators for std::complex and other arithmetic types.
Check if a given type satisfies the array concept.
Definition concepts.hpp:205
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
Array auto operator+(L &&l, R &&r)
Addition operator for two nda::Array types.
Array auto operator/(L &&l, R &&r)
Division operator for two nda::Array types.
auto operator*(L &&l, R &&r)
Multiplication operator for two nda::Array types.
expr_unary<'-', A > operator-(A &&a)
Unary minus operator for nda::Array types.
basic_array< ValueType, 2, Layout, 'M', ContainerPolicy > matrix
Alias template of an nda::basic_array with rank 2 and an 'M' algebra.
constexpr char get_algebra
Constexpr variable that specifies the algebra of a type.
Definition traits.hpp:116
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:126
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:182
constexpr layout_info_t get_layout_info
Constexpr variable that specifies the nda::layout_info_t of type A.
Definition traits.hpp:311
auto matmul(A &&a, B &&b)
Compute the matrix-matrix product of two nda::matrix objects.
Definition matmul.hpp:128
auto inv(M const &m)
Compute the inverse of a matrix.
Definition inv.hpp:141
auto matvecmul(A const &a, X const &x)
Compute the matrix-vector product of an nda::matrix and an nda::vector object.
constexpr bool is_scalar_v
Constexpr variable that is true if type S is a scalar type, i.e. arithmetic or complex.
Definition traits.hpp:69
Provides functions to compute the inverse of a matrix.
Macros used in the nda library.
Provides a generic matrix-matrix multiplication.
Provides a generic matrix-vector multiplication.
Lazy unary expression for nda::Array types.
A a
nda::Array object.
auto operator()(Args &&...args) const
Function call operator.
constexpr long size() const
Get the total size of the nda::Array operand.
constexpr auto shape() const
Get the shape of the nda::Array operand.
Lazy binary expression for nda::ArrayOrScalar types.
L l
nda::ArrayOrScalar left hand side operand.
constexpr long size() const
Get the total size of the expression (result of the operation).
static constexpr bool r_is_scalar
Constexpr variable that is true if the right hand side operand is a scalar.
std::decay_t< L > L_t
Decay type of the left hand side operand.
constexpr decltype(auto) shape() const
Get the shape of the expression (result of the operation).
auto operator[](Arg &&arg) const
Subscript operator.
static constexpr layout_info_t compute_layout_info()
Compute the layout information of the expression.
static constexpr char algebra
Constexpr variable specifying the algebra of one of the non-scalar operands.
R r
nda::ArrayOrScalar right hand side operand.
auto operator()(Args const &...args) const
Function call operator.
static constexpr bool l_is_scalar
Constexpr variable that is true if the left hand side operand is a scalar.
std::decay_t< R > R_t
Decay type of the right hand side operand.
Stores information about the memory layout and the stride order of an array/view.
Definition traits.hpp:285
Provides type traits for the nda library.