TRIQS/mpi 1.3.0
C++ interface to MPI
Loading...
Searching...
No Matches
datatypes.hpp
Go to the documentation of this file.
1// Copyright (c) 2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Alexander Hampel, Olivier Parcollet, Nils Wentzell
16
21
22#pragma once
23
24#include "./utils.hpp"
25
26#include <mpi.h>
27
#include <algorithm>
#include <array>
#include <complex>
#include <concepts>
#include <cstdlib>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
36
37namespace mpi {
38
43
58 template <typename T> struct mpi_type {};
59
60#define D(T, MPI_TY) \
61 \
62 template <> struct mpi_type<T> { \
63 [[nodiscard]] static MPI_Datatype get() noexcept { return MPI_TY; } \
64 }
65
66 // mpi_type specialization for various built-in types
67 D(bool, MPI_C_BOOL);
68 D(char, MPI_CHAR);
69 D(int, MPI_INT);
70 D(long, MPI_LONG);
71 D(long long, MPI_LONG_LONG);
72 D(double, MPI_DOUBLE);
73 D(float, MPI_FLOAT);
74 D(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
75 D(unsigned long, MPI_UNSIGNED_LONG);
76 D(unsigned int, MPI_UNSIGNED);
77 D(unsigned long long, MPI_UNSIGNED_LONG_LONG);
78#undef D
79
84 template <typename E>
85 requires(std::is_enum_v<E>)
86 struct mpi_type<E> : mpi_type<std::underlying_type_t<E>> {};
87
92 template <typename T> struct mpi_type<const T> : mpi_type<T> {};
93
99 template <typename T, typename = void> constexpr bool has_mpi_type = false;
100
105 template <typename T> constexpr bool has_mpi_type<T, std::void_t<decltype(mpi_type<T>::get())>> = true;
106
107 namespace detail {
108
109 // Helper struct to check if member types are mpi-serializable, i.e. have an associated mpi_type
110 struct serialize_checker {
111 template <typename T>
112 void operator&(T &)
113 requires(has_mpi_type<T>)
114 {}
115 };
116
117 } // namespace detail
118
123 template <typename T>
124 concept Serializable = requires(const T ac, T a, detail::serialize_checker ar) {
125 { ac.serialize(ar) } -> std::same_as<void>;
126 { a.deserialize(ar) } -> std::same_as<void>;
127 };
128
141 template <typename... Ts> [[nodiscard]] MPI_Datatype get_mpi_type(std::tuple<Ts...> tup) {
142 static constexpr int N = sizeof...(Ts);
143 std::array<MPI_Datatype, N> types = {mpi_type<std::remove_reference_t<Ts>>::get()...};
144
145 // the number of elements per type (we want 1 per type)
146 std::array<int, N> blocklen;
147 for (int i = 0; i < N; ++i) { blocklen[i] = 1; }
148
149 // displacements of the blocks in bytes w.r.t. to the memory address of the first block
150 std::array<MPI_Aint, N> disp;
151 // initialize displacement array from the tuple element addresses
152 []<size_t... Is>(std::index_sequence<Is...>, auto &t, MPI_Aint *d) {
153 ((d[Is] = (char *)&std::get<Is>(t) - (char *)&std::get<0>(t)), ...);
154 // account for non-trivial memory layouts of the tuple elements
155 auto min_el = *std::min_element(d, d + sizeof...(Ts));
156 ((d[Is] -= min_el), ...);
157 }(std::index_sequence_for<Ts...>{}, tup, disp.data());
158
159 // create and return MPI datatype
160 MPI_Datatype cty{};
161 check_mpi_call(MPI_Type_create_struct(N, blocklen.data(), disp.data(), types.data(), &cty), "MPI_Type_create_struct");
162 check_mpi_call(MPI_Type_commit(&cty), "MPI_Type_commit");
163 return cty;
164 }
165
170 template <typename... Ts> struct mpi_type<std::tuple<Ts...>> {
171 [[nodiscard]] static MPI_Datatype get() noexcept {
172 static MPI_Datatype type = get_mpi_type(std::tuple<Ts...>{});
173 return type;
174 }
175 };
176
198 template <typename U>
199 requires(not Serializable<U>) and requires(U u) { tie_data(u); }
200 struct mpi_type<U> {
201 [[nodiscard]] static MPI_Datatype get() noexcept {
202 static MPI_Datatype type = get_mpi_type(tie_data(U{}));
203 return type;
204 }
205 };
206
207 namespace detail {
208
209 // Archive helper class to obtain MPI custom type info using references to class members.
210 struct mpi_archive {
211 std::vector<int> block_lengths{};
212 std::vector<MPI_Aint> displacements{};
213 std::vector<MPI_Datatype> types{};
214 MPI_Aint base_address{};
215
216 // Constructor sets the base address of the object.
217 explicit mpi_archive(const void *base) { MPI_Get_address(base, &base_address); }
218
219 // Overloaded operator& to process members to set the block lengths, displacements and MPI types.
220 template <typename T>
221 requires(has_mpi_type<T>)
222 mpi_archive &operator&(const T &member) {
223 types.push_back(mpi_type<T>::get());
224 MPI_Aint address{};
225 MPI_Get_address(&member, &address);
226 displacements.push_back(MPI_Aint_diff(address, base_address));
227 block_lengths.push_back(1);
228 return *this;
229 }
230 };
231
232 } // namespace detail
233
251 template <Serializable T> [[nodiscard]] MPI_Datatype get_mpi_type(const T &obj) {
252 detail::mpi_archive ar(&obj);
253 obj.serialize(ar);
254 MPI_Datatype mpi_type{};
255 MPI_Type_create_struct(static_cast<int>(ar.block_lengths.size()), ar.block_lengths.data(), ar.displacements.data(), ar.types.data(), &mpi_type);
256 MPI_Type_commit(&mpi_type);
257 return mpi_type;
258 }
259
264 template <Serializable S> struct mpi_type<S> {
265 [[nodiscard]] static MPI_Datatype get() noexcept {
266 static MPI_Datatype type = get_mpi_type(S{});
267 return type;
268 }
269 };
270
272
273} // namespace mpi
A concept that checks if objects of a type can be serialized and deserialized.
MPI_Datatype get_mpi_type(std::tuple< Ts... > tup)
Create a new MPI_Datatype from a tuple.
constexpr bool has_mpi_type
Type trait to check if a type T has a corresponding MPI datatype, i.e. if mpi::mpi_type has been specialized for it.
Definition datatypes.hpp:99
void check_mpi_call(int errcode, const std::string &mpi_routine)
Check the success of an MPI call.
Definition utils.hpp:51
Map C++ datatypes to the corresponding MPI datatypes.
Definition datatypes.hpp:58
Provides general utilities related to MPI.