mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/kernel/cuda_native/csrc/kernels/include/dropout.h code style (#970)
parent
22d1df224d
commit
632e94abde
|
@ -3,12 +3,14 @@
|
||||||
#include <cuda.h>
|
#include <cuda.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
||||||
#include "kernels.h"
|
#include "kernels.h"
|
||||||
|
|
||||||
template <typename T> class Dropout {
|
template <typename T>
|
||||||
public:
|
class Dropout {
|
||||||
|
public:
|
||||||
struct Config {
|
struct Config {
|
||||||
float ratio;
|
float ratio;
|
||||||
bool training;
|
bool training;
|
||||||
|
@ -88,7 +90,7 @@ public:
|
||||||
|
|
||||||
void SetTrainingMode(bool training) { _config.training = training; }
|
void SetTrainingMode(bool training) { _config.training = training; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
uint8_t *_mask;
|
uint8_t *_mask;
|
||||||
Config _config;
|
Config _config;
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue