TRIQS/mpi 1.3.0
C++ interface to MPI
Loading...
Searching...
No Matches
datatypes.hpp
Go to the documentation of this file.
1// Copyright (c) 2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Alexander Hampel, Olivier Parcollet, Nils Wentzell
16
21
22#pragma once
23
24#include "./utils.hpp"
25
26#include <mpi.h>
27
28#include <algorithm>
29#include <array>
30#include <complex>
31#include <cstdlib>
32#include <tuple>
33#include <type_traits>
34#include <utility>
35
36namespace mpi {
37
42
// Primary template mapping a C++ type to an MPI datatype.
//
// It is intentionally empty: only specializations provide the static member function `get()`,
// so a type without a specialization has no `mpi_type<T>::get()`.
template <class T> struct mpi_type {};
58
59#define D(T, MPI_TY) \
60 \
61 template <> struct mpi_type<T> { \
62 [[nodiscard]] static MPI_Datatype get() noexcept { return MPI_TY; } \
63 }
64
65 // mpi_type specialization for various built-in types
66 D(bool, MPI_C_BOOL);
67 D(char, MPI_CHAR);
68 D(int, MPI_INT);
69 D(long, MPI_LONG);
70 D(long long, MPI_LONG_LONG);
71 D(double, MPI_DOUBLE);
72 D(float, MPI_FLOAT);
73 D(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
74 D(unsigned long, MPI_UNSIGNED_LONG);
75 D(unsigned int, MPI_UNSIGNED);
76 D(unsigned long long, MPI_UNSIGNED_LONG_LONG);
77#undef D
78
83 template <typename E>
84 requires(std::is_enum_v<E>)
85 struct mpi_type<E> : mpi_type<std::underlying_type_t<E>> {};
86
91 template <typename T> struct mpi_type<const T> : mpi_type<T> {};
92
97 template <typename T, typename = void> constexpr bool has_mpi_type = false;
98
103 template <typename T> constexpr bool has_mpi_type<T, std::void_t<decltype(mpi_type<T>::get())>> = true;
104
105 namespace detail {
106
107 // Helper struct to check if member types are mpi-serializable, i.e. have an associated mpi_type
108 struct serialize_checker {
109 template <typename T>
110 void operator&(T &)
111 requires(has_mpi_type<T>)
112 {}
113 };
114
115 } // namespace detail
116
121 template <typename T>
122 concept Serializable = requires(const T ac, T a, detail::serialize_checker ar) {
123 { ac.serialize(ar) } -> std::same_as<void>;
124 { a.deserialize(ar) } -> std::same_as<void>;
125 };
126
139 template <typename... Ts> [[nodiscard]] MPI_Datatype get_mpi_type(std::tuple<Ts...> tup) {
140 static constexpr int N = sizeof...(Ts);
141 std::array<MPI_Datatype, N> types = {mpi_type<std::remove_reference_t<Ts>>::get()...};
142
143 // the number of elements per type (we want 1 per type)
144 std::array<int, N> blocklen;
145 for (int i = 0; i < N; ++i) { blocklen[i] = 1; }
146
147 // displacements of the blocks in bytes w.r.t. to the memory address of the first block
148 std::array<MPI_Aint, N> disp;
149 // initialize displacement array from the tuple element addresses
150 []<size_t... Is>(std::index_sequence<Is...>, auto &t, MPI_Aint *d) {
151 ((d[Is] = (char *)&std::get<Is>(t) - (char *)&std::get<0>(t)), ...);
152 // account for non-trivial memory layouts of the tuple elements
153 auto min_el = *std::min_element(d, d + sizeof...(Ts));
154 ((d[Is] -= min_el), ...);
155 }(std::index_sequence_for<Ts...>{}, tup, disp.data());
156
157 // create and return MPI datatype
158 MPI_Datatype cty{};
159 check_mpi_call(MPI_Type_create_struct(N, blocklen.data(), disp.data(), types.data(), &cty), "MPI_Type_create_struct");
160 check_mpi_call(MPI_Type_commit(&cty), "MPI_Type_commit");
161 return cty;
162 }
163
168 template <typename... Ts> struct mpi_type<std::tuple<Ts...>> {
169 [[nodiscard]] static MPI_Datatype get() noexcept {
170 static MPI_Datatype type = get_mpi_type(std::tuple<Ts...>{});
171 return type;
172 }
173 };
174
196 template <typename U>
197 requires(not Serializable<U>) and requires(U u) { tie_data(u); }
198 struct mpi_type<U> {
199 [[nodiscard]] static MPI_Datatype get() noexcept {
200 static MPI_Datatype type = get_mpi_type(tie_data(U{}));
201 return type;
202 }
203 };
204
205 namespace detail {
206
207 // Archive helper class to obtain MPI custom type info using references to class members.
208 struct mpi_archive {
209 std::vector<int> block_lengths{};
210 std::vector<MPI_Aint> displacements{};
211 std::vector<MPI_Datatype> types{};
212 MPI_Aint base_address{};
213
214 // Constructor sets the base address of the object.
215 explicit mpi_archive(const void *base) { MPI_Get_address(base, &base_address); }
216
217 // Overloaded operator& to process members to set the block lengths, displacements and MPI types.
218 template <typename T>
219 requires(has_mpi_type<T>)
220 mpi_archive &operator&(const T &member) {
221 types.push_back(mpi_type<T>::get());
222 MPI_Aint address{};
223 MPI_Get_address(&member, &address);
224 displacements.push_back(MPI_Aint_diff(address, base_address));
225 block_lengths.push_back(1);
226 return *this;
227 }
228 };
229
230 } // namespace detail
231
249 template <Serializable T> [[nodiscard]] MPI_Datatype get_mpi_type(const T &obj) {
250 detail::mpi_archive ar(&obj);
251 obj.serialize(ar);
252 MPI_Datatype mpi_type{};
253 MPI_Type_create_struct(static_cast<int>(ar.block_lengths.size()), ar.block_lengths.data(), ar.displacements.data(), ar.types.data(), &mpi_type);
254 MPI_Type_commit(&mpi_type);
255 return mpi_type;
256 }
257
262 template <Serializable S> struct mpi_type<S> {
263 [[nodiscard]] static MPI_Datatype get() noexcept {
264 static MPI_Datatype type = get_mpi_type(S{});
265 return type;
266 }
267 };
268
270
271} // namespace mpi
A concept that checks if objects of a type can be serialized and deserialized.
MPI_Datatype get_mpi_type(std::tuple< Ts... > tup)
Create a new MPI_Datatype from a tuple.
constexpr bool has_mpi_type
Type trait to check if a type T has a corresponding MPI datatype, i.e. if mpi::mpi_type has been specialized.
Definition datatypes.hpp:97
void check_mpi_call(int errcode, const std::string &mpi_routine)
Check the success of an MPI call.
Definition utils.hpp:73
Map C++ datatypes to the corresponding MPI datatypes.
Definition datatypes.hpp:57
Provides general utilities related to MPI.