TRIQS/mpi 1.3.0
C++ interface to MPI
Loading...
Searching...
No Matches
ranges.hpp
Go to the documentation of this file.
1// Copyright (c) 2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Alexander Hampel, Olivier Parcollet, Nils Wentzell
16
21
22#pragma once
23
24#include "./chunk.hpp"
25#include "./communicator.hpp"
26#include "./datatypes.hpp"
27#include "./environment.hpp"
29#include "./macros.hpp"
30#include "./utils.hpp"
31
32#include <itertools/itertools.hpp>
33#include <mpi.h>
34
35#include <algorithm>
36#include <concepts>
37#include <limits>
38#include <numeric>
39#include <ranges>
40#include <stdexcept>
41#include <type_traits>
42#include <utility>
43#include <vector>
44
45namespace mpi {
46
51
69 template <std::ranges::sized_range R> void broadcast_range(R &&rg, communicator c = {}, int root = 0) { // NOLINT (ranges need not be forwarded)
70 // check the size of the range
71 auto size = static_cast<long>(std::ranges::size(rg));
72 EXPECTS_WITH_MESSAGE(all_equal(size, c), "Range sizes are not equal on all processes in mpi::broadcast_range");
73
74 // do nothing if no elements are broadcasted
75 if (size <= 0) return;
76
77 // call the MPI C library if the ranges are contiguous with MPI compatible value types, otherwise do element-wise
78 // broadcasts
79 if constexpr (MPICompatibleRange<R>) {
80 // in case there is no active MPI environment or if the communicator size is < 2, do nothing
81 if (!has_env || c.size() < 2) return;
82
83 // make the MPI C library call (allow the number of elements to larger than INT_MAX)
84 constexpr long max_int = std::numeric_limits<int>::max();
85 for (long offset = 0; size > 0; offset += max_int, size -= max_int) {
86 auto const count = static_cast<int>(std::min(size, max_int));
87 check_mpi_call(MPI_Bcast(std::ranges::data(rg) + offset, count, mpi_type<std::ranges::range_value_t<R>>::get(), root, c.get()), "MPI_Bcast");
88 }
89 } else {
90 // otherwise call the generic broadcast for each element separately
91 for (auto &x : rg) broadcast(x, c, root);
92 }
93 }
94
118 template <std::ranges::sized_range R1, std::ranges::sized_range R2>
119 void reduce_range(R1 &&in_rg, R2 &&out_rg, communicator c = {}, int root = 0, bool all = false, // NOLINT (ranges need not be forwarded)
120 MPI_Op op = MPI_SUM) {
121 // check the size of the input range
122 auto size = static_cast<long>(std::ranges::size(in_rg));
123 EXPECTS_WITH_MESSAGE(all_equal(size, c), "Input range sizes are not equal on all processes in mpi::reduce_range");
124
125 // do nothing if no elements are reduced
126 if (size <= 0) return;
127
128 // check the size of the output range
129 bool const receives = (c.rank() == root || all);
130 if (receives) EXPECTS_WITH_MESSAGE(size == std::ranges::size(out_rg), "Input and output range sizes are not equal in mpi::reduce_range");
131
132 // call the MPI C library if the ranges are contiguous with MPI compatible value types
133 if constexpr (MPICompatibleRange<R1> && MPICompatibleRange<R2>) {
134 static_assert(std::same_as<std::remove_cvref_t<std::ranges::range_value_t<R1>>, std::remove_cvref_t<std::ranges::range_value_t<R2>>>,
135 "Value types of input and output ranges not compatible in mpi::reduce_range");
136
137 // check if the reduction is in place
138 bool const in_place = (static_cast<void const *>(std::ranges::data(in_rg)) == static_cast<void *>(std::ranges::data(out_rg)));
139 if (all) {
140 EXPECTS_WITH_MESSAGE(all_equal(static_cast<int>(in_place), c),
141 "Either zero or all receiving processes have to choose the in place option in mpi::reduce_range");
142 }
143
144 // in case there is no active MPI environment or if the communicator size is < 2, copy to the output range
145 if (!has_env || c.size() < 2) {
146 std::ranges::copy(std::forward<R1>(in_rg), std::ranges::data(out_rg));
147 return;
148 }
149
150 // make the MPI C library call (allow the number of elements to larger than INT_MAX)
151 constexpr long max_int = std::numeric_limits<int>::max();
152 for (long offset = 0; size > 0; offset += max_int, size -= max_int) {
153 auto in_data = static_cast<void const *>(std::ranges::data(in_rg) + offset);
154 auto out_data = std::ranges::data(out_rg) + offset;
155 if (receives and in_place) in_data = MPI_IN_PLACE;
156 auto const count = static_cast<int>(std::min(size, max_int));
157 if (all) {
158 check_mpi_call(MPI_Allreduce(in_data, out_data, count, mpi_type<std::ranges::range_value_t<R1>>::get(), op, c.get()), "MPI_Allreduce");
159 } else {
160 check_mpi_call(MPI_Reduce(in_data, out_data, count, mpi_type<std::ranges::range_value_t<R1>>::get(), op, root, c.get()), "MPI_Reduce");
161 }
162 }
163 } else {
164 // fallback to element-wise reduction if the range is not contiguous with an MPI compatible value type
165 if (size <= std::ranges::size(out_rg)) {
166 // on ranks where the output range size is large enough, reduce into the output elements
167 for (auto &&[x_in, x_out] : itertools::zip(in_rg, out_rg)) reduce_into(x_in, x_out, c, root, all, op);
168 } else {
169 // on all other ranks, reduce into a dummy output object (needs to be default constructible)
170 using out_value_t = std::ranges::range_value_t<R2>;
171 if constexpr (std::is_default_constructible_v<out_value_t>) {
172 out_value_t out_dummy{};
173 for (auto &&x_in : in_rg) reduce_into(x_in, out_dummy, c, root, all, op);
174 } else {
175 // if it is not default constructible, is there something we can do?
176 throw std::runtime_error("Cannot default construct dummy object in mpi::reduce_range");
177 }
178 }
179 }
180 }
181
211 template <MPICompatibleRange R1, MPICompatibleRange R2>
212 requires(std::same_as<std::remove_cvref_t<std::ranges::range_value_t<R1>>, std::remove_cvref_t<std::ranges::range_value_t<R2>>>)
213 void scatter_range(R1 &&in_rg, R2 &&out_rg, long scatter_size, communicator c = {}, int root = 0, // NOLINT (ranges need not be forwarded)
214 long chunk_size = 1) {
215 // check the number of elements to be scattered
216 EXPECTS_WITH_MESSAGE(all_equal(scatter_size, c), "Number of elements to be scattered is not equal on all processes in mpi::scatter_range");
217
218 // do nothing if no elements are scattered
219 if (scatter_size == 0) return;
220
221 // check the size of the input range on root
222 if (c.rank() == root) {
223 EXPECTS_WITH_MESSAGE(scatter_size == std::ranges::size(in_rg),
224 "Input range size on root is not equal the number of elements to be scattered in mpi::scatter_range");
225 }
226
227 // check the size of the output range
228 auto const recvcount = static_cast<int>(chunk_length(scatter_size, c.size(), c.rank(), chunk_size));
229 EXPECTS_WITH_MESSAGE(recvcount == std::ranges::size(out_rg),
230 "Output range size is not equal the number of elements to be received in mpi::scatter_range");
231
232 // in case there is no active MPI environment or if the communicator size is < 2, copy to output range
233 if (!has_env || c.size() < 2) {
234 std::ranges::copy(std::forward<R1>(in_rg), std::ranges::data(out_rg));
235 return;
236 }
237
238 // prepare arguments for the MPI call
239 auto sendcounts = std::vector<int>(c.size());
240 auto displs = std::vector<int>(c.size() + 1, 0);
241 for (int i = 0; i < c.size(); ++i) {
242 sendcounts[i] = static_cast<int>(chunk_length(scatter_size, c.size(), i, chunk_size));
243 displs[i + 1] = sendcounts[i] + displs[i];
244 }
245
246 // make the MPI C library call
247 check_mpi_call(MPI_Scatterv(std::ranges::data(in_rg), sendcounts.data(), displs.data(), mpi_type<std::ranges::range_value_t<R1>>::get(),
248 std::ranges::data(out_rg), recvcount, mpi_type<std::ranges::range_value_t<R2>>::get(), root, c.get()),
249 "MPI_Scatterv");
250 }
251
276 template <MPICompatibleRange R1, MPICompatibleRange R2>
277 requires(std::same_as<std::remove_cvref_t<std::ranges::range_value_t<R1>>, std::remove_cvref_t<std::ranges::range_value_t<R2>>>)
278 void gather_range(R1 &&in_rg, R2 &&out_rg, communicator c = {}, int root = 0, bool all = false) { // NOLINT (ranges need not be forwarded)
279 // get the receive counts (sendcount from each process) and the displacements
280 auto sendcount = static_cast<int>(std::ranges::size(in_rg));
281 auto recvcounts = all_gather(sendcount, c);
282 auto displs = std::vector<int>(c.size() + 1, 0);
283 std::partial_sum(recvcounts.begin(), recvcounts.end(), displs.begin() + 1);
284
285 // do nothing if there are no elements to gather
286 if (displs.back() == 0) return;
287
288 // check the size of the output range on receiving ranks
289 if (c.rank() == root || all) {
290 EXPECTS_WITH_MESSAGE(displs.back() == std::ranges::size(out_rg),
291 "Output range size is not equal the number of elements to be received in mpi::gather_range");
292 }
293
294 // in case there is no active MPI environment or if the communicator size is < 2, copy to the output range
295 if (!has_env || c.size() < 2) {
296 std::ranges::copy(std::forward<R1>(in_rg), std::ranges::data(out_rg));
297 return;
298 }
299
300 // make the MPI C library call
301 if (all) {
302 check_mpi_call(MPI_Allgatherv(std::ranges::data(in_rg), sendcount, mpi_type<std::ranges::range_value_t<R1>>::get(), std::ranges::data(out_rg),
303 recvcounts.data(), displs.data(), mpi_type<std::ranges::range_value_t<R2>>::get(), c.get()),
304 "MPI_Allgatherv");
305 } else {
306 check_mpi_call(MPI_Gatherv(std::ranges::data(in_rg), sendcount, mpi_type<std::ranges::range_value_t<R1>>::get(), std::ranges::data(out_rg),
307 recvcounts.data(), displs.data(), mpi_type<std::ranges::range_value_t<R2>>::get(), root, c.get()),
308 "MPI_Gatherv");
309 }
310 }
311
313
314} // namespace mpi
Provides utilities to distribute a range across MPI processes.
C++ wrapper around MPI_Comm providing various convenience functions.
Provides a C++ wrapper class for an MPI_Comm object.
Provides utilities to map C++ datatypes to MPI datatypes.
Provides an MPI environment for initializing and finalizing an MPI program.
Provides generic implementations for a subset of collective MPI communications (broadcast,...
void reduce_into(T1 &&x_in, T2 &&x_out, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Generic MPI reduce that reduces directly into an existing output object.
void reduce_range(R1 &&in_rg, R2 &&out_rg, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an MPI reduce for std::ranges::sized_range objects.
Definition ranges.hpp:119
void gather_range(R1 &&in_rg, R2 &&out_rg, communicator c={}, int root=0, bool all=false)
Implementation of an MPI gather for mpi::MPICompatibleRange objects.
Definition ranges.hpp:278
void broadcast_range(R &&rg, communicator c={}, int root=0)
Implementation of an MPI broadcast for std::ranges::sized_range objects.
Definition ranges.hpp:69
void scatter_range(R1 &&in_rg, R2 &&out_rg, long scatter_size, communicator c={}, int root=0, long chunk_size=1)
Implementation of an MPI scatter for mpi::MPICompatibleRange objects.
Definition ranges.hpp:213
bool all_equal(T const &x, communicator c={})
Checks if a given object is equal across all ranks in the given communicator.
void broadcast(T &&x, communicator c={}, int root=0)
Generic MPI broadcast.
decltype(auto) all_gather(T &&x, communicator c={})
Generic MPI all-gather.
static const bool has_env
Boolean variable that is true, if one of the environment variables OMPI_COMM_WORLD_RANK,...
long chunk_length(long end, int nranges, int i, long min_size=1)
Get the length of the ith subrange after splitting the integer range [0, end) as evenly as possible among nranges subranges.
Definition chunk.hpp:50
void check_mpi_call(int errcode, const std::string &mpi_routine)
Check the success of an MPI call.
Definition utils.hpp:51
Macros used in the mpi library.
Map C++ datatypes to the corresponding MPI datatypes.
Definition datatypes.hpp:57
Provides general utilities related to MPI.