TRIQS/nda 1.3.0
Multi-dimensional array library for C++
outer_product.hpp
// Copyright (c) 2020--present, The Simons Foundation
// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
// SPDX-License-Identifier: Apache-2.0
// See LICENSE in the root of this distribution for details.

#pragma once

#include "../blas/ger.hpp"
#include "../blas/tools.hpp"
#include "../concepts.hpp"
#include "../declarations.hpp"
#include "../macros.hpp"
#include "../mem/policies.hpp"
#include "../traits.hpp"

namespace nda::linalg {

  template <MemoryArray A, MemoryArray B>
  auto outer_product(A const &a, B const &b) {
    // both inputs must be contiguous so that they can be reshaped into vectors
    EXPECTS(a.is_contiguous());
    EXPECTS(b.is_contiguous());

    // determine the return type: its rank is rank(A) + rank(B); vector x vector
    // gives a matrix ('M'), every other combination gives an array ('A')
    auto constexpr rank = get_rank<A> + get_rank<B>;
    auto constexpr algebra = []() {
      if constexpr (get_algebra<A> == 'V' and get_algebra<B> == 'V') {
        return 'M';
      } else {
        return 'A';
      }
    }();
    // the result uses the contiguous counterpart of A's layout policy
    using layout_pol = typename A::layout_policy_t::contiguous_t;
    // cont_pol: container (memory) policy of the result; its definition is omitted from this listing
    using return_t = basic_array<get_value_t<A>, rank, layout_pol, algebra, cont_pol>;

    // compute the outer product with a single rank-1 update (BLAS ger): view a and b
    // as vectors and the zero-initialized result as an a.size() x b.size() matrix
    auto res = return_t::zeros(stdutil::join(a.shape(), b.shape()));
    auto a_vec = reshape(a, std::array{a.size()});
    auto b_vec = reshape(b, std::array{b.size()});
    auto mat = reshape(res, std::array{a.size(), b.size()});
    nda::blas::ger(1.0, a_vec, b_vec, mat);

    return res;
  }

} // namespace nda::linalg
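To make the reshape-plus-ger step concrete, here is a minimal standalone sketch in plain C++ that does not use nda. The helpers `rank1_update` and `outer_product_flat` are hypothetical names: `rank1_update` imitates the unconjugated update m += alpha * x * y^T performed by BLAS ger/geru, and the result buffer is stored either row-major (C layout) or column-major (Fortran layout). The sketch assumes that both inputs and the result use the same storage order; under that assumption, flattening the inputs and viewing the result as an a.size() x b.size() matrix yields the multi-dimensional outer product.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

// Hypothetical helper with the semantics of an unconjugated BLAS ger/geru update:
// m += alpha * x * y^T, where m is an x.size() by y.size() matrix stored in a flat
// buffer, either row-major (C layout) or column-major (Fortran layout).
template <typename T>
void rank1_update(T alpha, std::vector<T> const &x, std::vector<T> const &y, std::vector<T> &m, bool col_major) {
  std::size_t const nx = x.size();
  std::size_t const ny = y.size();
  assert(m.size() == nx * ny); // precondition, in the spirit of EXPECTS
  for (std::size_t i = 0; i < nx; ++i) {
    for (std::size_t j = 0; j < ny; ++j) {
      // flat position of matrix element (i, j) in the chosen storage order
      std::size_t const idx = col_major ? i + j * nx : i * ny + j;
      m[idx] += alpha * x[i] * y[j];
    }
  }
}

// Hypothetical outer product of two contiguous arrays given as flat buffers.
// The result is zero-initialized and then updated once, as in the nda code; its
// multi-dimensional shape is the concatenation of the input shapes.
template <typename T>
std::vector<T> outer_product_flat(std::vector<T> const &a, std::vector<T> const &b, bool col_major) {
  std::vector<T> res(a.size() * b.size(), T{});
  rank1_update(T{1}, a, b, res, col_major);
  return res;
}
```

For example, with `a` of shape (2, 2) and `b` of shape (3,), both flattened in C order, `outer_product_flat(a, b, false)` returns 12 elements, and the element with multi-index (i0, i1, j) sits at flat position (i0 * 2 + i1) * 3 + j. With complex inputs a = b = (i), the single result element is i * i = -1; a conjugating update (BLAS gerc) would give +1 instead.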
Provides definitions and type traits involving the different memory address spaces supported by nda.
A generic multi-dimensional array.
Provides concepts for the nda library.
Provides various convenient aliases and helper functions for nda::basic_array and nda::basic_array_vi...
Provides a generic interface to the BLAS ger and geru routines.
auto reshape(A &&a, std::array< Int, R > const &new_shape): Reshape an nda::basic_array or nda::basic_array_view.
constexpr char get_algebra: Constexpr variable that specifies the algebra of a type (defined in traits.hpp:116).
constexpr int get_rank: Constexpr variable that specifies the rank of an nda::Array or of a contiguous 1-dimensional range (defined in traits.hpp:126).
static constexpr bool has_C_layout: Constexpr variable that is true if the given nda::Array type has a C memory layout (defined in tools.hpp:65).
static constexpr bool has_F_layout: Constexpr variable that is true if the given nda::Array type has a Fortran memory layout (defined in tools.hpp:55).
void ger(get_value_t< X > alpha, X const &x, Y const &y, M &&m): Interface to the BLAS ger and geru routines (defined in ger.hpp:55).
auto outer_product(A const &a, B const &b): Outer product of two arrays/views.
heap_basic< mem::mallocator< AdrSp > > heap: Alias template of the nda::heap_basic policy using an nda::mem::mallocator (defined in policies.hpp:52).
constexpr std::array< T, R1+R2 > join(std::array< T, R1 > const &a1, std::array< T, R2 > const &a2): Make a new std::array by joining two existing std::array objects (defined in array.hpp:299).
Macros used in the nda library.
Defines various memory handling policies.
Provides various traits and utilities for the BLAS interface.
Provides type traits for the nda library.
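The return-type logic of outer_product can be checked in isolation. The following C++17 sketch is a stand-in, not nda code: `join_arrays` is a hypothetical re-implementation of what the `join` entry above describes (concatenating two std::array objects), and `result_algebra` is a hypothetical constexpr function encoding the rule that two vectors ('V') produce a matrix ('M') while every other combination produces a plain array ('A'). The rank of the result equals the size of the joined shape, R1 + R2.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for the join helper: concatenate two std::array objects
// into one of size R1 + R2 (the shape of the outer product).
template <typename T, std::size_t R1, std::size_t R2>
constexpr std::array<T, R1 + R2> join_arrays(std::array<T, R1> const &a1, std::array<T, R2> const &a2) {
  std::array<T, R1 + R2> res{};
  for (std::size_t i = 0; i < R1; ++i) res[i] = a1[i];
  for (std::size_t i = 0; i < R2; ++i) res[R1 + i] = a2[i];
  return res;
}

// Hypothetical mirror of the algebra rule in outer_product:
// vector ('V') x vector ('V') -> matrix ('M'), anything else -> array ('A').
constexpr char result_algebra(char alg_a, char alg_b) { return (alg_a == 'V' && alg_b == 'V') ? 'M' : 'A'; }

// Compile-time checks: a (2, 3) array times a (4,) vector has shape (2, 3, 4), i.e. rank 3.
constexpr auto joined = join_arrays(std::array<long, 2>{2, 3}, std::array<long, 1>{4});
static_assert(joined.size() == 3 && joined[0] == 2 && joined[1] == 3 && joined[2] == 4);
static_assert(result_algebra('V', 'V') == 'M');
static_assert(result_algebra('V', 'M') == 'A');
static_assert(result_algebra('A', 'V') == 'A');
```

Unlike this sketch, the nda code makes the choice from the argument types, evaluating `get_algebra<A>` and `get_algebra<B>` with `if constexpr` inside an immediately invoked lambda.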