TRIQS/nda
2.0.0
Multi-dimensional array library for C++
Toggle main menu visibility
Loading...
Searching...
No Matches
gemm.hpp
Go to the documentation of this file.
1
// Copyright (c) 2019--present, The Simons Foundation
2
// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3
// SPDX-License-Identifier: Apache-2.0
4
// See LICENSE in the root of this distribution for details.
5
10
11
#pragma once
12
13
#include "
./interface/cxx_interface.hpp
"
14
#include "
./tools.hpp
"
15
#include "
../concepts.hpp
"
16
#include "
../device.hpp
"
17
#include "
../layout_transforms.hpp
"
18
#include "
../macros.hpp
"
19
#include "
../mem/address_space.hpp
"
20
#include "
../traits.hpp
"
21
22
#include <utility>
23
24
namespace
nda::blas {
25
30
56
template
<BlasArrayOrConj<2> A, BlasArrayOrConjFor<A, 2> B, BlasArrayFor<A, 2> C>
57
void
gemm
(
get_value_t<A>
alpha, A
const
&a, B
const
&b,
get_value_t<A>
beta, C &&c) {
58
// if C is in C-layout, compute the transpose of the product
59
if
constexpr
(
has_C_layout<C>
) {
60
gemm
(alpha,
transpose
(b),
transpose
(a), beta,
transpose
(std::forward<C>(c)));
61
}
else
{
62
// get underlying matrix in case it is given as a conjugate expression
63
auto
&mat_a =
get_array
(a);
64
auto
&mat_b =
get_array
(b);
65
66
// check the dimensions of the input/output arrays/views
67
auto
const
[m, k] = mat_a.shape();
68
auto
const
[l, n] = mat_b.shape();
69
EXPECTS(k == l);
70
EXPECTS(m == c.extent(0));
71
EXPECTS(n == c.extent(1));
72
73
// arrays/views must be BLAS compatible
74
EXPECTS(mat_a.indexmap().min_stride() == 1);
75
EXPECTS(mat_b.indexmap().min_stride() == 1);
76
EXPECTS(c.indexmap().min_stride() == 1);
77
78
// perform the actual library call
79
if
constexpr
(
mem::have_device_compatible_addr_space<A, B, C>
) {
80
device::gemm(
get_op<A>
,
get_op<B>
, m, n, k, alpha, mat_a.data(),
get_ld
(mat_a), mat_b.data(),
get_ld
(mat_b), beta, c.data(),
get_ld
(c));
81
}
else
{
82
f77::gemm(
get_op<A>
,
get_op<B>
, m, n, k, alpha, mat_a.data(),
get_ld
(mat_a), mat_b.data(),
get_ld
(mat_b), beta, c.data(),
get_ld
(c));
83
}
84
}
85
}
86
88
89
}
// namespace nda::blas
address_space.hpp
Provides definitions and type traits involving the different memory address spaces supported by nda.
cxx_interface.hpp
Provides a C++ interface for various BLAS routines.
tools.hpp
Provides various traits and utilities for the BLAS interface.
concepts.hpp
Provides concepts for the nda library.
device.hpp
Provides GPU and non-GPU specific functionality.
nda::transpose
auto transpose(A &&a)
Transpose the memory layout of an nda::MemoryArray or an nda::expr_call.
Definition
layout_transforms.hpp:180
nda::get_value_t
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition
traits.hpp:212
nda::blas_lapack::get_array
MemoryArray decltype(auto) get_array(A &&a)
Get the underlying array of a conjugate lazy expression or return the array itself in case it is an nda::MemoryArray.
Definition
tools.hpp:68
nda::blas_lapack::get_op
static constexpr char get_op
Variable template that determines the BLAS matrix operation tag ('N','T','C') based on the given bool...
Definition
tools.hpp:104
nda::blas_lapack::get_ld
int get_ld(A const &a)
Get the leading dimension of an nda::MemoryArray with rank 1 or 2 for BLAS/LAPACK calls.
Definition
tools.hpp:128
nda::blas_lapack::has_C_layout
static constexpr bool has_C_layout
Constexpr variable that is true if all given nda::Array types have nda::C_layout.
Definition
tools.hpp:89
nda::blas::gemm
void gemm(get_value_t< A > alpha, A const &a, B const &b, get_value_t< A > beta, C &&c)
Interface to the BLAS/cuBLAS gemm routine.
Definition
gemm.hpp:57
nda::mem::have_device_compatible_addr_space
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
Definition
address_space.hpp:177
layout_transforms.hpp
Provides functions to transform the memory layout of an nda::basic_array or nda::basic_array_view.
macros.hpp
Macros used in the nda library.
traits.hpp
Provides type traits for the nda library.
nda
blas
gemm.hpp
Generated by
1.17.0