TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
elementwise_trinary.hpp
Go to the documentation of this file.
1// Copyright (c) 2024--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
14#include "./tools.hpp"
15#include "../exceptions.hpp"
17#include "../traits.hpp"
18
19#include <string_view>
20#include <utility>
21
22namespace nda::tensor {
23
28
71 template <BlasArrayOrConj A, BlasArrayOrConjFor<A> B, BlasArrayFor<A> C>
72 void elementwise_trinary(get_value_t<A> alpha, A const &a, std::string_view idx_a, get_value_t<A> beta, B const &b, std::string_view idx_b,
73 get_value_t<A> gamma, C &&c, std::string_view idx_c, binary_op op_AB = binary_op::SUM, // NOLINT
74 binary_op op_ABC = binary_op::SUM) {
75 // compile-time checks
76 constexpr bool run_on_device = mem::have_device_compatible_addr_space<A, B, C>;
77 static_assert(!run_on_device || have_cutensor, "nda::tensor::elementwise_trinary: cuTENSOR support is required");
78 static_assert(run_on_device || (get_rank<A> == get_rank<B> && get_rank<A> == get_rank<C>),
79 "nda::tensor::elementwise_trinary: host fallback requires identical ranks");
80
81 // dispatch to backends
82 if constexpr (run_on_device) {
83 device::elementwise_trinary(alpha, a, idx_a, beta, b, idx_b, gamma, c, idx_c, c, op_AB, op_ABC);
84 } else {
85 require_equal_indices(idx_a, idx_b, get_rank<A>, "elementwise_trinary");
86 require_equal_indices(idx_b, idx_c, get_rank<A>, "elementwise_trinary");
87 c = nda::map([alpha, beta, gamma, op_AB, op_ABC](auto x, auto y, auto z) {
88 return detail::apply_binary(op_ABC, detail::apply_binary(op_AB, alpha * x, beta * y), gamma * z);
89 })(a, b, c);
90 }
91 }
92
94 template <BlasArrayOrConj A, BlasArrayOrConjFor<A> B, BlasArrayFor<A> C>
95 void elementwise_trinary(A const &a, std::string_view idx_a, B const &b, std::string_view idx_b, C &&c, std::string_view idx_c, // NOLINT
96 binary_op op_AB = binary_op::SUM, binary_op op_ABC = binary_op::SUM) {
97 elementwise_trinary(get_value_t<A>{1}, a, idx_a, get_value_t<A>{1}, b, idx_b, get_value_t<A>{0}, std::forward<C>(c), idx_c, op_AB, op_ABC);
98 }
99
101
102} // namespace nda::tensor
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a C++ interface for various cuTENSOR routines.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
mapped< F > map(F f)
Create a lazy function call expression on arrays/views.
Definition map.hpp:206
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:147
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
void elementwise_trinary(get_value_t< A > alpha, A const &a, std::string_view idx_a, get_value_t< A > beta, B const &b, std::string_view idx_b, get_value_t< A > gamma, C &&c, std::string_view idx_c, binary_op op_AB=binary_op::SUM, binary_op op_ABC=binary_op::SUM)
In-place elementwise trinary tensor operation with cuTENSOR/nda dispatch.
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
Definition tools.hpp:40
binary_op
Binary operations for tensor operations.
Definition tools.hpp:67
void require_equal_indices(std::string_view idx_a, std::string_view idx_b, int rank, std::string_view op_name)
Check if two index strings are equal and have a specified length.
Definition tools.hpp:247
Provides various traits and utilities for the tensor interface.
Provides type traits for the nda library.