22namespace nda::tensor {
62 template <BlasArrayOrConj A, BlasArrayFor<A> B>
66 static_assert(!run_on_device ||
have_cutensor,
"nda::tensor::add: cuTENSOR support is required");
70 if constexpr (run_on_device) {
71 device::elementwise_binary(alpha, a, idx_a, beta, b, idx_b, b);
73 tblis::add(alpha, a, idx_a, beta, b, idx_b);
76 b = alpha * a + beta * b;
118 template <BlasArrayOrConj A, BlasArrayOrConjFor<A> B, BlasArrayFor<A> C>
120 std::string_view idx_c) {
123 static_assert(!run_on_device ||
have_cutensor,
"nda::tensor::add: cuTENSOR support is required");
125 "nda::tensor::add: host fallback requires identical ranks");
128 if constexpr (run_on_device) {
129 device::elementwise_binary(alpha, a, idx_a, beta, b, idx_b, c, binary_op::SUM);
136 c = alpha * a + beta * b;
141 template <BlasArrayOrConj A, BlasArrayFor<A> B>
142 void add(A
const &a, std::string_view idx_a, B &&b, std::string_view idx_b) {
147 template <BlasArrayOrConj A, BlasArrayOrConjFor<A> B, BlasArrayFor<A> C>
148 void add(A
const &a, std::string_view idx_a, B
const &b, std::string_view idx_b, C &&c, std::string_view idx_c) {
153 template <BlasArrayOrConj A, BlasArrayFor<A> B>
162 template <BlasArrayOrConj A, BlasArrayFor<A> B>
163 void add(A
const &a, B &&b) {
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a C++ interface for various cuTENSOR routines.
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
void add(get_value_t< A > alpha, A const &a, std::string_view idx_a, get_value_t< A > beta, B &&b, std::string_view idx_b)
Tensor addition with cuTENSOR/TBLIS/nda dispatch.
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
void require_equal_indices(std::string_view idx_a, std::string_view idx_b, int rank, std::string_view op_name)
Check if two index strings are equal and have a specified length.
std::string_view default_index()
Generate a default index string ("abc...") of a given length.
Provides a C++ interface for various TBLIS tensor routines.
Provides type traits for the nda library.