USRP_Server  2.0
A flexible, GPU-accelerated radio-frequency readout software.
fir.cu
Go to the documentation of this file.
1 
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <math.h>
7 
8 #include <cuda_runtime.h>
9 #include <cublas_v2.h>
10 
11 #include "fir.hpp"
12 
/* Evaluate a cuBLAS call exactly once and abort with a diagnostic if it did
 * not return CUBLAS_STATUS_SUCCESS.
 *
 * The previous definition wrapped the call in assert(), which expands to
 * nothing when NDEBUG is defined -- release builds then never executed the
 * wrapped cuBLAS operation at all. This form always runs X. */
#define checkcublas(X)                                                        \
    do {                                                                      \
        cublasStatus_t checkcublas_status_ = (X);                             \
        if (checkcublas_status_ != CUBLAS_STATUS_SUCCESS) {                   \
            fprintf(stderr, "cuBLAS error %d at %s:%d: %s\n",                 \
                    (int)checkcublas_status_, __FILE__, __LINE__, #X);        \
            abort();                                                          \
        }                                                                     \
    } while (0)
14 
/**
 * Build a decimating polyphase FIR filter bound to a cuBLAS handle and stream.
 *
 * @param handle  cuBLAS handle used by fir_apply(); stored, not created here.
 * @param stream  stream on which the asynchronous copies/memsets are issued.
 * @param hcoeff  host array of M * f complex taps; copied to the device.
 * @param M       decimation factor (input samples per output sample); must
 *                be positive and divide nt.
 * @param f       taps per polyphase branch; total taps = M * f. Must be >= 1.
 * @param nt      input samples consumed per fir_apply() call.
 */
FIR::FIR(cublasHandle_t handle, cudaStream_t stream, float2 *hcoeff, int M, int f, int nt) :
    _handle(handle), _stream(stream), _M(M), _f(f)
{
    // Validate before deriving sizes: M == 0 would divide by zero below, and
    // f == 0 would make fir_shift()'s tail length (_f - 1) negative.
    assert(M > 0);
    assert(f > 0);
    assert(nt % M == 0);

    _ntap = M * f;
    _nb   = nt / M;        // output samples produced per fir_apply() call
    _nout = nt + f - 1;    // fir_apply() only touches the first _nb + f - 1 entries
    _nt   = nt;

    // Start from NULL so the asserts below never read an uninitialized
    // pointer in case cudaMalloc fails without writing it.
    _dout = NULL;
    _dcoeff = NULL;
    _dtrapz = NULL;

    cudaMalloc(&_dout, (size_t)_nout * sizeof(float2));
    assert(_dout != NULL);
    // Zero the accumulator: fir_apply() only adds into it. The target must be
    // the device pointer itself -- the previous `&_dout` pointed at this host
    // object's member, so the device buffer was never cleared.
    cudaMemset(_dout, 0, (size_t)_nout * sizeof(float2));

    cudaMalloc(&_dcoeff, (size_t)_ntap * sizeof(float2));
    assert(_dcoeff != NULL);
    cudaMemcpy(_dcoeff, hcoeff, (size_t)_ntap * sizeof(float2), cudaMemcpyHostToDevice);

    // GEMM scratch: _f columns of _nb samples each (leading dimension _nb).
    cudaMalloc(&_dtrapz, (size_t)_nb * _f * sizeof(float2));
    assert(_dtrapz != NULL);
}
35 
/**
 * Release the three device buffers allocated by the constructor.
 *
 * The cuBLAS handle and stream passed to the constructor are not released
 * here.
 */
FIR::~FIR()
{
    cudaFree(_dout);
    cudaFree(_dtrapz);
    cudaFree(_dcoeff);

    // Reset the freed pointers explicitly. The previous memset(this, 0, ...)
    // overwrote the raw bytes of an object whose class has a user-provided
    // destructor (not trivially copyable; GCC flags this with
    // -Wclass-memaccess) and also wiped the non-owned handle/stream fields.
    _dout = NULL;
    _dtrapz = NULL;
    _dcoeff = NULL;
}
43 
/**
 * Filter one input block and accumulate the result into _dout.
 *
 * din is read as a column-major _M x _nb complex matrix (leading dimension
 * _M), i.e. _nt = _nb * _M consecutive samples. cuBLAS expects it in device
 * memory -- assumption, confirm at call sites.
 *
 * Step 1 (GEMM, no conjugation):
 *   _dtrapz[i*_nb + k] = sum_m din[k*_M + m] * _dcoeff[i*_M + m]
 *   for k in [0, _nb), i in [0, _f).
 * Step 2 (AXPY): column i of _dtrapz is added into _dout starting at offset
 *   _f - 1 - i, so every output sample collects all _f branch contributions.
 */
void FIR::fir_apply(const float2 *din)
{
    // Host-side scalars: assumes the handle is in the default
    // CUBLAS_POINTER_MODE_HOST -- TODO confirm no caller switches it.
    const float2 alpha = {1.0f, 0.0f};
    const float2 beta  = {0.0f, 0.0f};
    cublasStatus_t status;

    // Bind the (possibly shared) handle to this filter's stream, so the
    // GEMM/AXPY below are ordered with the cudaMemcpyAsync/cudaMemsetAsync
    // that fir_to_dev(), fir_to_host() and fir_shift() issue on _stream.
    status = cublasSetStream(_handle, _stream);
    checkcublas(status);

    // Each cuBLAS call is made outside checkcublas() so that it still runs
    // even if that macro compiles to nothing (e.g. an assert under NDEBUG).
    status = cublasCgemm(_handle, CUBLAS_OP_T, CUBLAS_OP_N,
                         _nb, _f, _M,
                         &alpha,
                         din, _M,
                         _dcoeff, _M,
                         &beta,
                         _dtrapz, _nb);
    checkcublas(status);

    for (int i = 0; i < _f; i++) {
        status = cublasCaxpy(_handle, _nb,
                             &alpha,
                             &_dtrapz[i * _nb], 1,
                             &_dout[_f - i - 1], 1);
        checkcublas(status);
    }

    (void)status;  // avoid set-but-unused warnings if checkcublas() is a no-op
}
63 
/**
 * Prepare _dout for the next fir_apply() call.
 *
 * fir_to_dev()/fir_to_host() read the first _nb entries of _dout as output.
 * The following _f - 1 entries hold partial sums that still need
 * contributions from the next input block. This moves that tail to the
 * front and zeroes the next _nb entries so fir_apply() can accumulate into
 * them. Both operations are enqueued on _stream.
 */
void FIR::fir_shift()
{
    const int rem = _f - 1;

    // The source [_nb, _nb + rem) and the destination [0, rem) overlap when
    // rem > _nb, and cudaMemcpyAsync does not support overlapping ranges.
    // Copy forward in chunks of at most _nb elements:
    //  - each chunk's source starts _nb elements past its destination, so a
    //    single chunk never overlaps itself;
    //  - stream ordering guarantees an earlier chunk has read its source
    //    before a later chunk overwrites those elements.
    // When _nb == 0, the source equals the destination, so nothing needs
    // copying.
    if (_nb > 0) {
        for (int off = 0; off < rem; off += _nb) {
            const int len = (rem - off < _nb) ? (rem - off) : _nb;
            cudaMemcpyAsync(&_dout[off], &_dout[off + _nb], (size_t)len * sizeof(float2),
                            cudaMemcpyDeviceToDevice, _stream);
        }
    }
    cudaMemsetAsync(&_dout[rem], 0, (size_t)_nb * sizeof(float2), _stream);
}
70 
/**
 * Enqueue a device-to-host copy of the _nb output samples of the current
 * block. There is one output sample per _M input samples; _M is the
 * decimation factor.
 *
 * The copy is asynchronous on _stream; synchronize _stream before reading
 * hout.
 */
void FIR::fir_to_host(float2 *hout)
{
    const size_t nbytes = (size_t)_nb * sizeof(float2);
    cudaMemcpyAsync(hout, _dout, nbytes, cudaMemcpyDeviceToHost, _stream);
}
76 
// TODO (original author): to be refined.
/**
 * Enqueue a device-to-device copy of the _nb output samples of the current
 * block into dout. There is one output sample per _M input samples; _M is
 * the decimation factor.
 *
 * dout must be device memory. The copy is asynchronous on _stream.
 */
void FIR::fir_to_dev(float2 *dout)
{
    const size_t nbytes = (size_t)_nb * sizeof(float2);
    cudaMemcpyAsync(dout, _dout, nbytes, cudaMemcpyDeviceToDevice, _stream);
}
/**
 * Process one input block end to end:
 *   1. fir_apply()  -- filter din and accumulate into the internal buffer;
 *   2. fir_to_dev() -- copy the _nb finished samples into hout;
 *   3. fir_shift()  -- carry the partial tail over to the next block.
 *
 * NOTE(review): despite its name, hout is written by fir_to_dev() with a
 * cudaMemcpyDeviceToDevice copy, so it must point to device memory.
 * fir_to_dev() and fir_shift() enqueue their work on _stream without
 * synchronizing the host.
 */
void FIR::run_fir(const float2 *din, float2 *hout)
{
    // The order is significant: the output must be copied out before
    // fir_shift() overwrites the front of the accumulator.
    fir_apply(din);
    fir_to_dev(hout);
    fir_shift();
}
FIR(cublasHandle_t handle, cudaStream_t stream, float2 *hcoeff, int M, int f, int nt)
Definition: fir.cu:15
float2 * _dout
Definition: fir.hpp:26
~FIR()
Definition: fir.cu:36
float2 * _dtrapz
Definition: fir.hpp:28
void fir_shift()
Definition: fir.cu:64
int _ntap
Definition: fir.hpp:21
float2 * _dcoeff
Definition: fir.hpp:25
int _nb
Definition: fir.hpp:22
int _M
Definition: fir.hpp:19
#define checkcublas(X)
Definition: fir.cu:13
int _nout
Definition: fir.hpp:24
int _nt
Definition: fir.hpp:23
cublasHandle_t _handle
Definition: fir.hpp:17
void fir_apply(const float2 *din)
Definition: fir.cu:44
void fir_to_dev(float2 *dout)
Definition: fir.cu:79
int _f
Definition: fir.hpp:20
cudaStream_t _stream
Definition: fir.hpp:18
void fir_to_host(float2 *hout)
Definition: fir.cu:71
void run_fir(const float2 *din, float2 *hout)
Definition: fir.cu:83