TRIQS/mpi 2.0.0
C++ interface to MPI
Loading...
Searching...
No Matches
datatypes.hpp
Go to the documentation of this file.
1// Copyright (c) 2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Alexander Hampel, Olivier Parcollet, Nils Wentzell
16
21
22#pragma once
23
#include "./utils.hpp"

#include <mpi.h>

#include <algorithm>
#include <array>
#include <complex>
#include <concepts>
#include <cstdlib>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
36
37namespace mpi {
38
43
/**
 * @brief Map C++ datatypes to the corresponding MPI datatypes.
 * @details The primary template is deliberately empty. Types that can be communicated provide a specialization with a
 * static `get()` function returning the MPI_Datatype to use for T.
 *
 * @tparam T C++ type.
 */
template <typename T>
struct mpi_type {
  // no get(): T has no associated MPI datatype
};
61
62#define D(T, MPI_TY) \
63 \
64 template <> struct mpi_type<T> { \
65 [[nodiscard]] static MPI_Datatype get() noexcept { return MPI_TY; } \
66 }
67
68 // mpi_type specialization for various built-in types
69 D(bool, MPI_C_BOOL);
70 D(char, MPI_CHAR);
71 D(int, MPI_INT);
72 D(long, MPI_LONG);
73 D(long long, MPI_LONG_LONG);
74 D(double, MPI_DOUBLE);
75 D(float, MPI_FLOAT);
76 D(std::complex<double>, MPI_C_DOUBLE_COMPLEX);
77 D(unsigned long, MPI_UNSIGNED_LONG);
78 D(unsigned int, MPI_UNSIGNED);
79 D(unsigned long long, MPI_UNSIGNED_LONG_LONG);
80#undef D
81
86 template <typename E>
87 requires(std::is_enum_v<E>)
88 struct mpi_type<E> : mpi_type<std::underlying_type_t<E>> {};
89
94 template <typename T> struct mpi_type<const T> : mpi_type<T> {};
95
102 template <typename T> constexpr bool has_mpi_type = requires { mpi_type<T>::get(); };
103
104 namespace detail {
105
106 // Helper struct to check if member types are mpi-serializable, i.e. have an associated mpi_type
107 struct serialize_checker {
108 template <typename T>
109 void operator&(T &)
110 requires(has_mpi_type<T>)
111 {}
112 };
113
114 } // namespace detail
115
120 template <typename T>
121 concept Serializable = requires(const T ac, T a, detail::serialize_checker ar) {
122 { ac.serialize(ar) } -> std::same_as<void>;
123 { a.deserialize(ar) } -> std::same_as<void>;
124 };
125
138 template <typename... Ts> [[nodiscard]] MPI_Datatype get_mpi_type(std::tuple<Ts...> tup) {
139 static constexpr int N = sizeof...(Ts);
140 std::array<MPI_Datatype, N> types = {mpi_type<std::remove_reference_t<Ts>>::get()...};
141
142 // the number of elements per type (we want 1 per type)
143 std::array<int, N> blocklen;
144 for (int i = 0; i < N; ++i) { blocklen[i] = 1; }
145
146 // displacements of the blocks in bytes w.r.t. to the memory address of the first block
147 std::array<MPI_Aint, N> disp;
148 // initialize displacement array from the tuple element addresses
149 []<size_t... Is>(std::index_sequence<Is...>, auto &t, MPI_Aint *d) {
150 ((d[Is] = (char *)&std::get<Is>(t) - (char *)&std::get<0>(t)), ...);
151 // account for non-trivial memory layouts of the tuple elements
152 auto min_el = *std::min_element(d, d + sizeof...(Ts));
153 ((d[Is] -= min_el), ...);
154 }(std::index_sequence_for<Ts...>{}, tup, disp.data());
155
156 // create and return MPI datatype
157 MPI_Datatype cty{};
158 check_mpi_call(MPI_Type_create_struct(N, blocklen.data(), disp.data(), types.data(), &cty), "MPI_Type_create_struct");
159 check_mpi_call(MPI_Type_commit(&cty), "MPI_Type_commit");
160 return cty;
161 }
162
167 template <typename... Ts> struct mpi_type<std::tuple<Ts...>> {
168 [[nodiscard]] static MPI_Datatype get() noexcept {
169 static MPI_Datatype type = get_mpi_type(std::tuple<Ts...>{});
170 return type;
171 }
172 };
173
195 template <typename U>
196 requires(not Serializable<U>) and requires(U u) { tie_data(u); }
197 struct mpi_type<U> {
198 [[nodiscard]] static MPI_Datatype get() noexcept {
199 static MPI_Datatype type = get_mpi_type(tie_data(U{}));
200 return type;
201 }
202 };
203
204 namespace detail {
205
206 // Archive helper class to obtain MPI custom type info using references to class members.
207 struct mpi_archive {
208 std::vector<int> block_lengths{};
209 std::vector<MPI_Aint> displacements{};
210 std::vector<MPI_Datatype> types{};
211 MPI_Aint base_address{};
212
213 // Constructor sets the base address of the object.
214 explicit mpi_archive(const void *base) { MPI_Get_address(base, &base_address); }
215
216 // Overloaded operator& to process members to set the block lengths, displacements and MPI types.
217 template <typename T>
218 requires(has_mpi_type<T>)
219 mpi_archive &operator&(const T &member) {
220 types.push_back(mpi_type<T>::get());
221 MPI_Aint address{};
222 MPI_Get_address(&member, &address);
223 displacements.push_back(MPI_Aint_diff(address, base_address));
224 block_lengths.push_back(1);
225 return *this;
226 }
227 };
228
229 } // namespace detail
230
249 template <Serializable T> [[nodiscard]] MPI_Datatype get_mpi_type(const T &obj) {
250 detail::mpi_archive ar(&obj);
251 obj.serialize(ar);
252 MPI_Datatype mpi_type{};
253 MPI_Type_create_struct(static_cast<int>(ar.block_lengths.size()), ar.block_lengths.data(), ar.displacements.data(), ar.types.data(), &mpi_type);
254 MPI_Type_commit(&mpi_type);
255 return mpi_type;
256 }
257
262 template <Serializable S> struct mpi_type<S> {
263 [[nodiscard]] static MPI_Datatype get() noexcept {
264 static MPI_Datatype type = get_mpi_type(S{});
265 return type;
266 }
267 };
268
270
271} // namespace mpi
A concept that checks if objects of a type can be serialized and deserialized.
MPI_Datatype get_mpi_type(std::tuple< Ts... > tup)
Create a new MPI_Datatype from a tuple.
constexpr bool has_mpi_type
Type trait to check if a type T has a corresponding MPI datatype, i.e. if mpi::mpi_type has been specialized for T.
void check_mpi_call(int errcode, const std::string &mpi_routine)
Check the success of an MPI call.
Definition utils.hpp:48
Map C++ datatypes to the corresponding MPI datatypes.
Definition datatypes.hpp:60
Provides general utilities related to MPI.