26namespace nda::tensor {
29 using namespace nda::blas_lapack;
37#if defined(NDA_HAVE_CUDA) && defined(NDA_HAVE_CUTENSOR)
51 template <BlasArrayOrConj A>
/// Binary operations usable in tensor operations.
///
/// Each enumerator names the rule used to combine two elements x and y.
/// The enumerators are stored in a single byte (std::uint8_t).
enum class binary_op : std::uint8_t {
  SUM,     ///< x + y
  PROD,    ///< x * y
  SUM_ABS, ///< |x| + |y|
  MAX,     ///< max(x, y); not supported for complex value types
  MAX_ABS, ///< max(|x|, |y|)
  MIN,     ///< min(x, y); not supported for complex value types
  MIN_ABS, ///< min(|x|, |y|)
  NORM_2   ///< sqrt(|x|^2 + |y|^2)
};
104 IDENTITY, SQRT, RELU, CONJ, RCP, SIGMOID, TANH, EXP, LOG, ABS, NEG,
105 SIN, COS, TAN, SINH, COSH, ASIN, ACOS, ATAN, ASINH, ACOSH, ATANH,
106 CEIL, FLOOR, MISH, SWISH, SOFT_PLUS, SOFT_SIGN
113 template <
typename T>
116 case binary_op::SUM:
return x + y;
117 case binary_op::PROD:
return x * y;
118 case binary_op::SUM_ABS:
return std::abs(x) + std::abs(y);
119 case binary_op::MAX_ABS:
return std::max(std::abs(x), std::abs(y));
120 case binary_op::MIN_ABS:
return std::min(std::abs(x), std::abs(y));
121 case binary_op::NORM_2:
return std::sqrt(std::norm(x) + std::norm(y));
125 return (op == binary_op::MAX ? std::max(x, y) : std::min(x, y));
127 NDA_RUNTIME_ERROR <<
"nda::tensor: binary_op::MAX/MIN are unsupported for complex value types";
145 template <
typename T>
188 template <BlasArray A>
189 requires std::convertible_to<data_ptr_t<A>, T *>
206 template <BlasArrayOrConj A>
207 requires std::convertible_to<data_ptr_t<A>, T *>
219 template <
typename U>
220 requires(!std::same_as<U, T> && std::convertible_to<U *, T *>)
226 template <BlasArray A>
227 tensor_view(A &&,
unary_op) -> tensor_view<std::remove_pointer_t<data_ptr_t<A>>>;
229 template <BlasArrayOrConj A>
230 tensor_view(A &&) -> tensor_view<std::remove_pointer_t<data_ptr_t<A>>>;
233 template <
typename T>
247 inline void require_equal_indices(std::string_view idx_a, std::string_view idx_b,
int rank, std::string_view op_name) {
248 if (
static_cast<int>(idx_a.size()) != rank || idx_a != idx_b) {
249 NDA_RUNTIME_ERROR <<
"nda::tensor::" << op_name <<
": fallback to nda operations requires identical index strings of length " << rank
250 <<
": idx_a = '" << idx_a <<
"', idx_b = '" << idx_b <<
"'";
264 requires(R >= 0 && R <= 26)
266 static const auto arr = []()
constexpr {
267 std::array<char, R> s{};
268 for (
int i = 0; i < R; ++i) s[i] = static_cast<char>(
'a' + i);
272 return {arr.data(), arr.size()};
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an n...
static constexpr bool is_conj_array_expr
Constexpr variable that is true if the given type is a conjugate lazy expression.
decltype(get_array(std::declval< A >()).data()) data_ptr_t
Data pointer type of an nda::blas_lapack::BlasArrayOrConj.
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
unary_op
Unary element-wise operations for tensor operations.
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
binary_op
Binary operations for tensor operations.
tensor_view< const T > const_tensor_view
Alias for a tensor_view with const value type.
void require_equal_indices(std::string_view idx_a, std::string_view idx_b, int rank, std::string_view op_name)
Check if two index strings are equal and have a specified length.
std::string_view default_index()
Generate a default index string ("abc...") of a given length.
constexpr bool is_complex_v
Constexpr variable that is true if type T is a std::complex type.
A type-erased, non-owning view of an nda::MemoryArray or a conjugate lazy expression.
T value_type
Value type of the tensor (can be const).
tensor_view(A &&a)
Construct a tensor view from an nda::MemoryArray or a conjugate lazy expression.
tensor_view()=default
Default constructor initializes an empty view.
tensor_view(T *p)
Construct a rank-0 tensor view from a pointer to a scalar value.
tensor_view(tensor_view< U > tv)
Construct a tensor view from another tensor view with a convertible value type.
tensor_view(A &&a, unary_op op)
Construct a tensor view from an nda::MemoryArray and an nda::tensor::unary_op.
Provides type traits for the nda library.