TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
gather.hpp
Go to the documentation of this file.
1// Copyright (c) 2020-2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Olivier Parcollet, Nils Wentzell
16
21
22#pragma once
23
24#include "./utils.hpp"
26#include "../concepts.hpp"
27#include "../declarations.hpp"
28#include "../layout/range.hpp"
29#include "../macros.hpp"
30#include "../stdutil/array.hpp"
31#include "../traits.hpp"
32
33#include <mpi/mpi.hpp>
34
35#include <cstddef>
36#include <functional>
37#include <numeric>
38#include <span>
39#include <type_traits>
40#include <utility>
41
42namespace nda::detail {
43
44 // Helper function to get the shape and total size of the gathered array/view.
45 template <typename A>
46 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
47 auto mpi_gather_shape_impl(A const &a, mpi::communicator comm, int root, bool all) {
48 auto dims = a.shape();
49 dims[0] = mpi::all_reduce(dims[0], comm);
50 auto gathered_size = std::accumulate(dims.begin(), dims.end(), 1l, std::multiplies<>());
51 if (!all && comm.rank() != root) dims = nda::stdutil::make_initialized_array<dims.size()>(0l);
52 return std::make_pair(dims, gathered_size);
53 }
54
55} // namespace nda::detail
56
57namespace nda {
58
63
96 template <typename A1, typename A2>
97 requires(is_regular_or_view_v<A1> and std::decay_t<A1>::is_stride_order_C()
98 and is_regular_or_view_v<A2> and std::decay_t<A2>::is_stride_order_C())
99 void mpi_gather_capi(A1 const &a_in, A2 &&a_out, mpi::communicator comm = {}, int root = 0, bool all = false) { // NOLINT
100 // check the shape of the input arrays/views
101 EXPECTS_WITH_MESSAGE(detail::have_mpi_equal_shapes(a_in(nda::range(1), nda::ellipsis{}), comm),
102 "Error in nda::mpi_gather_capi: Shapes of arrays/views must be equal save the first one");
103
104 // simply copy if there is no active MPI environment or if the communicator size is < 2
105 if (not mpi::has_env || comm.size() < 2) {
106 a_out = a_in;
107 return;
108 }
109
110 // check if the input arrays/views can be used in the MPI call
111 detail::check_layout_mpi_compatible(a_in, "mpi_gather_capi");
112
113 // get output shape, resize or check the output array/view and prepare output span
114 auto [dims, gathered_size] = detail::mpi_gather_shape_impl(a_in, comm, root, all);
115 auto a_out_span = std::span{a_out.data(), 0};
116 if (all || (comm.rank() == root)) {
117 // check if the output array/view can be used in the MPI call
118 detail::check_layout_mpi_compatible(a_out, "mpi_gather_capi");
119
120 // resize/check the size of the output array/view
121 resize_or_check_if_view(a_out, dims);
122
123 // prepare the output span
124 a_out_span = std::span{a_out.data(), static_cast<std::size_t>(a_out.size())};
125 }
126
127 // gather the data
128 auto a_in_span = std::span{a_in.data(), static_cast<std::size_t>(a_in.size())};
129 mpi::gather_range(a_in_span, a_out_span, gathered_size, comm, root, all);
130 }
131
151 template <typename A>
152 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
153 auto lazy_mpi_gather(A &&a, mpi::communicator comm = {}, int root = 0, bool all = false) {
154 return mpi::lazy<mpi::tag::gather, A>{std::forward<A>(a), comm, root, all};
155 }
156
175 template <typename A>
176 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
177 auto mpi_gather(A const &a, mpi::communicator comm = {}, int root = 0, bool all = false) {
178 using return_t = get_regular_t<A>;
179 return_t a_out;
180 mpi_gather_capi(a, a_out, comm, root, all);
181 return a_out;
182 }
183
185
186} // namespace nda
187
202template <nda::Array A>
203struct mpi::lazy<mpi::tag::gather, A> {
205 using value_type = typename std::decay_t<A>::value_type;
206
208 using stored_type = A;
209
212
214 mpi::communicator comm;
215
217 const int root{0}; // NOLINT (const is fine here)
218
220 const bool all{false}; // NOLINT (const is fine here)
221
235 [[nodiscard]] auto shape() const { return nda::detail::mpi_gather_shape_impl(rhs, comm, root, all).first; }
236
252 template <nda::Array T>
253 requires(std::decay_t<T>::is_stride_order_C())
254 void invoke(T &&target) const { // NOLINT (temporary views are allowed here)
256 }
257};
Provides utility functions for std::array.
Provides basic functions to create and manipulate arrays and views.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
bool all(A const &a)
Do all elements of the array evaluate to true?
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
auto lazy_mpi_gather(A &&a, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of a lazy MPI gather for nda::basic_array or nda::basic_array_view types.
Definition gather.hpp:153
auto mpi_gather(A const &a, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types.
Definition gather.hpp:177
void mpi_gather_capi(A1 const &a_in, A2 &&a_out, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types using a C-style API.
Definition gather.hpp:99
decltype(basic_array{std::declval< T >()}) get_regular_t
Get the type of the nda::basic_array that would be obtained by constructing an array from a given type.
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:163
constexpr std::array< T, R > make_initialized_array(T v)
Create a new std::array object initialized with a specific value.
Definition array.hpp:168
Macros used in the nda library.
Provides various utility functions used by the MPI interface of nda.
Includes the itertools header and provides some additional utilities.
const int root
MPI root process.
Definition gather.hpp:217
void invoke(T &&target) const
Execute the lazy MPI operation and write the result to a target array/view.
Definition gather.hpp:254
auto shape() const
Compute the shape of the nda::ArrayInitializer object.
Definition gather.hpp:235
const bool all
Should all processes receive the result?
Definition gather.hpp:220
A stored_type
Type of the array/view stored in the lazy object.
Definition gather.hpp:208
typename std::decay_t< A >::value_type value_type
Value type of the array/view.
Definition gather.hpp:205
mpi::communicator comm
MPI communicator.
Definition gather.hpp:214
stored_type rhs
Array/View to be gathered.
Definition gather.hpp:211
Mimics Python's ... syntax.
Definition range.hpp:49
Provides type traits for the nda library.