// Invert g viewed as a matrix, with g's indices grouped into row/column
// indices according to channel CH (grouping defined by
// channel_matrix_view<CH>, not shown here).
// NOTE(review): the return statement and closing brace are not visible in this
// excerpt (original lines 43-47 are missing); presumably `return g_inv;`.
28 template <Channel_t CH> g2_nn_t inverse(g2_nn_cvt g) {
// Result container created from g; its data is overwritten at the end.
30 auto g_inv = make_gf(g);
// Array with g's scalar type and data rank, stored in the memory layout
// associated with channel CH so its indices can be grouped into a matrix.
32 using dat_t = array<g2_nn_cvt::scalar_t, g2_nn_cvt::data_rank, channel_memory_layout<CH>>;
// Two channel-ordered copies of g's data. The second copy's values are never
// read; it only provides storage of the right shape/layout for the inverse.
34 auto g_w_dat = dat_t{g.data()};
35 auto g_w_inv_dat = dat_t{g.data()};
// Matrix views over the channel-ordered arrays.
37 auto mat = channel_matrix_view<CH>(g_w_dat);
38 auto mat_inv = channel_matrix_view<CH>(g_w_inv_dat);
// Matrix inverse. Assumes channel_matrix_view returns a view, so this
// assignment writes into g_w_inv_dat — confirm.
40 mat_inv = nda::linalg::inv(mat);
// Copy the inverted, channel-ordered data back into the result's storage.
42 g_inv.data() = g_w_inv_dat;
// Slice-wise inverse: for every point w of the first mesh component of g, the
// (w, _, _) slice is inverted with inverse<CH> (presumably the g2_nn_cvt
// overload) and stored in the same slice of the result.
// NOTE(review): the return statement and closing braces are not visible in
// this excerpt (original lines 56-61 are missing); presumably `return g_inv;`.
48 template <Channel_t CH> g2_iw_t inverse(g2_iw_cvt g) {
// Result container created from g; every slice visited below is overwritten.
51 auto g_inv = make_gf(g);
// Loop over the first component of g's (product) mesh.
53 for (
auto w : std::get<0>(g.mesh())) {
55 g_inv[w, _, _] = inverse<CH>(g[w, _, _]);
// Matrix product C = A * B, with A and B viewed as matrices by grouping their
// indices according to channel CH (grouping defined by
// channel_matrix_view<CH>, not shown here).
// NOTE(review): the code after the product is not visible in this excerpt
// (original lines 77-83 are missing); presumably it copies C_w_dat into a
// result gf and returns it.
62 template <Channel_t CH> g2_nn_t product(g2_nn_cvt A, g2_nn_cvt B) {
// Array with the operands' scalar type and data rank, stored in the memory
// layout associated with channel CH.
66 using dat_t = array<g2_nn_cvt::scalar_t, g2_nn_cvt::data_rank, channel_memory_layout<CH>>;
// Channel-ordered copies of the operands.
68 auto A_w_dat = dat_t{A.data()};
69 auto B_w_dat = dat_t{B.data()};
// C_w_dat is copied from A_w_dat only to get storage of the right shape and
// layout; its values are overwritten by the product below.
70 auto C_w_dat = dat_t{A_w_dat};
72 auto A_mat = channel_matrix_view<CH>(A_w_dat);
73 auto B_mat = channel_matrix_view<CH>(B_w_dat);
74 auto C_mat = channel_matrix_view<CH>(C_w_dat);
// Assumes C_mat is a view, so the product is written into C_w_dat — confirm.
76 C_mat = A_mat * B_mat;
// Slice-wise product: for every point w of the first mesh component of A,
// C[w, _, _] = product<CH>(A[w, _, _], B[w, _, _]).
// NOTE(review): the declaration of C (original lines 85-87) and the return /
// closing braces (lines 91-96) are not visible in this excerpt.
// NOTE(review): B is indexed with A's mesh points; presumably A and B share the
// same mesh. Nothing in the visible code checks this.
84 template <Channel_t CH> g2_iw_t product(g2_iw_cvt A, g2_iw_cvt B) {
88 for (
auto w : std::get<0>(A.mesh())) {
90 C[w, _, _] = product<CH>(A[w, _, _], B[w, _, _]);
// Presumably builds the channel-CH identity with the shape of g (per the
// name). The filling code is not visible here, so this is unconfirmed.
// NOTE(review): the construction of I (original lines 100-102) and the code
// that fills I_mat and returns (lines 105-112) are missing from this excerpt.
97 template <Channel_t CH> g2_nn_t identity(g2_nn_cvt g) {
// Array with g's scalar type and data rank, stored in the memory layout
// associated with channel CH.
99 using dat_t = array<g2_nn_cvt::scalar_t, g2_nn_cvt::data_rank, channel_memory_layout<CH>>;
// Channel-ordered copy of I's data and a matrix view over it.
103 auto I_w_dat = dat_t{I.data()};
104 auto I_mat = channel_matrix_view<CH>(I_w_dat);
// Slice-wise identity: for every point w of the first mesh component of g, the
// (w, _, _) slice of I is replaced by identity<CH> applied to that slice.
// NOTE(review): I's construction (original lines 114-116) and the return /
// closing braces (lines 120-125) are not visible in this excerpt.
// NOTE(review): I[w, _, _] (not g[w, _, _]) is passed to identity<CH>; confirm
// that identity<CH> only uses the shape of its argument.
113 template <Channel_t CH> g2_iw_t identity(g2_iw_cvt g) {
117 for (
auto w : std::get<0>(g.mesh())) {
119 I[w, _, _] = identity<CH>(I[w, _, _]);
126 template g2_nn_t inverse<Channel_t::PH>(g2_nn_cvt);
127 template g2_nn_t inverse<Channel_t::PH_bar>(g2_nn_cvt);
128 template g2_nn_t inverse<Channel_t::PP>(g2_nn_cvt);
130 g2_nn_t inverse_PH(g2_nn_vt g) {
return inverse<Channel_t::PH>(g); }
131 g2_nn_t inverse_PP(g2_nn_vt g) {
return inverse<Channel_t::PP>(g); }
132 g2_nn_t inverse_PH_bar(g2_nn_vt g) {
return inverse<Channel_t::PH_bar>(g); }
134 template g2_iw_t inverse<Channel_t::PH>(g2_iw_cvt);
135 template g2_iw_t inverse<Channel_t::PH_bar>(g2_iw_cvt);
136 template g2_iw_t inverse<Channel_t::PP>(g2_iw_cvt);
138 g2_iw_t inverse_PH(g2_iw_vt g) {
return inverse<Channel_t::PH>(g); }
139 g2_iw_t inverse_PP(g2_iw_vt g) {
return inverse<Channel_t::PP>(g); }
140 g2_iw_t inverse_PH_bar(g2_iw_vt g) {
return inverse<Channel_t::PH_bar>(g); }
144 template g2_nn_t product<Channel_t::PH>(g2_nn_cvt, g2_nn_cvt);
145 template g2_nn_t product<Channel_t::PH_bar>(g2_nn_cvt, g2_nn_cvt);
146 template g2_nn_t product<Channel_t::PP>(g2_nn_cvt, g2_nn_cvt);
148 g2_nn_t product_PH(g2_nn_vt A, g2_nn_vt B) {
return product<Channel_t::PH>(A, B); }
149 g2_nn_t product_PP(g2_nn_vt A, g2_nn_vt B) {
return product<Channel_t::PP>(A, B); }
150 g2_nn_t product_PH_bar(g2_nn_vt A, g2_nn_vt B) {
return product<Channel_t::PH_bar>(A, B); }
152 template g2_iw_t product<Channel_t::PH>(g2_iw_cvt, g2_iw_cvt);
153 template g2_iw_t product<Channel_t::PH_bar>(g2_iw_cvt, g2_iw_cvt);
154 template g2_iw_t product<Channel_t::PP>(g2_iw_cvt, g2_iw_cvt);
156 g2_iw_t product_PH(g2_iw_vt A, g2_iw_vt B) {
return product<Channel_t::PH>(A, B); }
157 g2_iw_t product_PP(g2_iw_vt A, g2_iw_vt B) {
return product<Channel_t::PP>(A, B); }
158 g2_iw_t product_PH_bar(g2_iw_vt A, g2_iw_vt B) {
return product<Channel_t::PH_bar>(A, B); }
162 template g2_nn_t identity<Channel_t::PH>(g2_nn_cvt);
163 template g2_nn_t identity<Channel_t::PH_bar>(g2_nn_cvt);
164 template g2_nn_t identity<Channel_t::PP>(g2_nn_cvt);
166 g2_nn_t identity_PH(g2_nn_vt g) {
return identity<Channel_t::PH>(g); }
167 g2_nn_t identity_PP(g2_nn_vt g) {
return identity<Channel_t::PP>(g); }
168 g2_nn_t identity_PH_bar(g2_nn_vt g) {
return identity<Channel_t::PH_bar>(g); }
170 template g2_iw_t identity<Channel_t::PH>(g2_iw_cvt);
171 template g2_iw_t identity<Channel_t::PH_bar>(g2_iw_cvt);
172 template g2_iw_t identity<Channel_t::PP>(g2_iw_cvt);
174 g2_iw_t identity_PH(g2_iw_vt g) {
return identity<Channel_t::PH>(g); }
175 g2_iw_t identity_PP(g2_iw_vt g) {
return identity<Channel_t::PP>(g); }
176 g2_iw_t identity_PH_bar(g2_iw_vt g) {
return identity<Channel_t::PH_bar>(g); }
180 array<g2_nn_cvt::scalar_t, 4> scalar_product_PH(g2_n_cvt vL, g2_nn_cvt M, g2_n_cvt vR) {
182 using res_layout = contiguous_layout_with_stride_order<nda::encode(std::array{0, 1, 2, 3})>;
183 using L_layout = contiguous_layout_with_stride_order<nda::encode(std::array{4, 3, 0, 1, 2})>;
184 using R_layout = contiguous_layout_with_stride_order<nda::encode(std::array{0, 2, 1, 3, 4})>;
185 using mat_layout = contiguous_layout_with_stride_order<nda::encode(std::array{0, 2, 3, 1, 4, 5})>;
187 using L_vec_t = array<g2_n_cvt::scalar_t, g2_n_cvt::data_rank, L_layout>;
188 using R_vec_t = array<g2_n_cvt::scalar_t, g2_n_cvt::data_rank, R_layout>;
190 using mat_t = array<g2_nn_cvt::scalar_t, g2_nn_cvt::data_rank, mat_layout>;
192 array<g2_nn_cvt::scalar_t, 4, res_layout> res_dat(M.target_shape());
194 auto M_dat = mat_t{M.data()};
195 auto vL_dat = L_vec_t{vL.data()};
196 auto vR_dat = R_vec_t{vR.data()};
198 matrix_view<g2_n_cvt::scalar_t> res(group_indices_view(res_dat, idx_group<0, 1>, idx_group<2, 3>));
199 matrix_view<g2_n_cvt::scalar_t> vecL(group_indices_view(vL_dat, idx_group<4, 3>, idx_group<0, 1, 2>));
200 matrix_view<g2_n_cvt::scalar_t> vecR(group_indices_view(vR_dat, idx_group<0, 2, 1>, idx_group<3, 4>));
201 matrix_view<g2_n_cvt::scalar_t> mat(group_indices_view(M_dat, idx_group<0, 2, 3>, idx_group<1, 4, 5>));
204 res = vecL * mat * vecR;