TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
assign.hpp
Go to the documentation of this file.
1// Copyright (c) 2024--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
15#include "./tools.hpp"
16#include "../concepts.hpp"
17#include "../layout/range.hpp"
19#include "../macros.hpp"
21#include "../traits.hpp"
22
23#include <string_view>
24#include <type_traits>
25#include <utility>
26
27namespace nda::tensor {
28
33
34 namespace detail {
35
37 template <int Axis, int Rank>
38 decltype(auto) slice_axis(auto &&arr, long i) {
39 return [&]<size_t... Before, size_t... After>(std::index_sequence<Before...>, std::index_sequence<After...>) -> decltype(auto) {
40 return arr(((void)Before, ::nda::range::all)..., i, ((void)After, ::nda::range::all)...);
41 }(std::make_index_sequence<Axis>{}, std::make_index_sequence<Rank - Axis - 1>{});
42 }
43
44 // Recursively copy from \f$ A \f$ into \f$ B \f$ by descending A's slowest-varying axis until both operands have a
45 // layout that nda::assign_from_ndarray can handle directly.
46 template <MemoryArray A, MemoryArray B>
47 void rec_copy(A const &a, B &&b) { // NOLINT(cppcoreguidelines-missing-std-forward)
48 constexpr bool same_stride_order = get_layout_info<A>.stride_order == get_layout_info<B>.stride_order;
49 if constexpr (get_rank<A> == 1 || (same_stride_order && has_layout_strided_1d<A> && has_layout_strided_1d<B>)) {
50 b = a;
51 } else {
52 constexpr int rank = get_rank<A>;
53 // slowest-varying axis of A — accessed at the type level to avoid taking constexpr of the runtime parameter
54 constexpr int axis = decode<rank>(get_layout_info<A>.stride_order)[0];
55 long n = b.extent(axis);
56 for (long i = 0; i < n; ++i) rec_copy(slice_axis<axis, rank>(a, i), slice_axis<axis, rank>(b, i));
57 }
58 }
59
60 } // namespace detail
61
100 template <MemoryArray A, MemoryArray B>
102 void assign(A const &a, std::string_view idx_a, B &&b, std::string_view idx_b) { // NOLINT
103 // compile-time checks
104 constexpr bool device_compat = mem::have_device_compatible_addr_space<A, B>;
105 constexpr bool host_compat = mem::have_host_compatible_addr_space<A, B>;
106 constexpr bool use_cutensor = device_compat && have_cutensor && is_blas_lapack_v<get_value_t<A>>;
107 constexpr bool use_tblis = host_compat && have_tblis && is_blas_lapack_v<get_value_t<A>>;
108 static_assert(use_cutensor || use_tblis || get_rank<A> == get_rank<B>,
109 "nda::tensor::assign: host/cross-memory fallback requires identical ranks");
110
111 // dispatch to backends
112 if constexpr (use_cutensor) {
113 device::permute(get_value_t<A>{1}, a, idx_a, b, idx_b);
114 } else if constexpr (use_tblis) {
115 tblis::add(get_value_t<A>{1}, a, idx_a, get_value_t<A>{0}, b, idx_b);
116 } else {
117 require_equal_indices(idx_a, idx_b, get_rank<A>, "assign");
118 if constexpr (host_compat) {
119 b = a;
120 } else {
121 detail::rec_copy(a, b);
122 }
123 }
124 }
125
127 template <MemoryArray A, MemoryArray B>
129 void assign(A const &a, B &&b) { // NOLINT
130 assign(a, default_index<get_rank<A>>(), std::forward<B>(b), default_index<get_rank<B>>());
131 }
132
134
135} // namespace nda::tensor
Provides definitions and type traits involving the different memory address spaces supported by nda.
decltype(auto) slice_axis(auto &&arr, long i)
Slice arr on the compile-time axis Axis at index i, padding the other axes with range::all.
Definition assign.hpp:38
Provides concepts for the nda library.
Provides a C++ interface for various cuTENSOR routines.
constexpr bool have_same_value_type_v
Constexpr variable that is true if all types in As have the same value type as A0.
Definition traits.hpp:225
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:147
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
constexpr bool has_layout_strided_1d
Constexpr variable that is true if type A has the strided_1d nda::layout_prop_e guarantee.
Definition traits.hpp:363
constexpr layout_info_t get_layout_info
Constexpr variable that specifies the nda::layout_info_t of type A.
Definition traits.hpp:350
static constexpr bool have_host_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Host.
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
void assign(A const &a, std::string_view idx_a, B &&b, std::string_view idx_b)
Tensor assignment with cuTENSOR/TBLIS/nda dispatch.
Definition assign.hpp:102
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
Definition tools.hpp:47
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
Definition tools.hpp:40
void require_equal_indices(std::string_view idx_a, std::string_view idx_b, int rank, std::string_view op_name)
Check if two index strings are equal and have a specified length.
Definition tools.hpp:247
std::string_view default_index()
Generate a default index string ("abc...") of a given length.
Definition tools.hpp:265
constexpr std::array< int, N > decode(uint64_t binary_representation)
Decode a uint64_t into a std::array<int, N>.
constexpr bool is_blas_lapack_v
Constexpr variable that is true if type T is either of type 'float', 'double', 'std::complex<float>' or ...
Definition traits.hpp:95
Provides functions to transform the memory layout of an nda::basic_array or nda::basic_array_view.
Macros used in the nda library.
Includes the itertools header and provides some additional utilities.
Provides a C++ interface for various TBLIS tensor routines.
Provides various traits and utilities for the tensor interface.
Provides type traits for the nda library.