TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
arithmetic.hpp
Go to the documentation of this file.
1// Copyright (c) 2019-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Olivier Parcollet, Nils Wentzell
16
17/**
18 * @file
19 * @brief Provides lazy expressions for nda::Array types.
20 */
21
22#pragma once
23
24#include "./concepts.hpp"
25#include "./declarations.hpp"
26#include "./linalg/matmul.hpp"
27#include "./linalg/det_and_inverse.hpp"
28#include "./macros.hpp"
29#include "./stdutil/complex.hpp"
30#include "./traits.hpp"
31
32#include <functional>
33#include <type_traits>
34#include <utility>
35
36#ifdef NDA_ENFORCE_BOUNDCHECK
37#include "./exceptions.hpp"
38#endif // NDA_ENFORCE_BOUNDCHECK
39
40namespace nda {
41
42 /**
43 * @addtogroup av_ops
44 * @{
45 */
46
47 /**
48 * @brief Lazy unary expression for nda::Array types.
49 *
50 * @details A lazy unary expression contains a single operand and a unary operation. It fulfills the nda::Array
51 * concept and can therefore be used in any other expression or function that expects an nda::Array type.
52 *
53 * The only supported unary operation is the negation operation ('-').
54 *
55 * @tparam OP Char representing the unary operation.
56 * @param A nda::Array type.
57 */
58 template <char OP, Array A>
59 struct expr_unary {
60 static_assert(OP == '-', "Error in nda::expr_unary: Only negation is supported");
61
62 /// nda::Array object.
63 A a;
64
65 /**
66 * @brief Function call operator.
67 *
68 * @details Forwards the arguments to the nda::Array operand and negates the result.
69 *
70 * @tparam Args Types of the arguments.
71 * @param args Function call arguments.
72 * @return If the result of the forwarded function call is another nda::Array, a new lazy expression is returned.
73 * Otherwise the result is negated and returned.
74 */
75 template <typename... Args>
76 auto operator()(Args &&...args) const {
77 return -a(std::forward<Args>(args)...);
78 }
79
80 /**
81 * @brief Get the shape of the nda::Array operand.
82 * @return `std::array<long, Rank>` object specifying the shape of the operand.
83 */
84 [[nodiscard]] constexpr auto shape() const { return a.shape(); }
85
86 /**
87 * @brief Get the total size of the nda::Array operand.
88 * @return Number of elements contained in the operand.
89 */
90 [[nodiscard]] constexpr long size() const { return a.size(); }
91 };
92
93 /**
94 * @brief Lazy binary expression for nda::ArrayOrScalar types.
95 *
96 * @details A lazy binary expression contains a two operands and a binary operation. It fulfills the nda::Array
97 * concept and can therefore be used in any other expression or function that expects an nda::Array type.
98 *
99 * The supported binary operations are addition ('+'), subtraction ('-'), multiplication ('*') and division ('/').
100 *
101 * @tparam OP Char representing the unary operation.
102 * @param L nda::ArrayOrScalar type of left hand side.
103 * @param R nda::ArrayOrScalar type of right hand side.
104 */
105 template <char OP, ArrayOrScalar L, ArrayOrScalar R>
106 struct expr {
107 /// nda::ArrayOrScalar left hand side operand.
108 L l;
109
110 /// nda::ArrayOrScalar right hand side operand.
111 R r;
112
113 /// Decay type of the left hand side operand.
114 using L_t = std::decay_t<L>;
115
116 /// Decay type of the right hand side operand.
117 using R_t = std::decay_t<R>;
118
119 // FIXME : we should use is_scalar_for_v but the trait needs work to accommodate scalar L or R
120 /// Constexpr variable that is true if the left hand side operand is a scalar.
121 static constexpr bool l_is_scalar = nda::is_scalar_v<L>;
122
123 /// Constexpr variable that is true if the right hand side operand is a scalar.
124 static constexpr bool r_is_scalar = nda::is_scalar_v<R>;
125
126 /// Constexpr variable specifying the algebra of one of the non-scalar operands.
127 static constexpr char algebra = (l_is_scalar ? get_algebra<R> : get_algebra<L>);
128
129 /**
130 * @brief Compute the layout information of the expression.
131 * @return nda::layout_info_t object.
132 */
134 if (l_is_scalar) return (algebra == 'A' ? get_layout_info<R> : layout_info_t{}); // 1 as an array has all flags, it is just 1
135 if (r_is_scalar) return (algebra == 'A' ? get_layout_info<L> : layout_info_t{}); // 1 as a matrix does not, as it is diagonal only.
136 return get_layout_info<R> & get_layout_info<L>; // default case. Take the logical and of all flags
137 }
138
139 /**
140 * @brief Get the shape of the expression (result of the operation).
141 * @return `std::array<long, Rank>` object specifying the shape of the expression.
142 */
143 [[nodiscard]] constexpr decltype(auto) shape() const {
144 if constexpr (l_is_scalar) {
145 return r.shape();
146 } else if constexpr (r_is_scalar) {
147 return l.shape();
148 } else {
149 EXPECTS(l.shape() == r.shape());
150 return l.shape();
151 }
152 }
153
154 /**
155 * @brief Get the total size of the expression (result of the operation).
156 * @return Number of elements contained in the expression.
157 */
158 [[nodiscard]] constexpr long size() const {
159 if constexpr (l_is_scalar) {
160 return r.size();
161 } else if constexpr (r_is_scalar) {
162 return l.size();
163 } else {
164 EXPECTS(l.size() == r.size());
165 return l.size();
166 }
167 }
168
169 /**
170 * @brief Function call operator.
171 *
172 * @details Forwards the arguments to the nda::Array operands and performs the binary operation.
173 *
174 * @tparam Args Types of the arguments.
175 * @param args Function call arguments.
176 * @return If the result of the forwarded function calls contains another nda::Array, a new lazy expression is
177 * returned. Otherwise the result of the binary operation is returned.
178 */
179 template <typename... Args>
180 auto operator()(Args const &...args) const {
181 // addition
182 if constexpr (OP == '+') {
183 if constexpr (l_is_scalar) {
184 // lhs is a scalar
185 if constexpr (algebra == 'M')
186 // rhs is a matrix
187 return (std::equal_to{}(args...) ? l + r(args...) : r(args...));
188 else
189 // rhs is an array
190 return l + r(args...);
191 } else if constexpr (r_is_scalar) {
192 // rhs is a scalar
193 if constexpr (algebra == 'M')
194 // lhs is a matrix
195 return (std::equal_to{}(args...) ? l(args...) + r : l(args...));
196 else
197 // lhs is an array
198 return l(args...) + r;
199 } else
200 // both are arrays or matrices
201 return l(args...) + r(args...);
202 }
203
204 // subtraction
205 if constexpr (OP == '-') {
206 if constexpr (l_is_scalar) {
207 // lhs is a scalar
208 if constexpr (algebra == 'M')
209 // rhs is a matrix
210 return (std::equal_to{}(args...) ? l - r(args...) : -r(args...));
211 else
212 // rhs is an array
213 return l - r(args...);
214 } else if constexpr (r_is_scalar) {
215 // rhs is a scalar
216 if constexpr (algebra == 'M')
217 // lhs is a matrix
218 return (std::equal_to{}(args...) ? l(args...) - r : l(args...));
219 else
220 // lhs is an array
221 return l(args...) - r;
222 } else
223 // both are arrays or matrices
224 return l(args...) - r(args...);
225 }
226
227 // multiplication
228 if constexpr (OP == '*') {
229 if constexpr (l_is_scalar)
230 // lhs is a scalar
231 return l * r(args...);
232 else if constexpr (r_is_scalar)
233 // rhs is a scalar
234 return l(args...) * r;
235 else {
236 // both are arrays (matrix product is not supported here)
237 static_assert(algebra != 'M', "Error in nda::expr: Matrix algebra not supported");
238 return l(args...) * r(args...);
239 }
240 }
241
242 // division
243 if constexpr (OP == '/') {
244 if constexpr (l_is_scalar) {
245 // lhs is a scalar
246 static_assert(algebra != 'M', "Error in nda::expr: Matrix algebra not supported");
247 return l / r(args...);
248 } else if constexpr (r_is_scalar)
249 // rhs is a scalar
250 return l(args...) / r;
251 else {
252 // both are arrays (matrix division is not supported here)
253 static_assert(algebra != 'M', "Error in nda::expr: Matrix algebra not supported");
254 return l(args...) / r(args...);
255 }
256 }
257 }
258
259 /**
260 * @brief Subscript operator.
261 *
262 * @details Simply forwards the argument to the function call operator.
263 *
264 * @tparam Arg Type of the argument.
265 * @param arg Subscript argument.
266 * @return Result of the corresponding function call.
267 */
268 template <typename Arg>
269 auto operator[](Arg &&arg) const {
270 static_assert(get_rank<expr> == 1, "Error in nda::expr: Subscript operator only available for expressions of rank 1");
271 return operator()(std::forward<Arg>(arg));
272 }
273 };
274
275 /**
276 * @brief Unary minus operator for nda::Array types.
277 *
278 * @details It performs lazy elementwise negation.
279 *
280 * @tparam A nda::Array type.
281 * @param a nda::Array operand.
282 * @return Lazy unary expression for the negation operation.
283 */
284 template <Array A>
285 expr_unary<'-', A> operator-(A &&a) {
286 return {std::forward<A>(a)};
287 }
288
289 /**
290 * @brief Addition operator for two nda::Array types.
291 *
292 * @details It performs lazy elementwise addition.
293 *
294 * @tparam L nda::Array type of left hand side.
295 * @tparam R nda::Array type of right hand side.
296 * @param l nda::Array left hand side operand.
297 * @param r nda::Array right hand side operand.
298 * @return Lazy binary expression for the addition operation.
299 */
300 template <Array L, Array R>
301 Array auto operator+(L &&l, R &&r) {
302 static_assert(get_rank<L> == get_rank<R>, "Error in lazy nda::operator+: Rank mismatch");
303 return expr<'+', L, R>{std::forward<L>(l), std::forward<R>(r)};
304 }
305
306 /**
307 * @brief Addition operator for an nda::Array and an nda::Scalar.
308 *
309 * @details Depending on the algebra of the nda::Array, it performs the following lazy operations:
310 * - 'A': Elementwise addition.
311 * - 'M': Addition of the nda::Scalar to the elements on the shorter diagonal of the matrix.
312 *
313 * @tparam A nda::Array type.
314 * @tparam S nda::Scalar type.
315 * @param a nda::Array left hand side operand.
316 * @param s nda::Scalar right hand side operand.
317 * @return Lazy binary expression for the addition operation.
318 */
319 template <Array A, Scalar S>
320 Array auto operator+(A &&a, S &&s) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
321 return expr<'+', A, std::decay_t<S>>{std::forward<A>(a), s};
322 }
323
324 /**
325 * @brief Addition operator for an nda::Scalar and an nda::Array.
326 *
327 * @details Depending on the algebra of the nda::Array, it performs the following lazy operations:
328 * - 'A': Elementwise addition.
329 * - 'M': Addition of the nda::Scalar to the elements on the shorter diagonal of the matrix.
330 *
331 * @tparam S nda::Scalar type.
332 * @tparam A nda::Array type.
333 * @param s nda::Scalar left hand side operand.
334 * @param a nda::Array right hand side operand.
335 * @return Lazy binary expression for the addition operation.
336 */
337 template <Scalar S, Array A>
338 Array auto operator+(S &&s, A &&a) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
339 return expr<'+', std::decay_t<S>, A>{s, std::forward<A>(a)};
340 }
341
342 /**
343 * @brief Subtraction operator for two nda::Array types.
344 *
345 * @details It performs lazy elementwise subtraction.
346 *
347 * @tparam L nda::Array type of left hand side.
348 * @tparam R nda::Array type of right hand side.
349 * @param l nda::Array left hand side operand.
350 * @param r nda::Array right hand side operand.
351 * @return Lazy binary expression for the subtraction operation.
352 */
353 template <Array L, Array R>
354 Array auto operator-(L &&l, R &&r) {
355 static_assert(get_rank<L> == get_rank<R>, "Error in lazy nda::operator-: Rank mismatch");
356 return expr<'-', L, R>{std::forward<L>(l), std::forward<R>(r)};
357 }
358
359 /**
360 * @brief Subtraction operator for an nda::Array and an nda::Scalar.
361 *
362 * @details Depending on the algebra of the nda::Array, it performs the following lazy operations:
363 * - 'A': Elementwise subtraction.
364 * - 'M': Subtraction of the nda::Scalar from the elements on the shorter diagonal of the matrix.
365 *
366 * @tparam A nda::Array type.
367 * @tparam S nda::Scalar type.
368 * @param a nda::Array left hand side operand.
369 * @param s nda::Scalar right hand side operand.
370 * @return Lazy binary expression for the subtraction operation.
371 */
372 template <Array A, Scalar S>
373 Array auto operator-(A &&a, S &&s) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
374 return expr<'-', A, std::decay_t<S>>{std::forward<A>(a), s};
375 }
376
377 /**
378 * @brief Subtraction operator for an nda::Scalar and an nda::Array.
379 *
380 * @details Depending on the algebra of the nda::Array, it performs the following lazy operations:
381 * - 'A': Elementwise subtraction.
382 * - 'M': Subtraction of the elements on the shorter diagonal of the matrix from the nda::Scalar.
383 *
384 * @tparam S nda::Scalar type.
385 * @tparam A nda::Array type.
386 * @param s nda::Scalar left hand side operand.
387 * @param a nda::Array right hand side operand.
388 * @return Lazy binary expression for the subtraction operation.
389 */
390 template <Scalar S, Array A>
391 Array auto operator-(S &&s, A &&a) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
392 return expr<'-', std::decay_t<S>, A>{s, std::forward<A>(a)};
393 }
394
395 /**
396 * @brief Multiplication operator for two nda::Array types.
397 *
398 * @details The input arrays must have one of the following algebras:
399 * - 'A' * 'A': Elementwise multiplication of two arrays returns a lazy nda::expr object.
400 * - 'M' * 'M': Matrix-matrix multiplication calls nda::matmul and returns the result.
401 * - 'M' * 'V': Matrix-vector multiplication calls nda::matvecmul and returns the result.
402 *
403 * Obvious restrictions on the ranks and shapes of the input arrays apply.
404 *
405 * @tparam L nda::Array type of left hand side.
406 * @tparam R nda::Array type of right hand side.
407 * @param l nda::Array left hand side operand.
408 * @param r nda::Array right hand side operand.
409 * @return Either a lazy binary expression for the multiplication operation ('A' * 'A') or the result
410 * of the matrix-matrix or matrix-vector multiplication.
411 */
412 template <Array L, Array R>
413 auto operator*(L &&l, R &&r) {
414 // allowed algebras: A * A or M * M or M * V
415 static constexpr char l_algebra = get_algebra<L>;
416 static constexpr char r_algebra = get_algebra<R>;
417 static_assert(l_algebra != 'V', "Error in nda::operator*: Can not multiply vector by an array or a matrix");
418
419 // two arrays: A * A
420 if constexpr (l_algebra == 'A') {
421 static_assert(r_algebra == 'A', "Error in nda::operator*: Both types need to be arrays");
422 static_assert(get_rank<L> == get_rank<R>, "Error in nda::operator*: Rank mismatch");
423#ifdef NDA_ENFORCE_BOUNDCHECK
424 if (l.shape() != r.shape()) NDA_RUNTIME_ERROR << "Error in nda::operator*: Dimension mismatch: " << l.shape() << " != " << r.shape();
425#endif
426 return expr<'*', L, R>{std::forward<L>(l), std::forward<R>(r)};
427 }
428
429 // two matrices: M * M
430 if constexpr (l_algebra == 'M') {
431 static_assert(r_algebra != 'A', "Error in nda::operator*: Can not multiply a matrix by an array");
432 if constexpr (r_algebra == 'M')
433 // matrix * matrix
434 return matmul(std::forward<L>(l), std::forward<R>(r));
435 else
436 // matrix * vector
437 return matvecmul(std::forward<L>(l), std::forward<R>(r));
438 }
439 }
440
441 /**
442 * @brief Multiplication operator for an nda::Array and an nda::Scalar.
443 *
444 * @details It performs lazy elementwise multiplication.
445 *
446 * @tparam A nda::Array type.
447 * @tparam S nda::Scalar type.
448 * @param a nda::Array left hand side operand.
449 * @param s nda::Scalar right hand side operand.
450 * @return Lazy binary expression for the multiplication operation.
451 */
452 template <Array A, Scalar S>
453 Array auto operator*(A &&a, S &&s) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
454 return expr<'*', A, std::decay_t<S>>{std::forward<A>(a), s};
455 }
456
457 /**
458 * @brief Multiplication operator for an nda::Scalar and an nda::Array.
459 *
460 * @details It performs elementwise multiplication.
461 *
462 * @tparam S nda::Scalar type.
463 * @tparam A nda::Array type.
464 * @param s nda::Scalar left hand side operand.
465 * @param a nda::Array right hand side operand.
466 * @return Lazy binary expression for the multiplication operation.
467 */
468 template <Scalar S, Array A>
469 Array auto operator*(S &&s, A &&a) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
470 return expr<'*', std::decay_t<S>, A>{s, std::forward<A>(a)};
471 }
472
473 /**
474 * @brief Division operator for two nda::Array types.
475 *
476 * @details The input arrays must have one of the following algebras:
477 * - 'A' / 'A': Elementwise division of two arrays returns a lazy nda::expr object.
478 * - 'M' / 'M': Multiplies the lhs matrix with the inverse of the rhs matrix and returns the result.
479 *
480 * Obvious restrictions on the ranks and shapes of the input arrays apply.
481 *
482 * @tparam L nda::Array type of left hand side.
483 * @tparam R nda::Array type of right hand side.
484 * @param l nda::Array left hand side operand.
485 * @param r nda::Array right hand side operand.
486 * @return Either a lazy binary expression for the division operation ('A' * 'A') or the result
487 * of the matrix-inverse matrix multiplication.
488 */
489 template <Array L, Array R>
490 Array auto operator/(L &&l, R &&r) {
491 // allowed algebras: A / A or M / M
492 static constexpr char l_algebra = get_algebra<L>;
493 static constexpr char r_algebra = get_algebra<R>;
494 static_assert(l_algebra != 'V', "Error in nda::operator/: Can not divide a vector by an array or a matrix");
495
496 // two arrays: A / A
497 if constexpr (l_algebra == 'A') {
498 static_assert(r_algebra == 'A', "Error in nda::operator/: Both types need to be arrays");
499 static_assert(get_rank<L> == get_rank<R>, "Error in nda::operator/: Rank mismatch");
500#ifdef NDA_ENFORCE_BOUNDCHECK
501 if (l.shape() != r.shape()) NDA_RUNTIME_ERROR << "Error in nda::operator/: Dimension mismatch: " << l.shape() << " != " << r.shape();
502#endif
503 return expr<'/', L, R>{std::forward<L>(l), std::forward<R>(r)};
504 }
505
506 // two matrices: M / M
507 if constexpr (l_algebra == 'M') {
508 static_assert(r_algebra == 'M', "Error in nda::operator*: Can not divide a matrix by an array/vector");
509 return std::forward<L>(l) * inverse(matrix<get_value_t<R>>{std::forward<R>(r)});
510 }
511 }
512
513 /**
514 * @brief Division operator for an nda::Array and an nda::Scalar.
515 *
516 * @details It performs lazy elementwise division.
517 *
518 * @tparam A nda::Array type.
519 * @tparam S nda::Scalar type.
520 * @param a nda::Array left hand side operand.
521 * @param s nda::Scalar right hand side operand.
522 * @return Lazy binary expression for the division operation.
523 */
524 template <Array A, Scalar S>
525 Array auto operator/(A &&a, S &&s) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
526 return expr<'/', A, std::decay_t<S>>{std::forward<A>(a), s};
527 }
528
529 /**
530 * @brief Division operator for an nda::Scalar and an nda::Array.
531 *
532 * @details Depending on the algebra of the nda::Array, it performs the following lazy operations:
533 * - 'A': Elementwise division.
534 * - 'M': Multiplication of the nda::Scalar with the inverse of the matrix.
535 *
536 * @tparam S nda::Scalar type.
537 * @tparam A nda::Array type.
538 * @param s nda::Scalar left hand side operand.
539 * @param a nda::Array right hand side operand.
540 * @return Lazy binary expression for the division operation (multiplication in case of a matrix).
541 */
542 template <Scalar S, Array A>
543 Array auto operator/(S &&s, A &&a) { // NOLINT (S&& is mandatory for proper concept Array <: typename to work)
544 static constexpr char algebra = get_algebra<A>;
545 if constexpr (algebra == 'M')
546 return s * inverse(matrix<get_value_t<A>>{std::forward<A>(a)});
547 else
548 return expr<'/', std::decay_t<S>, A>{s, std::forward<A>(a)};
549 }
550
551 /** @} */
552
553} // namespace nda
Array auto operator-(L &&l, R &&r)
Subtraction operator for two nda::Array types.
Array auto operator+(L &&l, R &&r)
Addition operator for two nda::Array types.
Array auto operator/(L &&l, R &&r)
Division operator for two nda::Array types.
auto operator*(L &&l, R &&r)
Multiplication operator for two nda::Array types.
expr_unary<'-', A > operator-(A &&a)
Unary minus operator for nda::Array types.
#define EXPECTS(X)
Definition macros.hpp:59
Lazy unary expression for nda::Array types.
A a
nda::Array object.
auto operator()(Args &&...args) const
Function call operator.
constexpr long size() const
Get the total size of the nda::Array operand.
constexpr auto shape() const
Get the shape of the nda::Array operand.
Lazy binary expression for nda::ArrayOrScalar types.
L l
nda::ArrayOrScalar left hand side operand.
constexpr long size() const
Get the total size of the expression (result of the operation).
static constexpr bool r_is_scalar
Constexpr variable that is true if the right hand side operand is a scalar.
constexpr decltype(auto) shape() const
Get the shape of the expression (result of the operation).
auto operator[](Arg &&arg) const
Subscript operator.
static constexpr layout_info_t compute_layout_info()
Compute the layout information of the expression.
static constexpr char algebra
Constexpr variable specifying the algebra of one of the non-scalar operands.
R r
nda::ArrayOrScalar right hand side operand.
auto operator()(Args const &...args) const
Function call operator.
static constexpr bool l_is_scalar
Constexpr variable that is true if the left hand side operand is a scalar.
Stores information about the memory layout and the stride order of an array/view.
Definition traits.hpp:295