TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
cutensor_interface.cpp
Go to the documentation of this file.
1// Copyright (c) 2024--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
12#include "../tools.hpp"
13#include "../../concepts.hpp"
14#include "../../device.hpp"
15#include "../../exceptions.hpp"
16
17#include <cutensor.h>
18
19#include <algorithm>
20#include <bit>
21#include <complex>
22#include <cstdint>
23#include <string_view>
24#include <type_traits>
25#include <vector>
26
27namespace nda::tensor::device {
28
// Synchronization flag read by the cuTENSOR helpers in this file (they forward it to cuda_device_sync).
// It is thread_local on purpose: each thread keeps its own setting. Accessed through the two functions below.
thread_local bool synchronize = true; // NOLINT (per-thread option is on purpose)

// Store a new value in the calling thread's synchronization flag.
void set_synchronization(bool do_sync) noexcept {
  synchronize = do_sync;
}

// Return the current value of the calling thread's synchronization flag.
bool get_synchronization() noexcept {
  return synchronize;
}
33
34 // Get the cutensor handle.
35 cutensorHandle_t &get_handle() {
36 struct handle_storage_t { // RAII for handle
37 handle_storage_t() { cutensorCreate(&handle); }
38 ~handle_storage_t() { cutensorDestroy(handle); }
39 cutensorHandle_t handle = {};
40 };
41 static auto sto = handle_storage_t{};
42 return sto.handle;
43 }
44
45 // Anonymous namespace for local functions.
46 namespace {
47
48 // Check the success of a cutensor operation.
49 void cutensor_error_check(cutensorStatus_t status, std::string_view func) {
50 if (status != CUTENSOR_STATUS_SUCCESS) {
51 NDA_RUNTIME_ERROR << "cuTENSOR runtime error in function " << func << "\n"
52 << " cutensorStatus_t: " << status << "\n"
53 << " cutensorGetErrorString: " << cutensorGetErrorString(status) << "\n";
54 }
55 }
56
57 // Cuda data type conversion.
58 template <typename T, typename U = std::remove_const_t<T>>
59 constexpr auto cuda_data_type() {
60 if constexpr (std::is_same_v<U, float>) {
61 return CUTENSOR_R_32F;
62 } else if constexpr (std::is_same_v<U, double>) {
63 return CUTENSOR_R_64F;
64 } else if constexpr (std::is_same_v<U, std::complex<float>>) {
65 return CUTENSOR_C_32F;
66 } else if constexpr (std::is_same_v<U, std::complex<double>>) {
67 return CUTENSOR_C_64F;
68 }
69 }
70
71 // Cutensor compute type conversion.
72 template <typename T, typename U = std::remove_const_t<T>>
73 constexpr auto cutensor_compute_type() {
74 if constexpr (AnyOf<U, float, std::complex<float>>) {
75 return CUTENSOR_COMPUTE_DESC_32F;
76 } else if constexpr (AnyOf<U, double, std::complex<double>>) {
77 return CUTENSOR_COMPUTE_DESC_64F;
78 }
79 }
80
// Find the alignment (in bytes) of a given pointer, capped at 256.
//
// The result is the largest power of two that divides the address, but never more than 256 (the original note
// here gives the cudaMalloc default alignment as the reason for the cap). A null pointer yields 256: the
// previous implementation shifted by the full bit width for address 0, which is undefined behaviour.
template <typename T>
auto find_alignment(T *p) {
  constexpr std::uintptr_t max_alignment = 256;
  auto const address = reinterpret_cast<std::uintptr_t>(p); // NOLINT (reinterpret_cast is necessary here)

  // Forcing the max_alignment bit on guarantees a set bit at or below position 8, so the lowest set bit of the
  // result is min(lowest set bit of address, 256) and is well defined even when address == 0.
  auto const capped = address | max_alignment;

  // isolate the lowest set bit (two's complement trick: x & -x)
  auto const lowest_bit = capped & (~capped + 1);
  return static_cast<std::uint32_t>(lowest_bit);
}
89
// Turn an index string into the array of 32-bit integer mode labels that is handed to cuTENSOR.
// Every character of idx becomes one label (its character value converted to std::int32_t), in the same order.
auto to_modes(std::string_view idx) {
  std::vector<std::int32_t> modes;
  modes.reserve(idx.size());
  for (char const label : idx) { modes.push_back(static_cast<std::int32_t>(label)); }
  return modes;
}
92
93 // Map unary_op enum to cuTENSOR unary operator.
94 // clang-format off
95 cutensorOperator_t to_cutensor_unary_op(unary_op op) {
96 switch (op) {
97 case unary_op::IDENTITY: return CUTENSOR_OP_IDENTITY;
98 case unary_op::SQRT: return CUTENSOR_OP_SQRT;
99 case unary_op::RELU: return CUTENSOR_OP_RELU;
100 case unary_op::CONJ: return CUTENSOR_OP_CONJ;
101 case unary_op::RCP: return CUTENSOR_OP_RCP;
102 case unary_op::SIGMOID: return CUTENSOR_OP_SIGMOID;
103 case unary_op::TANH: return CUTENSOR_OP_TANH;
104 case unary_op::EXP: return CUTENSOR_OP_EXP;
105 case unary_op::LOG: return CUTENSOR_OP_LOG;
106 case unary_op::ABS: return CUTENSOR_OP_ABS;
107 case unary_op::NEG: return CUTENSOR_OP_NEG;
108 case unary_op::SIN: return CUTENSOR_OP_SIN;
109 case unary_op::COS: return CUTENSOR_OP_COS;
110 case unary_op::TAN: return CUTENSOR_OP_TAN;
111 case unary_op::SINH: return CUTENSOR_OP_SINH;
112 case unary_op::COSH: return CUTENSOR_OP_COSH;
113 case unary_op::ASIN: return CUTENSOR_OP_ASIN;
114 case unary_op::ACOS: return CUTENSOR_OP_ACOS;
115 case unary_op::ATAN: return CUTENSOR_OP_ATAN;
116 case unary_op::ASINH: return CUTENSOR_OP_ASINH;
117 case unary_op::ACOSH: return CUTENSOR_OP_ACOSH;
118 case unary_op::ATANH: return CUTENSOR_OP_ATANH;
119 case unary_op::CEIL: return CUTENSOR_OP_CEIL;
120 case unary_op::FLOOR: return CUTENSOR_OP_FLOOR;
121 case unary_op::MISH: return CUTENSOR_OP_MISH;
122 case unary_op::SWISH: return CUTENSOR_OP_SWISH;
123 case unary_op::SOFT_PLUS: return CUTENSOR_OP_SOFT_PLUS;
124 case unary_op::SOFT_SIGN: return CUTENSOR_OP_SOFT_SIGN;
125 default: NDA_RUNTIME_ERROR << "nda::tensor::cutensor: unary_op has no cuTENSOR equivalent";
126 }
127 }
128 // clang-format on
129
130 // Map binary_op enum to cuTENSOR binary operator.
131 cutensorOperator_t to_cutensor_binary_op(binary_op op) {
132 switch (op) {
133 case binary_op::SUM: return CUTENSOR_OP_ADD;
134 case binary_op::PROD: return CUTENSOR_OP_MUL;
135 case binary_op::MAX: return CUTENSOR_OP_MAX;
136 case binary_op::MIN: return CUTENSOR_OP_MIN;
137 default: NDA_RUNTIME_ERROR << "nda::tensor::cutensor: binary_op has no cuTENSOR equivalent";
138 }
139 }
140
141 // Create a tensor descriptor from a tensor view.
142 template <typename T>
143 auto create_tensor_desc(tensor_view<T> tv) {
144 cutensorTensorDescriptor_t desc{};
145 auto status = cutensorCreateTensorDescriptor(get_handle(), &desc, static_cast<std::uint32_t>(tv.ndim), tv.extents, tv.strides,
146 cuda_data_type<T>(), find_alignment(tv.data));
147 cutensor_error_check(status, "cutensorCreateTensorDescriptor");
148 return desc;
149 }
150
151 // Destroy a given tensor descriptor.
152 void destroy_tensor_desc(cutensorTensorDescriptor_t desc) {
153 cutensor_error_check(cutensorDestroyTensorDescriptor(desc), "cutensorDestroyTensorDescriptor");
154 }
155
156 // Create a plan preference with default algorithm and no JIT.
157 cutensorPlanPreference_t create_plan_pref() {
158 cutensorPlanPreference_t pref{};
159 cutensor_error_check(cutensorCreatePlanPreference(get_handle(), &pref, CUTENSOR_ALGO_DEFAULT, CUTENSOR_JIT_MODE_NONE),
160 "cutensorCreatePlanPreference");
161 return pref;
162 }
163
164 // Destroy a given plan preference.
165 void destroy_plan_pref(cutensorPlanPreference_t pref) {
166 cutensor_error_check(cutensorDestroyPlanPreference(pref), "cutensorDestroyPlanPreference");
167 }
168
169 // Create an execution plan from an operation descriptor, plan preference and workspace size limit.
170 cutensorPlan_t create_plan(cutensorOperationDescriptor_t op_desc, cutensorPlanPreference_t pref, std::uint64_t workspace_limit = 0) {
171 cutensorPlan_t plan{};
172 cutensor_error_check(cutensorCreatePlan(get_handle(), &plan, op_desc, pref, workspace_limit), "cutensorCreatePlan");
173 return plan;
174 }
175
176 // Destroy a given execution plan.
177 void destroy_plan(cutensorPlan_t plan) { cutensor_error_check(cutensorDestroyPlan(plan), "cutensorDestroyPlan"); }
178
179 // Destroy a given operation descriptor.
180 void destroy_op_desc(cutensorOperationDescriptor_t op_desc) {
181 cutensor_error_check(cutensorDestroyOperationDescriptor(op_desc), "cutensorDestroyOperationDescriptor");
182 }
183
184 // Estimate the workspace size for an operation.
185 std::uint64_t estimate_workspace(cutensorOperationDescriptor_t op_desc, cutensorPlanPreference_t pref,
186 cutensorWorksizePreference_t ws_pref = CUTENSOR_WORKSPACE_DEFAULT) {
187 std::uint64_t size = 0;
188 cutensor_error_check(cutensorEstimateWorkspaceSize(get_handle(), op_desc, pref, ws_pref, &size), "cutensorEstimateWorkspaceSize");
189 return size;
190 }
191
192 // Helper function to call the cuTENSOR permutation routine: B = alpha * opA(A).
193 template <typename T>
194 void permute_impl(T alpha, const_tensor_view<T> A, std::string_view idx_A, tensor_view<T> B, std::string_view idx_B) {
195 auto &handle = get_handle();
196
197 // create tensor descriptors
198 auto desc_A = create_tensor_desc(A);
199 auto desc_B = create_tensor_desc(B);
200
201 // convert index strings to mode arrays
202 auto modes_A = to_modes(idx_A);
203 auto modes_B = to_modes(idx_B);
204
205 // create operation descriptor
206 cutensorOperationDescriptor_t op_desc{};
207 cutensor_error_check(cutensorCreatePermutation(handle, &op_desc, desc_A, modes_A.data(), to_cutensor_unary_op(A.op), desc_B, modes_B.data(),
208 cutensor_compute_type<T>()),
209 "cutensorCreatePermutation");
210
211 // create plan preference, estimate workspace, and create plan
212 auto pref = create_plan_pref();
213 auto ws_limit = estimate_workspace(op_desc, pref);
214 auto plan = create_plan(op_desc, pref, ws_limit);
215
216 // execute permutation
217 cutensor_error_check(cutensorPermute(handle, plan, &alpha, A.data, B.data, nullptr /*stream*/), "cutensorPermute");
218
219 // synchronize
220 cuda_device_sync(synchronize, "cutensorPermute");
221
222 // cleanup
223 destroy_plan(plan);
224 destroy_plan_pref(pref);
225 destroy_op_desc(op_desc);
226 destroy_tensor_desc(desc_A);
227 destroy_tensor_desc(desc_B);
228 }
229
230 // Helper function to call the cuTENSOR elementwise binary routine: D = op_AC(alpha * op_A(A), gamma * op_C(C)).
231 // D must have the same descriptor (shape/strides) as C but may point to different memory.
232 template <typename T>
233 void elementwise_binary_impl(T alpha, const_tensor_view<T> A, std::string_view idx_A, T gamma, const_tensor_view<T> C, std::string_view idx_C,
234 tensor_view<T> D, binary_op op_AC) {
235 auto &handle = get_handle();
236
237 // create tensor descriptors (D descriptor must match C in shape/modes)
238 auto desc_A = create_tensor_desc(A);
239 auto desc_C = create_tensor_desc(C);
240 auto desc_D = create_tensor_desc(D);
241
242 // convert index strings to mode arrays
243 auto modes_A = to_modes(idx_A);
244 auto modes_C = to_modes(idx_C);
245
246 // create operation descriptor (D has same modes as C)
247 cutensorOperationDescriptor_t op_desc{};
248 cutensor_error_check(cutensorCreateElementwiseBinary(handle, &op_desc, desc_A, modes_A.data(), to_cutensor_unary_op(A.op), desc_C,
249 modes_C.data(), to_cutensor_unary_op(C.op), desc_D, modes_C.data(),
250 to_cutensor_binary_op(op_AC), cutensor_compute_type<T>()),
251 "cutensorCreateElementwiseBinary");
252
253 // create plan preference, estimate workspace, and create plan
254 auto pref = create_plan_pref();
255 auto ws_limit = estimate_workspace(op_desc, pref);
256 auto plan = create_plan(op_desc, pref, ws_limit);
257
258 // execute elementwise binary
259 cutensor_error_check(cutensorElementwiseBinaryExecute(handle, plan, &alpha, A.data, &gamma, C.data, D.data, nullptr /*stream*/),
260 "cutensorElementwiseBinaryExecute");
261
262 // synchronize
263 cuda_device_sync(synchronize, "cutensorElementwiseBinaryExecute");
264
265 // cleanup
266 destroy_plan(plan);
267 destroy_plan_pref(pref);
268 destroy_op_desc(op_desc);
269 destroy_tensor_desc(desc_A);
270 destroy_tensor_desc(desc_C);
271 destroy_tensor_desc(desc_D);
272 }
273
274 // Helper function to call the cuTENSOR elementwise trinary routine: D = op_ABC(op_AB(alpha * op_A(A), beta * op_B(B)), gamma * op_C(C)).
275 // D must have the same descriptor (shape/strides) as C but may point to different memory.
276 template <typename T>
277 void elementwise_trinary_impl(T alpha, const_tensor_view<T> A, std::string_view idx_A, T beta, const_tensor_view<T> B, std::string_view idx_B,
278 T gamma, const_tensor_view<T> C, std::string_view idx_C, tensor_view<T> D, binary_op op_AB, binary_op op_ABC) {
279 auto &handle = get_handle();
280
281 // create tensor descriptors (D descriptor must match C in shape/modes)
282 auto desc_A = create_tensor_desc(A);
283 auto desc_B = create_tensor_desc(B);
284 auto desc_C = create_tensor_desc(C);
285 auto desc_D = create_tensor_desc(D);
286
287 // convert index strings to mode arrays
288 auto modes_A = to_modes(idx_A);
289 auto modes_B = to_modes(idx_B);
290 auto modes_C = to_modes(idx_C);
291
292 // create operation descriptor (D has same modes as C)
293 cutensorOperationDescriptor_t op_desc{};
294 cutensor_error_check(cutensorCreateElementwiseTrinary(handle, &op_desc, desc_A, modes_A.data(), to_cutensor_unary_op(A.op), desc_B,
295 modes_B.data(), to_cutensor_unary_op(B.op), desc_C, modes_C.data(),
296 to_cutensor_unary_op(C.op), desc_D, modes_C.data(), to_cutensor_binary_op(op_AB),
297 to_cutensor_binary_op(op_ABC), cutensor_compute_type<T>()),
298 "cutensorCreateElementwiseTrinary");
299
300 // create plan preference, estimate workspace, and create plan
301 auto pref = create_plan_pref();
302 auto ws_limit = estimate_workspace(op_desc, pref);
303 auto plan = create_plan(op_desc, pref, ws_limit);
304
305 // execute elementwise trinary
306 cutensor_error_check(cutensorElementwiseTrinaryExecute(handle, plan, &alpha, A.data, &beta, B.data, &gamma, C.data, D.data, nullptr /*stream*/),
307 "cutensorElementwiseTrinaryExecute");
308
309 // synchronize
310 cuda_device_sync(synchronize, "cutensorElementwiseTrinaryExecute");
311
312 // cleanup
313 destroy_plan(plan);
314 destroy_plan_pref(pref);
315 destroy_op_desc(op_desc);
316 destroy_tensor_desc(desc_A);
317 destroy_tensor_desc(desc_B);
318 destroy_tensor_desc(desc_C);
319 destroy_tensor_desc(desc_D);
320 }
321
322 // Helper function to call the cuTENSOR reduction routine: D = alpha * opReduce(op_A(A)) + beta * op_C(C).
323 // D must have the same descriptor (shape/strides) as C but may point to different memory.
324 // The modes of C/D must be a subset of the modes of A. The modes in A but not in C are reduced.
325 template <typename T>
326 void reduce_impl(T alpha, const_tensor_view<T> A, std::string_view idx_A, T beta, const_tensor_view<T> C, std::string_view idx_C,
327 tensor_view<T> D, binary_op op_reduce) {
328 auto &handle = get_handle();
329
330 // create tensor descriptors (D descriptor must match C in shape/modes)
331 auto desc_A = create_tensor_desc(A);
332 auto desc_C = create_tensor_desc(C);
333 auto desc_D = create_tensor_desc(D);
334
335 // convert index strings to mode arrays
336 auto modes_A = to_modes(idx_A);
337 auto modes_C = to_modes(idx_C);
338
339 // create operation descriptor (D has same modes as C)
340 cutensorOperationDescriptor_t op_desc{};
341 cutensor_error_check(cutensorCreateReduction(handle, &op_desc, desc_A, modes_A.data(), to_cutensor_unary_op(A.op), desc_C, modes_C.data(),
342 to_cutensor_unary_op(C.op), desc_D, modes_C.data(), to_cutensor_binary_op(op_reduce),
343 cutensor_compute_type<T>()),
344 "cutensorCreateReduction");
345
346 // create plan preference, estimate workspace, and create plan
347 auto pref = create_plan_pref();
348 auto ws_limit = estimate_workspace(op_desc, pref);
349 auto plan = create_plan(op_desc, pref, ws_limit);
350
351 // query the actual required workspace size from the plan
352 std::uint64_t ws_size = 0;
353 cutensor_error_check(cutensorPlanGetAttribute(handle, plan, CUTENSOR_PLAN_REQUIRED_WORKSPACE, &ws_size, sizeof(ws_size)),
354 "cutensorPlanGetAttribute");
355
356 // allocate workspace
357 void *workspace = nullptr;
358 if (ws_size > 0) { device_error_check(cudaMalloc(&workspace, ws_size), "cudaMalloc"); }
359
360 // execute reduction
361 cutensor_error_check(cutensorReduce(handle, plan, &alpha, A.data, &beta, C.data, D.data, workspace, ws_size, nullptr /*stream*/),
362 "cutensorReduce");
363
364 // synchronize
365 cuda_device_sync(synchronize, "cutensorReduce");
366
367 // free workspace
368 if (workspace) { device_error_check(cudaFree(workspace), "cudaFree"); }
369
370 // cleanup
371 destroy_plan(plan);
372 destroy_plan_pref(pref);
373 destroy_op_desc(op_desc);
374 destroy_tensor_desc(desc_A);
375 destroy_tensor_desc(desc_C);
376 destroy_tensor_desc(desc_D);
377 }
378
379 // Helper function to call the cuTENSOR contraction routine: D = alpha * op_A(A) * op_B(B) + beta * op_C(C).
380 // D must have the same descriptor (shape/strides) as C but may point to different memory.
381 template <typename T>
382 void contract_impl(T alpha, const_tensor_view<T> A, std::string_view idx_A, const_tensor_view<T> B, std::string_view idx_B, T beta,
383 const_tensor_view<T> C, std::string_view idx_C, tensor_view<T> D) {
384 auto &handle = get_handle();
385
386 // create tensor descriptors (D descriptor must match C in shape/modes)
387 auto desc_A = create_tensor_desc(A);
388 auto desc_B = create_tensor_desc(B);
389 auto desc_C = create_tensor_desc(C);
390 auto desc_D = create_tensor_desc(D);
391
392 // convert index strings to mode arrays
393 auto modes_A = to_modes(idx_A);
394 auto modes_B = to_modes(idx_B);
395 auto modes_C = to_modes(idx_C);
396
397 // create operation descriptor (D has same modes as C)
398 cutensorOperationDescriptor_t op_desc{};
399 cutensor_error_check(cutensorCreateContraction(handle, &op_desc, desc_A, modes_A.data(), to_cutensor_unary_op(A.op), desc_B, modes_B.data(),
400 to_cutensor_unary_op(B.op), desc_C, modes_C.data(), to_cutensor_unary_op(C.op), desc_D,
401 modes_C.data(), cutensor_compute_type<T>()),
402 "cutensorCreateContraction");
403
404 // create plan preference, estimate workspace, and create plan
405 auto pref = create_plan_pref();
406 auto ws_limit = estimate_workspace(op_desc, pref);
407 auto plan = create_plan(op_desc, pref, ws_limit);
408
409 // query the actual required workspace size from the plan
410 std::uint64_t ws_size = 0;
411 cutensor_error_check(cutensorPlanGetAttribute(handle, plan, CUTENSOR_PLAN_REQUIRED_WORKSPACE, &ws_size, sizeof(ws_size)),
412 "cutensorPlanGetAttribute");
413
414 // allocate workspace
415 void *workspace = nullptr;
416 if (ws_size > 0) { device_error_check(cudaMalloc(&workspace, ws_size), "cudaMalloc"); }
417
418 // execute contraction
419 cutensor_error_check(cutensorContract(handle, plan, &alpha, A.data, B.data, &beta, C.data, D.data, workspace, ws_size, nullptr /*stream*/),
420 "cutensorContract");
421
422 // synchronize
423 cuda_device_sync(synchronize, "cutensorContract");
424
425 // free workspace
426 if (workspace) { device_error_check(cudaFree(workspace), "cudaFree"); }
427
428 // cleanup
429 destroy_plan(plan);
430 destroy_plan_pref(pref);
431 destroy_op_desc(op_desc);
432 destroy_tensor_desc(desc_A);
433 destroy_tensor_desc(desc_B);
434 destroy_tensor_desc(desc_C);
435 destroy_tensor_desc(desc_D);
436 }
437
438 } // namespace
439
440 // permute
441 void permute(float alpha, const_tensor_view<float> A, std::string_view idx_A, tensor_view<float> B, std::string_view idx_B) {
442 permute_impl(alpha, A, idx_A, B, idx_B);
443 }
444 void permute(double alpha, const_tensor_view<double> A, std::string_view idx_A, tensor_view<double> B, std::string_view idx_B) {
445 permute_impl(alpha, A, idx_A, B, idx_B);
446 }
447 void permute(std::complex<float> alpha, const_tensor_view<std::complex<float>> A, std::string_view idx_A, tensor_view<std::complex<float>> B,
448 std::string_view idx_B) {
449 permute_impl(alpha, A, idx_A, B, idx_B);
450 }
451 void permute(std::complex<double> alpha, const_tensor_view<std::complex<double>> A, std::string_view idx_A, tensor_view<std::complex<double>> B,
452 std::string_view idx_B) {
453 permute_impl(alpha, A, idx_A, B, idx_B);
454 }
455
456 // elementwise_binary
457 void elementwise_binary(float alpha, const_tensor_view<float> A, std::string_view idx_A, float gamma, const_tensor_view<float> C,
458 std::string_view idx_C, tensor_view<float> D, binary_op op_AC) {
459 elementwise_binary_impl(alpha, A, idx_A, gamma, C, idx_C, D, op_AC);
460 }
461 void elementwise_binary(double alpha, const_tensor_view<double> A, std::string_view idx_A, double gamma, const_tensor_view<double> C,
462 std::string_view idx_C, tensor_view<double> D, binary_op op_AC) {
463 elementwise_binary_impl(alpha, A, idx_A, gamma, C, idx_C, D, op_AC);
464 }
465 void elementwise_binary(std::complex<float> alpha, const_tensor_view<std::complex<float>> A, std::string_view idx_A, std::complex<float> gamma,
466 const_tensor_view<std::complex<float>> C, std::string_view idx_C, tensor_view<std::complex<float>> D, binary_op op_AC) {
467 elementwise_binary_impl(alpha, A, idx_A, gamma, C, idx_C, D, op_AC);
468 }
469 void elementwise_binary(std::complex<double> alpha, const_tensor_view<std::complex<double>> A, std::string_view idx_A, std::complex<double> gamma,
470 const_tensor_view<std::complex<double>> C, std::string_view idx_C, tensor_view<std::complex<double>> D, binary_op op_AC) {
471 elementwise_binary_impl(alpha, A, idx_A, gamma, C, idx_C, D, op_AC);
472 }
473
474 // elementwise_trinary
475 void elementwise_trinary(float alpha, const_tensor_view<float> A, std::string_view idx_A, float beta, const_tensor_view<float> B,
476 std::string_view idx_B, float gamma, const_tensor_view<float> C, std::string_view idx_C, tensor_view<float> D,
477 binary_op op_AB, binary_op op_ABC) {
478 elementwise_trinary_impl(alpha, A, idx_A, beta, B, idx_B, gamma, C, idx_C, D, op_AB, op_ABC);
479 }
480 void elementwise_trinary(double alpha, const_tensor_view<double> A, std::string_view idx_A, double beta, const_tensor_view<double> B,
481 std::string_view idx_B, double gamma, const_tensor_view<double> C, std::string_view idx_C, tensor_view<double> D,
482 binary_op op_AB, binary_op op_ABC) {
483 elementwise_trinary_impl(alpha, A, idx_A, beta, B, idx_B, gamma, C, idx_C, D, op_AB, op_ABC);
484 }
485 void elementwise_trinary(std::complex<float> alpha, const_tensor_view<std::complex<float>> A, std::string_view idx_A, std::complex<float> beta,
486 const_tensor_view<std::complex<float>> B, std::string_view idx_B, std::complex<float> gamma,
487 const_tensor_view<std::complex<float>> C, std::string_view idx_C, tensor_view<std::complex<float>> D, binary_op op_AB,
488 binary_op op_ABC) {
489 elementwise_trinary_impl(alpha, A, idx_A, beta, B, idx_B, gamma, C, idx_C, D, op_AB, op_ABC);
490 }
491 void elementwise_trinary(std::complex<double> alpha, const_tensor_view<std::complex<double>> A, std::string_view idx_A, std::complex<double> beta,
492 const_tensor_view<std::complex<double>> B, std::string_view idx_B, std::complex<double> gamma,
493 const_tensor_view<std::complex<double>> C, std::string_view idx_C, tensor_view<std::complex<double>> D, binary_op op_AB,
494 binary_op op_ABC) {
495 elementwise_trinary_impl(alpha, A, idx_A, beta, B, idx_B, gamma, C, idx_C, D, op_AB, op_ABC);
496 }
497
498 // reduce
499 void reduce(float alpha, const_tensor_view<float> A, std::string_view idx_A, float beta, const_tensor_view<float> C, std::string_view idx_C,
500 tensor_view<float> D, binary_op op_reduce) {
501 reduce_impl(alpha, A, idx_A, beta, C, idx_C, D, op_reduce);
502 }
503 void reduce(double alpha, const_tensor_view<double> A, std::string_view idx_A, double beta, const_tensor_view<double> C, std::string_view idx_C,
504 tensor_view<double> D, binary_op op_reduce) {
505 reduce_impl(alpha, A, idx_A, beta, C, idx_C, D, op_reduce);
506 }
507 void reduce(std::complex<float> alpha, const_tensor_view<std::complex<float>> A, std::string_view idx_A, std::complex<float> beta,
508 const_tensor_view<std::complex<float>> C, std::string_view idx_C, tensor_view<std::complex<float>> D, binary_op op_reduce) {
509 reduce_impl(alpha, A, idx_A, beta, C, idx_C, D, op_reduce);
510 }
511 void reduce(std::complex<double> alpha, const_tensor_view<std::complex<double>> A, std::string_view idx_A, std::complex<double> beta,
512 const_tensor_view<std::complex<double>> C, std::string_view idx_C, tensor_view<std::complex<double>> D, binary_op op_reduce) {
513 reduce_impl(alpha, A, idx_A, beta, C, idx_C, D, op_reduce);
514 }
515
516 // contract
517 void contract(float alpha, const_tensor_view<float> A, std::string_view idx_A, const_tensor_view<float> B, std::string_view idx_B, float beta,
518 const_tensor_view<float> C, std::string_view idx_C, tensor_view<float> D) {
519 contract_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C, D);
520 }
521 void contract(double alpha, const_tensor_view<double> A, std::string_view idx_A, const_tensor_view<double> B, std::string_view idx_B, double beta,
522 const_tensor_view<double> C, std::string_view idx_C, tensor_view<double> D) {
523 contract_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C, D);
524 }
525 void contract(std::complex<float> alpha, const_tensor_view<std::complex<float>> A, std::string_view idx_A, const_tensor_view<std::complex<float>> B,
526 std::string_view idx_B, std::complex<float> beta, const_tensor_view<std::complex<float>> C, std::string_view idx_C,
527 tensor_view<std::complex<float>> D) {
528 contract_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C, D);
529 }
530 void contract(std::complex<double> alpha, const_tensor_view<std::complex<double>> A, std::string_view idx_A,
531 const_tensor_view<std::complex<double>> B, std::string_view idx_B, std::complex<double> beta,
532 const_tensor_view<std::complex<double>> C, std::string_view idx_C, tensor_view<std::complex<double>> D) {
533 contract_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C, D);
534 }
535
536} // namespace nda::tensor::device
Provides concepts for the nda library.
Provides a C++ interface for various cuTENSOR routines.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
#define device_error_check(ARG1, ARG2)
Trigger a compilation error every time the nda::device_error_check function is called.
Definition device.hpp:196
void cuda_device_sync(bool do_sync=true, std::string_view func="")
Empty function if CudaSupport is not enabled.
Definition device.hpp:205
unary_op
Unary element-wise operations for tensor operations.
Definition tools.hpp:103
binary_op
Binary operations for tensor operations.
Definition tools.hpp:67
tensor_view< const T > const_tensor_view
Alias for a tensor_view with const value type.
Definition tools.hpp:234
Provides various traits and utilities for the tensor interface.