TRIQS/mpi 1.3.0
C++ interface to MPI
Loading...
Searching...
No Matches
generic_communication.hpp
Go to the documentation of this file.
1// Copyright (c) 2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Alexander Hampel, Olivier Parcollet, Nils Wentzell
16
25#pragma once
26
27#include "./datatypes.hpp"
28#include "./lazy.hpp"
29#include "./utils.hpp"
30
31#include <mpi.h>
32
33#include <type_traits>
34#include <utility>
35#include <vector>
36
37namespace mpi {
38
namespace detail {

  // Type trait to check if a type is a std::vector (with any allocator).
  template <typename T> inline constexpr bool is_std_vector = false;

  // Specialization of is_std_vector for std::vector<T, Alloc>. Matching every allocator (not only std::allocator<T>)
  // lets convert() handle e.g. std::pmr::vector element-wise as well.
  template <typename T, typename Alloc> inline constexpr bool is_std_vector<std::vector<T, Alloc>> = true;

  // Convert an object of type V to an object of type T.
  //
  // If T is a std::vector, every element of v is converted (recursively) to T::value_type and moved into a result
  // whose capacity is reserved up front. Otherwise T is direct-list-initialized from the moved input: T{std::move(v)}.
  template <typename T, typename V> T convert(V v) {
    if constexpr (is_std_vector<T>) {
      T res;
      res.reserve(v.size());
      // `auto &&` (instead of `auto &`) also binds to proxy references, e.g. those of std::vector<bool>.
      for (auto &&x : v) res.emplace_back(convert<typename T::value_type>(std::move(x)));
      return res;
    } else
      return T{std::move(v)};
  }

} // namespace detail
64
76 template <typename T> [[gnu::always_inline]] void broadcast(T &&x, communicator c = {}, int root = 0) {
77 static_assert(not std::is_const_v<T>, "mpi::broadcast cannot be called on const objects");
78 if (has_env) mpi_broadcast(std::forward<T>(x), c, root);
79 }
80
96 template <typename T>
97 [[gnu::always_inline]] inline decltype(auto) reduce(T &&x, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
98 // return type of mpi_reduce
99 using r_t = decltype(mpi_reduce(std::forward<T>(x), c, root, all, op));
100 if constexpr (is_mpi_lazy<r_t>) {
101 return mpi_reduce(std::forward<T>(x), c, root, all, op);
102 } else {
103 if (has_env)
104 return mpi_reduce(std::forward<T>(x), c, root, all, op);
105 else
106 return detail::convert<r_t>(std::forward<T>(x));
107 }
108 }
109
123 template <typename T>
124 [[gnu::always_inline]] inline void reduce_in_place(T &&x, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
125 static_assert(not std::is_const_v<T>, "In-place mpi functions cannot be called on const objects");
126 if (has_env) mpi_reduce_in_place(std::forward<T>(x), c, root, all, op);
127 }
128
142 template <typename T> [[gnu::always_inline]] inline decltype(auto) scatter(T &&x, mpi::communicator c = {}, int root = 0) {
143 // return type of mpi_scatter
144 using r_t = decltype(mpi_scatter(std::forward<T>(x), c, root));
145 if constexpr (is_mpi_lazy<r_t>) {
146 return mpi_scatter(std::forward<T>(x), c, root);
147 } else {
148 if (has_env)
149 return mpi_scatter(std::forward<T>(x), c, root);
150 else
151 return detail::convert<r_t>(std::forward<T>(x));
152 }
153 }
154
169 template <typename T> [[gnu::always_inline]] inline decltype(auto) gather(T &&x, mpi::communicator c = {}, int root = 0, bool all = false) {
170 // return type of mpi_gather
171 using r_t = decltype(mpi_gather(std::forward<T>(x), c, root, all));
172 if constexpr (is_mpi_lazy<r_t>) {
173 return mpi_gather(std::forward<T>(x), c, root, all);
174 } else {
175 if (has_env)
176 return mpi_gather(std::forward<T>(x), c, root, all);
177 else
178 return detail::convert<r_t>(std::forward<T>(x));
179 }
180 }
181
186 template <typename T> [[gnu::always_inline]] inline decltype(auto) all_reduce(T &&x, communicator c = {}, MPI_Op op = MPI_SUM) {
187 return reduce(std::forward<T>(x), c, 0, true, op);
188 }
189
194 template <typename T> [[gnu::always_inline]] inline void all_reduce_in_place(T &&x, communicator c = {}, MPI_Op op = MPI_SUM) {
195 reduce_in_place(std::forward<T>(x), c, 0, true, op);
196 }
197
202 template <typename T> [[gnu::always_inline]] inline decltype(auto) all_gather(T &&x, communicator c = {}) {
203 return gather(std::forward<T>(x), c, 0, true);
204 }
205
217 template <typename T>
218 requires(has_mpi_type<T>)
219 void mpi_broadcast(T &x, communicator c = {}, int root = 0) {
220 check_mpi_call(MPI_Bcast(&x, 1, mpi_type<T>::get(), root, c.get()), "MPI_Bcast");
221 }
222
237 template <typename T>
238 requires(has_mpi_type<T>)
239 T mpi_reduce(T const &x, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
240 T b;
241 auto d = mpi_type<T>::get();
242 if (!all)
243 // old MPI implementations may require a non-const send buffer
244 check_mpi_call(MPI_Reduce(const_cast<T *>(&x), &b, 1, d, op, root, c.get()), "MPI_Reduce"); // NOLINT
245 else
246 check_mpi_call(MPI_Allreduce(const_cast<T *>(&x), &b, 1, d, op, c.get()), "MPI_Allreduce"); // NOLINT
247 return b;
248 }
249
263 template <typename T>
264 requires(has_mpi_type<T>)
265 void mpi_reduce_in_place(T &x, communicator c = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
266 if (!all)
267 check_mpi_call(MPI_Reduce((c.rank() == root ? MPI_IN_PLACE : &x), &x, 1, mpi_type<T>::get(), op, root, c.get()), "MPI_Reduce");
268 else
269 check_mpi_call(MPI_Allreduce(MPI_IN_PLACE, &x, 1, mpi_type<T>::get(), op, c.get()), "MPI_Allreduce");
270 }
271
287 template <typename T> bool all_equal(T const &x, communicator c = {}) {
288 if (!has_env) return true;
289 auto min_obj = all_reduce(x, c, MPI_MIN);
290 auto max_obj = all_reduce(x, c, MPI_MAX);
291 return min_obj == max_obj;
292 }
293
296} // namespace mpi
C++ wrapper around MPI_Comm providing various convenience functions.
Provides utilities to map C++ datatypes to MPI datatypes.
decltype(auto) scatter(T &&x, mpi::communicator c={}, int root=0)
Generic MPI scatter.
auto mpi_reduce(std::array< T, N > const &arr, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an MPI reduce for a std::array.
Definition array.hpp:86
void reduce_in_place(T &&x, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Generic in-place MPI reduce.
void all_reduce_in_place(T &&x, communicator c={}, MPI_Op op=MPI_SUM)
Generic MPI all-reduce in-place.
decltype(auto) reduce(T &&x, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Generic MPI reduce.
auto mpi_scatter(std::vector< T > const &v, communicator c={}, int root=0)
Implementation of an MPI scatter for a std::vector.
Definition vector.hpp:106
void mpi_broadcast(std::array< T, N > &arr, communicator c={}, int root=0)
Implementation of an MPI broadcast for a std::array.
Definition array.hpp:51
bool all_equal(T const &x, communicator c={})
Checks if a given object is equal across all ranks in the given communicator.
decltype(auto) all_reduce(T &&x, communicator c={}, MPI_Op op=MPI_SUM)
Generic MPI all-reduce.
void broadcast(T &&x, communicator c={}, int root=0)
Generic MPI broadcast.
std::string mpi_gather(std::string const &s, communicator c={}, int root=0, bool all=false)
Implementation of an MPI gather for a std::string.
Definition string.hpp:65
decltype(auto) gather(T &&x, mpi::communicator c={}, int root=0, bool all=false)
Generic MPI gather.
decltype(auto) all_gather(T &&x, communicator c={})
Generic MPI all-gather.
void mpi_reduce_in_place(std::array< T, N > &arr, communicator c={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an in-place MPI reduce for a std::array.
Definition array.hpp:67
static const bool has_env
Boolean variable that is true, if one of the environment variables OMPI_COMM_WORLD_RANK,...
constexpr bool is_mpi_lazy
Type trait to check if a type is mpi::lazy.
Definition lazy.hpp:84
constexpr bool has_mpi_type
Type trait to check if a type T has a corresponding MPI datatype, i.e. if mpi::mpi_type has been spec...
Definition datatypes.hpp:89
void check_mpi_call(int errcode, const std::string &mpi_routine)
Check the success of an MPI call.
Definition utils.hpp:72
Provides a struct and tags to represent lazy MPI communication.
Provides general utilities related to MPI.