TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
gather.hpp
Go to the documentation of this file.
// Copyright (c) 2020-2024 Simons Foundation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0.txt
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Authors: Thomas Hahn, Olivier Parcollet, Nils Wentzell

22#pragma once
23
24#include "./utils.hpp"
26#include "../concepts.hpp"
27#include "../declarations.hpp"
28#include "../layout/range.hpp"
29#include "../macros.hpp"
30#include "../stdutil/array.hpp"
31#include "../traits.hpp"
32
33#include <mpi/mpi.hpp>
34
35#include <cstddef>
36#include <functional>
37#include <numeric>
38#include <span>
39#include <type_traits>
40#include <utility>
41
42namespace nda::detail {
43
44 // Helper function to get the shape and total size of the gathered array/view.
45 template <typename A>
46 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
47 auto mpi_gather_shape_impl(A const &a, mpi::communicator comm, int root, bool all) {
48 auto dims = a.shape();
49 dims[0] = mpi::all_reduce(dims[0], comm);
50 auto gathered_size = std::accumulate(dims.begin(), dims.end(), 1l, std::multiplies<>());
51 if (!all && comm.rank() != root) dims = nda::stdutil::make_initialized_array<dims.size()>(0l);
52 return std::make_pair(dims, gathered_size);
53 }
54
55} // namespace nda::detail
56
57namespace nda {
58
96 template <typename A1, typename A2>
97 requires(is_regular_or_view_v<A1> and std::decay_t<A1>::is_stride_order_C()
98 and is_regular_or_view_v<A2> and std::decay_t<A2>::is_stride_order_C())
99 void mpi_gather_capi(A1 const &a_in, A2 &&a_out, mpi::communicator comm = {}, int root = 0, bool all = false) { // NOLINT
100 // check the shape of the input arrays/views
101 EXPECTS_WITH_MESSAGE(detail::have_mpi_equal_shapes(a_in(nda::range(1), nda::ellipsis{}), comm),
102 "Error in nda::mpi_gather_capi: Shapes of arrays/views must be equal save the first one");
103
104 // simply copy if there is no active MPI environment or if the communicator size is < 2
105 if (not mpi::has_env || comm.size() < 2) {
106 a_out = a_in;
107 return;
108 }
109
110 // check if the input arrays/views can be used in the MPI call
111 detail::check_layout_mpi_compatible(a_in, "mpi_gather_capi");
112
113 // get output shape, resize or check the output array/view and prepare output span
114 auto [dims, gathered_size] = detail::mpi_gather_shape_impl(a_in, comm, root, all);
115 auto a_out_span = std::span{a_out.data(), 0};
116 if (all || (comm.rank() == root)) {
117 // check if the output array/view can be used in the MPI call
118 detail::check_layout_mpi_compatible(a_out, "mpi_gather_capi");
119
120 // resize/check the size of the output array/view
121 resize_or_check_if_view(a_out, dims);
122
123 // prepare the output span
124 a_out_span = std::span{a_out.data(), static_cast<std::size_t>(a_out.size())};
125 }
126
127 // gather the data
128 auto a_in_span = std::span{a_in.data(), static_cast<std::size_t>(a_in.size())};
129 mpi::gather_range(a_in_span, a_out_span, gathered_size, comm, root, all);
130 }
131
151 template <typename A>
152 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
153 auto lazy_mpi_gather(A &&a, mpi::communicator comm = {}, int root = 0, bool all = false) {
154 return mpi::lazy<mpi::tag::gather, A>{std::forward<A>(a), comm, root, all};
155 }
156
175 template <typename A>
176 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
177 auto mpi_gather(A const &a, mpi::communicator comm = {}, int root = 0, bool all = false) {
178 using return_t = get_regular_t<A>;
179 return_t a_out;
180 mpi_gather_capi(a, a_out, comm, root, all);
181 return a_out;
182 }
183
186} // namespace nda
187
202template <nda::Array A>
203struct mpi::lazy<mpi::tag::gather, A> {
205 using value_type = typename std::decay_t<A>::value_type;
206
208 using stored_type = A;
209
212
214 mpi::communicator comm;
215
217 const int root{0}; // NOLINT (const is fine here)
218
220 const bool all{false}; // NOLINT (const is fine here)
221
235 [[nodiscard]] auto shape() const { return nda::detail::mpi_gather_shape_impl(rhs, comm, root, all).first; }
236
252 template <nda::Array T>
253 requires(std::decay_t<T>::is_stride_order_C())
254 void invoke(T &&target) const { // NOLINT (temporary views are allowed here)
255 nda::mpi_gather_capi(rhs, target, comm, root, all);
256 }
257};
Provides utility functions for std::array.
Provides basic functions to create and manipulate arrays and views.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
bool all(A const &a)
Do all elements of the array evaluate to true?
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
auto lazy_mpi_gather(A &&a, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of a lazy MPI gather for nda::basic_array or nda::basic_array_view types.
Definition gather.hpp:153
auto mpi_gather(A const &a, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types.
Definition gather.hpp:177
void mpi_gather_capi(A1 const &a_in, A2 &&a_out, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types using a C-style API.
Definition gather.hpp:99
decltype(basic_array{std::declval< T >()}) get_regular_t
Get the type of the nda::basic_array that would be obtained by constructing an array from a given type.
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:163
constexpr std::array< T, R > make_initialized_array(T v)
Create a new std::array object initialized with a specific value.
Definition array.hpp:168
Macros used in the nda library.
Provides various utility functions used by the MPI interface of nda.
Includes the itertools header and provides some additional utilities.
void invoke(T &&target) const
Execute the lazy MPI operation and write the result to a target array/view.
Definition gather.hpp:254
auto shape() const
Compute the shape of the nda::ArrayInitializer object.
Definition gather.hpp:235
A stored_type
Type of the array/view stored in the lazy object.
Definition gather.hpp:208
typename std::decay_t< A >::value_type value_type
Value type of the array/view.
Definition gather.hpp:205
mpi::communicator comm
MPI communicator.
Definition gather.hpp:214
stored_type rhs
Array/View to be gathered.
Definition gather.hpp:211
Mimics Python's ... syntax.
Definition range.hpp:49
Provides type traits for the nda library.