#include <itertools/itertools.hpp>

template <std::ranges::sized_range R> void broadcast_range(R &&rg, communicator c = {}, int root = 0) {
  // the range must have the same size on all processes
  auto size = static_cast<long>(std::ranges::size(rg));
  EXPECTS_WITH_MESSAGE(all_equal(size, c), "Range sizes are not equal on all processes in mpi::broadcast_range");

  // nothing to do for empty ranges
  if (size <= 0) return;

  if constexpr (MPICompatibleRange<R>) {
    // nothing to do without an MPI runtime or on a single-process communicator
    if (!has_env || c.size() < 2) return;

    // MPI-compatible ranges can be broadcast directly, in chunks of at most INT_MAX elements
    constexpr long max_int = std::numeric_limits<int>::max();
    for (long offset = 0; size > 0; offset += max_int, size -= max_int) {
      auto const count = static_cast<int>(std::min(size, max_int));
      check_mpi_call(MPI_Bcast(std::ranges::data(rg) + offset, count, mpi_type<std::ranges::range_value_t<R>>::get(), root, c.get()), "MPI_Bcast");
    }
  } else {
    // fallback for non MPI-compatible ranges (elided in the original excerpt):
    // presumably each element is broadcast separately via the generic mpi::broadcast
    for (auto &&x : rg) broadcast(x, c, root);
  }
}
template <std::ranges::sized_range R1, std::ranges::sized_range R2>
void reduce_range(R1 &&in_rg, R2 &&out_rg, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
  // the input range must have the same size on all processes
  auto size = static_cast<long>(std::ranges::size(in_rg));
  EXPECTS_WITH_MESSAGE(all_equal(size, c), "Input range sizes are not equal on all processes in mpi::reduce_range");

  // nothing to do for empty ranges
  if (size <= 0) return;

  // on receiving processes, the output range must be able to hold the result
  bool const receives = (c.rank() == root || all);
  if (receives) EXPECTS_WITH_MESSAGE(size == std::ranges::size(out_rg), "Input and output range sizes are not equal in mpi::reduce_range");

  if constexpr (MPICompatibleRange<R1> && MPICompatibleRange<R2>) {
    static_assert(std::same_as<std::remove_cvref_t<std::ranges::range_value_t<R1>>, std::remove_cvref_t<std::ranges::range_value_t<R2>>>,
                  "Value types of input and output ranges not compatible in mpi::reduce_range");

    // does this process reduce in place, i.e. do the input and output ranges coincide?
    bool const in_place = (static_cast<void const *>(std::ranges::data(in_rg)) == static_cast<void *>(std::ranges::data(out_rg)));
    EXPECTS_WITH_MESSAGE(all_equal(static_cast<int>(in_place), c),
                         "Either zero or all receiving processes have to choose the in place option in mpi::reduce_range");

    // without an MPI runtime or on a single-process communicator, simply copy the input to the output
    if (!has_env || c.size() < 2) {
      std::ranges::copy(std::forward<R1>(in_rg), std::ranges::data(out_rg));
      return;
    }

    // reduce in chunks of at most INT_MAX elements
    constexpr long max_int = std::numeric_limits<int>::max();
    for (long offset = 0; size > 0; offset += max_int, size -= max_int) {
      auto in_data  = static_cast<void const *>(std::ranges::data(in_rg) + offset);
      auto out_data = std::ranges::data(out_rg) + offset;
      if (receives and in_place) in_data = MPI_IN_PLACE;
      auto const count = static_cast<int>(std::min(size, max_int));
      if (all) // (branching on `all` is reconstructed; the excerpt shows both calls without the guard)
        check_mpi_call(MPI_Allreduce(in_data, out_data, count, mpi_type<std::ranges::range_value_t<R1>>::get(), op, c.get()), "MPI_Allreduce");
      else
        check_mpi_call(MPI_Reduce(in_data, out_data, count, mpi_type<std::ranges::range_value_t<R1>>::get(), op, root, c.get()), "MPI_Reduce");
    }
  } else {
    if (size <= std::ranges::size(out_rg)) {
      // reduce element-wise directly into the output range
      for (auto &&[x_in, x_out] : itertools::zip(in_rg, out_rg)) reduce_into(x_in, x_out, c, root, all, op);
    } else {
      // non-receiving processes may have a smaller output range: reduce into a dummy object instead
      using out_value_t = std::ranges::range_value_t<R2>;
      if constexpr (std::is_default_constructible_v<out_value_t>) {
        out_value_t out_dummy{};
        for (auto &&x_in : in_rg) reduce_into(x_in, out_dummy, c, root, all, op);
      } else {
        throw std::runtime_error("Cannot default construct dummy object in mpi::reduce_range");
      }
    }
  }
}
template <MPICompatibleRange R1, MPICompatibleRange R2>
  requires(std::same_as<std::remove_cvref_t<std::ranges::range_value_t<R1>>, std::remove_cvref_t<std::ranges::range_value_t<R2>>>)
void scatter_range(R1 &&in_rg, R2 &&out_rg, long scatter_size, communicator c = {}, int root = 0, long chunk_size = 1) {
  // the number of elements to be scattered must be the same on all processes
  EXPECTS_WITH_MESSAGE(all_equal(scatter_size, c), "Number of elements to be scattered is not equal on all processes in mpi::scatter_range");

  // nothing to do if there are no elements to scatter
  if (scatter_size == 0) return;

  // on root, the input range must hold all elements to be scattered
  if (c.rank() == root) {
    EXPECTS_WITH_MESSAGE(scatter_size == std::ranges::size(in_rg),
                         "Input range size on root is not equal the number of elements to be scattered in mpi::scatter_range");
  }

  // the output range must be able to hold the chunk this process receives
  auto const recvcount = static_cast<int>(chunk_length(scatter_size, c.size(), c.rank(), chunk_size));
  EXPECTS_WITH_MESSAGE(recvcount == std::ranges::size(out_rg),
                       "Output range size is not equal the number of elements to be received in mpi::scatter_range");

  // without an MPI runtime or on a single-process communicator, simply copy the input to the output
  if (!has_env || c.size() < 2) {
    std::ranges::copy(std::forward<R1>(in_rg), std::ranges::data(out_rg));
    return;
  }

  // compute the send counts and displacements
  auto sendcounts = std::vector<int>(c.size());
  auto displs     = std::vector<int>(c.size() + 1, 0);
  for (int i = 0; i < c.size(); ++i) {
    sendcounts[i] = static_cast<int>(chunk_length(scatter_size, c.size(), i, chunk_size));
    displs[i + 1] = sendcounts[i] + displs[i];
  }

  // scatter the chunks
  check_mpi_call(MPI_Scatterv(std::ranges::data(in_rg), sendcounts.data(), displs.data(), mpi_type<std::ranges::range_value_t<R1>>::get(),
                              std::ranges::data(out_rg), recvcount, mpi_type<std::ranges::range_value_t<R2>>::get(), root, c.get()),
                 "MPI_Scatterv");
}
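
A scatter sketch under the same assumptions; chunk_length (see the reference list at the end) is used to size the local output range consistently with the send counts computed above:

#include <mpi/mpi.hpp> // assumed umbrella header
#include <numeric>
#include <vector>

void distribute(mpi::communicator c) {
  long const n = 10;
  std::vector<int> in;
  if (c.rank() == 0) { // only root needs to hold all n elements
    in.resize(n);
    std::iota(in.begin(), in.end(), 0);
  }
  std::vector<int> out(mpi::chunk_length(n, c.size(), c.rank()));
  mpi::scatter_range(in, out, n, c, 0); // rank r now holds its contiguous chunk of [0, n)
}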
template <MPICompatibleRange R1, MPICompatibleRange R2>
  requires(std::same_as<std::remove_cvref_t<std::ranges::range_value_t<R1>>, std::remove_cvref_t<std::ranges::range_value_t<R2>>>)
void gather_range(R1 &&in_rg, R2 &&out_rg, communicator c = {}, int root = 0, bool all = false) {
  // compute the receive counts and displacements
  auto sendcount  = static_cast<int>(std::ranges::size(in_rg));
  auto recvcounts = all_gather(sendcount, c); // (reconstructed: the excerpt elides this line; presumably the send counts of all processes are gathered)
  auto displs     = std::vector<int>(c.size() + 1, 0);
  std::partial_sum(recvcounts.begin(), recvcounts.end(), displs.begin() + 1);

  // nothing to do if there are no elements to gather
  if (displs.back() == 0) return;

  // on receiving processes, the output range must be able to hold all gathered elements
  if (c.rank() == root || all) {
    EXPECTS_WITH_MESSAGE(displs.back() == std::ranges::size(out_rg),
                         "Output range size is not equal the number of elements to be received in mpi::gather_range");
  }

  // without an MPI runtime or on a single-process communicator, simply copy the input to the output
  if (!has_env || c.size() < 2) {
    std::ranges::copy(std::forward<R1>(in_rg), std::ranges::data(out_rg));
    return;
  }

  if (all) { // (branching on `all` is reconstructed; the excerpt shows both calls without the guard)
    check_mpi_call(MPI_Allgatherv(std::ranges::data(in_rg), sendcount, mpi_type<std::ranges::range_value_t<R1>>::get(), std::ranges::data(out_rg),
                                  recvcounts.data(), displs.data(), mpi_type<std::ranges::range_value_t<R2>>::get(), c.get()),
                   "MPI_Allgatherv");
  } else {
    check_mpi_call(MPI_Gatherv(std::ranges::data(in_rg), sendcount, mpi_type<std::ranges::range_value_t<R1>>::get(), std::ranges::data(out_rg),
                               recvcounts.data(), displs.data(), mpi_type<std::ranges::range_value_t<R2>>::get(), root, c.get()),
                   "MPI_Gatherv");
  }
}
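
A gather sketch under the same assumptions, with differently sized contributions per rank; all_gather (see the reference list) is assumed to return a container of the per-rank values:

#include <mpi/mpi.hpp> // assumed umbrella header
#include <numeric>
#include <vector>

void collect_on_root(mpi::communicator c) {
  std::vector<int> in(c.rank() + 1, c.rank()); // rank r contributes r+1 copies of r
  // assumption: all_gather of an int returns a std::vector<int> of per-rank values
  auto sizes       = mpi::all_gather(static_cast<int>(in.size()), c);
  long const total = std::accumulate(sizes.begin(), sizes.end(), 0L);
  std::vector<int> out(c.rank() == 0 ? total : 0); // only root receives
  mpi::gather_range(in, out, c, 0);                // on root: 0, 1, 1, 2, 2, 2, ...
}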
Referenced symbols:

void reduce_into(T1 &&x_in, T2 &&x_out, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM)
    Generic MPI reduce that reduces directly into an existing output object.

void reduce_range(R1 &&in_rg, R2 &&out_rg, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM)
    Implementation of an MPI reduce for std::ranges::sized_range objects.

void gather_range(R1 &&in_rg, R2 &&out_rg, communicator c = {}, int root = 0, bool all = false)
    Implementation of an MPI gather for mpi::MPICompatibleRange objects.

void broadcast_range(R &&rg, communicator c = {}, int root = 0)
    Implementation of an MPI broadcast for std::ranges::sized_range objects.

void scatter_range(R1 &&in_rg, R2 &&out_rg, long scatter_size, communicator c = {}, int root = 0, long chunk_size = 1)
    Implementation of an MPI scatter for mpi::MPICompatibleRange objects.

bool all_equal(T const &x, communicator c = {})
    Checks if a given object is equal across all ranks in the given communicator.

void broadcast(T &&x, communicator c = {}, int root = 0)
    Generic MPI broadcast.

decltype(auto) all_gather(T &&x, communicator c = {})
    Generic MPI all-gather.

static const bool has_env
    Boolean variable that is true if one of the environment variables OMPI_COMM_WORLD_RANK, ... is set.

long chunk_length(long end, int nranges, int i, long min_size = 1)
    Get the length of the ith subrange after splitting the integer range [0, end) as evenly as possible across nranges subranges.

void check_mpi_call(int errcode, const std::string &mpi_routine)
    Check the success of an MPI call.

mpi_type
    Map C++ datatypes to the corresponding MPI datatypes.
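
Finally, a hedged end-to-end sketch combining the utilities above: scatter a large vector from root, transform each local chunk, and gather the results back. Only functions from the reference list are used; the exact chunk tie-breaking of chunk_length is the library's choice.

#include <mpi/mpi.hpp> // assumed umbrella header
#include <numeric>
#include <vector>

int main(int argc, char *argv[]) {
  mpi::environment env(argc, argv);
  mpi::communicator world;
  long const n = 1000;

  std::vector<double> full;
  if (world.rank() == 0) { // only root holds the full data
    full.resize(n);
    std::iota(full.begin(), full.end(), 0.0);
  }

  // scatter, work on the local chunk, gather back to root
  std::vector<double> chunk(mpi::chunk_length(n, world.size(), world.rank()));
  mpi::scatter_range(full, chunk, n, world, 0);
  for (auto &x : chunk) x *= 2.0;
  mpi::gather_range(chunk, full, world, 0); // full on root now holds the doubled data
}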