TRIQS/nda 1.3.0
Multi-dimensional array library for C++
Loading...
Searching...
No Matches
bound_check_worker.hpp
Go to the documentation of this file.
1// Copyright (c) 2018 Commissariat à l'énergie atomique et aux énergies alternatives (CEA)
2// Copyright (c) 2018 Centre national de la recherche scientifique (CNRS)
3// Copyright (c) 2018-2022 Simons Foundation
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0.txt
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16//
17// Authors: Olivier Parcollet, Nils Wentzell
18
24#pragma once
25
26#include "./range.hpp"
27
28#include <cstdint>
29#include <stdexcept>
30#include <sstream>
31
32namespace nda::detail {
33
34 // Check the bounds when accessing single elements or slices of an array/view.
35 struct bound_check_worker {
36 // Shape of the array/view.
37 long const *lengths{};
38
39 // Error code to store the positions of the arguments which are out of bounds.
40 uint32_t error_code = 0;
41
42 // Number of dimensions that are covered by a given nda::ellipsis.
43 int ellipsis_loss = 0;
44
45 // Current dimension to be checked.
46 int N = 0;
47
48 // Check if the given index is within the bounds of the array/view.
49 void check_current_dim(long idx) {
50 if ((idx < 0) or (idx >= lengths[N])) { error_code += 1ul << N; }
51 ++N;
52 }
53
54 // Check if the given nda::range is within the bounds of the array/view.
55 void check_current_dim(range const &r) {
56 if (r.size() > 0) {
57 auto first_idx = r.first();
58 auto last_idx = first_idx + (r.size() - 1) * r.step();
59 if (first_idx < 0 or first_idx >= lengths[N] or last_idx < 0 or last_idx >= lengths[N]) error_code += 1ul << N;
60 }
61 ++N;
62 }
63
64 // Check the bounds when an nda::range::all_t is encountered (no need to check anything).
65 void check_current_dim(range::all_t) { ++N; }
66
67 // Check the bounds when an nda::ellipsis is encountered (no need to check anything).
68 void check_current_dim(ellipsis) { N += ellipsis_loss + 1; }
69
70 // Accumulate an error message for the current dimension and index.
71 void accumulate_error_msg(std::stringstream &fs, long idx) {
72 if (error_code & (1ull << N)) fs << "Argument " << N << " = " << idx << " is not within [0," << lengths[N] << "[.\n";
73 N++;
74 }
75
76 // Accumulate an error message for the current dimension and nda::range.
77 void accumulate_error_msg(std::stringstream &fs, range const &r) {
78 if (error_code & (1ull << N)) fs << "Argument " << N << " = " << r << " is not within [0," << lengths[N] << "[.\n";
79 ++N;
80 }
81
82 // Accumulate an error message for the current dimension and nda::range::all_t.
83 void accumulate_error_msg(std::stringstream &, range::all_t) { ++N; }
84
85 // Accumulate an error message for the current dimension and nda::ellipsis.
86 void accumulate_error_msg(std::stringstream &, ellipsis) { N += ellipsis_loss + 1; }
87 };
88
89} // namespace nda::detail
90
91namespace nda {
92
105 template <typename... Args>
106 void assert_in_bounds(int rank, long const *lengths, Args const &...args) {
107 // initialize the bounds checker
108 detail::bound_check_worker w{lengths};
109
110 // number of dimensions that are covered by an nda::ellipsis
111 w.ellipsis_loss = rank - sizeof...(Args);
112
113 // check the bounds on each argument/index
114 (w.check_current_dim(args), ...);
115
116 // if no error, stop here
117 if (!w.error_code) return;
118
119 // accumulate error message and throw
120 w.N = 0;
121 std::stringstream fs;
122 (w.accumulate_error_msg(fs, args), ...);
123 throw std::runtime_error("Index/Range out of bounds:\n" + fs.str());
124 }
125
126} // namespace nda
void assert_in_bounds(int rank, long const *lengths, Args const &...args)
Check if the given indices/arguments are within the bounds of an array/view.
Includes the itertools header and provides some additional utilities.