diff --git a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h index 005a36ba1..ec447ad84 100644 --- a/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h +++ b/colossalai/kernel/cuda_native/csrc/kernels/include/softmax.h @@ -10,8 +10,9 @@ using namespace std; -template class Softmax { -public: +template +class Softmax { + public: struct Config { size_t nhead; Config(size_t nhead) : nhead(nhead) {} @@ -36,6 +37,6 @@ public: void reset_size(size_t nhead) { config_.nhead = nhead; } -private: + private: Config config_; };