TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
gather.hpp
Go to the documentation of this file.
1// Copyright (c) 2020-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Olivier Parcollet, Nils Wentzell
16
22#pragma once
23
#include "../concepts.hpp"
#include "../exceptions.hpp"
#include "../traits.hpp"

#include <mpi/mpi.hpp>

#include <limits>
#include <type_traits>
#include <utility>
#include <vector>
34
49template <nda::Array A>
50struct mpi::lazy<mpi::tag::gather, A> {
52 using value_type = typename std::decay_t<A>::value_type;
53
55 using const_view_type = decltype(std::declval<const A>()());
56
59
61 mpi::communicator comm;
62
64 const int root{0}; // NOLINT (const is fine here)
65
67 const bool all{false}; // NOLINT (const is fine here)
68
80 [[nodiscard]] auto shape() const {
81 auto dims = rhs.shape();
82 long dim0 = dims[0];
83 if (!all) {
84 dims[0] = mpi::reduce(dim0, comm, root);
85 if (comm.rank() != root) dims[0] = 1;
86 } else
87 dims[0] = mpi::all_reduce(dim0, comm);
88 return dims;
89 }
90
97 template <nda::Array T>
98 void invoke(T &&target) const { // NOLINT (temporary views are allowed here)
99 // check if the arrays can be used in the MPI call
100 if (not target.is_contiguous()) NDA_RUNTIME_ERROR << "Error in MPI gather for nda::Array: Target array needs to be contiguous";
101 static_assert(std::decay_t<A>::layout_t::stride_order_encoded == std::decay_t<T>::layout_t::stride_order_encoded,
102 "Error in MPI gather for nda::Array: Incompatible stride orders");
103
104 // special case for non-mpi runs
105 if (not mpi::has_env) {
106 target = rhs;
107 return;
108 }
109
110 // get target shape and resize or check the target array
111 auto dims = shape();
112 if (all || (comm.rank() == root)) nda::resize_or_check_if_view(target, dims);
113
114 // gather receive counts and memory displacements
115 auto recvcounts = std::vector<int>(comm.size());
116 auto displs = std::vector<int>(comm.size() + 1, 0);
117 int sendcount = rhs.size();
118 auto mpi_int_type = mpi::mpi_type<int>::get();
119 if (!all)
120 MPI_Gather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, root, comm.get());
121 else
122 MPI_Allgather(&sendcount, 1, mpi_int_type, &recvcounts[0], 1, mpi_int_type, comm.get());
123
124 for (int r = 0; r < comm.size(); ++r) displs[r + 1] = recvcounts[r] + displs[r];
125
126 // gather the data
127 auto mpi_value_type = mpi::mpi_type<value_type>::get();
128 if (!all)
129 MPI_Gatherv((void *)rhs.data(), sendcount, mpi_value_type, target.data(), &recvcounts[0], &displs[0], mpi_value_type, root, comm.get());
130 else
131 MPI_Allgatherv((void *)rhs.data(), sendcount, mpi_value_type, target.data(), &recvcounts[0], &displs[0], mpi_value_type, comm.get());
132 }
133};
134
135namespace nda {
136
165 template <typename A>
166 ArrayInitializer<std::remove_reference_t<A>> auto mpi_gather(A &&a, mpi::communicator comm = {}, int root = 0, bool all = false)
168 {
169 if (not a.is_contiguous()) NDA_RUNTIME_ERROR << "Error in MPI gather for nda::Array: Array needs to be contiguous";
170 return mpi::lazy<mpi::tag::gather, A>{std::forward<A>(a), comm, root, all};
171 }
172
173} // namespace nda
Provides basic functions to create and manipulate arrays and views.
Check if a given type satisfies the array initializer concept for a given nda::MemoryArray type.
Definition concepts.hpp:325
Provides concepts for the nda library.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
bool all(A const &a)
Do all elements of the array evaluate to true?
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
ArrayInitializer< std::remove_reference_t< A > > auto mpi_gather(A &&a, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types.
Definition gather.hpp:166
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:163
auto shape() const
Compute the shape of the target array.
Definition gather.hpp:80
const_view_type rhs
View of the array/view to be gathered.
Definition gather.hpp:58
void invoke(T &&target) const
Execute the lazy MPI operation and write the result to a target array/view.
Definition gather.hpp:98
decltype(std::declval< const A >()()) const_view_type
Const view type of the array/view stored in the lazy object.
Definition gather.hpp:55
typename std::decay_t< A >::value_type value_type
Value type of the array/view.
Definition gather.hpp:52
mpi::communicator comm
MPI communicator.
Definition gather.hpp:61
Provides type traits for the nda library.