TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
gather.hpp
Go to the documentation of this file.
1// Copyright (c) 2020--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "./utils.hpp"
15#include "../concepts.hpp"
16#include "../declarations.hpp"
17#include "../layout/range.hpp"
18#include "../macros.hpp"
19#include "../stdutil/array.hpp"
20#include "../traits.hpp"
21
22#include <mpi/mpi.hpp>
23
24#include <cstddef>
25#include <span>
26#include <type_traits>
27
28namespace nda::detail {
29
30 // Helper function to get the shape of the gathered array/view.
31 template <typename A>
32 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
33 auto mpi_gather_shape_impl(A const &a, mpi::communicator comm, int root, bool all) {
34 auto dims = a.shape();
35 dims[0] = mpi::all_reduce(dims[0], comm);
36 if (!all && comm.rank() != root) dims = nda::stdutil::make_initialized_array<dims.size()>(0l);
37 return dims;
38 }
39
40} // namespace nda::detail
41
42namespace nda {
43
48
82 template <typename A1, typename A2>
83 requires(is_regular_or_view_v<A1> and std::decay_t<A1>::is_stride_order_C()
84 and is_regular_or_view_v<A2> and std::decay_t<A2>::is_stride_order_C())
85 void mpi_gather_into(A1 const &a_in, A2 &&a_out, mpi::communicator comm = {}, int root = 0, bool all = false) { // NOLINT
86 // check the shape of the input arrays/views
87 EXPECTS_WITH_MESSAGE(detail::have_mpi_equal_shapes(a_in(nda::range(1), nda::ellipsis{}), comm),
88 "Error in nda::mpi_gather_into: Shapes of arrays/views must be equal save the first one");
89
90 // get output shape
91 auto dims = detail::mpi_gather_shape_impl(a_in, comm, root, all);
92
93 // check if the input and output arrays/views can be used in the MPI call and resize or check the output array/view
94 detail::check_layout_mpi_compatible(a_in, "mpi_gather_into");
95 bool const receives = (all || (comm.rank() == root));
96 if (receives) {
97 detail::check_layout_mpi_compatible(a_out, "mpi_gather_into");
98 resize_or_check_if_view(a_out, dims);
99 }
100
101 // gather the data
102 auto a_in_span = std::span{a_in.data(), static_cast<std::size_t>(a_in.size())};
103 auto a_out_span = std::span{a_out.data(), static_cast<std::size_t>(a_out.size())};
104 mpi::gather_range(a_in_span, a_out_span, comm, root, all);
105 }
106
125 template <typename A>
126 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
127 auto mpi_gather(A const &a, mpi::communicator comm = {}, int root = 0, bool all = false) {
128 using return_t = get_regular_t<A>;
129 return_t a_out;
130 mpi::gather_into(a, a_out, comm, root, all);
131 return a_out;
132 }
133
135
136} // namespace nda
Provides utility functions for std::array.
Provides basic functions to create and manipulate arrays and views.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
bool all(A const &a)
Do all elements of the array evaluate to true?
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
auto mpi_gather(A const &a, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types.
Definition gather.hpp:127
void mpi_gather_into(A1 const &a_in, A2 &&a_out, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types that gathers directly into a given output array/view.
Definition gather.hpp:85
decltype(basic_array{std::declval< T >()}) get_regular_t
Get the type of the nda::basic_array that would be obtained by constructing an array from a given type.
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:152
constexpr std::array< T, R > make_initialized_array(T v)
Create a new std::array object initialized with a specific value.
Definition array.hpp:155
Macros used in the nda library.
Provides various utility functions used by the MPI interface of nda.
Includes the itertools header and provides some additional utilities.
Mimics Python's ... syntax.
Definition range.hpp:36
Provides type traits for the nda library.