TRIQS/mpi 2.0.0
C++ interface to MPI
Loading...
Searching...
No Matches
generic_communication.hpp
Go to the documentation of this file.
1// Copyright (c) 2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Alexander Hampel, Olivier Parcollet, Nils Wentzell
16
24
25#pragma once
26
27#include "./communicator.hpp"
28#include "./datatypes.hpp"
29#include "./macros.hpp"
30#include "./utils.hpp"
31
32#include <mpi.h>
33
34#include <algorithm>
35#include <concepts>
36#include <ranges>
37#include <type_traits>
38#include <vector>
39
40namespace mpi {
41
47 template <typename R>
48 concept MPICompatibleRange = std::ranges::contiguous_range<R> && std::ranges::sized_range<R> && has_mpi_type<std::ranges::range_value_t<R>>;
49
54
68 template <typename T> void broadcast(T &&x, communicator c = {}, int root = 0) { // NOLINT (forwarding is not needed)
69 mpi_broadcast(x, c, root);
70 }
71
89 template <typename T>
90 decltype(auto) reduce(T &&x, communicator c = {}, int root = 0, bool all = false, // NOLINT (forwarding is not needed)
91 MPI_Op op = MPI_SUM) {
92 if constexpr (requires { mpi_reduce(x, c, root, all, op); }) {
93 return mpi_reduce(x, c, root, all, op);
94 } else {
95 std::remove_cvref_t<T> res;
96 reduce_into(x, res, c, root, all, op);
97 return res;
98 }
99 }
100
116 template <typename T>
117 void reduce_in_place(T &&x, communicator c = {}, int root = 0, bool all = false, // NOLINT (forwarding is not needed)
118 MPI_Op op = MPI_SUM) {
119 mpi_reduce_into(x, x, c, root, all, op);
120 }
121
139 template <typename T1, typename T2>
140 void reduce_into(T1 &&x_in, T2 &&x_out, communicator c = {}, int root = 0, // NOLINT (forwarding is not needed)
141 bool all = false, MPI_Op op = MPI_SUM) {
142 mpi_reduce_into(x_in, x_out, c, root, all, op);
143 }
144
160 template <typename T> decltype(auto) scatter(T &&x, mpi::communicator c = {}, int root = 0) { // NOLINT (forwarding is not needed)
161 if constexpr (requires { mpi_scatter(x, c, root); }) {
162 return mpi_scatter(x, c, root);
163 } else {
164 std::remove_cvref_t<T> res;
165 scatter_into(x, res, c, root);
166 return res;
167 }
168 }
169
185 template <typename T1, typename T2>
186 void scatter_into(T1 &&x_in, T2 &&x_out, communicator c = {}, int root = 0) { // NOLINT (forwarding is not needed)
187 mpi_scatter_into(x_in, x_out, c, root);
188 }
189
206 template <typename T> decltype(auto) gather(T &&x, communicator c = {}, int root = 0, bool all = false) { // NOLINT (forwarding is not needed)
207 if constexpr (requires { mpi_gather(x, c, root, all); }) {
208 return mpi_gather(x, c, root, all);
209 } else {
210 std::remove_cvref_t<T> res;
211 gather_into(x, res, c, root, all);
212 return res;
213 }
214 }
215
232 template <typename T1, typename T2>
233 void gather_into(T1 &&x_in, T2 &&x_out, communicator c = {}, int root = 0, // NOLINT (forwarding is not needed)
234 bool all = false) {
235 mpi_gather_into(x_in, x_out, c, root, all);
236 }
237
242 template <typename T> decltype(auto) all_reduce(T &&x, communicator c = {}, MPI_Op op = MPI_SUM) { // NOLINT (forwarding is not needed)
243 return reduce(x, c, 0, true, op);
244 }
245
250 template <typename T> void all_reduce_in_place(T &&x, communicator c = {}, MPI_Op op = MPI_SUM) { // NOLINT (forwarding is not needed)
251 reduce_in_place(x, c, 0, true, op);
252 }
253
258 template <typename T1, typename T2>
259 void all_reduce_into(T1 &&x_in, T2 &&x_out, communicator c = {}, MPI_Op op = MPI_SUM) { // NOLINT (forwarding is not needed)
260 return reduce_into(x_in, x_out, c, 0, true, op);
261 }
262
267 template <typename T> decltype(auto) all_gather(T &&x, communicator c = {}) { // NOLINT (forwarding is not needed)
268 return gather(x, c, 0, true);
269 }
270
275 template <typename T1, typename T2> void all_gather_into(T1 &&x_in, T2 &&x_out, communicator c = {}) { // NOLINT (forwarding is not needed)
276 return gather_into(x_in, x_out, c, 0, true);
277 }
278
292 template <typename T> bool all_equal(T const &x, communicator c = {}) {
293 if (!has_env || c.size() < 2) return true;
294 auto min_obj = all_reduce(x, c, MPI_MIN);
295 auto max_obj = all_reduce(x, c, MPI_MAX);
296 return min_obj == max_obj;
297 }
298
312 template <typename T>
313 requires(has_mpi_type<T>)
314 void mpi_broadcast(T &x, communicator c = {}, int root = 0) {
315 // in case there is no active MPI environment or if the communicator size is < 2, do nothing
316 if (!has_env || c.size() < 2) return;
317
318 // make the MPI C library call
319 check_mpi_call(MPI_Bcast(&x, 1, mpi_type<T>::get(), root, c.get()), "MPI_Bcast");
320 }
321
338 template <typename T>
339 requires(has_mpi_type<T>)
340 T mpi_reduce(T const &x, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
341 // in case there is no active MPI environment or if the communicator size is < 2, return the input object
342 if (!has_env || c.size() < 2) return x;
343
344 // make the MPI C library call with a default constructed output object
345 T res;
346 if (all) {
347 check_mpi_call(MPI_Allreduce(&x, &res, 1, mpi_type<T>::get(), op, c.get()), "MPI_Allreduce");
348 } else {
349 check_mpi_call(MPI_Reduce(&x, &res, 1, mpi_type<T>::get(), op, root, c.get()), "MPI_Reduce");
350 }
351 return res;
352 }
353
374 template <typename T>
375 requires(has_mpi_type<T>)
376 void mpi_reduce_into(T const &x_in, T &x_out, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
377 // check if the reduction is in place
378 auto in_ptr = static_cast<void const *>(&x_in);
379 auto out_ptr = static_cast<void *>(&x_out);
380 bool const in_place = (in_ptr == out_ptr);
381 if (all) {
382 EXPECTS_WITH_MESSAGE(all_equal(static_cast<int>(in_place), c),
383 "Either zero or all receiving processes have to choose the in place option in mpi_reduce_into");
384 }
385
386 // in case there is no active MPI environment or if the communicator size is < 2, do nothing (in place) or copy
387 if (!has_env || c.size() < 2) {
388 if (!in_place) x_out = x_in;
389 return;
390 }
391
392 // make the MPI C library call
393 if (in_place && (c.rank() == root || all)) in_ptr = MPI_IN_PLACE;
394 if (all) {
395 check_mpi_call(MPI_Allreduce(in_ptr, out_ptr, 1, mpi_type<T>::get(), op, c.get()), "MPI_Allreduce");
396 } else {
397 check_mpi_call(MPI_Reduce(in_ptr, out_ptr, 1, mpi_type<T>::get(), op, root, c.get()), "MPI_Reduce");
398 }
399 }
400
414 template <typename T>
415 requires(has_mpi_type<T>)
416 std::vector<T> mpi_gather(T const &x, communicator c = {}, int root = 0, bool all = false) {
417 std::vector<T> res(c.rank() == root || all ? c.size() : 0);
418 mpi_gather_into(x, res, c, root, all);
419 return res;
420 }
421
440 template <typename T, MPICompatibleRange R>
441 requires(has_mpi_type<T> && std::same_as<T, std::remove_cvref_t<std::ranges::range_value_t<R>>>)
442 void mpi_gather_into(T const &x, R &&rg, communicator c = {}, int root = 0, bool all = false) { // NOLINT (ranges need not be forwarded)
443 // check the size of the output range
444 if (c.rank() == root || all) {
445 EXPECTS_WITH_MESSAGE(c.size() == std::ranges::size(rg), "Output range size is not equal the number of ranks in mpi_gather_into");
446 }
447
448 // in case there is no active MPI environment or if the communicator size is < 2, copy the input into the range
449 if (!has_env || c.size() < 2) {
450 std::ranges::copy(std::views::single(x), std::ranges::begin(rg));
451 return;
452 }
453
454 // make the MPI C library call
455 using value_t = std::ranges::range_value_t<R>;
456 if (all) {
457 check_mpi_call(MPI_Allgather(&x, 1, mpi_type<T>::get(), std::ranges::data(rg), 1, mpi_type<value_t>::get(), c.get()), "MPI_Allgather");
458 } else {
459 check_mpi_call(MPI_Gather(&x, 1, mpi_type<T>::get(), std::ranges::data(rg), 1, mpi_type<value_t>::get(), root, c.get()), "MPI_Gather");
460 }
461 }
462
464
465} // namespace mpi
C++ wrapper around MPI_Comm providing various convenience functions.
Provides a C++ wrapper class for an MPI_Comm object.
A concept that checks if a range type is contiguous and sized and has an MPI compatible value type.
Provides utilities to map C++ datatypes to MPI datatypes.
void scatter_into(T1 &&x_in, T2 &&x_out, communicator c={}, int root=0)
Generic MPI scatter that scatters directly into an existing output object.
decltype(auto) scatter(T &&x, mpi::communicator c={}, int root=0)
Generic MPI scatter.
void gather_into(T1 &&x_in, T2 &&x_out, communicator c={}, int root=0, bool all=false)
Generic MPI gather that gathers directly into an existing output object.
void all_gather_into(T1 &&x_in, T2 &&x_out, communicator c={})
Generic MPI all-gather that gathers directly into an existing output object.
auto mpi_reduce(std::array< T, N > const &arr, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an MPI reduce for a std::array.
Definition array.hpp:74
void mpi_reduce_into(std::array< T1, N1 > const &arr_in, std::array< T2, N2 > &arr_out, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an MPI reduce for a std::array that reduces directly into an existing output array.
Definition array.hpp:99
void mpi_scatter_into(std::vector< T > const &v_in, std::vector< T > &v_out, communicator c={}, int root=0)
Implementation of an MPI scatter for a std::vector that scatters directly into an existing output vector.
Definition vector.hpp:119
void mpi_gather_into(T const &x, R &&rg, communicator c={}, int root=0, bool all=false)
Implementation of an MPI gather that gathers directly into an existing output range for types that have a corresponding MPI datatype.
void reduce_into(T1 &&x_in, T2 &&x_out, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Generic MPI reduce that reduces directly into an existing output object.
void reduce_in_place(T &&x, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Generic in-place MPI reduce.
void all_reduce_in_place(T &&x, communicator c={}, MPI_Op op=MPI_SUM)
Generic MPI all-reduce in place.
decltype(auto) reduce(T &&x, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Generic MPI reduce.
void mpi_broadcast(std::array< T, N > &arr, communicator c={}, int root=0)
Implementation of an MPI broadcast for a std::array.
Definition array.hpp:53
bool all_equal(T const &x, communicator c={})
Checks if a given object is equal across all ranks in the given communicator.
std::vector< T > mpi_gather(T const &x, communicator c={}, int root=0, bool all=false)
Implementation of an MPI gather for types that have a corresponding MPI datatype.
decltype(auto) all_reduce(T &&x, communicator c={}, MPI_Op op=MPI_SUM)
Generic MPI all-reduce.
void broadcast(T &&x, communicator c={}, int root=0)
Generic MPI broadcast.
void all_reduce_into(T1 &&x_in, T2 &&x_out, communicator c={}, MPI_Op op=MPI_SUM)
Generic MPI all-reduce that reduces directly into an existing output object.
decltype(auto) gather(T &&x, communicator c={}, int root=0, bool all=false)
Generic MPI gather.
decltype(auto) all_gather(T &&x, communicator c={})
Generic MPI all-gather.
static const bool has_env
Boolean variable that checks if there is an active MPI runtime environment.
constexpr bool has_mpi_type
Type trait to check if a type T has a corresponding MPI datatype, i.e. if mpi::mpi_type has been specialized.
void check_mpi_call(int errcode, const std::string &mpi_routine)
Check the success of an MPI call.
Definition utils.hpp:48
Macros used in the mpi library.
Map C++ datatypes to the corresponding MPI datatypes.
Definition datatypes.hpp:60
Provides general utilities related to MPI.