27namespace nda::tensor::device {
/// Thread-local switch written by set_synchronization() and read by get_synchronization(); starts out enabled.
/// NOTE(review): presumably consulted by the device-sync helper after cuTENSOR calls — confirm against callers.
thread_local bool synchronize{true};

/// Store @p do_sync as the calling thread's synchronization setting.
void set_synchronization(bool do_sync) noexcept {
  synchronize = do_sync;
}

/// Return the calling thread's current synchronization setting.
bool get_synchronization() noexcept {
  bool const current = synchronize;
  return current;
}
/// Provide the cuTENSOR library handle used by every routine in this file.
/// The handle is owned by a function-local static wrapper: cutensorCreate runs on the first call,
/// and cutensorDestroy runs when the static is destroyed at program exit. A single handle is
/// therefore shared by all callers and threads.
/// NOTE(review): the statuses returned by cutensorCreate/cutensorDestroy are ignored, so a failed
/// creation goes unnoticed here — consider checking it (e.g. via cutensor_error_check).
/// NOTE(review): confirm that sharing one cuTENSOR handle across threads is acceptable for this usage.
35 cutensorHandle_t &get_handle() {
36 struct handle_storage_t {
// Constructor acquires the handle; destructor releases it.
37 handle_storage_t() { cutensorCreate(&handle); }
38 ~handle_storage_t() { cutensorDestroy(handle); }
39 cutensorHandle_t handle = {};
// Function-local static: constructed once, on first use.
41 static auto sto = handle_storage_t{};
/// Report an error through NDA_RUNTIME_ERROR when a cuTENSOR call did not return CUTENSOR_STATUS_SUCCESS.
/// @param status Status code returned by the cuTENSOR call.
/// @param func Name of the cuTENSOR function; it is included in the error message.
/// The message contains the function name, the numeric status and cutensorGetErrorString(status).
49 void cutensor_error_check(cutensorStatus_t status, std::string_view func) {
50 if (status != CUTENSOR_STATUS_SUCCESS) {
51 NDA_RUNTIME_ERROR <<
"cuTENSOR runtime error in function " << func <<
"\n"
52 <<
" cutensorStatus_t: " << status <<
"\n"
53 <<
" cutensorGetErrorString: " << cutensorGetErrorString(status) <<
"\n";
/// Map the value type T (after stripping const, via U) to its cuTENSOR data-type enumerator:
/// float -> CUTENSOR_R_32F, double -> CUTENSOR_R_64F,
/// std::complex<float> -> CUTENSOR_C_32F, std::complex<double> -> CUTENSOR_C_64F.
/// NOTE(review): handling of any other type is not visible in this excerpt (presumably a compile-time error) — confirm.
58 template <
typename T,
typename U = std::remove_const_t<T>>
59 constexpr auto cuda_data_type() {
60 if constexpr (std::is_same_v<U, float>) {
61 return CUTENSOR_R_32F;
62 }
else if constexpr (std::is_same_v<U, double>) {
63 return CUTENSOR_R_64F;
64 }
else if constexpr (std::is_same_v<U, std::complex<float>>) {
65 return CUTENSOR_C_32F;
66 }
else if constexpr (std::is_same_v<U, std::complex<double>>) {
67 return CUTENSOR_C_64F;
/// Select the cuTENSOR compute descriptor for value type T (after stripping const, via U).
/// Single precision (float, std::complex<float>) -> CUTENSOR_COMPUTE_DESC_32F.
/// Double precision (double, std::complex<double>) -> CUTENSOR_COMPUTE_DESC_64F.
/// Real and complex types of the same precision share one compute descriptor.
72 template <
typename T,
typename U = std::remove_const_t<T>>
73 constexpr auto cutensor_compute_type() {
74 if constexpr (AnyOf<U, float, std::complex<float>>) {
75 return CUTENSOR_COMPUTE_DESC_32F;
76 }
else if constexpr (AnyOf<U, double, std::complex<double>>) {
77 return CUTENSOR_COMPUTE_DESC_64F;
/// Return the alignment, in bytes, reported for pointer @p p when building cuTENSOR tensor descriptors.
/// This is the largest power of two that divides the address, capped at 256.
///
/// A null pointer (address 0) has no set bit, so it is reported as maximally aligned (256).
/// Deriving a shift count from the trailing-zero count of 0 would shift by the full bit width,
/// which is undefined behavior.
template <typename T>
auto find_alignment(T *p) {
  auto const x = reinterpret_cast<std::uintptr_t>(p);
  constexpr std::uintptr_t max_alignment = 256;
  if (x == 0) return static_cast<std::uint32_t>(max_alignment);
  // x & (~x + 1) isolates the lowest set bit, i.e. the largest power of two dividing x.
  std::uintptr_t const alignment = x & (~x + 1);
  return static_cast<std::uint32_t>(std::min(alignment, max_alignment));
}
/// Convert an index string (one character per mode) into the int32 mode labels passed to cuTENSOR.
/// Each character's value becomes one label, in the same order.
auto to_modes(std::string_view idx) {
  std::vector<std::int32_t> modes;
  modes.reserve(idx.size());
  for (char const c : idx) modes.push_back(static_cast<std::int32_t>(c));
  return modes;
}
/// Translate an nda unary_op into the corresponding cuTENSOR element-wise operator.
/// Each supported enumerator maps one-to-one onto the CUTENSOR_OP_* value of the same name.
/// Any value without a cuTENSOR counterpart reaches the default branch, which reports an
/// error through NDA_RUNTIME_ERROR.
95 cutensorOperator_t to_cutensor_unary_op(
unary_op op) {
97 case unary_op::IDENTITY:
return CUTENSOR_OP_IDENTITY;
98 case unary_op::SQRT:
return CUTENSOR_OP_SQRT;
99 case unary_op::RELU:
return CUTENSOR_OP_RELU;
100 case unary_op::CONJ:
return CUTENSOR_OP_CONJ;
101 case unary_op::RCP:
return CUTENSOR_OP_RCP;
102 case unary_op::SIGMOID:
return CUTENSOR_OP_SIGMOID;
103 case unary_op::TANH:
return CUTENSOR_OP_TANH;
104 case unary_op::EXP:
return CUTENSOR_OP_EXP;
105 case unary_op::LOG:
return CUTENSOR_OP_LOG;
106 case unary_op::ABS:
return CUTENSOR_OP_ABS;
107 case unary_op::NEG:
return CUTENSOR_OP_NEG;
108 case unary_op::SIN:
return CUTENSOR_OP_SIN;
109 case unary_op::COS:
return CUTENSOR_OP_COS;
110 case unary_op::TAN:
return CUTENSOR_OP_TAN;
111 case unary_op::SINH:
return CUTENSOR_OP_SINH;
112 case unary_op::COSH:
return CUTENSOR_OP_COSH;
113 case unary_op::ASIN:
return CUTENSOR_OP_ASIN;
114 case unary_op::ACOS:
return CUTENSOR_OP_ACOS;
115 case unary_op::ATAN:
return CUTENSOR_OP_ATAN;
116 case unary_op::ASINH:
return CUTENSOR_OP_ASINH;
117 case unary_op::ACOSH:
return CUTENSOR_OP_ACOSH;
118 case unary_op::ATANH:
return CUTENSOR_OP_ATANH;
119 case unary_op::CEIL:
return CUTENSOR_OP_CEIL;
120 case unary_op::FLOOR:
return CUTENSOR_OP_FLOOR;
121 case unary_op::MISH:
return CUTENSOR_OP_MISH;
122 case unary_op::SWISH:
return CUTENSOR_OP_SWISH;
123 case unary_op::SOFT_PLUS:
return CUTENSOR_OP_SOFT_PLUS;
124 case unary_op::SOFT_SIGN:
return CUTENSOR_OP_SOFT_SIGN;
125 default: NDA_RUNTIME_ERROR <<
"nda::tensor::cutensor: unary_op has no cuTENSOR equivalent";
/// Translate an nda binary_op into a cuTENSOR operator:
/// SUM -> CUTENSOR_OP_ADD, PROD -> CUTENSOR_OP_MUL, MAX -> CUTENSOR_OP_MAX, MIN -> CUTENSOR_OP_MIN.
/// Any other value reaches the default branch, which reports an error through NDA_RUNTIME_ERROR.
131 cutensorOperator_t to_cutensor_binary_op(
binary_op op) {
133 case binary_op::SUM:
return CUTENSOR_OP_ADD;
134 case binary_op::PROD:
return CUTENSOR_OP_MUL;
135 case binary_op::MAX:
return CUTENSOR_OP_MAX;
136 case binary_op::MIN:
return CUTENSOR_OP_MIN;
137 default: NDA_RUNTIME_ERROR <<
"nda::tensor::cutensor: binary_op has no cuTENSOR equivalent";
/// Create a cuTENSOR tensor descriptor for the view tv.
/// The descriptor records the rank (tv.ndim), tv.extents, tv.strides, the data type derived from T,
/// and the alignment of tv.data.
/// T may be const-qualified (callers pass const_tensor_view<T> = tensor_view<const T>); cuda_data_type
/// strips the const. Callers in this file release the descriptor with destroy_tensor_desc.
142 template <
typename T>
143 auto create_tensor_desc(tensor_view<T> tv) {
144 cutensorTensorDescriptor_t desc{};
145 auto status = cutensorCreateTensorDescriptor(get_handle(), &desc,
static_cast<std::uint32_t
>(tv.ndim), tv.extents, tv.strides,
146 cuda_data_type<T>(), find_alignment(tv.data));
147 cutensor_error_check(status,
"cutensorCreateTensorDescriptor");
/// Release a tensor descriptor obtained from create_tensor_desc; a failing status is reported via cutensor_error_check.
152 void destroy_tensor_desc(cutensorTensorDescriptor_t desc) {
153 cutensor_error_check(cutensorDestroyTensorDescriptor(desc),
"cutensorDestroyTensorDescriptor");
/// Create a plan preference with the default algorithm (CUTENSOR_ALGO_DEFAULT) and JIT compilation
/// disabled (CUTENSOR_JIT_MODE_NONE). Callers in this file release it with destroy_plan_pref.
157 cutensorPlanPreference_t create_plan_pref() {
158 cutensorPlanPreference_t pref{};
159 cutensor_error_check(cutensorCreatePlanPreference(get_handle(), &pref, CUTENSOR_ALGO_DEFAULT, CUTENSOR_JIT_MODE_NONE),
160 "cutensorCreatePlanPreference");
/// Release a plan preference obtained from create_plan_pref; a failing status is reported via cutensor_error_check.
165 void destroy_plan_pref(cutensorPlanPreference_t pref) {
166 cutensor_error_check(cutensorDestroyPlanPreference(pref),
"cutensorDestroyPlanPreference");
/// Create an execution plan for the operation descriptor op_desc, using plan preference pref.
/// @param workspace_limit Upper bound on the workspace the plan may use. Callers in this file pass the
/// value returned by estimate_workspace. Defaults to 0.
170 cutensorPlan_t create_plan(cutensorOperationDescriptor_t op_desc, cutensorPlanPreference_t pref, std::uint64_t workspace_limit = 0) {
171 cutensorPlan_t plan{};
172 cutensor_error_check(cutensorCreatePlan(get_handle(), &plan, op_desc, pref, workspace_limit),
"cutensorCreatePlan");
177 void destroy_plan(cutensorPlan_t plan) { cutensor_error_check(cutensorDestroyPlan(plan),
"cutensorDestroyPlan"); }
/// Release an operation descriptor; a failing status is reported via cutensor_error_check.
180 void destroy_op_desc(cutensorOperationDescriptor_t op_desc) {
181 cutensor_error_check(cutensorDestroyOperationDescriptor(op_desc),
"cutensorDestroyOperationDescriptor");
/// Ask cuTENSOR for the estimated workspace size of op_desc under plan preference pref.
/// @param ws_pref Workspace preference passed to cutensorEstimateWorkspaceSize; defaults to CUTENSOR_WORKSPACE_DEFAULT.
/// The estimate is written into `size`. Callers in this file use it as the workspace limit for create_plan.
185 std::uint64_t estimate_workspace(cutensorOperationDescriptor_t op_desc, cutensorPlanPreference_t pref,
186 cutensorWorksizePreference_t ws_pref = CUTENSOR_WORKSPACE_DEFAULT) {
187 std::uint64_t size = 0;
188 cutensor_error_check(cutensorEstimateWorkspaceSize(get_handle(), op_desc, pref, ws_pref, &size),
"cutensorEstimateWorkspaceSize");
/// Generic implementation of permute via cuTENSOR's permutation operation.
/// Per cuTENSOR's definition this computes B = alpha * op_A(A), relabelling modes from idx_A to idx_B,
/// where op_A is the unary operator stored in A.op.
/// NOTE(review): none of the cuTENSOR objects below are guarded; if any cutensor_error_check raises an
/// error, the descriptors, preference and plan created earlier are not released — consider RAII guards.
193 template <
typename T>
194 void permute_impl(T alpha,
const_tensor_view<T> A, std::string_view idx_A, tensor_view<T> B, std::string_view idx_B) {
195 auto &handle = get_handle();
// Tensor descriptors for input and output.
198 auto desc_A = create_tensor_desc(A);
199 auto desc_B = create_tensor_desc(B);
// Mode labels: one int32 per index character.
202 auto modes_A = to_modes(idx_A);
203 auto modes_B = to_modes(idx_B);
// Describe the permutation operation.
206 cutensorOperationDescriptor_t op_desc{};
207 cutensor_error_check(cutensorCreatePermutation(handle, &op_desc, desc_A, modes_A.data(), to_cutensor_unary_op(A.op), desc_B, modes_B.data(),
208 cutensor_compute_type<T>()),
209 "cutensorCreatePermutation");
// Build a plan whose workspace limit is the estimated size.
212 auto pref = create_plan_pref();
213 auto ws_limit = estimate_workspace(op_desc, pref);
214 auto plan = create_plan(op_desc, pref, ws_limit);
// Execute; the final nullptr selects the default CUDA stream.
217 cutensor_error_check(cutensorPermute(handle, plan, &alpha, A.data, B.data,
nullptr ),
"cutensorPermute");
// Release the cuTENSOR objects created above.
224 destroy_plan_pref(pref);
225 destroy_op_desc(op_desc);
226 destroy_tensor_desc(desc_A);
227 destroy_tensor_desc(desc_B);
/// Generic implementation of elementwise_binary via cuTENSOR's element-wise binary operation.
/// Per cuTENSOR's definition: D = op_AC(alpha * op_A(A), gamma * op_C(C)).
/// A is indexed by idx_A; C and D are both described with the modes of idx_C.
/// NOTE(review): no RAII guards — an error raised mid-way leaves the created descriptors/plan unreleased.
232 template <
typename T>
235 auto &handle = get_handle();
// Tensor descriptors for A, C and the output D.
238 auto desc_A = create_tensor_desc(A);
239 auto desc_C = create_tensor_desc(C);
240 auto desc_D = create_tensor_desc(D);
// Mode labels; D reuses modes_C.
243 auto modes_A = to_modes(idx_A);
244 auto modes_C = to_modes(idx_C);
247 cutensorOperationDescriptor_t op_desc{};
248 cutensor_error_check(cutensorCreateElementwiseBinary(handle, &op_desc, desc_A, modes_A.data(), to_cutensor_unary_op(A.op), desc_C,
249 modes_C.data(), to_cutensor_unary_op(C.op), desc_D, modes_C.data(),
250 to_cutensor_binary_op(op_AC), cutensor_compute_type<T>()),
251 "cutensorCreateElementwiseBinary");
// Build a plan whose workspace limit is the estimated size.
254 auto pref = create_plan_pref();
255 auto ws_limit = estimate_workspace(op_desc, pref);
256 auto plan = create_plan(op_desc, pref, ws_limit);
// Execute on the default CUDA stream (nullptr).
259 cutensor_error_check(cutensorElementwiseBinaryExecute(handle, plan, &alpha, A.data, &gamma, C.data, D.data,
nullptr ),
260 "cutensorElementwiseBinaryExecute");
// Release the cuTENSOR objects created above.
267 destroy_plan_pref(pref);
268 destroy_op_desc(op_desc);
269 destroy_tensor_desc(desc_A);
270 destroy_tensor_desc(desc_C);
271 destroy_tensor_desc(desc_D);
/// Generic implementation of elementwise_trinary via cuTENSOR's element-wise trinary operation.
/// Per cuTENSOR's definition: D = op_ABC(op_AB(alpha * op_A(A), beta * op_B(B)), gamma * op_C(C)).
/// D is described with the modes of idx_C.
/// NOTE(review): no RAII guards — an error raised mid-way leaves the created descriptors/plan unreleased.
276 template <
typename T>
279 auto &handle = get_handle();
// Tensor descriptors for the three inputs and the output D.
282 auto desc_A = create_tensor_desc(A);
283 auto desc_B = create_tensor_desc(B);
284 auto desc_C = create_tensor_desc(C);
285 auto desc_D = create_tensor_desc(D);
// Mode labels; D reuses modes_C.
288 auto modes_A = to_modes(idx_A);
289 auto modes_B = to_modes(idx_B);
290 auto modes_C = to_modes(idx_C);
293 cutensorOperationDescriptor_t op_desc{};
294 cutensor_error_check(cutensorCreateElementwiseTrinary(handle, &op_desc, desc_A, modes_A.data(), to_cutensor_unary_op(A.op), desc_B,
295 modes_B.data(), to_cutensor_unary_op(B.op), desc_C, modes_C.data(),
296 to_cutensor_unary_op(C.op), desc_D, modes_C.data(), to_cutensor_binary_op(op_AB),
297 to_cutensor_binary_op(op_ABC), cutensor_compute_type<T>()),
298 "cutensorCreateElementwiseTrinary");
// Build a plan whose workspace limit is the estimated size.
301 auto pref = create_plan_pref();
302 auto ws_limit = estimate_workspace(op_desc, pref);
303 auto plan = create_plan(op_desc, pref, ws_limit);
// Execute on the default CUDA stream (nullptr).
306 cutensor_error_check(cutensorElementwiseTrinaryExecute(handle, plan, &alpha, A.data, &beta, B.data, &gamma, C.data, D.data,
nullptr ),
307 "cutensorElementwiseTrinaryExecute");
// Release the cuTENSOR objects created above.
314 destroy_plan_pref(pref);
315 destroy_op_desc(op_desc);
316 destroy_tensor_desc(desc_A);
317 destroy_tensor_desc(desc_B);
318 destroy_tensor_desc(desc_C);
319 destroy_tensor_desc(desc_D);
/// Generic implementation of reduce via cuTENSOR's reduction operation.
/// Per cuTENSOR's definition: D = alpha * reduce_{op_reduce}(op_A(A)) + beta * op_C(C). Modes of A that
/// do not appear in idx_C are reduced. D is described with the modes of idx_C.
/// NOTE(review): the cudaMalloc'd workspace and the cuTENSOR objects are not guarded; if cutensorReduce
/// (or an earlier check) raises an error, they leak. The freeing code is not visible in this excerpt — confirm.
325 template <
typename T>
328 auto &handle = get_handle();
331 auto desc_A = create_tensor_desc(A);
332 auto desc_C = create_tensor_desc(C);
333 auto desc_D = create_tensor_desc(D);
// Mode labels; D reuses modes_C.
336 auto modes_A = to_modes(idx_A);
337 auto modes_C = to_modes(idx_C);
340 cutensorOperationDescriptor_t op_desc{};
341 cutensor_error_check(cutensorCreateReduction(handle, &op_desc, desc_A, modes_A.data(), to_cutensor_unary_op(A.op), desc_C, modes_C.data(),
342 to_cutensor_unary_op(C.op), desc_D, modes_C.data(), to_cutensor_binary_op(op_reduce),
343 cutensor_compute_type<T>()),
344 "cutensorCreateReduction");
347 auto pref = create_plan_pref();
348 auto ws_limit = estimate_workspace(op_desc, pref);
349 auto plan = create_plan(op_desc, pref, ws_limit);
// Unlike the element-wise ops, the reduction needs an explicit workspace: query its required size from the plan.
352 std::uint64_t ws_size = 0;
353 cutensor_error_check(cutensorPlanGetAttribute(handle, plan, CUTENSOR_PLAN_REQUIRED_WORKSPACE, &ws_size,
sizeof(ws_size)),
354 "cutensorPlanGetAttribute");
// Allocate device workspace only when the plan asks for some.
357 void *workspace =
nullptr;
358 if (ws_size > 0) {
device_error_check(cudaMalloc(&workspace, ws_size),
"cudaMalloc"); }
// Execute on the default CUDA stream (nullptr).
361 cutensor_error_check(cutensorReduce(handle, plan, &alpha, A.data, &beta, C.data, D.data, workspace, ws_size,
nullptr ),
// Release the cuTENSOR objects created above.
372 destroy_plan_pref(pref);
373 destroy_op_desc(op_desc);
374 destroy_tensor_desc(desc_A);
375 destroy_tensor_desc(desc_C);
376 destroy_tensor_desc(desc_D);
/// Generic implementation of contract via cuTENSOR's contraction operation.
/// Per cuTENSOR's definition: D = alpha * op_A(A) * op_B(B) + beta * op_C(C). Modes shared by A and B
/// but absent from idx_C are contracted. D is described with the modes of idx_C.
/// NOTE(review): the cudaMalloc'd workspace and the cuTENSOR objects are not guarded; if cutensorContract
/// (or an earlier check) raises an error, they leak. The freeing code is not visible in this excerpt — confirm.
381 template <
typename T>
384 auto &handle = get_handle();
387 auto desc_A = create_tensor_desc(A);
388 auto desc_B = create_tensor_desc(B);
389 auto desc_C = create_tensor_desc(C);
390 auto desc_D = create_tensor_desc(D);
// Mode labels; D reuses modes_C.
393 auto modes_A = to_modes(idx_A);
394 auto modes_B = to_modes(idx_B);
395 auto modes_C = to_modes(idx_C);
398 cutensorOperationDescriptor_t op_desc{};
399 cutensor_error_check(cutensorCreateContraction(handle, &op_desc, desc_A, modes_A.data(), to_cutensor_unary_op(A.op), desc_B, modes_B.data(),
400 to_cutensor_unary_op(B.op), desc_C, modes_C.data(), to_cutensor_unary_op(C.op), desc_D,
401 modes_C.data(), cutensor_compute_type<T>()),
402 "cutensorCreateContraction");
405 auto pref = create_plan_pref();
406 auto ws_limit = estimate_workspace(op_desc, pref);
407 auto plan = create_plan(op_desc, pref, ws_limit);
// Query the workspace size the chosen plan actually requires.
410 std::uint64_t ws_size = 0;
411 cutensor_error_check(cutensorPlanGetAttribute(handle, plan, CUTENSOR_PLAN_REQUIRED_WORKSPACE, &ws_size,
sizeof(ws_size)),
412 "cutensorPlanGetAttribute");
// Allocate device workspace only when the plan asks for some.
415 void *workspace =
nullptr;
416 if (ws_size > 0) {
device_error_check(cudaMalloc(&workspace, ws_size),
"cudaMalloc"); }
// Execute on the default CUDA stream (nullptr).
419 cutensor_error_check(cutensorContract(handle, plan, &alpha, A.data, B.data, &beta, C.data, D.data, workspace, ws_size,
nullptr ),
// Release the cuTENSOR objects created above.
430 destroy_plan_pref(pref);
431 destroy_op_desc(op_desc);
432 destroy_tensor_desc(desc_A);
433 destroy_tensor_desc(desc_B);
434 destroy_tensor_desc(desc_C);
435 destroy_tensor_desc(desc_D);
/// Public permute overloads for float, double, std::complex<float> and std::complex<double>.
/// Each one forwards its arguments unchanged to permute_impl<T>.
441 void permute(
float alpha,
const_tensor_view<float> A, std::string_view idx_A, tensor_view<float> B, std::string_view idx_B) {
442 permute_impl(alpha, A, idx_A, B, idx_B);
444 void permute(
double alpha,
const_tensor_view<double> A, std::string_view idx_A, tensor_view<double> B, std::string_view idx_B) {
445 permute_impl(alpha, A, idx_A, B, idx_B);
447 void permute(std::complex<float> alpha,
const_tensor_view<std::complex<float>> A, std::string_view idx_A, tensor_view<std::complex<float>> B,
448 std::string_view idx_B) {
449 permute_impl(alpha, A, idx_A, B, idx_B);
451 void permute(std::complex<double> alpha,
const_tensor_view<std::complex<double>> A, std::string_view idx_A, tensor_view<std::complex<double>> B,
452 std::string_view idx_B) {
453 permute_impl(alpha, A, idx_A, B, idx_B);
458 std::string_view idx_C, tensor_view<float> D,
binary_op op_AC) {
// float overload: forward unchanged to the generic implementation.
459 elementwise_binary_impl(alpha, A, idx_A, gamma, C, idx_C, D, op_AC);
462 std::string_view idx_C, tensor_view<double> D,
binary_op op_AC) {
// double overload: forward unchanged to the generic implementation.
463 elementwise_binary_impl(alpha, A, idx_A, gamma, C, idx_C, D, op_AC);
/// elementwise_binary overload for std::complex<float>; forwards to elementwise_binary_impl.
465 void elementwise_binary(std::complex<float> alpha,
const_tensor_view<std::complex<float>> A, std::string_view idx_A, std::complex<float> gamma,
467 elementwise_binary_impl(alpha, A, idx_A, gamma, C, idx_C, D, op_AC);
/// elementwise_binary overload for std::complex<double>; forwards to elementwise_binary_impl.
469 void elementwise_binary(std::complex<double> alpha,
const_tensor_view<std::complex<double>> A, std::string_view idx_A, std::complex<double> gamma,
471 elementwise_binary_impl(alpha, A, idx_A, gamma, C, idx_C, D, op_AC);
// Forward unchanged to the generic implementation.
478 elementwise_trinary_impl(alpha, A, idx_A, beta, B, idx_B, gamma, C, idx_C, D, op_AB, op_ABC);
// Forward unchanged to the generic implementation.
483 elementwise_trinary_impl(alpha, A, idx_A, beta, B, idx_B, gamma, C, idx_C, D, op_AB, op_ABC);
/// elementwise_trinary overload for std::complex<float>; forwards to elementwise_trinary_impl.
485 void elementwise_trinary(std::complex<float> alpha,
const_tensor_view<std::complex<float>> A, std::string_view idx_A, std::complex<float> beta,
486 const_tensor_view<std::complex<float>> B, std::string_view idx_B, std::complex<float> gamma,
489 elementwise_trinary_impl(alpha, A, idx_A, beta, B, idx_B, gamma, C, idx_C, D, op_AB, op_ABC);
/// elementwise_trinary overload for std::complex<double>; forwards to elementwise_trinary_impl.
491 void elementwise_trinary(std::complex<double> alpha,
const_tensor_view<std::complex<double>> A, std::string_view idx_A, std::complex<double> beta,
492 const_tensor_view<std::complex<double>> B, std::string_view idx_B, std::complex<double> gamma,
495 elementwise_trinary_impl(alpha, A, idx_A, beta, B, idx_B, gamma, C, idx_C, D, op_AB, op_ABC);
500 tensor_view<float> D,
binary_op op_reduce) {
// float overload: forward unchanged to the generic implementation.
501 reduce_impl(alpha, A, idx_A, beta, C, idx_C, D, op_reduce);
504 tensor_view<double> D,
binary_op op_reduce) {
// double overload: forward unchanged to the generic implementation.
505 reduce_impl(alpha, A, idx_A, beta, C, idx_C, D, op_reduce);
/// reduce overload for std::complex<float>; forwards to reduce_impl.
507 void reduce(std::complex<float> alpha,
const_tensor_view<std::complex<float>> A, std::string_view idx_A, std::complex<float> beta,
509 reduce_impl(alpha, A, idx_A, beta, C, idx_C, D, op_reduce);
/// reduce overload for std::complex<double>; forwards to reduce_impl.
511 void reduce(std::complex<double> alpha,
const_tensor_view<std::complex<double>> A, std::string_view idx_A, std::complex<double> beta,
512 const_tensor_view<std::complex<double>> C, std::string_view idx_C, tensor_view<std::complex<double>> D,
binary_op op_reduce) {
513 reduce_impl(alpha, A, idx_A, beta, C, idx_C, D, op_reduce);
// Forward unchanged to the generic implementation.
519 contract_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C, D);
// Forward unchanged to the generic implementation.
523 contract_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C, D);
526 std::string_view idx_B, std::complex<float> beta,
const_tensor_view<std::complex<float>> C, std::string_view idx_C,
527 tensor_view<std::complex<float>> D) {
// std::complex<float> overload: forward unchanged to the generic implementation.
528 contract_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C, D);
/// contract overload for std::complex<double>; forwards to contract_impl.
530 void contract(std::complex<double> alpha,
const_tensor_view<std::complex<double>> A, std::string_view idx_A,
531 const_tensor_view<std::complex<double>> B, std::string_view idx_B, std::complex<double> beta,
532 const_tensor_view<std::complex<double>> C, std::string_view idx_C, tensor_view<std::complex<double>> D) {
533 contract_impl(alpha, A, idx_A, B, idx_B, beta, C, idx_C, D);
Provides concepts for the nda library.
Provides a C++ interface for various cuTENSOR routines.
Provides GPU and non-GPU specific functionality.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
#define device_error_check(ARG1, ARG2)
Triggers a compilation error whenever the device_error_check macro is used.
void cuda_device_sync(bool do_sync=true, std::string_view func="")
Empty function if CudaSupport is not enabled.
unary_op
Unary element-wise operations for tensor operations.
binary_op
Binary operations for tensor operations.
tensor_view< const T > const_tensor_view
Alias for a tensor_view with const value type.