[NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h code style (#641)

pull/673/head
Xue Fuzhao 2022-04-02 14:38:40 +08:00 committed by binmakeswell
parent 055d0270c8
commit 10afec728f
2 changed files with 7 additions and 9 deletions

View File

@ -20,8 +20,7 @@ void check_gpu_error(T result, char const *const func, const char *const file,
template <typename T> template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele); void print_vec(const T *outv, std::string outn, int num_output_ele);
template <typename T> template <typename T> T *cuda_malloc(size_t ele_num);
T *cuda_malloc(size_t ele_num);
void cuda_free(void *pdata); void cuda_free(void *pdata);

View File

@ -1,15 +1,14 @@
#pragma once #pragma once
#include <string>
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <stdio.h> #include <stdio.h>
#include <string>
#include "kernels.h" #include "kernels.h"
template <typename T> template <typename T> class Dropout {
class Dropout { public:
public:
struct Config { struct Config {
float ratio; float ratio;
bool training; bool training;
@ -89,7 +88,7 @@ class Dropout {
void SetTrainingMode(bool training) { _config.training = training; } void SetTrainingMode(bool training) { _config.training = training; }
private: private:
uint8_t *_mask; uint8_t *_mask;
Config _config; Config _config;
}; };