TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
basic_functions.hpp
Go to the documentation of this file.
1// Copyright (c) 2019--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "./clef/clef.hpp"
14#include "./declarations.hpp"
15#include "./exceptions.hpp"
16#include "./layout/for_each.hpp"
18#include "./traits.hpp"
19
20#include <itertools/itertools.hpp>
21
22#include <array>
23#include <concepts>
24#include <optional>
25#include <random>
26#include <tuple>
27#include <type_traits>
28#include <utility>
29
30namespace nda {
31
36
49 template <typename T, mem::AddressSpace AdrSp = mem::Host, std::integral Int, auto Rank>
50 auto zeros(std::array<Int, Rank> const &shape) {
51 static_assert(AdrSp != mem::None);
52 if constexpr (Rank == 0)
53 return T{};
54 else if constexpr (AdrSp == mem::Host)
55 return array<T, Rank>::zeros(shape);
56 else
57 return cuarray<T, Rank>::zeros(shape);
58 }
59
71 template <typename T, mem::AddressSpace AdrSp = mem::Host, std::integral... Ints>
72 auto zeros(Ints... is) {
73 return zeros<T, AdrSp>(std::array<long, sizeof...(Ints)>{static_cast<long>(is)...});
74 }
75
87 template <typename T, std::integral Int, auto Rank>
88 auto ones(std::array<Int, Rank> const &shape)
89 requires(nda::is_scalar_v<T>)
90 {
91 if constexpr (Rank == 0)
92 return T{1};
93 else { return array<T, Rank>::ones(shape); }
94 }
95
106 template <typename T, std::integral... Ints>
107 auto ones(Ints... is) {
108 return ones<T>(std::array<long, sizeof...(Ints)>{is...});
109 }
110
120 template <std::integral Int = long>
121 auto arange(long first, long last, long step = 1) {
122 auto r = range(first, last, step);
123 auto a = array<Int, 1>(r.size());
124 for (auto [x, v] : itertools::zip(a, r)) x = v;
125 return a;
126 }
127
136 template <std::integral Int = long>
137 auto arange(long last) {
138 return arange<Int>(0, last);
139 }
140
153 template <typename RealType = double, std::integral Int, auto Rank>
154 auto rand(std::array<Int, Rank> const &shape)
155 requires(std::is_floating_point_v<RealType>)
156 {
157 if constexpr (Rank == 0) {
158 auto static gen = std::mt19937{};
159 auto static dist = std::uniform_real_distribution<>{0.0, 1.0};
160 return dist(gen);
161 } else {
162 return array<RealType, Rank>::rand(shape);
163 }
164 }
165
177 template <typename RealType = double, std::integral... Ints>
178 auto rand(Ints... is) {
179 return rand<RealType>(std::array<long, sizeof...(Ints)>{is...});
180 }
181
192 template <Array A>
193 long first_dim(A const &a) {
194 return a.extent(0);
195 }
196
207 template <Array A>
208 long second_dim(A const &a) {
209 return a.extent(1);
210 }
211
// Make a given object regular, i.e. turn it into an owning object where possible:
// - an nda::Array that is not a regular nda::basic_array (per is_regular_v) is converted into a
//   basic_array constructed from it (template arguments deduced by CTAD);
// - otherwise, a type that declares a nested `regular_t` different from its decayed self is converted
//   into that `regular_t`;
// - otherwise the argument is returned unchanged.
// Because the return type is decltype(auto), the "returned unchanged" branches return a reference
// (A& for lvalue arguments, A&& for rvalue arguments), not a copy.
// NOTE(review): is_regular_v is evaluated on A, which is an lvalue-reference type when an lvalue is passed;
// whether that trait sees through references is defined in traits.hpp -- confirm lvalue regular arrays
// take the intended branch.
226 template <typename A, typename A_t = std::decay_t<A>>
227 decltype(auto) make_regular(A &&a) {
228 if constexpr (Array<A> and not is_regular_v<A>) {
229 return basic_array{std::forward<A>(a)};
230 } else if constexpr (requires { typename A_t::regular_t; }) {
231 if constexpr (not std::is_same_v<A_t, typename A_t::regular_t>)
232 return typename A_t::regular_t{std::forward<A>(a)};
233 else
234 return std::forward<A>(a);
235 } else {
// No nested regular_t: pass the argument through (reference return, see above).
236 return std::forward<A>(a);
237 }
238 }
239
249 template <MemoryArray A>
250 decltype(auto) to_host(A &&a) {
251 if constexpr (not mem::on_host<A>) {
252 return get_regular_host_t<A>{std::forward<A>(a)};
253 } else {
254 return std::forward<A>(a);
255 }
256 }
257
267 template <MemoryArray A>
268 decltype(auto) to_device(A &&a) {
269 if constexpr (not mem::on_device<A>) {
270 return get_regular_device_t<A>{std::forward<A>(a)};
271 } else {
272 return std::forward<A>(a);
273 }
274 }
275
285 template <MemoryArray A>
286 decltype(auto) to_unified(A &&a) {
287 if constexpr (not mem::on_unified<A>) {
288 return get_regular_unified_t<A>{std::forward<A>(a)};
289 } else {
290 return std::forward<A>(a);
291 }
292 }
293
// Ensure that `a` has shape `sha`:
// - if the shapes already match, nothing is done;
// - if `a` is a regular array (is_regular_v), it is resized to `sha`;
// - otherwise (e.g. a view) an error is raised via NDA_RUNTIME_ERROR, reporting both shapes.
// NOTE(review): source line 308 is absent from this listing (possibly a constraint on A) -- restore it from
// the original header before editing this function.
306 template <typename A>
307 void resize_or_check_if_view(A &a, std::array<long, A::rank> const &sha)
309 {
310 if (a.shape() == sha) return;
311 if constexpr (is_regular_v<A>) {
312 a.resize(sha);
313 } else {
314 NDA_RUNTIME_ERROR << "Error in nda::resize_or_check_if_view: Size mismatch: " << a.shape() << " != " << sha;
315 }
316 }
317
330 template <typename T, int R, typename LP, char A, typename CP>
334
348 template <typename T, int R, typename LP, char A, typename AP, typename OP>
352
365 template <typename T, int R, typename LP, char A, typename CP>
367 return array_view<T, R>{a};
368 }
369
383 template <typename T, int R, typename LP, char A, typename AP, typename OP>
387
400 template <typename T, int R, typename LP, char A, typename CP>
404
418 template <typename T, int R, typename LP, char A, typename AP, typename OP>
422
435 template <typename T, int R, typename LP, char A, typename CP>
439
453 template <typename T, int R, typename LP, char A, typename AP, typename OP>
457
468 template <Array LHS, Array RHS>
469 bool operator==(LHS const &lhs, RHS const &rhs) {
470 // FIXME not implemented in clang
471#ifndef __clang__
472 static_assert(std::equality_comparable_with<get_value_t<LHS>, get_value_t<RHS>>,
473 "Error in nda::operator==: Only defined when elements are comparable");
474#endif
475 if (lhs.shape() != rhs.shape()) return false;
476 bool r = true;
477 nda::for_each(lhs.shape(), [&](auto &&...x) { r &= (lhs(x...) == rhs(x...)); });
478 return r;
479 }
480
// Compare a rank-1 nda::Array with a contiguous range (e.g. a std::vector): the range is wrapped in a
// basic_array_view (template arguments deduced by CTAD) and compared with the Array/Array operator==.
491 template <ArrayOfRank<1> A, std::ranges::contiguous_range R>
492 bool operator==(A const &a, R const &rg) {
493 return a == basic_array_view{rg};
494 }
495
// Symmetric overload with the contiguous range on the left: simply evaluates `a == rg` with the operands swapped.
506 template <std::ranges::contiguous_range R, ArrayOfRank<1> A>
507 bool operator==(R const &rg, A const &a) {
508 return a == rg;
509 }
510
520 template <Array A, typename F>
521 void clef_auto_assign(A &&a, F &&f) { // NOLINT (Should we forward the references?)
522 nda::for_each(a.shape(), [&a, &f](auto &&...x) {
523 if constexpr (clef::is_function<std::decay_t<decltype(f(x...))>>) {
524 clef_auto_assign(a(x...), f(x...));
525 } else {
526 a(x...) = f(x...);
527 }
528 });
529 }
530
// Check whether a (non-empty) nda::MemoryArray has a block-strided layout: equally sized contiguous
// blocks whose starts are separated by a constant stride.
// Returns std::optional<std::tuple<int, int, int>> holding (n_blocks, block_size, block_str) on success,
// or an empty optional if more than one dimension breaks contiguity.
// A fully contiguous array yields n_blocks == 1 and block_size == block_str == data_size.
// Precondition (EXPECTS): `a` is not empty.
548 template <MemoryArray A>
549 auto get_block_layout(A const &a) {
550 EXPECTS(!a.empty());
551 using opt_t = std::optional<std::tuple<int, int, int>>;
552
553 auto const &shape = a.indexmap().lengths();
554 auto const &strides = a.indexmap().strides();
555 auto const &order = a.indexmap().stride_order;
556
// Memory span of dimension order[0], which the loop below treats as the outermost dimension.
// NOTE(review): these products are stored in `int`; if lengths/strides are `long`, very large arrays
// could be truncated -- confirm sizes stay below INT_MAX.
557 int data_size = shape[order[0]] * strides[order[0]];
558 int block_size = data_size;
559 int block_str = data_size;
560 int n_blocks = 1;
561
// Walk the dimensions in stride order. A dimension is contiguous with its inner dimensions when its stride
// equals the memory span of the next dimension in stride order (1 for the last one).
562 for (auto n : range(A::rank)) {
563 auto inner_size = (n == A::rank - 1) ? 1 : strides[order[n + 1]] * shape[order[n + 1]];
564 if (strides[order[n]] != inner_size) {
// block_size only drops below data_size once a strided dimension has been recorded.
565 if (block_size < data_size) // second strided dimension
566 return opt_t{};
567 // found a strided dimension with (assumed) contiguous inner blocks
568 n_blocks = a.size() / inner_size;
569 block_size = inner_size;
570 block_str = strides[order[n]];
571 }
572 }
// Consistency checks: the blocks cover all elements and the strided blocks span the outermost dimension.
573 ASSERT(n_blocks * block_size == a.size());
574 ASSERT(n_blocks * block_str == shape[order[0]] * strides[order[0]]);
575 return opt_t{std::make_tuple(n_blocks, block_size, block_str)};
576 }
577
591 template <size_t Axis = 0, Array A0, Array... As>
592 auto concatenate(A0 const &a0, As const &...as) {
593 // sanity checks
594 auto constexpr rank = A0::rank;
595 static_assert(Axis < rank);
596 static_assert(have_same_rank_v<A0, As...>);
597 static_assert(have_same_value_type_v<A0, As...>);
598 for (auto ax [[maybe_unused]] : range(rank)) { EXPECTS(ax == Axis or ((a0.extent(ax) == as.extent(ax)) and ... and true)); }
599
600 // construct concatenated array
601 auto new_shape = a0.shape();
602 new_shape[Axis] = (as.extent(Axis) + ... + new_shape[Axis]);
603 auto new_array = array<get_value_t<A0>, rank>(new_shape);
604
605 // slicing helper function
606 auto slice_Axis = [](Array auto &a, range r) {
607 auto all_or_range = std::make_tuple(range::all, r);
608 return [&]<auto... Is>(std::index_sequence<Is...>) { return a(std::get<Is == Axis>(all_or_range)...); }(std::make_index_sequence<rank>{});
609 };
610
611 // initialize concatenated array
612 long offset = 0;
613 for (auto const &a_view : {basic_array_view(a0), basic_array_view(as)...}) {
614 slice_Axis(new_array, range(offset, offset + a_view.extent(Axis))) = a_view;
615 offset += a_view.extent(Axis);
616 }
617
618 return new_array;
619 };
620
622
623} // namespace nda
Provides definitions and type traits involving the different memory address spaces supported by nda.
A generic view of a multi-dimensional array.
A generic multi-dimensional array.
long extent(int i) const noexcept
Get the extent of the ith dimension.
static basic_array ones(std::array< Int, Rank > const &shape)
static basic_array zeros(std::array< Int, Rank > const &shape)
static basic_array rand(std::array< Int, Rank > const &shape)
Includes all relevant headers for the core clef library.
Check if a given type satisfies the array concept.
Definition concepts.hpp:230
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
Provides for_each functions for multi-dimensional arrays/views.
auto rand(std::array< Int, Rank > const &shape)
Make an array of the given shape and initialize it with random values from the uniform distribution over [0, 1).
decltype(auto) make_regular(A &&a)
Make a given object regular.
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
decltype(auto) to_unified(A &&a)
Convert an nda::MemoryArray to its regular type on unified memory.
auto make_const_view(basic_array< T, R, LP, A, CP > const &a)
Make an nda::basic_array_view with a const value type from a given nda::basic_array.
auto zeros(std::array< Int, Rank > const &shape)
Make an array of the given shape on the given address space and zero-initialize it.
auto arange(long first, long last, long step=1)
Make a 1-dimensional integer array and initialize it with values of a given nda::range.
decltype(auto) to_host(A &&a)
Convert an nda::MemoryArray to its regular type on host memory.
auto make_matrix_view(basic_array< T, R, LP, A, CP > const &a)
Make an nda::matrix_view of a given nda::basic_array.
decltype(auto) to_device(A &&a)
Convert an nda::MemoryArray to its regular type on device memory.
auto make_array_view(basic_array< T, R, LP, A, CP > const &a)
Make an nda::array_view of a given nda::basic_array.
auto make_array_const_view(basic_array< T, R, LP, A, CP > const &a)
Make an nda::array_const_view of a given nda::basic_array.
auto ones(std::array< Int, Rank > const &shape)
Make an array of the given shape and one-initialize it.
auto concatenate(A0 const &a0, As const &...as)
Join a sequence of nda::Array types along an existing axis.
basic_array_view< ValueType, Rank, Layout, 'A', default_accessor, borrowed<> > array_view
Alias template of an nda::basic_array_view with an 'A' algebra, nda::default_accessor and nda::borrowed<>.
basic_array_view< ValueType, 2, Layout, 'M', default_accessor, borrowed<> > matrix_view
Alias template of an nda::basic_array_view with rank 2, an 'M' algebra, nda::default_accessor and nda::borrowed<>.
basic_array_view< ValueType const, Rank, Layout, 'A', default_accessor, borrowed<> > array_const_view
Same as nda::array_view except for const value types.
basic_array< ValueType, Rank, Layout, 'A', ContainerPolicy > array
Alias template of an nda::basic_array with an 'A' algebra.
constexpr bool is_regular_v
Constexpr variable that is true if type A is a regular array, i.e. an nda::basic_array.
Definition traits.hpp:134
std::conditional_t< mem::on_device< RT >, RT, basic_array< get_value_t< RT >, get_rank< RT >, get_contiguous_layout_policy< get_rank< RT >, get_layout_info< RT >.stride_order >, get_algebra< RT >, heap< mem::Device > > > get_regular_device_t
Get the type of the nda::basic_array that would be obtained by constructing an array on device memory...
std::conditional_t< mem::on_unified< RT >, RT, basic_array< get_value_t< RT >, get_rank< RT >, get_contiguous_layout_policy< get_rank< RT >, get_layout_info< RT >.stride_order >, get_algebra< RT >, heap< mem::Unified > > > get_regular_unified_t
Get the type of the nda::basic_array that would be obtained by constructing an array on unified memor...
long first_dim(A const &a)
Get the extent of the first dimension of the array.
std::conditional_t< mem::on_host< RT >, RT, basic_array< get_value_t< RT >, get_rank< RT >, get_contiguous_layout_policy< get_rank< RT >, get_layout_info< RT >.stride_order >, get_algebra< RT >, heap< mem::Host > > > get_regular_host_t
Get the type of the nda::basic_array that would be obtained by constructing an array on host memory f...
constexpr bool have_same_value_type_v
Constexpr variable that is true if all types in As have the same value type as A0.
Definition traits.hpp:185
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:181
constexpr bool have_same_rank_v
Constexpr variable that is true if all types in As have the same rank as A0.
Definition traits.hpp:189
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:152
bool operator==(LHS const &lhs, RHS const &rhs)
Equal-to comparison operator for two nda::Array objects.
long second_dim(A const &a)
Get the extent of the second dimension of the array.
void clef_auto_assign(A &&a, F &&f)
Overload of nda::clef::clef_auto_assign function for nda::Array objects.
__inline__ void for_each(std::array< Int, R > const &shape, F &&f)
Loop over all possible index values of a given shape and apply a function to them.
Definition for_each.hpp:116
auto get_block_layout(A const &a)
Check if a given nda::MemoryArray has a block-strided layout.
AddressSpace
Enum providing identifiers for the different memory address spaces.
static constexpr bool on_device
Constexpr variable that is true if all given types have a Device address space.
static constexpr bool on_unified
Constexpr variable that is true if all given types have a Unified address space.
static constexpr bool on_host
Constexpr variable that is true if all given types have a Host address space.
constexpr bool is_scalar_v
Constexpr variable that is true if type S is a scalar type, i.e. arithmetic or complex.
Definition traits.hpp:68
Provides type traits for the nda library.