TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
contract.hpp
Go to the documentation of this file.
1// Copyright (c) 2024--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
15#include "./tools.hpp"
17#include "../traits.hpp"
18
19#include <string_view>
20#include <utility>
21
22namespace nda::tensor {
23
28
65 template <BlasArrayOrConj A, BlasArrayOrConjFor<A> B, BlasArrayFor<A> C>
66 void contract(get_value_t<A> alpha, A const &a, std::string_view idx_a, B const &b, std::string_view idx_b, get_value_t<A> beta, C &&c, // NOLINT
67 std::string_view idx_c) {
68 // compile-time checks
69 constexpr bool run_on_device = mem::have_device_compatible_addr_space<A, B, C>;
70 static_assert(!run_on_device || have_cutensor, "nda::tensor::contract: cuTENSOR support is required");
71 static_assert(run_on_device || have_tblis, "nda::tensor::contract: TBLIS support is required");
72
73 // dispatch to backends
74 if constexpr (run_on_device) {
75 device::contract(alpha, a, idx_a, b, idx_b, beta, c, idx_c, c);
76 } else {
77 tblis::mult(alpha, a, idx_a, b, idx_b, beta, c, idx_c);
78 }
79 }
80
82 template <BlasArrayOrConj A, BlasArrayOrConjFor<A> B, BlasArrayFor<A> C>
83 void contract(A const &a, std::string_view idx_a, B const &b, std::string_view idx_b, C &&c, std::string_view idx_c) { // NOLINT
84 contract(get_value_t<A>{1}, a, idx_a, b, idx_b, get_value_t<A>{0}, std::forward<C>(c), idx_c);
85 }
86
88
89} // namespace nda::tensor
Provides definitions and type traits involving the different memory address spaces supported by nda.
Provides a C++ interface for various cuTENSOR routines.
std::decay_t< decltype(get_first_element(std::declval< A const >()))> get_value_t
Get the value type of an array/view or a scalar type.
Definition traits.hpp:212
static constexpr bool have_device_compatible_addr_space
Constexpr variable that is true if all given types have an address space compatible with Device.
void contract(get_value_t< A > alpha, A const &a, std::string_view idx_a, B const &b, std::string_view idx_b, get_value_t< A > beta, C &&c, std::string_view idx_c)
Tensor contraction with cuTENSOR/TBLIS dispatch.
Definition contract.hpp:66
static constexpr bool have_tblis
Constexpr variable that is true if nda is configured with TBLIS support.
Definition tools.hpp:47
static constexpr bool have_cutensor
Constexpr variable that is true if nda is configured with cuTENSOR support.
Definition tools.hpp:40
Provides a C++ interface for various TBLIS tensor routines.
Provides various traits and utilities for the tensor interface.
Provides type traits for the nda library.