2021-12-21 04:19:52 +00:00
|
|
|
#pragma once
|
|
|
|
|
|
|
|
#include <cublas_v2.h>
|
|
|
|
#include <cuda.h>
|
|
|
|
|
|
|
|
#include <iostream>
|
|
|
|
#include <string>
|
|
|
|
|
|
|
|
#include "cuda_util.h"
|
|
|
|
|
|
|
|
class Context {
|
2022-05-15 00:59:50 +00:00
|
|
|
public:
|
2021-12-21 04:19:52 +00:00
|
|
|
Context() : _stream(nullptr) {
|
|
|
|
CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
|
|
|
|
}
|
|
|
|
|
|
|
|
virtual ~Context() {}
|
|
|
|
|
|
|
|
static Context &Instance() {
|
|
|
|
static Context _ctx;
|
|
|
|
return _ctx;
|
|
|
|
}
|
|
|
|
|
|
|
|
void set_stream(cudaStream_t stream) {
|
|
|
|
_stream = stream;
|
|
|
|
CHECK_GPU_ERROR(cublasSetStream(_cublasHandle, _stream));
|
|
|
|
}
|
|
|
|
|
|
|
|
cudaStream_t get_stream() { return _stream; }
|
|
|
|
|
|
|
|
cublasHandle_t get_cublashandle() { return _cublasHandle; }
|
|
|
|
|
2022-05-15 00:59:50 +00:00
|
|
|
private:
|
2021-12-21 04:19:52 +00:00
|
|
|
cudaStream_t _stream;
|
|
|
|
cublasHandle_t _cublasHandle;
|
|
|
|
};
|