TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
gather.hpp
Go to the documentation of this file.
1// Copyright (c) 2020--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "./utils.hpp"
15#include "../declarations.hpp"
16#include "../macros.hpp"
17#include "../stdutil/array.hpp"
18#include "../traits.hpp"
19
20#include <mpi/mpi.hpp>
21
22#include <cstddef>
23#include <span>
24#include <type_traits>
25
26namespace nda::detail {
27
28 // Helper function to get the shape of the gathered array/view.
29 template <typename A>
30 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
31 auto mpi_gather_shape_impl(A const &a, mpi::communicator comm, int root, bool all) {
32 auto dims = a.shape();
33 dims[0] = mpi::all_reduce(dims[0], comm);
34 if (!all && comm.rank() != root) dims = nda::stdutil::make_initialized_array<dims.size()>(0l);
35 return dims;
36 }
37
38} // namespace nda::detail
39
40namespace nda {
41
46
79 template <typename A1, typename A2>
80 requires(is_regular_or_view_v<A1> and std::decay_t<A1>::is_stride_order_C()
81 and is_regular_or_view_v<A2> and std::decay_t<A2>::is_stride_order_C())
82 void mpi_gather_into(A1 const &a_in, A2 &&a_out, mpi::communicator comm = {}, int root = 0, bool all = false) { // NOLINT
83 // check the shape of the input arrays/views
84 EXPECTS_WITH_MESSAGE(detail::have_mpi_equal_shapes(a_in(nda::range(1), nda::ellipsis{}), comm),
85 "Error in nda::mpi_gather_into: Shapes of arrays/views must be equal save the first one");
86
87 // get output shape
88 auto dims = detail::mpi_gather_shape_impl(a_in, comm, root, all);
89
90 // check if the input and output arrays/views can be used in the MPI call and resize or check the output array/view
91 detail::check_layout_mpi_compatible(a_in, "mpi_gather_into");
92 bool const receives = (all || (comm.rank() == root));
93 if (receives) {
94 detail::check_layout_mpi_compatible(a_out, "mpi_gather_into");
95 resize_or_check_if_view(a_out, dims);
96 }
97
98 // gather the data
99 auto a_in_span = std::span{a_in.data(), static_cast<std::size_t>(a_in.size())};
100 auto a_out_span = std::span{a_out.data(), static_cast<std::size_t>(a_out.size())};
101 mpi::gather_range(a_in_span, a_out_span, comm, root, all);
102 }
103
122 template <typename A>
123 requires(is_regular_or_view_v<A> and std::decay_t<A>::is_stride_order_C())
124 auto mpi_gather(A const &a, mpi::communicator comm = {}, int root = 0, bool all = false) {
125 using return_t = get_regular_t<A>;
126 return_t a_out;
127 mpi::gather_into(a, a_out, comm, root, all);
128 return a_out;
129 }
130
132
133} // namespace nda
Provides utility functions for std::array.
Provides basic functions to create and manipulate arrays and views.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
bool all(A const &a)
Do all elements of the array evaluate to true?
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
auto mpi_gather(A const &a, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types.
Definition gather.hpp:124
void mpi_gather_into(A1 const &a_in, A2 &&a_out, mpi::communicator comm={}, int root=0, bool all=false)
Implementation of an MPI gather for nda::basic_array or nda::basic_array_view types that gathers directly into an existing array/view.
Definition gather.hpp:82
decltype(basic_array{std::declval< T >()}) get_regular_t
Get the type of the nda::basic_array that would be obtained by constructing an array from a given type.
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Definition traits.hpp:162
constexpr std::array< T, R > make_initialized_array(T v)
Create a new std::array object initialized with a specific value.
Definition array.hpp:155
Macros used in the nda library.
Provides various utility functions used by the MPI interface of nda.
Mimics Python's ... syntax.
Definition range.hpp:36
Provides type traits for the nda library.