49template <nda::Array A>
50struct mpi::lazy<mpi::tag::reduce, A> {
67 const bool all{
false};
70 const MPI_Op
op{MPI_SUM};
77 [[nodiscard]]
auto shape()
const {
return rhs.shape(); }
88 template <nda::Array T>
91 if (not target.is_contiguous()) NDA_RUNTIME_ERROR <<
"Error in MPI reduce for nda::Array: Target array needs to be contiguous";
92 static_assert(std::decay_t<A>::layout_t::stride_order_encoded == std::decay_t<T>::layout_t::stride_order_encoded,
93 "Error in MPI reduce for nda::Array: Incompatible stride orders");
96 if (not mpi::has_env) {
102 if constexpr (not mpi::has_mpi_type<value_type>) {
104 target =
nda::map([
this](
auto const &x) {
return mpi::reduce(x, this->comm, this->
root, this->
all, this->
op); })(
rhs);
107 bool in_place = (target.data() ==
rhs.data());
109 if (
rhs.size() != target.size())
110 NDA_RUNTIME_ERROR <<
"Error in MPI reduce for nda::Array: In-place reduction requires arrays of the same size";
113 if (std::abs(target.data() -
rhs.data()) <
rhs.size()) NDA_RUNTIME_ERROR <<
"Error in MPI reduce for nda::Array: Overlapping arrays";
116 void *target_ptr = (
void *)target.data();
117 void *rhs_ptr = (
void *)
rhs.data();
118 auto count =
rhs.size();
119 auto mpi_value_type = mpi::mpi_type<value_type>::get();
122 MPI_Reduce((
comm.rank() ==
root ? MPI_IN_PLACE : rhs_ptr), rhs_ptr, count, mpi_value_type,
op,
root,
comm.get());
124 MPI_Reduce(rhs_ptr, target_ptr, count, mpi_value_type,
op,
root,
comm.get());
127 MPI_Allreduce(MPI_IN_PLACE, rhs_ptr, count, mpi_value_type,
op,
comm.get());
129 MPI_Allreduce(rhs_ptr, target_ptr, count, mpi_value_type,
op,
comm.get());
164 template <
typename A>
169 if (not a.is_contiguous()) NDA_RUNTIME_ERROR <<
"Error in MPI reduce for nda::Array: Array needs to be contiguous";
170 return mpi::lazy<mpi::tag::reduce, A>{std::forward<A>(a), comm, root,
all, op};
Provides basic functions to create and manipulate arrays and views.
Check if a given type satisfies the array initializer concept for a given nda::MemoryArray type.
Provides concepts for the nda library.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
bool all(A const &a)
Do all elements of the array evaluate to true?
void resize_or_check_if_view(A &a, std::array< long, A::rank > const &sha)
Resize a given regular array to the given shape or check if a given view has the correct shape.
mapped< F > map(F f)
Create a lazy function call expression on arrays/views.
ArrayInitializer< std::remove_reference_t< A > > auto mpi_reduce(A &&a, mpi::communicator comm={}, int root=0, bool all=false, MPI_Op op=MPI_SUM)
Implementation of an MPI reduce for nda::basic_array or nda::basic_array_view types.
constexpr bool is_regular_or_view_v
Constexpr variable that is true if type A is either a regular array or a view.
Provides lazy function calls on arrays/views.
const_view_type rhs
View of the array/view to be reduced.
auto shape() const
Compute the shape of the target array.
mpi::communicator comm
MPI communicator.
const bool all
Whether all processes should receive the result.
decltype(std::declval< const A >()()) const_view_type
Const view type of the array/view stored in the lazy object.
void invoke(T &&target) const
Execute the lazy MPI operation and write the result to a target array/view.
const MPI_Op op
MPI reduction operation.
typename std::decay_t< A >::value_type value_type
Value type of the array/view.
const int root
MPI root process.
Provides type traits for the nda library.