TRIQS/mpi 1.3.0
C++ interface to MPI
Loading...
Searching...
No Matches
generic_communication.hpp
Go to the documentation of this file.
1// Copyright (c) 2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Alexander Hampel, Olivier Parcollet, Nils Wentzell
16
24
25#pragma once
26
27#include "./communicator.hpp"
28#include "./datatypes.hpp"
29#include "./macros.hpp"
30#include "./utils.hpp"
31
32#include <mpi.h>
33
34#include <algorithm>
35#include <concepts>
36#include <ranges>
37#include <type_traits>
38#include <vector>
39
40namespace mpi {
41
  /**
   * @brief A concept that checks if a range type is contiguous and sized and has an MPI compatible value type.
   *
   * @details Such ranges can be passed directly to the MPI C library, since their elements are stored
   * contiguously in memory (`std::ranges::data`), their number of elements is known (`std::ranges::size`)
   * and their value type maps to an MPI datatype (`has_mpi_type`).
   *
   * @tparam R Range type to be checked.
   */
  template <typename R>
  concept MPICompatibleRange = std::ranges::contiguous_range<R> && std::ranges::sized_range<R> && has_mpi_type<std::ranges::range_value_t<R>>;
49
54
  /**
   * @brief Generic MPI broadcast.
   *
   * @details Dispatches (via ADL) to a type specific `mpi_broadcast` implementation.
   *
   * @tparam T Type to be broadcasted.
   * @param x Object to be broadcasted.
   * @param c mpi::communicator.
   * @param root Rank of the root process.
   */
  template <typename T> [[gnu::always_inline]] void broadcast(T &&x, communicator c = {}, int root = 0) { // NOLINT (forwarding is not needed)
    mpi_broadcast(x, c, root);
  }
71
89 template <typename T>
90 [[gnu::always_inline]] decltype(auto) reduce(T &&x, communicator c = {}, int root = 0, bool all = false, // NOLINT (forwarding is not needed)
91 MPI_Op op = MPI_SUM) {
92 if constexpr (requires { mpi_reduce(x, c, root, all, op); }) {
93 return mpi_reduce(x, c, root, all, op);
94 } else {
95 std::remove_cvref_t<T> res;
96 reduce_into(x, res, c, root, all, op);
97 return res;
98 }
99 }
100
  /**
   * @brief Generic in place MPI reduce.
   *
   * @details Dispatches (via ADL) to a type specific `mpi_reduce_into` implementation, with the same
   * object used as input and output (the in place option).
   *
   * @tparam T Type to be reduced.
   * @param x Object to be reduced in place.
   * @param c mpi::communicator.
   * @param root Rank of the root process.
   * @param all Should all processes receive the result of the reduction.
   * @param op MPI_Op used in the reduction.
   */
  template <typename T>
  [[gnu::always_inline]] void reduce_in_place(T &&x, communicator c = {}, int root = 0, bool all = false, // NOLINT (forwarding is not needed)
                                              MPI_Op op = MPI_SUM) {
    // an in place reduction is a reduce_into with identical input and output objects
    mpi_reduce_into(x, x, c, root, all, op);
  }
121
  /**
   * @brief Generic MPI reduce that reduces directly into an existing output object.
   *
   * @details Dispatches (via ADL) to a type specific `mpi_reduce_into` implementation.
   *
   * @tparam T1 Type to be reduced.
   * @tparam T2 Type of the output object.
   * @param x_in Object to be reduced.
   * @param x_out Object to reduce into.
   * @param c mpi::communicator.
   * @param root Rank of the root process.
   * @param all Should all processes receive the result of the reduction.
   * @param op MPI_Op used in the reduction.
   */
  template <typename T1, typename T2>
  [[gnu::always_inline]] void reduce_into(T1 &&x_in, T2 &&x_out, communicator c = {}, int root = 0, // NOLINT (forwarding is not needed)
                                          bool all = false, MPI_Op op = MPI_SUM) {
    mpi_reduce_into(x_in, x_out, c, root, all, op);
  }
144
160 template <typename T>
161 [[gnu::always_inline]] decltype(auto) scatter(T &&x, mpi::communicator c = {}, int root = 0) { // NOLINT (forwarding is not needed)
162 if constexpr (requires { mpi_scatter(x, c, root); }) {
163 return mpi_scatter(x, c, root);
164 } else {
165 std::remove_cvref_t<T> res;
166 scatter_into(x, res, c, root);
167 return res;
168 }
169 }
170
  /**
   * @brief Generic MPI scatter that scatters directly into an existing output object.
   *
   * @details Dispatches (via ADL) to a type specific `mpi_scatter_into` implementation.
   *
   * @tparam T1 Type to be scattered.
   * @tparam T2 Type of the output object.
   * @param x_in Object to be scattered.
   * @param x_out Object to scatter into.
   * @param c mpi::communicator.
   * @param root Rank of the root process.
   */
  template <typename T1, typename T2>
  [[gnu::always_inline]] void scatter_into(T1 &&x_in, T2 &&x_out, communicator c = {}, int root = 0) { // NOLINT (forwarding is not needed)
    mpi_scatter_into(x_in, x_out, c, root);
  }
190
207 template <typename T>
208 [[gnu::always_inline]] decltype(auto) gather(T &&x, communicator c = {}, int root = 0, bool all = false) { // NOLINT (forwarding is not needed)
209 if constexpr (requires { mpi_gather(x, c, root, all); }) {
210 return mpi_gather(x, c, root, all);
211 } else {
212 std::remove_cvref_t<T> res;
213 gather_into(x, res, c, root, all);
214 return res;
215 }
216 }
217
  /**
   * @brief Generic MPI gather that gathers directly into an existing output object.
   *
   * @details Dispatches (via ADL) to a type specific `mpi_gather_into` implementation.
   *
   * @tparam T1 Type to be gathered.
   * @tparam T2 Type of the output object.
   * @param x_in Object to be gathered.
   * @param x_out Object to gather into.
   * @param c mpi::communicator.
   * @param root Rank of the root process.
   * @param all Should all processes receive the result of the gather.
   */
  template <typename T1, typename T2>
  [[gnu::always_inline]] void gather_into(T1 &&x_in, T2 &&x_out, communicator c = {}, int root = 0, // NOLINT (forwarding is not needed)
                                          bool all = false) {
    mpi_gather_into(x_in, x_out, c, root, all);
  }
239
244 template <typename T>
245 [[gnu::always_inline]] decltype(auto) all_reduce(T &&x, communicator c = {}, MPI_Op op = MPI_SUM) { // NOLINT (forwarding is not needed)
246 return reduce(x, c, 0, true, op);
247 }
248
  /**
   * @brief Generic MPI all-reduce in place.
   *
   * @details Simply calls the generic `reduce_in_place` with the `all` flag set to true.
   *
   * @tparam T Type to be reduced.
   * @param x Object to be reduced in place.
   * @param c mpi::communicator.
   * @param op MPI_Op used in the reduction.
   */
  template <typename T>
  [[gnu::always_inline]] void all_reduce_in_place(T &&x, communicator c = {}, MPI_Op op = MPI_SUM) { // NOLINT (forwarding is not needed)
    reduce_in_place(x, c, 0, true, op);
  }
257
262 template <typename T1, typename T2>
263 [[gnu::always_inline]] void all_reduce_into(T1 &&x_in, T2 &&x_out, communicator c = {}, MPI_Op op = MPI_SUM) { // NOLINT (forwarding is not needed)
264 return reduce_into(x_in, x_out, c, 0, true, op);
265 }
266
271 template <typename T> [[gnu::always_inline]] decltype(auto) all_gather(T &&x, communicator c = {}) { // NOLINT (forwarding is not needed)
272 return gather(x, c, 0, true);
273 }
274
279 template <typename T1, typename T2>
280 [[gnu::always_inline]] void all_gather_into(T1 &&x_in, T2 &&x_out, communicator c = {}) { // NOLINT (forwarding is not needed)
281 return gather_into(x_in, x_out, c, 0, true);
282 }
283
297 template <typename T> bool all_equal(T const &x, communicator c = {}) {
298 if (!has_env || c.size() < 2) return true;
299 auto min_obj = all_reduce(x, c, MPI_MIN);
300 auto max_obj = all_reduce(x, c, MPI_MAX);
301 return min_obj == max_obj;
302 }
303
317 template <typename T>
318 requires(has_mpi_type<T>)
319 void mpi_broadcast(T &x, communicator c = {}, int root = 0) {
320 // in case there is no active MPI environment or if the communicator size is < 2, do nothing
321 if (!has_env || c.size() < 2) return;
322
323 // make the MPI C library call
324 check_mpi_call(MPI_Bcast(&x, 1, mpi_type<T>::get(), root, c.get()), "MPI_Bcast");
325 }
326
343 template <typename T>
344 requires(has_mpi_type<T>)
345 T mpi_reduce(T const &x, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
346 // in case there is no active MPI environment or if the communicator size is < 2, return the input object
347 if (!has_env || c.size() < 2) return x;
348
349 // make the MPI C library call with a default constructed output object
350 T res;
351 if (all) {
352 check_mpi_call(MPI_Allreduce(&x, &res, 1, mpi_type<T>::get(), op, c.get()), "MPI_Allreduce");
353 } else {
354 check_mpi_call(MPI_Reduce(&x, &res, 1, mpi_type<T>::get(), op, root, c.get()), "MPI_Reduce");
355 }
356 return res;
357 }
358
  /**
   * @brief Implementation of an MPI reduce that reduces directly into an existing output object,
   * for types that have a corresponding MPI datatype, i.e. for which `has_mpi_type` is true.
   *
   * @details Calls `MPI_Reduce` (or `MPI_Allreduce` if `all` is true) with a count of 1 and the
   * mapped MPI datatype. If input and output are the same object, the reduction is performed in
   * place using MPI_IN_PLACE on the receiving ranks. If there is no active MPI environment or if
   * the communicator has fewer than 2 ranks, the input is simply copied to the output (or nothing
   * is done in the in place case).
   *
   * @tparam T Type to be reduced.
   * @param x_in Object to be reduced.
   * @param x_out Object to reduce into.
   * @param c mpi::communicator.
   * @param root Rank of the root process.
   * @param all Should all processes receive the result of the reduction.
   * @param op MPI_Op used in the reduction.
   */
  template <typename T>
    requires(has_mpi_type<T>)
  void mpi_reduce_into(T const &x_in, T &x_out, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
    // check if the reduction is in place, i.e. if input and output are the same object
    auto in_ptr = static_cast<void const *>(&x_in);
    auto out_ptr = static_cast<void *>(&x_out);
    bool const in_place = (in_ptr == out_ptr);
    if (all) {
      // for an all-reduce, MPI requires the in place option to be chosen consistently on all ranks
      EXPECTS_WITH_MESSAGE(all_equal(static_cast<int>(in_place), c),
                           "Either zero or all receiving processes have to choose the in place option in mpi_reduce_into");
    }

    // in case there is no active MPI environment or if the communicator size is < 2, do nothing (in place) or copy
    if (!has_env || c.size() < 2) {
      if (!in_place) x_out = x_in;
      return;
    }

    // make the MPI C library call; only receiving ranks (the root, or every rank of an all-reduce)
    // may pass MPI_IN_PLACE as the send buffer
    if (in_place && (c.rank() == root || all)) in_ptr = MPI_IN_PLACE;
    if (all) {
      check_mpi_call(MPI_Allreduce(in_ptr, out_ptr, 1, mpi_type<T>::get(), op, c.get()), "MPI_Allreduce");
    } else {
      check_mpi_call(MPI_Reduce(in_ptr, out_ptr, 1, mpi_type<T>::get(), op, root, c.get()), "MPI_Reduce");
    }
  }
405
419 template <typename T>
420 requires(has_mpi_type<T>)
421 std::vector<T> mpi_gather(T const &x, communicator c = {}, int root = 0, bool all = false) {
422 std::vector<T> res(c.rank() == root || all ? c.size() : 0);
423 mpi_gather_into(x, res, c, root, all);
424 return res;
425 }
426
  /**
   * @brief Implementation of an MPI gather that gathers directly into an existing output range,
   * for types that have a corresponding MPI datatype, i.e. for which `has_mpi_type` is true.
   *
   * @details Calls `MPI_Gather` (or `MPI_Allgather` if `all` is true) with a count of 1 and the
   * mapped MPI datatype. On receiving ranks, the output range must have exactly one element per
   * rank. If there is no active MPI environment or if the communicator has fewer than 2 ranks,
   * the input is simply copied into the beginning of the output range.
   *
   * @tparam T Type to be gathered.
   * @tparam R mpi::MPICompatibleRange to gather into, with value type equal to T.
   * @param x Object to be gathered.
   * @param rg Range to gather into.
   * @param c mpi::communicator.
   * @param root Rank of the root process.
   * @param all Should all processes receive the result of the gather.
   */
  template <typename T, MPICompatibleRange R>
    requires(has_mpi_type<T> && std::same_as<T, std::remove_cvref_t<std::ranges::range_value_t<R>>>)
  void mpi_gather_into(T const &x, R &&rg, communicator c = {}, int root = 0, bool all = false) { // NOLINT (ranges need not be forwarded)
    // check the size of the output range (only receiving ranks use it)
    if (c.rank() == root || all) {
      EXPECTS_WITH_MESSAGE(c.size() == std::ranges::size(rg), "Output range size is not equal the number of ranks in mpi_gather_into");
    }

    // in case there is no active MPI environment or if the communicator size is < 2, copy the input into the range
    if (!has_env || c.size() < 2) {
      std::ranges::copy(std::views::single(x), std::ranges::begin(rg));
      return;
    }

    // make the MPI C library call
    using value_t = std::ranges::range_value_t<R>;
    if (all) {
      check_mpi_call(MPI_Allgather(&x, 1, mpi_type<T>::get(), std::ranges::data(rg), 1, mpi_type<value_t>::get(), c.get()), "MPI_Allgather");
    } else {
      check_mpi_call(MPI_Gather(&x, 1, mpi_type<T>::get(), std::ranges::data(rg), 1, mpi_type<value_t>::get(), root, c.get()), "MPI_Gather");
    }
  }
467
469
470} // namespace mpi
C++ wrapper around MPI_Comm providing various convenience functions.
Provides a C++ wrapper class for an MPI_Comm object.
A concept that checks if a range type is contiguous and sized and has an MPI compatible value type.
Provides utilities to map C++ datatypes to MPI datatypes.
void scatter_into(T1 &&x_in, T2 &&x_out, communicator c={}, int root=0)
Generic MPI scatter that scatters directly into an existing output object.
decltype(auto) scatter(T &&x, mpi::communicator c={}, int root=0)
Generic MPI scatter.
void gather_into(T1 &&x_in, T2 &&x_out, communicator c={}, int root=0, bool all=false)
Generic MPI gather that gathers directly into an existing output object.
void all_gather_into(T1 &&x_in, T2 &&x_out, communicator c={})
Generic MPI all-gather that gathers directly into an existing output object.
auto mpi_reduce(std::array< T, N > const &arr, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an MPI reduce for a std::array.
Definition array.hpp:74
void mpi_reduce_into(std::array< T1, N1 > const &arr_in, std::array< T2, N2 > &arr_out, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an MPI reduce for a std::array that reduces directly into an existing output array.
Definition array.hpp:99
void mpi_scatter_into(std::vector< T > const &v_in, std::vector< T > &v_out, communicator c={}, int root=0)
Implementation of an MPI scatter for a std::vector that scatters directly into an existing output vector.
Definition vector.hpp:119
void mpi_gather_into(T const &x, R &&rg, communicator c={}, int root=0, bool all=false)
Implementation of an MPI gather that gathers directly into an existing output range for types that have a corresponding MPI datatype.
void reduce_into(T1 &&x_in, T2 &&x_out, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Generic MPI reduce that reduces directly into an existing output object.
void reduce_in_place(T &&x, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Generic in place MPI reduce.
void all_reduce_in_place(T &&x, communicator c={}, MPI_Op op=MPI_SUM)
Generic MPI all-reduce in place.
decltype(auto) reduce(T &&x, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Generic MPI reduce.
void mpi_broadcast(std::array< T, N > &arr, communicator c={}, int root=0)
Implementation of an MPI broadcast for a std::array.
Definition array.hpp:53
bool all_equal(T const &x, communicator c={})
Checks if a given object is equal across all ranks in the given communicator.
std::vector< T > mpi_gather(T const &x, communicator c={}, int root=0, bool all=false)
Implementation of an MPI gather for types that have a corresponding MPI datatype.
decltype(auto) all_reduce(T &&x, communicator c={}, MPI_Op op=MPI_SUM)
Generic MPI all-reduce.
void broadcast(T &&x, communicator c={}, int root=0)
Generic MPI broadcast.
void all_reduce_into(T1 &&x_in, T2 &&x_out, communicator c={}, MPI_Op op=MPI_SUM)
Generic MPI all-reduce that reduces directly into an existing output object.
decltype(auto) gather(T &&x, communicator c={}, int root=0, bool all=false)
Generic MPI gather.
decltype(auto) all_gather(T &&x, communicator c={})
Generic MPI all-gather.
static const bool has_env
Boolean variable that is true, if one of the MPI runtime environment variables (e.g. OMPI_COMM_WORLD_RANK) is set, i.e. if an MPI environment is detected.
constexpr bool has_mpi_type
Type trait to check if a type T has a corresponding MPI datatype, i.e. if mpi::mpi_type has been specialized for it.
Definition datatypes.hpp:98
void check_mpi_call(int errcode, const std::string &mpi_routine)
Check the success of an MPI call.
Definition utils.hpp:51
Macros used in the mpi library.
Map C++ datatypes to the corresponding MPI datatypes.
Definition datatypes.hpp:57
Provides general utilities related to MPI.