TRIQS/nda
2.0.0
Multi-dimensional array library for C++
Toggle main menu visibility
Loading...
Searching...
No Matches
contract.hpp
Go to the documentation of this file.
1
// Copyright (c) 2024--present, The Simons Foundation
2
// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3
// SPDX-License-Identifier: Apache-2.0
4
// See LICENSE in the root of this distribution for details.
5
10
11
#pragma once
12
13
#include "
./interface/cutensor_interface.hpp
"
14
#include "
./interface/tblis_interface.hpp
"
15
#include "
./tools.hpp
"
16
#include "
../mem/address_space.hpp
"
17
#include "
../traits.hpp
"
18
19
#include <string_view>
20
#include <utility>
21
22
namespace
nda::tensor {
23
28
65
template
<BlasArrayOrConj A, BlasArrayOrConjFor<A> B, BlasArrayFor<A> C>
66
void
contract
(
get_value_t<A>
alpha, A
const
&a, std::string_view idx_a, B
const
&b, std::string_view idx_b,
get_value_t<A>
beta, C &&c,
// NOLINT
67
std::string_view idx_c) {
68
// compile-time checks
69
constexpr
bool
run_on_device =
mem::have_device_compatible_addr_space<A, B, C>
;
70
static_assert
(!run_on_device ||
have_cutensor
,
"nda::tensor::contract: cuTENSOR support is required"
);
71
static_assert
(run_on_device ||
have_tblis
,
"nda::tensor::contract: TBLIS support is required"
);
72
73
// dispatch to backends
74
if
constexpr
(run_on_device) {
75
device::contract(alpha, a, idx_a, b, idx_b, beta, c, idx_c, c);
76
}
else
{
77
tblis::mult(alpha, a, idx_a, b, idx_b, beta, c, idx_c);
78
}
79
}
80
82
template
<BlasArrayOrConj A, BlasArrayOrConjFor<A> B, BlasArrayFor<A> C>
83
void
contract
(A
const
&a, std::string_view idx_a, B
const
&b, std::string_view idx_b, C &&c, std::string_view idx_c) {
// NOLINT
84
contract
(
get_value_t<A>
{1}, a, idx_a, b, idx_b,
get_value_t<A>
{0}, std::forward<C>(c), idx_c);
85
}
86
88
89
}
// namespace nda::tensor
address_space.hpp
Provides definitions and type traits involving the different memory address spaces supported by nda.
cutensor_interface.hpp
Provides a C++ interface for various cuTENSOR routines.
nda::get_value_t
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition
traits.hpp:212
nda::mem::have_device_compatible_addr_space
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
Definition
address_space.hpp:177
nda::tensor::contract
void contract(get_value_t< A > alpha, A const &a, std::string_view idx_a, B const &b, std::string_view idx_b, get_value_t< A > beta, C &&c, std::string_view idx_c)
Tensor contraction with cuTENSOR/TBLIS dispatch.
Definition
contract.hpp:66
nda::tensor::have_tblis
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
Definition
tools.hpp:47
nda::tensor::have_cutensor
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
Definition
tools.hpp:40
tblis_interface.hpp
Provides a C++ interface for various TBLIS tensor routines.
tools.hpp
Provides various traits and utilities for the tensor interface.
traits.hpp
Provides type traits for the nda library.
nda
tensor
contract.hpp
Generated by
1.17.0