fix format (#611)

pull/673/head
Yuer867 2022-04-01 14:19:27 +08:00 committed by binmakeswell
parent d3d5bedc65
commit 5ecef13c16
2 changed files with 5 additions and 6 deletions

View File

@ -9,7 +9,7 @@
#include "cuda_util.h"
class Context {
public:
public:
Context() : _stream(nullptr) {
CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
}
@ -30,7 +30,7 @@ class Context {
cublasHandle_t get_cublashandle() { return _cublasHandle; }
private:
private:
cudaStream_t _stream;
cublasHandle_t _cublasHandle;
};

View File

@ -8,9 +8,8 @@
#include "cuda_util.h"
template <typename T>
class CrossEntropyLayer {
public:
template <typename T> class CrossEntropyLayer {
public:
CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
virtual ~CrossEntropyLayer();
@ -23,7 +22,7 @@ class CrossEntropyLayer {
void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);
private:
private:
void allocate_mem_buffer() {
// allocate local gpu memory
_loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);