TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
reduce.hpp
Go to the documentation of this file.
1// Copyright (c) 2020-2024 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Thomas Hahn, Olivier Parcollet, Nils Wentzell
16
22#pragma once
23
24#include "./utils.hpp"
26#include "../concepts.hpp"
27#include "../declarations.hpp"
28#include "../exceptions.hpp"
29#include "../macros.hpp"
30#include "../map.hpp"
31#include "../traits.hpp"
32
33#include <mpi.h>
34#include <mpi/mpi.hpp>
35
#include <array>
#include <cmath>
#include <cstddef>
#include <functional>
#include <span>
#include <type_traits>
#include <utility>
42
43namespace nda::detail {
44
45 // Helper function that (all)reduces arrays/views in-place.
46 template <typename A>
48 void mpi_reduce_in_place_impl(A &&a_out, mpi::communicator comm = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) { // NOLINT
49 // check the shape of the input arrays/views
50 EXPECTS_WITH_MESSAGE(have_mpi_equal_shapes(a_out, comm), "Error in nda::detail::mpi_reduce_in_place_impl: Shapes of arrays/views must be equal")
51
52 // do nothing if there is no active MPI environment or if the communicator size is < 2
53 if (not mpi::has_env || comm.size() < 2) { return; }
54
55 // reduce the arrays/views
56 using value_type = typename std::decay_t<A>::value_type;
57 if constexpr (not mpi::has_mpi_type<value_type>) {
58 // arrays/views of non-MPI types are reduced element-wise
59 nda::for_each(a_out.shape(), [&](auto... args) { mpi::reduce_in_place(a_out(args...), comm, root, all, op); });
60 } else {
61 // for MPI-types we have to perform some checks on the input/ouput arrays/views
62 check_layout_mpi_compatible(a_out, "detail::mpi_reduce_in_place_impl");
63
64 // reduce the data
65 auto a_out_span = std::span{a_out.data(), static_cast<std::size_t>(a_out.size())};
66 mpi::reduce_in_place_range(a_out_span, comm, root, all, op);
67 }
68 }
69
70} // namespace nda::detail
71
72namespace nda {
73
112 template <typename A1, typename A2>
114 void mpi_reduce_capi(A1 const &a_in, A2 &&a_out, mpi::communicator comm = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) { // NOLINT
115 // check the shape of the input arrays/views
116 EXPECTS_WITH_MESSAGE(detail::have_mpi_equal_shapes(a_in, comm), "Error in nda::mpi_reduce_capi: Shapes of arrays/views must be equal")
117
118 // simply copy if there is no active MPI environment or if the communicator size is < 2
119 if (not mpi::has_env || comm.size() < 2) {
120 a_out = a_in;
121 return;
122 }
123
124 // reduce the arrays/views
125 using value_type = typename std::decay_t<A1>::value_type;
126 if constexpr (not mpi::has_mpi_type<value_type>) {
127 // arrays/views of non-MPI types are reduced element-wise
128 a_out = nda::map([&](auto const &x) { return mpi::reduce(x, comm, root, all, op); })(a_in);
129 } else {
130 // for MPI-types we have to perform some checks on the input and ouput arrays/views
131 detail::check_layout_mpi_compatible(a_in, "mpi_reduce_capi");
132 if ((comm.rank() == root) || all) {
133 detail::check_layout_mpi_compatible(a_out, "mpi_reduce_capi");
134 resize_or_check_if_view(a_out, a_in.shape());
135 if (std::abs(a_out.data() - a_in.data()) < a_in.size()) NDA_RUNTIME_ERROR << "Error in nda::mpi_reduce_capi: Overlapping arrays";
136 }
137
138 // reduce the data
139 auto a_out_span = std::span{a_out.data(), static_cast<std::size_t>(a_out.size())};
140 auto a_in_span = std::span{a_in.data(), static_cast<std::size_t>(a_in.size())};
141 mpi::reduce_range(a_in_span, a_out_span, comm, root, all, op);
142 }
143 }
144
171 template <typename A>
173 auto lazy_mpi_reduce(A &&a, mpi::communicator comm = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
174 return mpi::lazy<mpi::tag::reduce, A>{std::forward<A>(a), comm, root, all, op};
175 }
176
195 template <typename A>
197 void mpi_reduce_in_place(A &&a, mpi::communicator comm = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) { // NOLINT
198 detail::mpi_reduce_in_place_impl(a, comm, root, all, op);
199 }
200
219 template <typename A>
221 auto mpi_reduce(A const &a, mpi::communicator comm = {}, int root = 0, bool all = false, MPI_Op op = MPI_SUM) {
222 using return_t = get_regular_t<A>;
223 return_t a_out;
224 mpi_reduce_capi(a, a_out, comm, root, all, op);
225 return a_out;
226 }
227
230} // namespace nda
231
246template <nda::Array A>
247struct mpi::lazy<mpi::tag::reduce, A> {
249 using value_type = typename std::decay_t<A>::value_type;
250
252 using stored_type = A;
253
256
258 mpi::communicator comm;
259
261 const int root{0}; // NOLINT (const is fine here)
262
264 const bool all{false}; // NOLINT (const is fine here)
265
267 const MPI_Op op{MPI_SUM}; // NOLINT (const is fine here)
268
279 [[nodiscard]] auto shape() const {
280 if ((comm.rank() == root) || all) return rhs.shape();
281 return std::array<long, std::remove_cvref_t<stored_type>::rank>{};
282 }
283
302 template <nda::Array T>
303 void invoke(T &&target) const { // NOLINT (temporary views are allowed here)
304 if (target.data() == rhs.data()) {
305 nda::detail::mpi_reduce_in_place_impl(target, comm, root, all, op);
306 } else {
307 nda::mpi_reduce_capi(rhs, target, comm, root, all, op);
308 }
309 }
310};
Provides basic functions to create and manipulate arrays and views.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
bool all(A const &a)
Do all elements of the array evaluate to true?
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
mapped< F > map(F f)
Create a lazy function call expression on arrays/views.
Definition map.hpp:213
auto lazy_mpi_reduce(A &&a, mpi::communicator comm={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of a lazy MPI reduce for nda::basic_array or nda::basic_array_view types.
Definition reduce.hpp:173
auto mpi_reduce(A const &a, mpi::communicator comm={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an MPI reduce for nda::basic_array or nda::basic_array_view types.
Definition reduce.hpp:221
void mpi_reduce_capi(A1 const &a_in, A2 &&a_out, mpi::communicator comm={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an MPI reduce for nda::basic_array or nda::basic_array_view types using a C-style API.
Definition reduce.hpp:114
void mpi_reduce_in_place(A &&a, mpi::communicator comm={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an in-place MPI reduce for nda::basic_array or nda::basic_array_view types.
Definition reduce.hpp:197
decltype(basic_array{std::declval< T >()}) get_regular_t
Get the type of the nda::basic_array that would be obtained by constructing an array from a given type.
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:163
__inline__ void for_each(std::array< Int, R > const &shape, F &&f)
Loop over all possible index values of a given shape and apply a function to them.
Definition for_each.hpp:129
Macros used in the nda library.
Provides lazy function calls on arrays/views.
Provides various utility functions used by the MPI interface of nda.
auto shape() const
Compute the shape of the nda::ArrayInitializer object.
Definition reduce.hpp:279
mpi::communicator comm
MPI communicator.
Definition reduce.hpp:258
stored_type rhs
Array/View to be reduced.
Definition reduce.hpp:255
void invoke(T &&target) const
Execute the lazy MPI operation and write the result to a target array/view.
Definition reduce.hpp:303
A stored_type
Type of the array/view stored in the lazy object.
Definition reduce.hpp:252
typename std::decay_t< A >::value_type value_type
Value type of the array/view.
Definition reduce.hpp:249
Provides type traits for the nda library.