TRIQS/nda 2.0.0
Multi-dimensional array library for C++
set.hpp
// Copyright (c) 2024--present, The Simons Foundation
// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
// SPDX-License-Identifier: Apache-2.0
// See LICENSE in the root of this distribution for details.

#pragma once

#include "./tools.hpp"
#include "../basic_array.hpp"
#include "../declarations.hpp"
#include "../mem/policies.hpp"
#include "../traits.hpp"

#include <array>

namespace nda::tensor {

  template <BlasArray A>
  void set(get_value_t<A> alpha, A &&a) { // NOLINT
    // compile-time checks: running on the device requires cuTENSOR support
    constexpr bool run_on_device = mem::have_device_compatible_addr_space<A>;
    static_assert(!run_on_device || have_cutensor, "nda::tensor::set: cuTENSOR support is required");

    // dispatch to backends
    if constexpr (run_on_device) {
      // a = 1 * tmp + 0 * a, where tmp is a rank-0 operand (empty index "") holding alpha that is
      // broadcast over all elements of a; the declaration of tmp does not appear in this listing
      device::elementwise_binary(get_value_t<A>{1}, tmp.data(), "", get_value_t<A>{0}, a, default_index<get_rank<A>>(), a, binary_op::SUM);
    } else if constexpr (have_tblis) {
      // TBLIS fill, labelling the dimensions of a with "abc..."
      tblis::set(alpha, a, default_index<get_rank<A>>());
    } else {
      // generic nda fallback: assign the scalar to every element
      a = alpha;
    }
  }

} // namespace nda::tensor
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides the generic class for arrays.
A generic multi-dimensional array.
Provides a C++ interface for various cuTENSOR routines.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_vi...
constexpr int get_rank: Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range. (Definition: traits.hpp:147)
std::decay_t<decltype(get_first_element(std::declval<A const>()))> get_value_t: Get the value type of an array/view or a scalar type. (Definition: traits.hpp:212)
static constexpr bool have_device_compatible_addr_space: Constexpr variable that is true if all given types have an address space compatible with Device.
heap_basic<mem::mallocator<AdrSp>> heap: Alias template of the nda::heap_basic policy using an nda::mem::mallocator. (Definition: policies.hpp:52)
void set(get_value_t<A> alpha, A &&a): In-place tensor constant fill with cuTENSOR/TBLIS/nda dispatch. (Definition: set.hpp:54)
static constexpr bool have_tblis: Constexpr variable that is true if nda is configured with TBLIS support. (Definition: tools.hpp:47)
static constexpr bool have_cutensor: Constexpr variable that is true if nda is configured with cuTENSOR support. (Definition: tools.hpp:40)
std::string_view default_index(): Generate a default index string ("abc...") of a given length. (Definition: tools.hpp:265)
Provides definitions of various layout policies.
Defines various memory handling policies.
Contiguous layout policy with C-order (row-major order). (Definition: policies.hpp:36)
Provides a C++ interface for various TBLIS tensor routines.
Provides various traits and utilities for the tensor interface.
Provides type traits for the nda library.