TRIQS/nda
2.0.0
Multi-dimensional array library for C++
Toggle main menu visibility
Loading...
Searching...
No Matches
set.hpp
Go to the documentation of this file.
1
// Copyright (c) 2024--present, The Simons Foundation
2
// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3
// SPDX-License-Identifier: Apache-2.0
4
// See LICENSE in the root of this distribution for details.
5
10
11
#pragma once
12
13
#include "
./interface/cutensor_interface.hpp
"
14
#include "
./interface/tblis_interface.hpp
"
15
#include "
./tools.hpp
"
16
#include "
../basic_array.hpp
"
17
#include "
../declarations.hpp
"
18
#include "
../layout/policies.hpp
"
19
#include "
../mem/address_space.hpp
"
20
#include "
../mem/policies.hpp
"
21
#include "
../traits.hpp
"
22
23
#include <array>
24
25
namespace
nda::tensor {
26
31
53
template
<BlasArray A>
54
void
set
(
get_value_t<A>
alpha, A &&a) {
// NOLINT
55
// compile-time checks
56
constexpr
bool
run_on_device =
mem::have_device_compatible_addr_space<A>
;
57
static_assert
(!run_on_device ||
have_cutensor
,
"nda::tensor::set: cuTENSOR support is required"
);
58
59
// dispatch to backends
60
if
constexpr
(run_on_device) {
61
auto
tmp =
basic_array<get_value_t<A>
, 1,
C_layout
,
'A'
,
heap<mem::get_addr_space<A>
>>(1, alpha);
62
device::elementwise_binary(
get_value_t<A>
{1}, tmp.data(),
""
,
get_value_t<A>
{0}, a,
default_index<get_rank<A>
>(), a, binary_op::SUM);
63
}
else
if
constexpr
(
have_tblis
) {
64
tblis::set(alpha, a,
default_index
<
get_rank<A>
>());
65
}
else
{
66
a = alpha;
67
}
68
}
69
71
72
}
// namespace nda::tensor
address_space.hpp
Provides definitions and type traits involving the different memory address spaces supported by nda.
basic_array.hpp
Provides the generic class for arrays.
nda::basic_array
A generic multi-dimensional array.
Definition
basic_array.hpp:92
cutensor_interface.hpp
Provides a C++ interface for various cuTENSOR routines.
declarations.hpp
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
nda::get_rank
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition
traits.hpp:147
nda::get_value_t
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition
traits.hpp:212
nda::mem::have_device_compatible_addr_space
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
Definition
address_space.hpp:177
nda::heap
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition
policies.hpp:52
nda::tensor::set
void set(get_value_t< A > alpha, A &&a)
In-place tensor constant fill with cuTENSOR/TBLIS/nda dispatch.
Definition
set.hpp:54
nda::tensor::have_tblis
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
Definition
tools.hpp:47
nda::tensor::have_cutensor
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
Definition
tools.hpp:40
nda::tensor::default_index
std::string_view default_index()
Generate a default index string ("abc...") of a given length.
Definition
tools.hpp:265
policies.hpp
Provides definitions of various layout policies.
policies.hpp
Defines various memory handling policies.
nda::C_layout
Contiguous layout policy with C-order (row-major order).
Definition
policies.hpp:36
tblis_interface.hpp
Provides a C++ interface for various TBLIS tensor routines.
tools.hpp
Provides various traits and utilities for the tensor interface.
traits.hpp
Provides type traits for the nda library.
nda
tensor
set.hpp
Generated by
1.17.0