TRIQS/triqs_tprf 4.0.0
A TRIQS application
Loading...
Searching...
No Matches
linalg.cpp
1/*******************************************************************************
2 *
3 * TRIQS: a Toolbox for Research in Interacting Quantum Systems
4 *
5 * Copyright (C) 2017, H. U.R. Strand
6 *
7 * TRIQS is free software: you can redistribute it and/or modify it under the
8 * terms of the GNU General Public License as published by the Free Software
9 * Foundation, either version 3 of the License, or (at your option) any later
10 * version.
11 *
12 * TRIQS is distributed in the hope that it will be useful, but WITHOUT ANY
13 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
14 * FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
15 * details.
16 *
17 * You should have received a copy of the GNU General Public License along with
18 * TRIQS. If not, see <http://www.gnu.org/licenses/>.
19 *
20 ******************************************************************************/
21
22#include "linalg.hpp"
23
24namespace triqs_tprf {
25
26 // ----------------------------------------------------
27
28 template <Channel_t CH> g2_nn_t inverse(g2_nn_cvt g) {
29
30 auto g_inv = make_gf(g);
31
32 using dat_t = array<g2_nn_cvt::scalar_t, g2_nn_cvt::data_rank, channel_memory_layout<CH>>;
33
34 auto g_w_dat = dat_t{g.data()};
35 auto g_w_inv_dat = dat_t{g.data()};
36
37 auto mat = channel_matrix_view<CH>(g_w_dat);
38 auto mat_inv = channel_matrix_view<CH>(g_w_inv_dat);
39
40 mat_inv = nda::linalg::inv(mat);
41
42 g_inv.data() = g_w_inv_dat;
43
44 return g_inv;
45 }
46
48 template <Channel_t CH> g2_iw_t inverse(g2_iw_cvt g) {
49
50 //channel_grouping<CH> chg;
51 auto g_inv = make_gf(g);
52
53 for (auto w : std::get<0>(g.mesh())) {
54 auto _ = all_t{};
55 g_inv[w, _, _] = inverse<CH>(g[w, _, _]);
56 }
57 return g_inv;
58 }
59
60 // ----------------------------------------------------
61
62 template <Channel_t CH> g2_nn_t product(g2_nn_cvt A, g2_nn_cvt B) {
63
64 auto C = make_gf(A);
65
66 using dat_t = array<g2_nn_cvt::scalar_t, g2_nn_cvt::data_rank, channel_memory_layout<CH>>;
67
68 auto A_w_dat = dat_t{A.data()};
69 auto B_w_dat = dat_t{B.data()};
70 auto C_w_dat = dat_t{A_w_dat};
71
72 auto A_mat = channel_matrix_view<CH>(A_w_dat);
73 auto B_mat = channel_matrix_view<CH>(B_w_dat);
74 auto C_mat = channel_matrix_view<CH>(C_w_dat);
75
76 C_mat = A_mat * B_mat;
77
78 C.data() = C_w_dat;
79
80 return C;
81 }
82
84 template <Channel_t CH> g2_iw_t product(g2_iw_cvt A, g2_iw_cvt B) {
85
86 auto C = make_gf(A);
87
88 for (auto w : std::get<0>(A.mesh())) {
89 auto _ = all_t{};
90 C[w, _, _] = product<CH>(A[w, _, _], B[w, _, _]);
91 }
92 return C;
93 }
94
95 // ----------------------------------------------------
96
97 template <Channel_t CH> g2_nn_t identity(g2_nn_cvt g) {
98
99 using dat_t = array<g2_nn_cvt::scalar_t, g2_nn_cvt::data_rank, channel_memory_layout<CH>>;
100
101 auto I = make_gf(g);
102
103 auto I_w_dat = dat_t{I.data()};
104 auto I_mat = channel_matrix_view<CH>(I_w_dat);
105
106 I_mat = 1.0; // This sets a nda::matrix to the identity matrix...
107 I.data() = I_w_dat;
108
109 return I;
110 }
111
113 template <Channel_t CH> g2_iw_t identity(g2_iw_cvt g) {
114
115 auto I = make_gf(g);
116
117 for (auto w : std::get<0>(g.mesh())) {
118 auto _ = all_t{};
119 I[w, _, _] = identity<CH>(I[w, _, _]);
120 }
121 return I;
122 }
123
124 // ----------------------------------------------------
125
126 template g2_nn_t inverse<Channel_t::PH>(g2_nn_cvt);
127 template g2_nn_t inverse<Channel_t::PH_bar>(g2_nn_cvt);
128 template g2_nn_t inverse<Channel_t::PP>(g2_nn_cvt);
129
130 g2_nn_t inverse_PH(g2_nn_vt g) { return inverse<Channel_t::PH>(g); }
131 g2_nn_t inverse_PP(g2_nn_vt g) { return inverse<Channel_t::PP>(g); }
132 g2_nn_t inverse_PH_bar(g2_nn_vt g) { return inverse<Channel_t::PH_bar>(g); }
133
134 template g2_iw_t inverse<Channel_t::PH>(g2_iw_cvt);
135 template g2_iw_t inverse<Channel_t::PH_bar>(g2_iw_cvt);
136 template g2_iw_t inverse<Channel_t::PP>(g2_iw_cvt);
137
138 g2_iw_t inverse_PH(g2_iw_vt g) { return inverse<Channel_t::PH>(g); }
139 g2_iw_t inverse_PP(g2_iw_vt g) { return inverse<Channel_t::PP>(g); }
140 g2_iw_t inverse_PH_bar(g2_iw_vt g) { return inverse<Channel_t::PH_bar>(g); }
141
142 // ----------------------------------------------------
143
144 template g2_nn_t product<Channel_t::PH>(g2_nn_cvt, g2_nn_cvt);
145 template g2_nn_t product<Channel_t::PH_bar>(g2_nn_cvt, g2_nn_cvt);
146 template g2_nn_t product<Channel_t::PP>(g2_nn_cvt, g2_nn_cvt);
147
148 g2_nn_t product_PH(g2_nn_vt A, g2_nn_vt B) { return product<Channel_t::PH>(A, B); }
149 g2_nn_t product_PP(g2_nn_vt A, g2_nn_vt B) { return product<Channel_t::PP>(A, B); }
150 g2_nn_t product_PH_bar(g2_nn_vt A, g2_nn_vt B) { return product<Channel_t::PH_bar>(A, B); }
151
152 template g2_iw_t product<Channel_t::PH>(g2_iw_cvt, g2_iw_cvt);
153 template g2_iw_t product<Channel_t::PH_bar>(g2_iw_cvt, g2_iw_cvt);
154 template g2_iw_t product<Channel_t::PP>(g2_iw_cvt, g2_iw_cvt);
155
156 g2_iw_t product_PH(g2_iw_vt A, g2_iw_vt B) { return product<Channel_t::PH>(A, B); }
157 g2_iw_t product_PP(g2_iw_vt A, g2_iw_vt B) { return product<Channel_t::PP>(A, B); }
158 g2_iw_t product_PH_bar(g2_iw_vt A, g2_iw_vt B) { return product<Channel_t::PH_bar>(A, B); }
159
160 // ----------------------------------------------------
161
162 template g2_nn_t identity<Channel_t::PH>(g2_nn_cvt);
163 template g2_nn_t identity<Channel_t::PH_bar>(g2_nn_cvt);
164 template g2_nn_t identity<Channel_t::PP>(g2_nn_cvt);
165
166 g2_nn_t identity_PH(g2_nn_vt g) { return identity<Channel_t::PH>(g); }
167 g2_nn_t identity_PP(g2_nn_vt g) { return identity<Channel_t::PP>(g); }
168 g2_nn_t identity_PH_bar(g2_nn_vt g) { return identity<Channel_t::PH_bar>(g); }
169
170 template g2_iw_t identity<Channel_t::PH>(g2_iw_cvt);
171 template g2_iw_t identity<Channel_t::PH_bar>(g2_iw_cvt);
172 template g2_iw_t identity<Channel_t::PP>(g2_iw_cvt);
173
174 g2_iw_t identity_PH(g2_iw_vt g) { return identity<Channel_t::PH>(g); }
175 g2_iw_t identity_PP(g2_iw_vt g) { return identity<Channel_t::PP>(g); }
176 g2_iw_t identity_PH_bar(g2_iw_vt g) { return identity<Channel_t::PH_bar>(g); }
177
178 // ----------------------------------------------------
179
180 array<g2_nn_cvt::scalar_t, 4> scalar_product_PH(g2_n_cvt vL, g2_nn_cvt M, g2_n_cvt vR) {
181
182 using res_layout = contiguous_layout_with_stride_order<nda::encode(std::array{0, 1, 2, 3})>;
183 using L_layout = contiguous_layout_with_stride_order<nda::encode(std::array{4, 3, 0, 1, 2})>;
184 using R_layout = contiguous_layout_with_stride_order<nda::encode(std::array{0, 2, 1, 3, 4})>;
185 using mat_layout = contiguous_layout_with_stride_order<nda::encode(std::array{0, 2, 3, 1, 4, 5})>;
186
187 using L_vec_t = array<g2_n_cvt::scalar_t, g2_n_cvt::data_rank, L_layout>;
188 using R_vec_t = array<g2_n_cvt::scalar_t, g2_n_cvt::data_rank, R_layout>;
189 //using mat_t = array<g2_nn_cvt::scalar_t, g2_nn_cvt::data_rank, channel_memory_layout<Channel_t::PH>>;
190 using mat_t = array<g2_nn_cvt::scalar_t, g2_nn_cvt::data_rank, mat_layout>;
191
192 array<g2_nn_cvt::scalar_t, 4, res_layout> res_dat(M.target_shape());
193
194 auto M_dat = mat_t{M.data()};
195 auto vL_dat = L_vec_t{vL.data()};
196 auto vR_dat = R_vec_t{vR.data()};
197
198 matrix_view<g2_n_cvt::scalar_t> res(group_indices_view(res_dat, idx_group<0, 1>, idx_group<2, 3>));
199 matrix_view<g2_n_cvt::scalar_t> vecL(group_indices_view(vL_dat, idx_group<4, 3>, idx_group<0, 1, 2>));
200 matrix_view<g2_n_cvt::scalar_t> vecR(group_indices_view(vR_dat, idx_group<0, 2, 1>, idx_group<3, 4>));
201 matrix_view<g2_n_cvt::scalar_t> mat(group_indices_view(M_dat, idx_group<0, 2, 3>, idx_group<1, 4, 5>));
202 //auto mat = channel_matrix_view<Channel_t::PH>(M_dat);
203
204 res = vecL * mat * vecR;
205
206 return res_dat;
207 }
208
209} // namespace triqs_tprf