22namespace nda::tensor {
54 template <BlasArray A>
58 static_assert(!run_on_device ||
have_cutensor,
"nda::tensor::scale: cuTENSOR support is required");
61 if (op == unary_op::NEG) {
63 op = unary_op::IDENTITY;
67 if constexpr (run_on_device) {
72 if (op == unary_op::IDENTITY || op == unary_op::CONJ) {
78 case unary_op::IDENTITY: a = alpha * a;
break;
79 case unary_op::CONJ: a = alpha *
nda::conj(a);
break;
80 case unary_op::ABS: a = alpha *
nda::abs(a);
break;
81 case unary_op::SQRT: a =
nda::map([alpha](
auto x) {
return alpha * std::sqrt(x); })(a);
break;
82 case unary_op::EXP: a =
nda::map([alpha](
auto x) {
return alpha * std::exp(x); })(a);
break;
83 case unary_op::LOG: a =
nda::map([alpha](
auto x) {
return alpha * std::log(x); })(a);
break;
84 case unary_op::RCP: a =
nda::map([alpha](
auto x) {
return alpha / x; })(a);
break;
86 NDA_RUNTIME_ERROR <<
"nda::tensor::scale: unsupported unary_op on nda host fallback "
87 "(supported: IDENTITY, CONJ, NEG, SQRT, ABS, EXP, LOG, RCP)";
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a C++ interface for various cuTENSOR routines.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
auto abs(A &&a)
Function abs for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types).
decltype(auto) conj(A &&a)
Function conj for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types with a com...
mapped< F > map(F f)
Create a lazy function call expression on arrays/views.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
void scale(get_value_t< A > alpha, A &&a, unary_op op=unary_op::IDENTITY)
In-place tensor scaling with cuTENSOR/TBLIS/nda dispatch and optional element-wise unary operation.
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
unary_op
Unary element-wise operations for tensor operations.
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
std::string_view default_index()
Generate a default index string ("abc...") of a given length.
Provides some custom implementations of standard mathematical functions used for lazy,...
Provides lazy, coefficient-wise array operations of standard mathematical functions together with ove...
A type-erased, non-owning view of an nda::MemoryArray or a conjugate lazy expression.
Provides a C++ interface for various TBLIS tensor routines.
Provides type traits for the nda library.