14#include <tblis/tblis.h>
18namespace nda::tensor::tblis {
21 using namespace ::tblis;
27 void set_impl(T alpha, tensor_view<T> A, std::string_view idx_A) {
28 tblis_scalar s(alpha);
29 tblis_tensor t(A.data, A.ndim, A.extents, A.strides);
30 tblis_tensor_set(
nullptr,
nullptr, &s, &t, idx_A.data());
35 void scale_impl(T alpha, tensor_view<T> A, std::string_view idx_A) {
36 tblis_tensor t(alpha, A.op == unary_op::CONJ, A.data, A.ndim, A.extents, A.strides);
37 tblis_tensor_scale(
nullptr,
nullptr, &t, idx_A.data());
41 reduce_t to_tblis_reduce_op(
binary_op op) {
43 case binary_op::SUM:
return REDUCE_SUM;
44 case binary_op::SUM_ABS:
return REDUCE_SUM_ABS;
45 case binary_op::MAX:
return REDUCE_MAX;
46 case binary_op::MAX_ABS:
return REDUCE_MAX_ABS;
47 case binary_op::MIN:
return REDUCE_MIN;
48 case binary_op::MIN_ABS:
return REDUCE_MIN_ABS;
49 case binary_op::NORM_2:
return REDUCE_NORM_2;
50 default: NDA_RUNTIME_ERROR <<
"nda::tensor::tblis::reduce: nda::tensor::binary_op has no TBLIS equivalent";
// Body fragment of the reduction helper; its signature and the declaration of
// `idx` are not visible in this view.
// Describe A for TBLIS with scalar factor T{1}; the second argument is true when
// A.op is unary_op::CONJ.
57 tblis_tensor t(T{1}, A.op == unary_op::CONJ, A.data, A.ndim, A.extents, A.strides);
// Scalar object that TBLIS writes the reduction result into, built from a
// value-initialised T.
58 tblis_scalar result(T{});
// `op` is converted with to_tblis_reduce_op; the addresses of `result` and `idx`
// are passed as the final two arguments.
60 tblis_tensor_reduce(
nullptr,
nullptr, to_tblis_reduce_op(op), &t, idx_A.data(), &result, &idx);
// Hand the result back as a T.
61 return result.as<T>();
// Body fragment of the dot-product helper; its signature is not visible in this view.
// Both operands are described with scalar factor T{1} and a conjugation flag taken
// from their respective `op` members.
67 tblis_tensor tA(T{1}, A.op == unary_op::CONJ, A.data, A.ndim, A.extents, A.strides);
68 tblis_tensor tB(T{1}, B.op == unary_op::CONJ, B.data, B.ndim, B.extents, B.strides);
// Scalar object that receives the result of tblis_tensor_dot.
69 tblis_scalar result(T{});
70 tblis_tensor_dot(
nullptr,
nullptr, &tA, idx_A.data(), &tB, idx_B.data(), &result);
// Hand the result back as a T.
71 return result.as<T>();
76 void add_impl(T alpha,
const_tensor_view<T> A, std::string_view idx_A, T beta, tensor_view<T> B, std::string_view idx_B) {
77 tblis_tensor tA(alpha, A.op == unary_op::CONJ, A.data, A.ndim, A.extents, A.strides);
78 tblis_tensor tB(beta, B.op == unary_op::CONJ, B.data, B.ndim, B.extents, B.strides);
79 tblis_tensor_add(
nullptr,
nullptr, &tA, idx_A.data(), &tB, idx_B.data());
// Tail of the multiplication helper; the start of its signature is not visible
// in this view.
85 std::string_view idx_C) {
// `alpha` is attached to A's descriptor, B gets the neutral factor T{1}, and
// `beta` is attached to C's descriptor. Each descriptor also carries a
// conjugation flag taken from the view's `op` member.
86 tblis_tensor tA(alpha, A.op == unary_op::CONJ, A.data, A.ndim, A.extents, A.strides);
87 tblis_tensor tB(T{1}, B.op == unary_op::CONJ, B.data, B.ndim, B.extents, B.strides);
88 tblis_tensor tC(beta, C.op == unary_op::CONJ, C.data, C.ndim, C.extents, C.strides);
// Pass all three descriptors and their index labels to TBLIS.
89 tblis_tensor_mult(
nullptr,
nullptr, &tA, idx_A.data(), &tB, idx_B.data(), &tC, idx_C.data());
95 void set(
float alpha, tensor_view<float> A, std::string_view idx_A) { set_impl(alpha, A, idx_A); }
96 void set(
double alpha, tensor_view<double> A, std::string_view idx_A) { set_impl(alpha, A, idx_A); }
97 void set(std::complex<float> alpha, tensor_view<std::complex<float>> A, std::string_view idx_A) { set_impl(alpha, A, idx_A); }
98 void set(std::complex<double> alpha, tensor_view<std::complex<double>> A, std::string_view idx_A) { set_impl(alpha, A, idx_A); }
101 void scale(
float alpha, tensor_view<float> A, std::string_view idx_A) { scale_impl(alpha, A, idx_A); }
102 void scale(
double alpha, tensor_view<double> A, std::string_view idx_A) { scale_impl(alpha, A, idx_A); }
103 void scale(std::complex<float> alpha, tensor_view<std::complex<float>> A, std::string_view idx_A) { scale_impl(alpha, A, idx_A); }
104 void scale(std::complex<double> alpha, tensor_view<std::complex<double>> A, std::string_view idx_A) { scale_impl(alpha, A, idx_A); }
109 std::complex<float> reduce(
binary_op op,
const_tensor_view<std::complex<float>> A, std::string_view idx_A) {
return reduce_impl(op, A, idx_A); }
110 std::complex<double> reduce(
binary_op op,
const_tensor_view<std::complex<double>> A, std::string_view idx_A) {
return reduce_impl(op, A, idx_A); }
// Fragments of the public `dot` overloads; most of their signatures are not
// visible in this view. Each visible body returns the result of dot_impl called
// with the same four arguments.
114 return dot_impl(A, idx_A, B, idx_B);
117 return dot_impl(A, idx_A, B, idx_B);
// End of a parameter list whose beginning is not visible here.
120 std::string_view idx_B) {
121 return dot_impl(A, idx_A, B, idx_B);
// End of a parameter list whose beginning is not visible here.
124 std::string_view idx_B) {
125 return dot_impl(A, idx_A, B, idx_B);
129 void add(
float alpha,
const_tensor_view<float> A, std::string_view idx_A,
float beta, tensor_view<float> B, std::string_view idx_B) {
130 add_impl(alpha, A, idx_A, beta, B, idx_B);
132 void add(
double alpha,
const_tensor_view<double> A, std::string_view idx_A,
double beta, tensor_view<double> B, std::string_view idx_B) {
133 add_impl(alpha, A, idx_A, beta, B, idx_B);
135 void add(std::complex<float> alpha,
const_tensor_view<std::complex<float>> A, std::string_view idx_A, std::complex<float> beta,
136 tensor_view<std::complex<float>> B, std::string_view idx_B) {
137 add_impl(alpha, A, idx_A, beta, B, idx_B);
139 void add(std::complex<double> alpha,
const_tensor_view<std::complex<double>> A, std::string_view idx_A, std::complex<double> beta,
140 tensor_view<std::complex<double>> B, std::string_view idx_B) {
141 add_impl(alpha, A, idx_A, beta, B, idx_B);
// Fragments of the public `mult` overloads; the beginnings of their signatures
// are not visible in this view. Each visible body forwards the same eight
// arguments to mult_impl.
// Tail of the overload whose output view is tensor_view<float>.
146 tensor_view<float> C, std::string_view idx_C) {
147 mult_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C);
// Tail of the overload whose output view is tensor_view<double>.
150 tensor_view<double> C, std::string_view idx_C) {
151 mult_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C);
// Tail of the overload taking std::complex<float> for `beta` and `C`.
154 std::string_view idx_B, std::complex<float> beta, tensor_view<std::complex<float>> C, std::string_view idx_C) {
155 mult_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C);
// Tail of the overload taking std::complex<double> for `beta` and `C`.
158 std::string_view idx_B, std::complex<double> beta, tensor_view<std::complex<double>> C, std::string_view idx_C) {
159 mult_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C);
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
binary_op
Binary operations for tensor operations.
tensor_view< const T > const_tensor_view
Alias for a tensor_view with const value type.
Provides a C++ interface for various TBLIS tensor routines.