mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/cuda_util.h code style (#641)
parent
055d0270c8
commit
10afec728f
|
@ -20,8 +20,7 @@ void check_gpu_error(T result, char const *const func, const char *const file,
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void print_vec(const T *outv, std::string outn, int num_output_ele);
|
void print_vec(const T *outv, std::string outn, int num_output_ele);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T> T *cuda_malloc(size_t ele_num);
|
||||||
T *cuda_malloc(size_t ele_num);
|
|
||||||
|
|
||||||
void cuda_free(void *pdata);
|
void cuda_free(void *pdata);
|
||||||
|
|
||||||
|
@ -29,6 +28,6 @@ template <typename T>
|
||||||
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
|
void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
|
||||||
std::string file, int line, cudaStream_t stream);
|
std::string file, int line, cudaStream_t stream);
|
||||||
|
|
||||||
#define CHECK_NAN_INF(ptr, size, stream) \
|
#define CHECK_NAN_INF(ptr, size, stream) \
|
||||||
check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
|
check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
|
||||||
check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))
|
check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))
|
||||||
|
|
|
@ -1,15 +1,14 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "kernels.h"
|
||||||
|
|
||||||
template <typename T>
|
template <typename T> class Dropout {
|
||||||
class Dropout {
|
public:
|
||||||
public:
|
|
||||||
struct Config {
|
struct Config {
|
||||||
float ratio;
|
float ratio;
|
||||||
bool training;
|
bool training;
|
||||||
|
@ -89,7 +88,7 @@ class Dropout {
|
||||||
|
|
||||||
void SetTrainingMode(bool training) { _config.training = training; }
|
void SetTrainingMode(bool training) { _config.training = training; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint8_t *_mask;
|
uint8_t *_mask;
|
||||||
Config _config;
|
Config _config;
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue