TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
reduce.hpp
Go to the documentation of this file.
1// Copyright (c) 2024--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
15#include "./tools.hpp"
16#include "../algorithms.hpp"
17#include "../basic_array.hpp"
18#include "../concepts.hpp"
19#include "../declarations.hpp"
21#include "../macros.hpp"
25#include "../mem/policies.hpp"
26#include "../traits.hpp"
27
28#include <cmath>
29#include <string_view>
30
31namespace nda::tensor {
32
37
   // Reduce all elements of the array `a` to a single value using the operation `op_reduce`
   // (binary_op::SUM when the argument is omitted).
   //
   // Backend selection, as implemented below:
   //  - arrays whose address space is device compatible: device::reduce (compilation fails
   //    unless have_cutensor is true),
   //  - otherwise, if have_tblis is true and the operation is not binary_op::PROD: tblis::reduce,
   //  - otherwise: a host fallback built from nda algorithms (sum, product, max_element,
   //    min_element, abs, abs2).
   //
   // a          Array to reduce; its type must satisfy the BlasArrayOrConj concept.
   // op_reduce  Reduction operation to apply.
   // returns    The reduced value, converted to get_value_t<A>.
   //
   // In the nda host fallback, an error is reported through NDA_RUNTIME_ERROR for
   // binary_op::MAX/MIN on complex value types and for any operation not listed in the switch.
66 template <BlasArrayOrConj A>
67 get_value_t<A> reduce(A const &a, binary_op op_reduce = binary_op::SUM) {
68 // compile-time checks
69 constexpr bool run_on_device = mem::have_device_compatible_addr_space<A>;
70 static_assert(!run_on_device || have_cutensor, "nda::tensor::reduce: cuTENSOR support is required");
71
72 // dispatch to backends
73 if constexpr (run_on_device) {
   // NOTE(review): `z` is used below but its declaration is not visible in this listing
   // (original source line 74 is absent); presumably a one-element array in device memory
   // that receives the result - confirm against the real header.
   // NOTE(review): get_value_t<A>{1} and get_value_t<A>{0} look like alpha/beta scaling
   // factors and "" like an empty (scalar) output index string - confirm against device::reduce.
75 device::reduce(get_value_t<A>{1}, a, default_index<get_rank<A>>(), get_value_t<A>{0}, z.data(), "", z.data(), op_reduce);
   // convert `z` to host memory and return its element at index 0
76 return nda::to_host(z)(0);
77 } else {
78 // TBLIS handles every op except PROD; fall through to the nda fallback for PROD
79 if constexpr (have_tblis) {
80 if (op_reduce != binary_op::PROD) return tblis::reduce(op_reduce, a, default_index<get_rank<A>>());
81 }
82 // MAX/MIN are not defined for complex value types
   // helper shared by the MAX and MIN cases of the switch below; the complex branch is
   // selected at compile time, so max_element/min_element are only used for non-complex types
83 auto max_min = [&](binary_op op) -> get_value_t<A> {
84 if constexpr (is_complex_v<get_value_t<A>>) {
85 NDA_RUNTIME_ERROR << "nda::tensor::reduce: binary_op::MAX/MIN are unsupported for complex value types";
86 } else {
87 return op == binary_op::MAX ? nda::max_element(a) : nda::min_element(a);
88 }
89 };
   // nda host fallback: map each operation onto the corresponding nda algorithm
90 switch (op_reduce) {
91 case binary_op::SUM: return nda::sum(a);
92 case binary_op::PROD: return nda::product(a);
93 case binary_op::SUM_ABS: return nda::sum(nda::abs(a));
94 case binary_op::MAX_ABS: return nda::max_element(nda::abs(a));
95 case binary_op::MIN_ABS: return nda::min_element(nda::abs(a));
   // square root of the sum of abs2 over all elements
96 case binary_op::NORM_2: return std::sqrt(nda::sum(nda::abs2(a)));
   // MAX deliberately shares the MIN case; max_min picks the algorithm from op_reduce
97 case binary_op::MAX:
98 case binary_op::MIN: return max_min(op_reduce);
99 default: NDA_RUNTIME_ERROR << "nda::tensor::reduce: unknown binary_op on nda host fallback";
100 }
101 }
102 }
103
105
106} // namespace nda::tensor
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides various algorithms to be used with nda::Array objects.
Provides the generic class for arrays.
A generic multi-dimensional array.
Provides concepts for the nda library.
Provides a C++ interface for various cuTENSOR routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
auto max_element(A const &a)
Find the maximum element of an array.
auto sum(A const &a)
Sum all the elements of an nda::Array object.
auto product(A const &a)
Multiply all the elements of an nda::Array object.
auto min_element(A const &a)
Find the minimum element of an array.
auto zeros(std::array< Int, Rank > const &shape)
Make an array of the given shape on the given address space and zero-initialize it.
decltype(auto) to_host(A &&a)
Convert an nda::MemoryArray to its regular type on host memory.
auto abs(A &&a)
Function abs for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types).
auto abs2(A &&a)
Function abs2 for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types).
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:147
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition policies.hpp:52
get_value_t< A > reduce(A const &a, binary_op op_reduce=binary_op::SUM)
Full tensor reduction with cuTENSOR/TBLIS/nda dispatch.
Definition reduce.hpp:67
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
Definition tools.hpp:47
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
Definition tools.hpp:40
binary_op
Binary operations for tensor operations.
Definition tools.hpp:67
std::string_view default_index()
Generate a default index string ("abc...") of a given length.
Definition tools.hpp:265
constexpr bool is_complex_v
Constexpr variable that is true if type T is a std::complex type.
Definition traits.hpp:65
Provides definitions of various layout policies.
Macros used in the nda library.
Provides some custom implementations of standard mathematical functions used for lazy, coefficient-wise array operations.
Provides lazy, coefficient-wise array operations of standard mathematical functions together with ove...
Defines various memory handling policies.
Contiguous layout policy with C-order (row-major order).
Definition policies.hpp:36
Provides a C++ interface for various TBLIS tensor routines.
Provides various traits and utilities for the tensor interface.
Provides type traits for the nda library.