TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
add.hpp
Go to the documentation of this file.
1// Copyright (c) 2024--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
15#include "./tools.hpp"
17#include "../traits.hpp"
18
19#include <string_view>
20#include <utility>
21
22namespace nda::tensor {
23
28
62 template <BlasArrayOrConj A, BlasArrayFor<A> B>
63 void add(get_value_t<A> alpha, A const &a, std::string_view idx_a, get_value_t<A> beta, B &&b, std::string_view idx_b) { // NOLINT
64 // compile-time checks
65 constexpr bool run_on_device = mem::have_device_compatible_addr_space<A, B>;
66 static_assert(!run_on_device || have_cutensor, "nda::tensor::add: cuTENSOR support is required");
67 static_assert(run_on_device || have_tblis || get_rank<A> == get_rank<B>, "nda::tensor::add: host fallback requires identical ranks");
68
69 // dispatch to backends
70 if constexpr (run_on_device) {
71 device::elementwise_binary(alpha, a, idx_a, beta, b, idx_b, b);
72 } else if constexpr (have_tblis) {
73 tblis::add(alpha, a, idx_a, beta, b, idx_b);
74 } else {
75 require_equal_indices(idx_a, idx_b, get_rank<A>, "add");
76 b = alpha * a + beta * b;
77 }
78 }
79
118 template <BlasArrayOrConj A, BlasArrayOrConjFor<A> B, BlasArrayFor<A> C>
119 void add(get_value_t<A> alpha, A const &a, std::string_view idx_a, get_value_t<A> beta, B const &b, std::string_view idx_b, C &&c, // NOLINT
120 std::string_view idx_c) {
121 // compile-time checks
122 constexpr bool run_on_device = mem::have_device_compatible_addr_space<A, B, C>;
123 static_assert(!run_on_device || have_cutensor, "nda::tensor::add: cuTENSOR support is required");
124 static_assert(run_on_device || have_tblis || (get_rank<A> == get_rank<B> && get_rank<A> == get_rank<C>),
125 "nda::tensor::add: host fallback requires identical ranks");
126
127 // dispatch to backends
128 if constexpr (run_on_device) {
129 device::elementwise_binary(alpha, a, idx_a, beta, b, idx_b, c, binary_op::SUM);
130 } else if constexpr (have_tblis) {
131 tblis::add(beta, b, idx_b, get_value_t<A>{0}, c, idx_c);
132 tblis::add(alpha, a, idx_a, get_value_t<A>{1}, c, idx_c);
133 } else {
134 require_equal_indices(idx_a, idx_b, get_rank<A>, "add");
135 require_equal_indices(idx_b, idx_c, get_rank<A>, "add");
136 c = alpha * a + beta * b;
137 }
138 }
139
141 template <BlasArrayOrConj A, BlasArrayFor<A> B>
142 void add(A const &a, std::string_view idx_a, B &&b, std::string_view idx_b) { // NOLINT
143 add(get_value_t<A>{1}, a, idx_a, get_value_t<A>{0}, std::forward<B>(b), idx_b);
144 }
145
147 template <BlasArrayOrConj A, BlasArrayOrConjFor<A> B, BlasArrayFor<A> C>
148 void add(A const &a, std::string_view idx_a, B const &b, std::string_view idx_b, C &&c, std::string_view idx_c) { // NOLINT
149 add(get_value_t<A>{1}, a, idx_a, get_value_t<A>{1}, b, idx_b, std::forward<C>(c), idx_c);
150 }
151
153 template <BlasArrayOrConj A, BlasArrayFor<A> B>
154 void add(get_value_t<A> alpha, A const &a, get_value_t<A> beta, B &&b) { // NOLINT
155 add(alpha, a, default_index<get_rank<A>>(), beta, std::forward<B>(b), default_index<get_rank<B>>());
156 }
157
162 template <BlasArrayOrConj A, BlasArrayFor<A> B>
163 void add(A const &a, B &&b) { // NOLINT
165 }
166
168
169} // namespace nda::tensor
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a C++ interface for various cuTENSOR routines.
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:147
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
void add(get_value_t< A > alpha, A const &a, std::string_view idx_a, get_value_t< A > beta, B &&b, std::string_view idx_b)
Tensor addition with cuTENSOR/TBLIS/nda dispatch.
Definition add.hpp:63
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
Definition tools.hpp:47
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
Definition tools.hpp:40
void require_equal_indices(std::string_view idx_a, std::string_view idx_b, int rank, std::string_view op_name)
Check if two index strings are equal and have a specified length.
Definition tools.hpp:247
std::string_view default_index()
Generate a default index string ("abc...") of a given length.
Definition tools.hpp:265
Provides a C++ interface for various TBLIS tensor routines.
Provides various traits and utilities for the tensor interface.
Provides type traits for the nda library.