TRIQS/nda 2.0.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
device.hpp
Go to the documentation of this file.
1// Copyright (c) 2023--present, The Simons Foundation
2// This file is part of TRIQS/nda and is licensed under the Apache License, Version 2.0.
3// SPDX-License-Identifier: Apache-2.0
4// See LICENSE in the root of this distribution for details.
5
10
11#pragma once
12
13#include <string_view>
14
15#ifdef NDA_HAVE_CUDA
16#include "./concepts.hpp"
17#include "./exceptions.hpp"
18
19#include <cuda_runtime.h>
20#include <cublas_v2.h>
21
22#include <complex>
23#include <exception>
24#include <string>
25#include <type_traits>
26#endif // NDA_HAVE_CUDA
27
28namespace nda {
29
/**
 * @brief Trigger a compilation error in case GPU specific functionality is used without configuring the project
 * with GPU support.
 *
 * @details The assertion depends on the template parameter. It is therefore only evaluated when the function
 * template is instantiated, i.e. when some code actually calls it. Merely including this header is harmless.
 *
 * @tparam flag Defaults to false, which makes the assertion fail on instantiation.
 */
template <bool flag = false>
void compile_error_no_gpu() {
  static_assert(flag, "Using device functionality without gpu support! Configure project with -DCudaSupport=ON.");
}
43
44#ifdef NDA_HAVE_CUDA
45
// `inline` (C++17) instead of `static`: in a header, a namespace-scope `static` variable creates a separate
// internal-linkage copy in every translation unit; an inline variable is one entity shared by all of them.

/// Constexpr variable that is true if the project is configured with GPU support.
inline constexpr bool have_device = true;

/// Constexpr variable that is true if the project is configured with CUDA support.
inline constexpr bool have_cuda = true;
51
58 inline void device_error_check(cudaError_t success, std::string message = "") {
59 if (success != cudaSuccess) {
60 NDA_RUNTIME_ERROR << "Cuda runtime error: " << std::to_string(success) << "\n"
61 << " message: " << message << "\n"
62 << " cudaGetErrorName: " << std::string(cudaGetErrorName(success)) << "\n"
63 << " cudaGetErrorString: " << std::string(cudaGetErrorString(success)) << "\n";
64 }
65 }
66
73 inline void cuda_device_sync(bool do_sync = true, std::string_view func = "") {
74 if (!do_sync) return;
75 std::string msg = "cudaDeviceSynchronize failed";
76 if (!func.empty()) {
77 msg += " after call to ";
78 msg.append(func);
79 }
80 device_error_check(cudaDeviceSynchronize(), std::move(msg));
81 }
82
95 inline cublasOperation_t get_cublas_op(char op) {
96 switch (op) {
97 case 'N': return CUBLAS_OP_N;
98 case 'T': return CUBLAS_OP_T;
99 case 'C': return CUBLAS_OP_C;
100 default: std::terminate(); return {};
101 }
102 }
103
109 template <FloatOrDouble T>
110 using cuda_complex_t = std::conditional_t<std::is_same_v<T, float>, cuComplex, cuDoubleComplex>;
111
123 template <FloatOrDouble T>
124 cuda_complex_t<T> cucplx(std::complex<T> c) {
125 return {c.real(), c.imag()};
126 }
127
139 template <FloatOrDouble T>
140 cuda_complex_t<T> *cucplx(std::complex<T> *c) {
141 return reinterpret_cast<cuda_complex_t<T> *>(c); // NOLINT
142 }
143
155 template <FloatOrDouble T>
156 cuda_complex_t<T> const *cucplx(std::complex<T> const *c) {
157 return reinterpret_cast<cuda_complex_t<T> const *>(c); // NOLINT
158 }
159
171 template <FloatOrDouble T>
172 cuda_complex_t<T> **cucplx(std::complex<T> **c) {
173 return reinterpret_cast<cuda_complex_t<T> **>(c); // NOLINT
174 }
175
188 template <FloatOrDouble T>
189 cuda_complex_t<T> const **cucplx(std::complex<T> const **c) {
190 return reinterpret_cast<cuda_complex_t<T> const **>(c); // NOLINT
191 }
192
193#else
194
// Without CUDA support, any use of device_error_check triggers a compilation error via compile_error_no_gpu().
// Both macro arguments are discarded and never evaluated. The expansion is left unqualified so that a qualified
// use such as `nda::device_error_check(a, b)` still expands to valid syntax (`nda::compile_error_no_gpu()`).
#define device_error_check(status, msg) compile_error_no_gpu()
197
// `inline` (C++17) instead of `static`: in a header, a namespace-scope `static` variable creates a separate
// internal-linkage copy in every translation unit; an inline variable is one entity shared by all of them.

/// Constexpr variable that is true if the project is configured with GPU support (false in this configuration).
inline constexpr bool have_device = false;

/// Constexpr variable that is true if the project is configured with CUDA support (false in this configuration).
inline constexpr bool have_cuda = false;
203
/**
 * @brief No-op stand-in for `cuda_device_sync` when CUDA support is not enabled.
 *
 * @param do_sync Ignored.
 * @param func Ignored.
 */
inline void cuda_device_sync(bool do_sync = true, std::string_view func = "") {
  // There is no device to synchronize with in this configuration; explicitly discard both arguments.
  static_cast<void>(do_sync);
  static_cast<void>(func);
}
206
207#endif // NDA_HAVE_CUDA
208
210
211} // namespace nda
Provides concepts for the nda library.
Provides a custom runtime error class and macros to assert conditions and throw exceptions.
static constexpr bool have_cuda
Constexpr variable that is true if the project is configured with CUDA support.
Definition device.hpp:202
#define device_error_check(ARG1, ARG2)
Trigger a compilation error every time the nda::device_error_check function is called.
Definition device.hpp:196
static constexpr bool have_device
Constexpr variable that is true if the project is configured with GPU support.
Definition device.hpp:199
void compile_error_no_gpu()
Trigger a compilation error in case GPU specific functionality is used without configuring the project with GPU support.
Definition device.hpp:40
void cuda_device_sync(bool do_sync=true, std::string_view func="")
Empty function if CudaSupport is not enabled.
Definition device.hpp:205
std::string to_string(std::array< T, R > const &a)
Get a string representation of a std::array.
Definition array.hpp:52