OPALX (Object Oriented Parallel Accelerator Library for Exascale) master (dc2a29eed580)
OPALX
Loading...
Searching...
No Matches
GSLBLAS.h
Go to the documentation of this file.
1//
2// GSL BLAS compatibility to replace gsl_blas
3//
4// Copyright (c) 2023, Paul Scherrer Institute, Villigen PSI, Switzerland
5// All rights reserved
6//
7// This file is part of OPAL.
8//
9// OPAL is free software: you can redistribute it and/or modify
10// it under the terms of the GNU General Public License as published by
11// the Free Software Foundation, either version 3 of the License, or
12// (at your option) any later version.
13//
14// You should have received a copy of the GNU General Public License
15// along with OPAL. If not, see <https://www.gnu.org/licenses/>.
16//
17
18#ifndef OPAL_GSL_BLAS_HH
19#define OPAL_GSL_BLAS_HH
20
21#include <cstring>
23#include "Utilities/GSLMatrix.h"
24
33
44inline void gsl_blas_dgemm(
45 CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, double alpha, const gsl_matrix* A,
46 const gsl_matrix* B, double beta, gsl_matrix* C) {
47 size_t m = (TransA == CblasNoTrans) ? A->size1 : A->size2;
48 size_t n = (TransB == CblasNoTrans) ? B->size2 : B->size1;
49 size_t k = (TransA == CblasNoTrans) ? A->size2 : A->size1;
50
51 if (C->size1 != m || C->size2 != n) {
52 throw std::runtime_error("gsl_blas_dgemm: size mismatch");
53 }
54
55 // Initialize C with beta*C
56 if (beta != 1.0) {
57 for (size_t i = 0; i < m; ++i) {
58 for (size_t j = 0; j < n; ++j) {
59 *gsl_matrix_ptr(C, i, j) *= beta;
60 }
61 }
62 }
63
64 // Perform multiplication
65 for (size_t i = 0; i < m; ++i) {
66 for (size_t j = 0; j < n; ++j) {
67 double sum = 0.0;
68 for (size_t l = 0; l < k; ++l) {
69 double a_val = (TransA == CblasNoTrans) ? *gsl_matrix_ptr(A, i, l)
70 : *gsl_matrix_ptr(A, l, i);
71 double b_val = (TransB == CblasNoTrans) ? *gsl_matrix_ptr(B, l, j)
72 : *gsl_matrix_ptr(B, j, l);
73 sum += a_val * b_val;
74 }
75 *gsl_matrix_ptr(C, i, j) += alpha * sum;
76 }
77 }
78}
79
// Complex matrix-matrix multiply and accumulate.
//
// Visible behaviour:
//  - m, n and k are taken from A and B according to TransA / TransB.
//  - std::runtime_error is thrown when C is not m x n.
//  - Elements of A (resp. B) are read with swapped indices when the selector
//    is not CblasNoTrans and are conjugated when it is CblasConjTrans.
//  - The beta-scaling pass is skipped when beta is exactly 1 + 0i.
//
// NOTE(review): this listing is incomplete. The embedded line numbering
// skips 93, 107 and 130, so the end of the parameter list and the right-hand
// sides of two assignments are not visible; see the notes at each gap.
// NOTE(review): only C's shape is validated; the inner dimension of op(B) is
// never compared with k in the visible code - confirm whether callers
// guarantee it.
90inline void gsl_blas_zgemm(
91 CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, gsl_complex alpha,
92 const gsl_matrix_complex* A, const gsl_matrix_complex* B, gsl_complex beta,
// NOTE(review): the rest of the signature (presumably the output matrix C
// and the opening brace) is missing from the listing here.
94 size_t m = (TransA == CblasNoTrans) ? A->size1 : A->size2;
95 size_t n = (TransB == CblasNoTrans) ? B->size2 : B->size1;
96 size_t k = (TransA == CblasNoTrans) ? A->size2 : A->size1;
97
98 if (C->size1 != m || C->size2 != n) {
99 throw std::runtime_error("gsl_blas_zgemm: size mismatch");
100 }
101
102 // Initialize C with beta*C
103 if (beta.dat[0] != 1.0 || beta.dat[1] != 0.0) {
104 for (size_t i = 0; i < m; ++i) {
105 for (size_t j = 0; j < n; ++j) {
106 *gsl_matrix_complex_ptr(C, i, j) =
// NOTE(review): right-hand side missing from the listing; presumably the
// product of beta and the current element of C - verify against the file.
108 }
109 }
110 }
111
112 // Perform multiplication
113 for (size_t i = 0; i < m; ++i) {
114 for (size_t j = 0; j < n; ++j) {
115 gsl_complex sum = gsl_complex(0.0, 0.0);
116 for (size_t l = 0; l < k; ++l) {
// Element (i, l) of op(A): indices are swapped for any transposed selector.
117 gsl_complex a_val = (TransA == CblasNoTrans) ? *gsl_matrix_complex_ptr(A, i, l)
118 : *gsl_matrix_complex_ptr(A, l, i);
119 if (TransA == CblasConjTrans) {
120 a_val = gsl_complex_conjugate(a_val);
121 }
// Element (l, j) of op(B), handled the same way as A above.
122 gsl_complex b_val = (TransB == CblasNoTrans) ? *gsl_matrix_complex_ptr(B, l, j)
123 : *gsl_matrix_complex_ptr(B, j, l);
124 if (TransB == CblasConjTrans) {
125 b_val = gsl_complex_conjugate(b_val);
126 }
127 sum = gsl_complex_add(sum, gsl_complex_mul(a_val, b_val));
128 }
129 *gsl_matrix_complex_ptr(C, i, j) =
// NOTE(review): right-hand side missing from the listing; presumably the
// current element of C plus alpha times sum - verify against the file.
131 }
132 }
133}
134
144inline void gsl_blas_dgemv(
145 CBLAS_TRANSPOSE TransA, double alpha, const gsl_matrix* A, const gsl_vector* x, double beta,
146 gsl_vector* y) {
147 size_t m = A->size1;
148 size_t n = A->size2;
149
150 if (TransA == CblasNoTrans) {
151 if (x->size != n || y->size != m) {
152 throw std::runtime_error("gsl_blas_dgemv: size mismatch");
153 }
154 // y = beta*y
155 for (size_t i = 0; i < m; ++i) {
156 *gsl_vector_ptr(y, i) *= beta;
157 }
158 // y += alpha*A*x
159 for (size_t i = 0; i < m; ++i) {
160 double sum = 0.0;
161 for (size_t j = 0; j < n; ++j) {
162 sum += *gsl_matrix_ptr(A, i, j) * *gsl_vector_ptr(x, j);
163 }
164 *gsl_vector_ptr(y, i) += alpha * sum;
165 }
166 } else {
167 if (x->size != m || y->size != n) {
168 throw std::runtime_error("gsl_blas_dgemv: size mismatch");
169 }
170 // y = beta*y
171 for (size_t i = 0; i < n; ++i) {
172 *gsl_vector_ptr(y, i) *= beta;
173 }
174 // y += alpha*A^T*x
175 for (size_t i = 0; i < n; ++i) {
176 double sum = 0.0;
177 for (size_t j = 0; j < m; ++j) {
178 sum += *gsl_matrix_ptr(A, j, i) * *gsl_vector_ptr(x, j);
179 }
180 *gsl_vector_ptr(y, i) += alpha * sum;
181 }
182 }
183}
184
// Complex matrix-vector multiply and accumulate.
//
// Visible behaviour:
//  - m and n are the row and column counts of A.
//  - For CblasNoTrans, x must have n entries and y must have m; for any
//    other selector the requirements are swapped. A mismatch throws
//    std::runtime_error.
//  - In the transposed branch, elements of A are conjugated when TransA is
//    CblasConjTrans.
//
// NOTE(review): this listing is incomplete. The embedded line numbering
// skips 196, 206, 214-215, 217-218, 226 and 238-239, so the end of the
// parameter list, the bodies of both scaling loops and the statements that
// finish the accumulation and store into y are not visible; see the notes at
// each gap.
194inline void gsl_blas_zgemv(
195 CBLAS_TRANSPOSE TransA, gsl_complex alpha, const gsl_matrix_complex* A,
// NOTE(review): the rest of the signature (presumably x, beta, y and the
// opening brace) is missing from the listing here.
197 size_t m = A->size1;
198 size_t n = A->size2;
199
200 if (TransA == CblasNoTrans) {
201 if (x->size != n || y->size != m) {
202 throw std::runtime_error("gsl_blas_zgemv: size mismatch");
203 }
204 // y = beta*y
205 for (size_t i = 0; i < m; ++i) {
// NOTE(review): loop body missing from the listing; presumably scales
// element i of y by beta - verify against the file.
207 }
208 // y += alpha*A*x
209 for (size_t i = 0; i < m; ++i) {
210 gsl_complex sum = gsl_complex(0.0, 0.0);
211 for (size_t j = 0; j < n; ++j) {
212 sum = gsl_complex_add(
213 sum,
// NOTE(review): second argument of gsl_complex_add missing from the listing;
// presumably the product of A(i, j) and x(j) - verify against the file.
216 }
// NOTE(review): the store into y (presumably y(i) plus alpha times sum) is
// missing from the listing here.
219 }
220 } else {
221 if (x->size != m || y->size != n) {
222 throw std::runtime_error("gsl_blas_zgemv: size mismatch");
223 }
224 // y = beta*y
225 for (size_t i = 0; i < n; ++i) {
// NOTE(review): loop body missing from the listing; presumably scales
// element i of y by beta - verify against the file.
227 }
228 // y += alpha*A^T*x or A^H*x
229 for (size_t i = 0; i < n; ++i) {
230 gsl_complex sum = gsl_complex(0.0, 0.0);
231 for (size_t j = 0; j < m; ++j) {
// Column i of A is walked, i.e. row i of the transposed matrix.
232 gsl_complex a_val = *gsl_matrix_complex_ptr(A, j, i);
233 if (TransA == CblasConjTrans) {
234 a_val = gsl_complex_conjugate(a_val);
235 }
236 sum = gsl_complex_add(sum, gsl_complex_mul(a_val, *gsl_vector_complex_ptr(x, j)));
237 }
// NOTE(review): the store into y (presumably y(i) plus alpha times sum) is
// missing from the listing here.
240 }
241 }
242}
243
244#endif // OPAL_GSL_BLAS_HH
void gsl_blas_zgemm(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, gsl_complex alpha, const gsl_matrix_complex *A, const gsl_matrix_complex *B, gsl_complex beta, gsl_matrix_complex *C)
Complex matrix-matrix multiply and accumulate.
Definition GSLBLAS.h:90
CBLAS_TRANSPOSE
Transpose operation selector for BLAS routines.
Definition GSLBLAS.h:32
@ CblasNoTrans
Definition GSLBLAS.h:32
@ CblasTrans
Definition GSLBLAS.h:32
@ CblasConjTrans
Definition GSLBLAS.h:32
void gsl_blas_zgemv(CBLAS_TRANSPOSE TransA, gsl_complex alpha, const gsl_matrix_complex *A, const gsl_vector_complex *x, gsl_complex beta, gsl_vector_complex *y)
Complex matrix-vector multiply and accumulate.
Definition GSLBLAS.h:194
void gsl_blas_dgemv(CBLAS_TRANSPOSE TransA, double alpha, const gsl_matrix *A, const gsl_vector *x, double beta, gsl_vector *y)
Real matrix-vector multiply and accumulate.
Definition GSLBLAS.h:144
void gsl_blas_dgemm(CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, double alpha, const gsl_matrix *A, const gsl_matrix *B, double beta, gsl_matrix *C)
Real matrix-matrix multiply and accumulate.
Definition GSLBLAS.h:44
gsl_complex gsl_complex_mul(gsl_complex a, gsl_complex b)
Product $a \cdot b$.
Definition GSLComplex.h:97
gsl_complex gsl_complex_conjugate(gsl_complex a)
Complex conjugate $\overline{a}$.
Definition GSLComplex.h:151
gsl_complex gsl_complex_add(gsl_complex a, gsl_complex b)
Sum $a + b$.
Definition GSLComplex.h:79
gsl_complex * gsl_vector_complex_ptr(gsl_vector_complex *v, size_t i)
Return pointer to element in a complex vector.
Definition GSLMatrix.h:252
double * gsl_vector_ptr(gsl_vector *v, size_t i)
Return pointer to element in a real vector.
Definition GSLMatrix.h:236
double * gsl_matrix_ptr(gsl_matrix *m, size_t i, size_t j)
Return pointer to element in a real matrix.
Definition GSLMatrix.h:199
gsl_complex * gsl_matrix_complex_ptr(gsl_matrix_complex *m, size_t i, size_t j)
Return pointer to element in a complex matrix.
Definition GSLMatrix.h:217
Complex number stored as $(\mathrm{Re}, \mathrm{Im})$ in `dat[0]`, `dat[1]`.
Definition GSLComplex.h:27
double dat[2]
Definition GSLComplex.h:28
Dense complex matrix in row-major storage.
Definition GSLMatrix.h:45
Dense real matrix in row-major storage.
Definition GSLMatrix.h:29
size_t size1
Definition GSLMatrix.h:30
size_t size2
Definition GSLMatrix.h:31
Dense complex vector with stride.
Definition GSLMatrix.h:76
Dense real vector with stride.
Definition GSLMatrix.h:61
size_t size
Definition GSLMatrix.h:62