TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
scatter.hpp
Go to the documentation of this file.
1// Copyright (c) 2020--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "./utils.hpp"
14#include "../declarations.hpp"
15#include "../macros.hpp"
16#include "../traits.hpp"
17
18#include <mpi/mpi.hpp>
19
20#include <cstddef>
21#include <functional>
22#include <numeric>
23#include <span>
24#include <tuple>
25#include <type_traits>
26
27namespace nda::detail {
28
29 // Helper function to get the shape and total size of the scattered array/view as well as the stride along the first
30 // dimension.
31 template <typename A>
32 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
33 auto mpi_scatter_shape_impl(A const &a, mpi::communicator comm, int root) {
34 auto dims = a.shape();
35 mpi::broadcast(dims, comm, root);
36 auto scattered_size = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<>());
37 auto stride0 = scattered_size / dims[0];
38 dims[0] = mpi::chunk_length(dims[0], comm.size(), comm.rank());
39 return std::make_tuple(dims, scattered_size, stride0);
40 }
41
42} // namespace nda::detail
43
44namespace nda {
45
50
79 template <typename A1, typename A2>
80 requires(is_regular_or_view_v<A1> and std::decay_t<A1>::is_stride_order_C()
81 and is_regular_or_view_v<A2> and std::decay_t<A2>::is_stride_order_C())
82 void mpi_scatter_into(A1 const &a_in, A2 &&a_out, mpi::communicator comm = {}, int root = 0) { // NOLINT
83 // check the ranks of the input arrays/views
84 EXPECTS_WITH_MESSAGE(detail::have_mpi_equal_ranks(a_in, comm), "Error in nda::mpi_scatter_into: Ranks of arrays/views must be equal")
85
86 // check if the input and output arrays/views can be used in the MPI call
87 if (comm.rank() == root) detail::check_layout_mpi_compatible(a_in, "mpi_scatter_into");
88 detail::check_layout_mpi_compatible(a_out, "mpi_scatter_into");
89
90 // get output shape and resize or check the output array/view
91 auto [dims, scattered_size, stride0] = detail::mpi_scatter_shape_impl(a_in, comm, root);
92 resize_or_check_if_view(a_out, dims);
93
94 // scatter the data
95 auto a_out_span = std::span{a_out.data(), static_cast<std::size_t>(a_out.size())};
96 auto a_in_span = std::span{a_in.data(), static_cast<std::size_t>(a_in.size())};
97 mpi::scatter_range(a_in_span, a_out_span, scattered_size, comm, root, stride0);
98 }
99
116 template <typename A>
117 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
118 auto mpi_scatter(A const &a, mpi::communicator comm = {}, int root = 0) {
119 using return_t = get_regular_t<A>;
120 return_t a_out;
121 mpi::scatter_into(a, a_out, comm, root);
122 return a_out;
123 }
124
126
127} // namespace nda
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
auto mpi_scatter(A const &a, mpi::communicator comm={}, int root=0)
Implementation of an MPI scatter for nda::basic_array or nda::basic_array_view types.
Definition scatter.hpp:118
void mpi_scatter_into(A1 const &a_in, A2 &&a_out, mpi::communicator comm={}, int root=0)
Implementation of an MPI scatter for nda::basic_array or nda::basic_array_view types that scatters directly into an existing array/view.
Definition scatter.hpp:82
decltype(basic_array{std::declval< T >()}) get_regular_t
Get the type of the nda::basic_array that would be obtained by constructing an array from a given type.
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:153
Macros used in the nda library.
Provides various utility functions used by the MPI interface of nda.
Provides type traits for the nda library.