TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
tools.hpp
Go to the documentation of this file.
1// Copyright (c) 2024--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
#include "../blas/tools.hpp"
#include "../exceptions.hpp"
#include "../traits.hpp"

#include <algorithm>
#include <array>
#include <cmath>
#include <concepts>
#include <cstdint>
#include <string_view>
#include <type_traits>
#include <utility>
25
26namespace nda::tensor {
27
28 // Import tools from the blas_lapack namespace.
29 using namespace nda::blas_lapack;
30
35
/**
 * @brief Constexpr variable that is true if nda is configured with cuTENSOR support.
 *
 * @details Requires both NDA_HAVE_CUDA and NDA_HAVE_CUTENSOR to be defined. The variable is declared
 * `inline constexpr` (instead of a namespace-scope `static constexpr`) so that there is a single entity shared by all
 * translation units including this header rather than one internal-linkage copy per translation unit.
 */
inline constexpr bool have_cutensor =
#if defined(NDA_HAVE_CUDA) && defined(NDA_HAVE_CUTENSOR)
    true;
#else
    false;
#endif // defined(NDA_HAVE_CUDA) && defined(NDA_HAVE_CUTENSOR)

/**
 * @brief Constexpr variable that is true if nda is configured with TBLIS support (NDA_HAVE_TBLIS is defined).
 */
inline constexpr bool have_tblis =
#ifdef NDA_HAVE_TBLIS
    true;
#else
    false;
#endif // NDA_HAVE_TBLIS
49
51 template <BlasArrayOrConj A>
52 using data_ptr_t = decltype(get_array(std::declval<A>()).data());
53
/**
 * @brief Binary operations for tensor operations.
 *
 * @details The scalar meaning of each enumerator, as implemented by nda::tensor::detail::apply_binary, is given
 * next to it (|.| denotes std::abs).
 */
enum class binary_op : std::uint8_t {
  SUM,     ///< x + y
  PROD,    ///< x * y
  SUM_ABS, ///< |x| + |y|
  MAX,     ///< max(x, y), real value types only
  MAX_ABS, ///< max(|x|, |y|)
  MIN,     ///< min(x, y), real value types only
  MIN_ABS, ///< min(|x|, |y|)
  NORM_2   ///< sqrt(norm(x) + norm(y))
};
68
// clang-format off
/**
 * @brief Unary element-wise operations for tensor operations.
 *
 * @details nda::tensor::tensor_view defaults to IDENTITY and uses CONJ for conjugate lazy expressions.
 * NOTE(review): the enumerator set appears to mirror cuTENSOR's `cutensorOperator_t`; confirm the exact semantics
 * of each operation (e.g. RCP, MISH, SWISH) against the backend documentation before relying on them.
 * The order of the enumerators defines their underlying values and must not be changed.
 */
enum class unary_op : std::uint8_t {
  IDENTITY, SQRT,  RELU,  CONJ,  RCP,   SIGMOID,   TANH,
  EXP,      LOG,   ABS,   NEG,   SIN,   COS,       TAN,
  SINH,     COSH,  ASIN,  ACOS,  ATAN,  ASINH,     ACOSH,
  ATANH,    CEIL,  FLOOR, MISH,  SWISH, SOFT_PLUS, SOFT_SIGN
};
// clang-format on
109
110 namespace detail {
111
112 // Apply a nda::tensor::binary_op to two scalar operands.
113 template <typename T>
114 T apply_binary(binary_op op, T x, T y) {
115 switch (op) {
116 case binary_op::SUM: return x + y;
117 case binary_op::PROD: return x * y;
118 case binary_op::SUM_ABS: return std::abs(x) + std::abs(y);
119 case binary_op::MAX_ABS: return std::max(std::abs(x), std::abs(y));
120 case binary_op::MIN_ABS: return std::min(std::abs(x), std::abs(y));
121 case binary_op::NORM_2: return std::sqrt(std::norm(x) + std::norm(y));
122 case binary_op::MAX:
123 case binary_op::MIN:
124 if constexpr (!is_complex_v<T>) {
125 return (op == binary_op::MAX ? std::max(x, y) : std::min(x, y));
126 } else {
127 NDA_RUNTIME_ERROR << "nda::tensor: binary_op::MAX/MIN are unsupported for complex value types";
128 }
129 }
130 return T{}; // unreachable
131 }
132
133 } // namespace detail
134
145 template <typename T>
146 struct tensor_view {
148 using value_type = T;
149
151 T *data = nullptr;
152
154 const long *extents = nullptr;
155
157 const long *strides = nullptr;
158
160 int ndim = 0;
161
163 unary_op op = unary_op::IDENTITY;
164
166 tensor_view() = default;
167
176 tensor_view(T *p) : data(p) {}
177
188 template <BlasArray A>
189 requires std::convertible_to<data_ptr_t<A>, T *>
190 tensor_view(A &&a, unary_op op) // NOLINT
191 : data(get_array(a).data()),
192 extents(get_array(a).indexmap().lengths().data()),
193 strides(get_array(a).indexmap().strides().data()),
194 ndim(get_rank<decltype(get_array(a))>),
195 op(op) {}
196
206 template <BlasArrayOrConj A>
207 requires std::convertible_to<data_ptr_t<A>, T *>
208 tensor_view(A &&a) : tensor_view(get_array(a), is_conj_array_expr<A> ? unary_op::CONJ : unary_op::IDENTITY) {} // NOLINT
209
219 template <typename U>
220 requires(!std::same_as<U, T> && std::convertible_to<U *, T *>)
222 : data(tv.data), extents(tv.extents), strides(tv.strides), ndim(tv.ndim), op(tv.op) {}
223 };
224
225 // Deduction guide: tensor_view(array) deduces T from the data pointer of the (underlying) array.
226 template <BlasArray A>
227 tensor_view(A &&, unary_op) -> tensor_view<std::remove_pointer_t<data_ptr_t<A>>>;
228
229 template <BlasArrayOrConj A>
230 tensor_view(A &&) -> tensor_view<std::remove_pointer_t<data_ptr_t<A>>>;
231
233 template <typename T>
235
247 inline void require_equal_indices(std::string_view idx_a, std::string_view idx_b, int rank, std::string_view op_name) {
248 if (static_cast<int>(idx_a.size()) != rank || idx_a != idx_b) {
249 NDA_RUNTIME_ERROR << "nda::tensor::" << op_name << ": fallback to nda operations requires identical index strings of length " << rank
250 << ": idx_a = '" << idx_a << "', idx_b = '" << idx_b << "'";
251 }
252 }
253
263 template <int R>
264 requires(R >= 0 && R <= 26)
265 std::string_view default_index() {
266 static const auto arr = []() constexpr {
267 std::array<char, R> s{};
268 for (int i = 0; i < R; ++i) s[i] = static_cast<char>('a' + i);
269 return s;
270 }();
271
272 return {arr.data(), arr.size()};
273 }
274
276
277} // namespace nda::tensor
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides various traits and utilities for the BLAS interface.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:147
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an n...
Definition tools.hpp:68
static constexpr bool is_conj_array_expr
Constexpr variable that is true if the given type is a conjugate lazy expression.
Definition tools.hpp:47
decltype(get_array(std::declval< A >()).data()) data_ptr_t
Data pointer type of an nda::blas_lapack::BlasArrayOrConj.
Definition tools.hpp:52
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
Definition tools.hpp:47
unary_op
Unary element-wise operations for tensor operations.
Definition tools.hpp:103
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support (requires both CUDA and cuTENSOR).
Definition tools.hpp:40
binary_op
Binary operations for tensor operations.
Definition tools.hpp:67
tensor_view< const T > const_tensor_view
Alias for a tensor_view with const value type.
Definition tools.hpp:234
void require_equal_indices(std::string_view idx_a, std::string_view idx_b, int rank, std::string_view op_name)
Check if two index strings are equal and have a specified length.
Definition tools.hpp:247
std::string_view default_index()
Generate a default index string ("abc...") of a given length.
Definition tools.hpp:265
constexpr bool is_complex_v
Constexpr variable that is true if type T is a std::complex type.
Definition traits.hpp:65
A type-erased, non-owning view of an nda::MemoryArray or a conjugate lazy expression.
Definition tools.hpp:146
T value_type
Value type of the tensor (can be const).
Definition tools.hpp:148
tensor_view(A &&a)
Construct a tensor view from an nda::MemoryArray or a conjugate lazy expression.
Definition tools.hpp:208
tensor_view()=default
Default constructor initializes an empty view.
tensor_view(T *p)
Construct a rank-0 tensor view from a pointer to a scalar value.
Definition tools.hpp:176
tensor_view(tensor_view< U > tv)
Construct a tensor view from another tensor view with a convertible value type.
Definition tools.hpp:221
tensor_view(A &&a, unary_op op)
Construct a tensor view from an nda::MemoryArray and an nda::tensor::unary_op.
Definition tools.hpp:190
Provides type traits for the nda library.