TRIQS/mpi 1.3.0
C++ interface to MPI
Loading...
Searching...
No Matches
datatypes.hpp
Go to the documentation of this file.
1// Copyright (c) 2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Alexander Hampel, Olivier Parcollet, Nils Wentzell
16
21
22#pragma once
23
#include "./utils.hpp"

#include <mpi.h>

#include <algorithm>
#include <array>
#include <complex>
#include <concepts>
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
35
36namespace mpi {
37
42
57 template <typename T> struct mpi_type {};
58
59#define D(T, MPI_TY) \
60 \
61 template <> struct mpi_type<T> { \
62 [[nodiscard]] static MPI_Datatype get() noexcept { return MPI_TY; } \
63 }
64
65 // mpi_type specialization for various built-in types
66 D(bool, MPI_C_BOOL);
67 D(char, MPI_CHAR);
68 D(int, MPI_INT);
69 D(long, MPI_LONG);
70 D(long long, MPI_LONG_LONG);
71 D(double, MPI_DOUBLE);
72 D(float, MPI_FLOAT);
73 D(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
74 D(unsigned long, MPI_UNSIGNED_LONG);
75 D(unsigned int, MPI_UNSIGNED);
76 D(unsigned long long, MPI_UNSIGNED_LONG_LONG);
77#undef D
78
83 template <typename E>
84 requires(std::is_enum_v<E>)
85 struct mpi_type<E> : mpi_type<std::underlying_type_t<E>> {};
86
91 template <typename T> struct mpi_type<const T> : mpi_type<T> {};
92
98 template <typename T, typename = void> constexpr bool has_mpi_type = false;
99
104 template <typename T> constexpr bool has_mpi_type<T, std::void_t<decltype(mpi_type<T>::get())>> = true;
105
106 namespace detail {
107
108 // Helper struct to check if member types are mpi-serializable, i.e. have an associated mpi_type
109 struct serialize_checker {
110 template <typename T>
111 void operator&(T &)
112 requires(has_mpi_type<T>)
113 {}
114 };
115
116 } // namespace detail
117
122 template <typename T>
123 concept Serializable = requires(const T ac, T a, detail::serialize_checker ar) {
124 { ac.serialize(ar) } -> std::same_as<void>;
125 { a.deserialize(ar) } -> std::same_as<void>;
126 };
127
140 template <typename... Ts> [[nodiscard]] MPI_Datatype get_mpi_type(std::tuple<Ts...> tup) {
141 static constexpr int N = sizeof...(Ts);
142 std::array<MPI_Datatype, N> types = {mpi_type<std::remove_reference_t<Ts>>::get()...};
143
144 // the number of elements per type (we want 1 per type)
145 std::array<int, N> blocklen;
146 for (int i = 0; i < N; ++i) { blocklen[i] = 1; }
147
148 // displacements of the blocks in bytes w.r.t. to the memory address of the first block
149 std::array<MPI_Aint, N> disp;
150 // initialize displacement array from the tuple element addresses
151 []<size_t... Is>(std::index_sequence<Is...>, auto &t, MPI_Aint *d) {
152 ((d[Is] = (char *)&std::get<Is>(t) - (char *)&std::get<0>(t)), ...);
153 // account for non-trivial memory layouts of the tuple elements
154 auto min_el = *std::min_element(d, d + sizeof...(Ts));
155 ((d[Is] -= min_el), ...);
156 }(std::index_sequence_for<Ts...>{}, tup, disp.data());
157
158 // create and return MPI datatype
159 MPI_Datatype cty{};
160 check_mpi_call(MPI_Type_create_struct(N, blocklen.data(), disp.data(), types.data(), &cty), "MPI_Type_create_struct");
161 check_mpi_call(MPI_Type_commit(&cty), "MPI_Type_commit");
162 return cty;
163 }
164
169 template <typename... Ts> struct mpi_type<std::tuple<Ts...>> {
170 [[nodiscard]] static MPI_Datatype get() noexcept {
171 static MPI_Datatype type = get_mpi_type(std::tuple<Ts...>{});
172 return type;
173 }
174 };
175
197 template <typename U>
198 requires(not Serializable<U>) and requires(U u) { tie_data(u); }
199 struct mpi_type<U> {
200 [[nodiscard]] static MPI_Datatype get() noexcept {
201 static MPI_Datatype type = get_mpi_type(tie_data(U{}));
202 return type;
203 }
204 };
205
206 namespace detail {
207
208 // Archive helper class to obtain MPI custom type info using references to class members.
209 struct mpi_archive {
210 std::vector<int> block_lengths{};
211 std::vector<MPI_Aint> displacements{};
212 std::vector<MPI_Datatype> types{};
213 MPI_Aint base_address{};
214
215 // Constructor sets the base address of the object.
216 explicit mpi_archive(const void *base) { MPI_Get_address(base, &base_address); }
217
218 // Overloaded operator& to process members to set the block lengths, displacements and MPI types.
219 template <typename T>
220 requires(has_mpi_type<T>)
221 mpi_archive &operator&(const T &member) {
222 types.push_back(mpi_type<T>::get());
223 MPI_Aint address{};
224 MPI_Get_address(&member, &address);
225 displacements.push_back(MPI_Aint_diff(address, base_address));
226 block_lengths.push_back(1);
227 return *this;
228 }
229 };
230
231 } // namespace detail
232
250 template <Serializable T> [[nodiscard]] MPI_Datatype get_mpi_type(const T &obj) {
251 detail::mpi_archive ar(&obj);
252 obj.serialize(ar);
253 MPI_Datatype mpi_type{};
254 MPI_Type_create_struct(static_cast<int>(ar.block_lengths.size()), ar.block_lengths.data(), ar.displacements.data(), ar.types.data(), &mpi_type);
255 MPI_Type_commit(&mpi_type);
256 return mpi_type;
257 }
258
263 template <Serializable S> struct mpi_type<S> {
264 [[nodiscard]] static MPI_Datatype get() noexcept {
265 static MPI_Datatype type = get_mpi_type(S{});
266 return type;
267 }
268 };
269
271
272} // namespace mpi
A concept that checks if objects of a type can be serialized and deserialized.
MPI_Datatype get_mpi_type(std::tuple< Ts... > tup)
Create a new MPI_Datatype from a tuple.
constexpr bool has_mpi_type
Type trait to check if a type T has a corresponding MPI datatype, i.e. if mpi::mpi_type has been specialized for it.
Definition datatypes.hpp:98
void check_mpi_call(int errcode, const std::string &mpi_routine)
Check the success of an MPI call.
Definition utils.hpp:51
Map C++ datatypes to the corresponding MPI datatypes.
Definition datatypes.hpp:57
Provides general utilities related to MPI.