32#include <itertools/itertools.hpp>
95 using value_t = std::ranges::range_value_t<R>;
96 auto const size = std::ranges::size(rg);
97 EXPECTS_WITH_MESSAGE(
all_equal(size, c),
"Range sizes are not equal across all processes in mpi::broadcast_range");
100 if (size == 0 || !
has_env || c.size() < 2)
return;
105 check_mpi_call(MPI_Bcast(std::ranges::data(rg), size, mpi_type<value_t>::get(), root, c.get()),
"MPI_Bcast");
108 for (
auto &val : rg)
broadcast(val, c, root);
154 template <contiguous_sized_range R>
156 MPI_Op op = MPI_SUM) {
158 using value_t = std::ranges::range_value_t<R>;
159 auto const size = std::ranges::size(rg);
160 EXPECTS_WITH_MESSAGE(
all_equal(size, c),
"Range sizes are not equal across all processes in mpi::reduce_in_place_range");
163 if (size == 0 || !
has_env || c.size() < 2)
return;
168 auto data = std::ranges::data(rg);
170 check_mpi_call(MPI_Reduce((c.rank() == root ? MPI_IN_PLACE : data), data, size, mpi_type<value_t>::get(), op, root, c.get()),
"MPI_Reduce");
172 check_mpi_call(MPI_Allreduce(MPI_IN_PLACE, data, size, mpi_type<value_t>::get(), op, c.get()),
"MPI_Allreduce");
225 template <contiguous_sized_range R1, contiguous_sized_range R2>
227 MPI_Op op = MPI_SUM) {
229 auto const in_size = std::ranges::size(in_rg);
230 EXPECTS_WITH_MESSAGE(
all_equal(in_size, c),
"Input range sizes are not equal across all processes in mpi::reduce_range");
231 if (c.rank() == root || all) {
232 EXPECTS_WITH_MESSAGE(in_size == std::ranges::size(out_rg),
"Input and output range sizes are not equal in mpi::reduce_range");
236 if (in_size == 0)
return;
239 if (!
has_env || c.size() < 2) {
240 std::ranges::copy(std::forward<R1>(in_rg), std::ranges::data(out_rg));
245 using in_value_t = std::ranges::range_value_t<R1>;
246 using out_value_t = std::ranges::range_value_t<R2>;
249 auto const in_data = std::ranges::data(in_rg);
250 auto out_data = std::ranges::data(out_rg);
252 check_mpi_call(MPI_Reduce(in_data, out_data, in_size, mpi_type<in_value_t>::get(), op, root, c.get()),
"MPI_Reduce");
254 check_mpi_call(MPI_Allreduce(in_data, out_data, in_size, mpi_type<in_value_t>::get(), op, c.get()),
"MPI_Allreduce");
258 if (c.rank() == root || all)
259 std::ranges::transform(std::forward<R1>(in_rg), std::ranges::data(out_rg), [&](
auto const &val) {
return reduce(val, c, root, all, op); });
262 std::ranges::for_each(std::forward<R1>(in_rg), [&](
auto const &val) { [[maybe_unused]] out_value_t ignore =
reduce(val, c, root, all, op); });
315 template <contiguous_sized_range R1, contiguous_sized_range R2>
316 requires(std::same_as<std::ranges::range_value_t<R1>, std::ranges::range_value_t<R2>>)
318 long chunk_size = 1) {
320 if (c.rank() == root) {
321 EXPECTS_WITH_MESSAGE(in_size == std::ranges::size(in_rg),
"Input range size not equal to provided size in mpi::scatter_range");
323 EXPECTS_WITH_MESSAGE(in_size ==
all_reduce(std::ranges::size(out_rg), c),
324 "Output range sizes don't add up to input range size in mpi::scatter_range");
327 if (in_size == 0)
return;
330 if (!
has_env || c.size() < 2) {
331 std::ranges::copy(std::forward<R1>(in_rg), std::ranges::data(out_rg));
336 int recvcount =
static_cast<int>(
chunk_length(in_size, c.size(), c.rank(), chunk_size));
337 EXPECTS_WITH_MESSAGE(recvcount == std::ranges::size(out_rg),
"Output range size is incorrect in mpi::scatter_range");
340 auto sendcounts = std::vector<int>(c.size());
341 auto displs = std::vector<int>(c.size() + 1, 0);
342 for (
int i = 0; i < c.size(); ++i) {
343 sendcounts[i] =
static_cast<int>(
chunk_length(in_size, c.size(), i, chunk_size));
344 displs[i + 1] = sendcounts[i] + displs[i];
348 using in_value_t = std::ranges::range_value_t<R1>;
349 using out_value_t = std::ranges::range_value_t<R2>;
352 auto const in_data = std::ranges::data(in_rg);
353 auto out_data = std::ranges::data(out_rg);
354 check_mpi_call(MPI_Scatterv(in_data, sendcounts.data(), displs.data(), mpi_type<in_value_t>::get(), out_data, recvcount,
355 mpi_type<out_value_t>::get(), root, c.get()),
359 throw std::runtime_error{
"Error in mpi::scatter_range: Types with no corresponding datatype can only be all-gathered"};
411 template <contiguous_sized_range R1, contiguous_sized_range R2>
415 auto const in_size = std::ranges::size(in_rg);
416 EXPECTS_WITH_MESSAGE(out_size =
all_reduce(in_size, c),
"Input range sizes don't add up to output range size in mpi::gather_range");
417 if (c.rank() == root || all) {
418 EXPECTS_WITH_MESSAGE(out_size == std::ranges::size(out_rg),
"Output range size is incorrect in mpi::gather_range");
422 if (out_size == 0)
return;
425 if (!
has_env || c.size() < 2) {
426 std::ranges::copy(std::forward<R1>(in_rg), std::ranges::data(out_rg));
431 auto recvcounts = std::vector<int>(c.size());
432 auto displs = std::vector<int>(c.size() + 1, 0);
433 int sendcount = in_size;
435 check_mpi_call(MPI_Gather(&sendcount, 1, mpi_type<int>::get(), recvcounts.data(), 1, mpi_type<int>::get(), root, c.get()),
"MPI_Gather");
437 check_mpi_call(MPI_Allgather(&sendcount, 1, mpi_type<int>::get(), recvcounts.data(), 1, mpi_type<int>::get(), c.get()),
"MPI_Allgather");
438 for (
int i = 0; i < c.size(); ++i) displs[i + 1] = recvcounts[i] + displs[i];
441 using in_value_t = std::ranges::range_value_t<R1>;
442 using out_value_t = std::ranges::range_value_t<R2>;
445 auto const in_data = std::ranges::data(in_rg);
446 auto out_data = std::ranges::data(out_rg);
448 check_mpi_call(MPI_Gatherv(in_data, sendcount, mpi_type<in_value_t>::get(), out_data, recvcounts.data(), displs.data(),
449 mpi_type<out_value_t>::get(), root, c.get()),
452 check_mpi_call(MPI_Allgatherv(in_data, sendcount, mpi_type<in_value_t>::get(), out_data, recvcounts.data(), displs.data(),
453 mpi_type<out_value_t>::get(), c.get()),
458 for (
int i = 0; i < c.size(); ++i) {
459 auto view = std::views::drop(out_rg, displs[i]) | std::views::take(displs[i + 1] - displs[i]);
460 if (c.rank() == i) std::ranges::copy(in_rg, std::ranges::begin(view));
465 throw std::runtime_error{
"Error in mpi::gather_range: Types with no corresponding datatype can only be all-gathered"};
Provides utilities to distribute a range across MPI processes.
C++ wrapper around MPI_Comm providing various convenience functions.
Provides a C++ wrapper class for an MPI_Comm object.
Provides utilities to map C++ datatypes to MPI datatypes.
Provides an MPI environment for initializing and finalizing an MPI program.
Provides generic implementations for a subset of collective MPI communications (broadcast, reduce, scatter, gather).
void reduce_in_place_range(R &&rg, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an in-place MPI reduce for an mpi::contiguous_sized_range object.
void reduce_range(R1 &&in_rg, R2 &&out_rg, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an MPI reduce for an mpi::contiguous_sized_range.
void scatter_range(R1 &&in_rg, R2 &&out_rg, long in_size, communicator c={}, int root=0, long chunk_size=1)
Implementation of an MPI scatter for an mpi::contiguous_sized_range.
decltype(auto) reduce(T &&x, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Generic MPI reduce.
void broadcast_range(R &&rg, communicator c={}, int root=0)
Implementation of an MPI broadcast for an mpi::contiguous_sized_range object.
void gather_range(R1 &&in_rg, R2 &&out_rg, long out_size, communicator c={}, int root=0, bool all=false)
Implementation of an MPI gather for an mpi::contiguous_sized_range.
bool all_equal(T const &x, communicator c={})
Checks if a given object is equal across all ranks in the given communicator.
decltype(auto) all_reduce(T &&x, communicator c={}, MPI_Op op=MPI_SUM)
Generic MPI all-reduce.
void broadcast(T &&x, communicator c={}, int root=0)
Generic MPI broadcast.
void mpi_reduce_in_place(std::array< T, N > &arr, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an in-place MPI reduce for a std::array.
static const bool has_env
Boolean variable that is true if one of the MPI-related environment variables (OMPI_COMM_WORLD_RANK, PMI_RANK, ...) is set, i.e. if the program appears to have been launched under MPI.
constexpr bool has_mpi_type
Type trait to check if a type T has a corresponding MPI datatype, i.e. if mpi::mpi_type has been specialized for T.
long chunk_length(long end, int nranges, int i, long min_size=1)
Get the length of the ith subrange after splitting the integer range [0, end) as evenly as possible across nranges subranges.
void check_mpi_call(int errcode, const std::string &mpi_routine)
Check the success of an MPI call.
Macros used in the mpi library.
Provides general utilities related to MPI.