TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
scale.hpp
Go to the documentation of this file.
1// Copyright (c) 2024--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
15#include "./tools.hpp"
16#include "../exceptions.hpp"
20#include "../traits.hpp"
21
22namespace nda::tensor {
23
28
54 template <BlasArray A>
55 void scale(get_value_t<A> alpha, A &&a, unary_op op = unary_op::IDENTITY) { // NOLINT
56 // compile-time checks
57 constexpr bool run_on_device = mem::have_device_compatible_addr_space<A>;
58 static_assert(!run_on_device || have_cutensor, "nda::tensor::scale: cuTENSOR support is required");
59
60 // fold NEG into the scalar (alpha * (-x) == -alpha * x) so backends always see IDENTITY
61 if (op == unary_op::NEG) {
62 alpha = -alpha;
63 op = unary_op::IDENTITY;
64 }
65
66 // dispatch to backends
67 if constexpr (run_on_device) {
68 device::permute(alpha, tensor_view{a, op}, default_index<get_rank<A>>(), tensor_view{a, unary_op::IDENTITY}, default_index<get_rank<A>>());
69 } else {
70 // TBLIS only handles IDENTITY and CONJ; other ops fall through to the nda fallback below
71 if constexpr (have_tblis) {
72 if (op == unary_op::IDENTITY || op == unary_op::CONJ) {
73 tblis::scale(alpha, tensor_view{a, op}, default_index<get_rank<A>>());
74 return;
75 }
76 }
77 switch (op) {
78 case unary_op::IDENTITY: a = alpha * a; break;
79 case unary_op::CONJ: a = alpha * nda::conj(a); break;
80 case unary_op::ABS: a = alpha * nda::abs(a); break;
81 case unary_op::SQRT: a = nda::map([alpha](auto x) { return alpha * std::sqrt(x); })(a); break;
82 case unary_op::EXP: a = nda::map([alpha](auto x) { return alpha * std::exp(x); })(a); break;
83 case unary_op::LOG: a = nda::map([alpha](auto x) { return alpha * std::log(x); })(a); break;
84 case unary_op::RCP: a = nda::map([alpha](auto x) { return alpha / x; })(a); break;
85 default:
86 NDA_RUNTIME_ERROR << "nda::tensor::scale: unsupported unary_op on nda host fallback "
87 "(supported: IDENTITY, CONJ, NEG, SQRT, ABS, EXP, LOG, RCP)";
88 }
89 }
90 }
91
93
94} // namespace nda::tensor
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a C++ interface for various cuTENSOR routines.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
auto abs(A &&a)
Function abs for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types).
decltype(auto) conj(A &&a)
Function conj for nda::ArrayOrScalar types (lazy and coefficient-wise for nda::Array types with a com...
mapped< F > map(F f)
Create a lazy function call expression on arrays/views.
Definition map.hpp:206
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
void scale(get_value_t< A > alpha, A &&a, unary_op op=unary_op::IDENTITY)
In-place tensor scaling with cuTENSOR/TBLIS/nda dispatch and optional element-wise unary operation.
Definition scale.hpp:55
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
Definition tools.hpp:47
unary_op
Unary element-wise operations for tensor operations.
Definition tools.hpp:103
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
Definition tools.hpp:40
std::string_view default_index()
Generate a default index string ("abc...") of a given length.
Definition tools.hpp:265
Provides some custom implementations of standard mathematical functions used for lazy,...
Provides lazy, coefficient-wise array operations of standard mathematical functions together with ove...
A type-erased, non-owning view of an nda::MemoryArray or a conjugate lazy expression.
Definition tools.hpp:146
Provides a C++ interface for various TBLIS tensor routines.
Provides various traits and utilities for the tensor interface.
Provides type traits for the nda library.