TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
scatter.hpp
Go to the documentation of this file.
1// Copyright (c) 2020-2023 Simons Foundation
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0.txt
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14//
15// Authors: Olivier Parcollet, Nils Wentzell
16
17/**
18 * @file
19 * @brief Provides an MPI scatter function for nda::Array types.
20 */
21
22#pragma once
23
24#include "../basic_functions.hpp"
25#include "../concepts.hpp"
26#include "../exceptions.hpp"
27#include "../traits.hpp"
28
29#include <mpi/mpi.hpp>
30
31#include <type_traits>
32#include <utility>
33#include <vector>
34
35/**
36 * @ingroup av_mpi
37 * @brief Specialization of the `mpi::lazy` class for nda::Array types and the `mpi::tag::scatter` tag.
38 *
39 * @details An object of this class is returned when scattering nda::Array objects across multiple MPI processes.
40 *
41 * It models an nda::ArrayInitializer, that means it can be used to initialize and assign to nda::basic_array and
42 * nda::basic_array_view objects. The input array will be a chunked along its first dimension using `mpi::chunk_length`.
43 *
44 * See nda::mpi_scatter for an example.
45 *
46 * @tparam A nda::Array type to be scattered.
47 */
48template <nda::Array A>
49struct mpi::lazy<mpi::tag::scatter, A> {
50 /// Value type of the array/view.
51 using value_type = typename std::decay_t<A>::value_type;
52
53 /// Const view type of the array/view stored in the lazy object.
54 using const_view_type = decltype(std::declval<const A>()());
55
56 /// View of the array/view to be scattered.
57 const_view_type rhs;
58
59 /// MPI communicator.
60 mpi::communicator comm;
61
62 /// MPI root process.
63 const int root{0}; // NOLINT (const is fine here)
64
65 /// Should all processes receive the result. (doesn't make sense for scatter)
66 const bool all{false}; // NOLINT (const is fine here)
67
68 /**
69 * @brief Compute the shape of the target array.
70 *
71 * @details The target shape will be the same as the input shape, except that the first dimension of the input array
72 * is chunked into equal (as much as possible) parts using `mpi::chunk_length` and assigned to each MPI process.
73 *
74 * @warning This makes an MPI call.
75 *
76 * @return Shape of the target array.
77 */
78 [[nodiscard]] auto shape() const {
79 auto dims = rhs.shape();
80 long dim0 = dims[0];
81 mpi::broadcast(dim0, comm, root);
82 dims[0] = mpi::chunk_length(dim0, comm.size(), comm.rank());
83 return dims;
84 }
85
86 /**
87 * @brief Execute the lazy MPI operation and write the result to a target array/view.
88 *
89 * @tparam T nda::Array type of the target array/view.
90 * @param target Target array/view.
91 */
92 template <nda::Array T>
93 void invoke(T &&target) const { // NOLINT (temporary views are allowed here)
94 if (not target.is_contiguous()) NDA_RUNTIME_ERROR << "Error in MPI scatter for nda::Array: Target array needs to be contiguous";
95 static_assert(std::decay_t<A>::layout_t::stride_order_encoded == std::decay_t<T>::layout_t::stride_order_encoded,
96 "Error in MPI scatter for nda::Array: Incompatible stride orders");
97
98 // special case for non-mpi runs
99 if (not mpi::has_env) {
100 target = rhs;
101 return;
102 }
103
104 // get target shape and resize or check the target array
105 auto dims = shape();
106 resize_or_check_if_view(target, dims);
107
108 // compute send counts, receive counts and memory displacements
109 auto dim0 = rhs.extent(0);
110 auto stride0 = rhs.indexmap().strides()[0];
111 auto sendcounts = std::vector<int>(comm.size());
112 auto displs = std::vector<int>(comm.size() + 1, 0);
113 int recvcount = mpi::chunk_length(dim0, comm.size(), comm.rank()) * stride0;
114 for (int r = 0; r < comm.size(); ++r) {
115 sendcounts[r] = mpi::chunk_length(dim0, comm.size(), r) * stride0;
116 displs[r + 1] = sendcounts[r] + displs[r];
117 }
118
119 // scatter the data
120 auto mpi_value_type = mpi::mpi_type<value_type>::get();
121 MPI_Scatterv((void *)rhs.data(), &sendcounts[0], &displs[0], mpi_value_type, (void *)target.data(), recvcount, mpi_value_type, root, comm.get());
122 }
123};
124
125namespace nda {
126
127 /**
128 * @ingroup av_mpi
129 * @brief Implementation of an MPI scatter for nda::basic_array or nda::basic_array_view types.
130 *
131 * @details Since the returned `mpi::lazy` object models an nda::ArrayInitializer, it can be used to initialize/assign
132 * to nda::basic_array and nda::basic_array_view objects:
133 *
134 * @code{.cpp}
135 * // create an array on all processes
136 * nda::array<int, 2> arr(10, 4);
137 *
138 * // ...
139 * // fill array on root process
140 * // ...
141 *
142 * // scatter the array to all processes
143 * nda::array<int, 2> res = mpi::scatter(arr);
144 * @endcode
145 *
146 * Here, the array `res` will have a shape of `(10 / comm.size(), 4)`.
147 *
148 * @tparam A nda::basic_array or nda::basic_array_view type.
149 * @param a Array or view to be scattered.
150 * @param comm `mpi::communicator` object.
151 * @param root Rank of the root process.
152 * @param all Should all processes receive the result of the scatter (not used).
153 * @return An `mpi::lazy` object modelling an nda::ArrayInitializer.
154 */
155 template <typename A>
156 ArrayInitializer<std::remove_reference_t<A>> auto mpi_scatter(A &&a, mpi::communicator comm = {}, int root = 0, bool all = false)
157 requires(is_regular_or_view_v<A>)
158 {
159 if (not a.is_contiguous()) NDA_RUNTIME_ERROR << "Error in MPI scatter for nda::Array: Array needs to be contiguous";
160 return mpi::lazy<mpi::tag::scatter, A>{std::forward<A>(a), comm, root, all};
161 }
162
163} // namespace nda
#define NDA_RUNTIME_ERROR
ArrayInitializer< std::remove_reference_t< A > > auto mpi_scatter(A &&a, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of an MPI scatter for nda::basic_array or nda::basic_array_view types.
Definition scatter.hpp:156
void invoke(T &&target) const
Execute the lazy MPI operation and write the result to a target array/view.
Definition scatter.hpp:93
const_view_type rhs
View of the array/view to be scattered.
Definition scatter.hpp:57
mpi::communicator comm
MPI communicator.
Definition scatter.hpp:60
const bool all
Should all processes receive the result. (doesn't make sense for scatter)
Definition scatter.hpp:66
const int root
MPI root process.
Definition scatter.hpp:63
auto shape() const
Compute the shape of the target array.
Definition scatter.hpp:78