TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
tblis_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2024--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#include "./tblis_interface.hpp"
12#include "../../exceptions.hpp"
13
14#include <tblis/tblis.h>
15
16#include <complex>
17
18namespace nda::tensor::tblis {
19
20 // Import TBLIS types and functions.
21 using namespace ::tblis;
22
23 namespace {
24
25 // Helper function to call set routine.
26 template <typename T>
27 void set_impl(T alpha, tensor_view<T> A, std::string_view idx_A) {
28 tblis_scalar s(alpha);
29 tblis_tensor t(A.data, A.ndim, A.extents, A.strides);
30 tblis_tensor_set(nullptr, nullptr, &s, &t, idx_A.data());
31 }
32
33 // Helper function to call scale routine.
34 template <typename T>
35 void scale_impl(T alpha, tensor_view<T> A, std::string_view idx_A) {
36 tblis_tensor t(alpha, A.op == unary_op::CONJ, A.data, A.ndim, A.extents, A.strides);
37 tblis_tensor_scale(nullptr, nullptr, &t, idx_A.data());
38 }
39
40 // Map our binary_op enum to TBLIS reduce_t.
41 reduce_t to_tblis_reduce_op(binary_op op) {
42 switch (op) {
43 case binary_op::SUM: return REDUCE_SUM;
44 case binary_op::SUM_ABS: return REDUCE_SUM_ABS;
45 case binary_op::MAX: return REDUCE_MAX;
46 case binary_op::MAX_ABS: return REDUCE_MAX_ABS;
47 case binary_op::MIN: return REDUCE_MIN;
48 case binary_op::MIN_ABS: return REDUCE_MIN_ABS;
49 case binary_op::NORM_2: return REDUCE_NORM_2;
50 default: NDA_RUNTIME_ERROR << "nda::tensor::tblis::reduce: nda::tensor::binary_op has no TBLIS equivalent";
51 }
52 }
53
54 // Helper function to call reduce routine.
55 template <typename T>
56 T reduce_impl(binary_op op, const_tensor_view<T> A, std::string_view idx_A) {
57 tblis_tensor t(T{1}, A.op == unary_op::CONJ, A.data, A.ndim, A.extents, A.strides);
58 tblis_scalar result(T{});
59 len_type idx = 0;
60 tblis_tensor_reduce(nullptr, nullptr, to_tblis_reduce_op(op), &t, idx_A.data(), &result, &idx);
61 return result.as<T>();
62 }
63
64 // Helper function to call dot routine.
65 template <typename T>
66 T dot_impl(const_tensor_view<T> A, std::string_view idx_A, const_tensor_view<T> B, std::string_view idx_B) {
67 tblis_tensor tA(T{1}, A.op == unary_op::CONJ, A.data, A.ndim, A.extents, A.strides);
68 tblis_tensor tB(T{1}, B.op == unary_op::CONJ, B.data, B.ndim, B.extents, B.strides);
69 tblis_scalar result(T{});
70 tblis_tensor_dot(nullptr, nullptr, &tA, idx_A.data(), &tB, idx_B.data(), &result);
71 return result.as<T>();
72 }
73
74 // Helper function to call add routine.
75 template <typename T>
76 void add_impl(T alpha, const_tensor_view<T> A, std::string_view idx_A, T beta, tensor_view<T> B, std::string_view idx_B) {
77 tblis_tensor tA(alpha, A.op == unary_op::CONJ, A.data, A.ndim, A.extents, A.strides);
78 tblis_tensor tB(beta, B.op == unary_op::CONJ, B.data, B.ndim, B.extents, B.strides);
79 tblis_tensor_add(nullptr, nullptr, &tA, idx_A.data(), &tB, idx_B.data());
80 }
81
82 // Helper function to call mult routine.
83 template <typename T>
84 void mult_impl(T alpha, const_tensor_view<T> A, std::string_view idx_A, const_tensor_view<T> B, std::string_view idx_B, T beta, tensor_view<T> C,
85 std::string_view idx_C) {
86 tblis_tensor tA(alpha, A.op == unary_op::CONJ, A.data, A.ndim, A.extents, A.strides);
87 tblis_tensor tB(T{1}, B.op == unary_op::CONJ, B.data, B.ndim, B.extents, B.strides);
88 tblis_tensor tC(beta, C.op == unary_op::CONJ, C.data, C.ndim, C.extents, C.strides);
89 tblis_tensor_mult(nullptr, nullptr, &tA, idx_A.data(), &tB, idx_B.data(), &tC, idx_C.data());
90 }
91
92 } // namespace
93
94 // set
95 void set(float alpha, tensor_view<float> A, std::string_view idx_A) { set_impl(alpha, A, idx_A); }
96 void set(double alpha, tensor_view<double> A, std::string_view idx_A) { set_impl(alpha, A, idx_A); }
97 void set(std::complex<float> alpha, tensor_view<std::complex<float>> A, std::string_view idx_A) { set_impl(alpha, A, idx_A); }
98 void set(std::complex<double> alpha, tensor_view<std::complex<double>> A, std::string_view idx_A) { set_impl(alpha, A, idx_A); }
99
100 // scale
101 void scale(float alpha, tensor_view<float> A, std::string_view idx_A) { scale_impl(alpha, A, idx_A); }
102 void scale(double alpha, tensor_view<double> A, std::string_view idx_A) { scale_impl(alpha, A, idx_A); }
103 void scale(std::complex<float> alpha, tensor_view<std::complex<float>> A, std::string_view idx_A) { scale_impl(alpha, A, idx_A); }
104 void scale(std::complex<double> alpha, tensor_view<std::complex<double>> A, std::string_view idx_A) { scale_impl(alpha, A, idx_A); }
105
106 // reduce
107 float reduce(binary_op op, const_tensor_view<float> A, std::string_view idx_A) { return reduce_impl(op, A, idx_A); }
108 double reduce(binary_op op, const_tensor_view<double> A, std::string_view idx_A) { return reduce_impl(op, A, idx_A); }
109 std::complex<float> reduce(binary_op op, const_tensor_view<std::complex<float>> A, std::string_view idx_A) { return reduce_impl(op, A, idx_A); }
110 std::complex<double> reduce(binary_op op, const_tensor_view<std::complex<double>> A, std::string_view idx_A) { return reduce_impl(op, A, idx_A); }
111
112 // dot
113 float dot(const_tensor_view<float> A, std::string_view idx_A, const_tensor_view<float> B, std::string_view idx_B) {
114 return dot_impl(A, idx_A, B, idx_B);
115 }
116 double dot(const_tensor_view<double> A, std::string_view idx_A, const_tensor_view<double> B, std::string_view idx_B) {
117 return dot_impl(A, idx_A, B, idx_B);
118 }
119 std::complex<float> dot(const_tensor_view<std::complex<float>> A, std::string_view idx_A, const_tensor_view<std::complex<float>> B,
120 std::string_view idx_B) {
121 return dot_impl(A, idx_A, B, idx_B);
122 }
123 std::complex<double> dot(const_tensor_view<std::complex<double>> A, std::string_view idx_A, const_tensor_view<std::complex<double>> B,
124 std::string_view idx_B) {
125 return dot_impl(A, idx_A, B, idx_B);
126 }
127
128 // add
129 void add(float alpha, const_tensor_view<float> A, std::string_view idx_A, float beta, tensor_view<float> B, std::string_view idx_B) {
130 add_impl(alpha, A, idx_A, beta, B, idx_B);
131 }
132 void add(double alpha, const_tensor_view<double> A, std::string_view idx_A, double beta, tensor_view<double> B, std::string_view idx_B) {
133 add_impl(alpha, A, idx_A, beta, B, idx_B);
134 }
135 void add(std::complex<float> alpha, const_tensor_view<std::complex<float>> A, std::string_view idx_A, std::complex<float> beta,
136 tensor_view<std::complex<float>> B, std::string_view idx_B) {
137 add_impl(alpha, A, idx_A, beta, B, idx_B);
138 }
139 void add(std::complex<double> alpha, const_tensor_view<std::complex<double>> A, std::string_view idx_A, std::complex<double> beta,
140 tensor_view<std::complex<double>> B, std::string_view idx_B) {
141 add_impl(alpha, A, idx_A, beta, B, idx_B);
142 }
143
144 // mult
145 void mult(float alpha, const_tensor_view<float> A, std::string_view idx_A, const_tensor_view<float> B, std::string_view idx_B, float beta,
146 tensor_view<float> C, std::string_view idx_C) {
147 mult_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C);
148 }
149 void mult(double alpha, const_tensor_view<double> A, std::string_view idx_A, const_tensor_view<double> B, std::string_view idx_B, double beta,
150 tensor_view<double> C, std::string_view idx_C) {
151 mult_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C);
152 }
153 void mult(std::complex<float> alpha, const_tensor_view<std::complex<float>> A, std::string_view idx_A, const_tensor_view<std::complex<float>> B,
154 std::string_view idx_B, std::complex<float> beta, tensor_view<std::complex<float>> C, std::string_view idx_C) {
155 mult_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C);
156 }
157 void mult(std::complex<double> alpha, const_tensor_view<std::complex<double>> A, std::string_view idx_A, const_tensor_view<std::complex<double>> B,
158 std::string_view idx_B, std::complex<double> beta, tensor_view<std::complex<double>> C, std::string_view idx_C) {
159 mult_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C);
160 }
161
162} // namespace nda::tensor::tblis
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
binary_op
Binary operations for tensor operations.
Definition tools.hpp:67
tensor_view< const T > const_tensor_view
Alias for a tensor_view with const value type.
Definition tools.hpp:234
Provides a C++ interface for various TBLIS tensor routines.