TRIQS/TRIQS 4.0.0
Researching Interacting Quantum Systems
Loading...
Searching...
No Matches
generator.hpp
1
2// Copyright (c) Lewis Baker
3// Licenced under MIT license. See LICENSE.txt for details.
5#ifndef CPPCORO_GENERATOR_HPP_INCLUDED
6#define CPPCORO_GENERATOR_HPP_INCLUDED
7
#include "coroutine.hpp"

#include <exception>
#include <functional>
#include <iterator>
#include <memory>
#include <type_traits>
#include <utility>
14
15namespace cppcoro
16{
17 template<typename T>
18 class generator;
19
20 namespace detail
21 {
22 template<typename T>
23 class generator_promise
24 {
25 public:
26
27 using value_type = std::remove_reference_t<T>;
28 using reference_type = std::conditional_t<std::is_reference_v<T>, T, T&>;
29 using pointer_type = value_type*;
30
31 generator_promise() = default;
32
33 generator<T> get_return_object() noexcept;
34
35 constexpr cppcoro::suspend_always initial_suspend() const noexcept { return {}; }
36 constexpr cppcoro::suspend_always final_suspend() const noexcept { return {}; }
37
38 template<
39 typename U = T,
40 std::enable_if_t<!std::is_rvalue_reference<U>::value, int> = 0>
41 cppcoro::suspend_always yield_value(std::remove_reference_t<T>& value) noexcept
42 {
43 m_value = std::addressof(value);
44 return {};
45 }
46
47 cppcoro::suspend_always yield_value(std::remove_reference_t<T>&& value) noexcept
48 {
49 m_value = std::addressof(value);
50 return {};
51 }
52
53 void unhandled_exception()
54 {
55 m_exception = std::current_exception();
56 }
57
58 void return_void()
59 {
60 }
61
62 reference_type value() const noexcept
63 {
64 return static_cast<reference_type>(*m_value);
65 }
66
67 // Don't allow any use of 'co_await' inside the generator coroutine.
68 template<typename U>
69 cppcoro::suspend_never await_transform(U&& value) = delete;
70
71 void rethrow_if_exception()
72 {
73 if (m_exception)
74 {
75 std::rethrow_exception(m_exception);
76 }
77 }
78
79 private:
80
81 pointer_type m_value{};
82 std::exception_ptr m_exception{};
83
84 };
85
/// Empty tag type returned by generator<T>::end(). A generator_iterator
/// compares equal to it when the iterator holds no coroutine or when its
/// coroutine has finished.
struct generator_sentinel
{
};
87
88 template<typename T>
89 class generator_iterator
90 {
91 using coroutine_handle = cppcoro::coroutine_handle<generator_promise<T>>;
92
93 public:
94
95 using iterator_category = std::input_iterator_tag;
96 // What type should we use for counting elements of a potentially infinite sequence?
97 using difference_type = std::ptrdiff_t;
98 using value_type = typename generator_promise<T>::value_type;
99 using reference = typename generator_promise<T>::reference_type;
100 using pointer = typename generator_promise<T>::pointer_type;
101
102 // Iterator needs to be default-constructible to satisfy the Range concept.
103 generator_iterator() noexcept
104 : m_coroutine(nullptr)
105 {}
106
107 explicit generator_iterator(coroutine_handle coroutine) noexcept
108 : m_coroutine(coroutine)
109 {}
110
111 friend bool operator==(const generator_iterator& it, generator_sentinel) noexcept
112 {
113 return !it.m_coroutine || it.m_coroutine.done();
114 }
115
116 friend bool operator!=(const generator_iterator& it, generator_sentinel s) noexcept
117 {
118 return !(it == s);
119 }
120
121 friend bool operator==(generator_sentinel s, const generator_iterator& it) noexcept
122 {
123 return (it == s);
124 }
125
126 friend bool operator!=(generator_sentinel s, const generator_iterator& it) noexcept
127 {
128 return it != s;
129 }
130
131 generator_iterator& operator++()
132 {
133 m_coroutine.resume();
134 if (m_coroutine.done())
135 {
136 m_coroutine.promise().rethrow_if_exception();
137 }
138
139 return *this;
140 }
141
142 // Need to provide post-increment operator to implement the 'Range' concept.
143 void operator++(int)
144 {
145 (void)operator++();
146 }
147
148 reference operator*() const noexcept
149 {
150 return m_coroutine.promise().value();
151 }
152
153 pointer operator->() const noexcept
154 {
155 return std::addressof(operator*());
156 }
157
158 private:
159
160 coroutine_handle m_coroutine;
161 };
162 }
163
164 template<typename T>
165 class [[nodiscard]] generator
166 {
167 public:
168
169 using promise_type = detail::generator_promise<T>;
170 using iterator = detail::generator_iterator<T>;
171
172 generator() noexcept
173 : m_coroutine(nullptr)
174 {}
175
176 generator(generator&& other) noexcept
177 : m_coroutine(other.m_coroutine)
178 {
179 other.m_coroutine = nullptr;
180 }
181
182 generator(const generator& other) = delete;
183
184 ~generator()
185 {
186 if (m_coroutine)
187 {
188 m_coroutine.destroy();
189 }
190 }
191
192 generator& operator=(generator other) noexcept
193 {
194 swap(other);
195 return *this;
196 }
197
199 {
200 if (m_coroutine)
201 {
202 m_coroutine.resume();
203 if (m_coroutine.done())
204 {
205 m_coroutine.promise().rethrow_if_exception();
206 }
207 }
208
209 return iterator{ m_coroutine };
210 }
211
212 detail::generator_sentinel end() noexcept
213 {
214 return detail::generator_sentinel{};
215 }
216
217 void swap(generator& other) noexcept
218 {
219 std::swap(m_coroutine, other.m_coroutine);
220 }
221
222 private:
223
224 friend class detail::generator_promise<T>;
225
226 explicit generator(cppcoro::coroutine_handle<promise_type> coroutine) noexcept
227 : m_coroutine(coroutine)
228 {}
229
230 cppcoro::coroutine_handle<promise_type> m_coroutine;
231
232 };
233
234 template<typename T>
235 void swap(generator<T>& a, generator<T>& b)
236 {
237 a.swap(b);
238 }
239
240 namespace detail
241 {
242 template<typename T>
243 generator<T> generator_promise<T>::get_return_object() noexcept
244 {
245 using coroutine_handle = cppcoro::coroutine_handle<generator_promise<T>>;
246 return generator<T>{ coroutine_handle::from_promise(*this) };
247 }
248 }
249
250 template<typename FUNC, typename T>
251 generator<std::invoke_result_t<FUNC&, typename generator<T>::iterator::reference>> fmap(FUNC func, generator<T> source)
252 {
253 for (auto&& value : source)
254 {
255 co_yield std::invoke(func, static_cast<decltype(value)>(value));
256 }
257 }
258}
259
260#endif
iterator_impl< false > iterator
Mutable block iterator type.
iterator end()
Get an iterator past the last block.
iterator begin()
Get an iterator to the first block.