TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
outer_product.hpp
Go to the documentation of this file.
1// Copyright (c) 2020--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include "../blas/ger.hpp"
14#include "../blas/tools.hpp"
15#include "../concepts.hpp"
16#include "../declarations.hpp"
17#include "../macros.hpp"
19#include "../mem/policies.hpp"
20#include "../stdutil/array.hpp"
21#include "../traits.hpp"
22
23#include <array>
24#include <type_traits>
25
26namespace nda::linalg {
27
60 template <blas_lapack::BlasArray A, blas_lapack::BlasArrayFor<A> B>
62 auto outer_product(A const &a, B const &b) {
63 // check the input arrays/views
64 EXPECTS(a.is_contiguous());
65 EXPECTS(b.is_contiguous());
66
67 // get the return type
68 auto constexpr rank = get_rank<A> + get_rank<B>;
69 auto constexpr algebra = []() {
70 if constexpr (get_algebra<A> == 'V' and get_algebra<B> == 'V') {
71 return 'M';
72 } else {
73 return 'A';
74 }
75 }();
76 using layout_pol = std::conditional_t<get_rank<A> == 1, typename B::layout_policy_t::contiguous_t, typename A::layout_policy_t::contiguous_t>;
77 using cont_pol = heap<mem::common_addr_space<A, B>>;
78 using return_t = basic_array<get_value_t<A>, rank, layout_pol, algebra, cont_pol>;
79
80 // use ger to calculate the outer product
81 auto res = return_t::zeros(stdutil::join(a.shape(), b.shape()));
82 auto a_vec = reshape(a, std::array{a.size()});
83 auto b_vec = reshape(b, std::array{b.size()});
84 auto mat = reshape(res, std::array{a.size(), b.size()});
85 blas::ger(1.0, a_vec, b_vec, mat);
86
87 return res;
88 }
89
90} // namespace nda::linalg
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides utility functions for std::array.
Provides various traits and utilities for the BLAS interface.
A generic multi-dimensional array.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_view.
Provides a generic interface to the BLAS/cuBLAS ger, geru and gerc routines.
auto reshape(A &&a, std::array< Int, R > const &new_shape)
Reshape an nda::basic_array or nda::basic_array_view.
constexpr char get_algebra
Constexpr variable that specifies the algebra of a type.
Definition traits.hpp:137
constexpr int get_rank
Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range.
Definition traits.hpp:147
static constexpr bool has_C_layout
Constexpr variable that is true if all given nda::Array types have nda::C_layout.
Definition tools.hpp:89
static constexpr bool has_F_layout
Constexpr variable that is true if all given nda::Array types have nda::F_layout.
Definition tools.hpp:79
void ger(get_value_t< X > alpha, X const &x, Y const &y, A &&a)
Interface to the BLAS/cuBLAS ger and geru routines.
Definition ger.hpp:54
auto outer_product(A const &a, B const &b)
Outer product of two arrays/views.
heap_basic< mem::mallocator< AdrSp > > heap
Alias template of the nda::heap_basic policy using an nda::mem::mallocator.
Definition policies.hpp:52
constexpr std::array< T, R1+R2 > join(std::array< T, R1 > const &a1, std::array< T, R2 > const &a2)
Make a new std::array by joining two existing std::array objects.
Definition array.hpp:299
Macros used in the nda library.
Defines various memory handling policies.
Provides type traits for the nda library.