mirror of https://github.com/hpcaitech/ColossalAI
refactor code
parent
21e1e3645c
commit
095c070a6e
|
@ -21,7 +21,7 @@ class CpuAdamX86Extension(_CudaExtension):
|
||||||
# necessary 4 functions
|
# necessary 4 functions
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path("cuda/cpu_adam.cpp"),
|
self.csrc_abs_path("x86/cpu_adam.cpp"),
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
|
@ -9,7 +9,7 @@
|
||||||
|
|
||||||
|
|
||||||
#include "block_reduce.h"
|
#include "block_reduce.h"
|
||||||
#include "type_shim.h"
|
#include "../common/micros.h"
|
||||||
|
|
||||||
template<typename scalar_t>
|
template<typename scalar_t>
|
||||||
__global__ void rms_layernorm_kernel(
|
__global__ void rms_layernorm_kernel(
|
||||||
|
|
|
@ -10,7 +10,7 @@ class InferenceOpsCudaExtension(_CudaExtension):
|
||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path(fname)
|
self.csrc_abs_path(fname)
|
||||||
for fname in [
|
for fname in [
|
||||||
"cuda/colossal_inference_C_frontend.cpp",
|
"cuda/pybind/inference.cpp",
|
||||||
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
"cuda/decode_kv_cache_memcpy_kernel.cu",
|
||||||
"cuda/activation_kernel.cu",
|
"cuda/activation_kernel.cu",
|
||||||
"cuda/rms_layernorm_kernel.cu",
|
"cuda/rms_layernorm_kernel.cu",
|
||||||
|
|
|
@ -7,7 +7,7 @@ class LayerNormCudaExtension(_CudaExtension):
|
||||||
super().__init__(name="layernorm_cuda")
|
super().__init__(name="layernorm_cuda")
|
||||||
|
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [self.csrc_abs_path(fname) for fname in ["cuda/layer_norm_cuda.cpp", "cuda/layer_norm_cuda_kernel.cu"]]
|
ret = [self.csrc_abs_path(fname) for fname in ["cuda/pybind/layer_norm.cpp", "cuda/layer_norm_kernel.cu"]]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def include_dirs(self):
|
def include_dirs(self):
|
||||||
|
|
|
@ -11,7 +11,7 @@ class MoeCudaExtension(_CudaExtension):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe_cuda.cpp", "cuda/moe_cuda_kernel.cu"]]
|
ret = [self.csrc_abs_path(fname) for fname in ["cuda/moe.cpp", "cuda/moe_kernel.cu"]]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def cxx_flags(self):
|
def cxx_flags(self):
|
||||||
|
|
|
@ -10,12 +10,12 @@ class FusedOptimizerCudaExtension(_CudaExtension):
|
||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path(fname)
|
self.csrc_abs_path(fname)
|
||||||
for fname in [
|
for fname in [
|
||||||
"cuda/colossal_C_frontend.cpp",
|
"cuda/pybind/optimizer.cpp",
|
||||||
"cuda/multi_tensor_sgd_kernel.cu",
|
"cuda/multi_tensor_sgd_kernel.cu",
|
||||||
"cuda/multi_tensor_scale_kernel.cu",
|
"cuda/multi_tensor_scale_kernel.cu",
|
||||||
"cuda/multi_tensor_adam.cu",
|
"cuda/multi_tensor_adam_kernel.cu",
|
||||||
"cuda/multi_tensor_l2norm_kernel.cu",
|
"cuda/multi_tensor_l2norm_kernel.cu",
|
||||||
"cuda/multi_tensor_lamb.cu",
|
"cuda/multi_tensor_lamb_kernel.cu",
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
|
@ -9,7 +9,7 @@ class ScaledMaskedSoftmaxCudaExtension(_CudaExtension):
|
||||||
def sources_files(self):
|
def sources_files(self):
|
||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path(fname)
|
self.csrc_abs_path(fname)
|
||||||
for fname in ["cuda/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_cuda.cu"]
|
for fname in ["cuda/pybind/scaled_masked_softmax.cpp", "cuda/scaled_masked_softmax_kernel.cu"]
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
|
@ -13,8 +13,8 @@ class ScaledUpperTriangleMaskedSoftmaxCudaExtension(_CudaExtension):
|
||||||
ret = [
|
ret = [
|
||||||
self.csrc_abs_path(fname)
|
self.csrc_abs_path(fname)
|
||||||
for fname in [
|
for fname in [
|
||||||
"cuda/scaled_upper_triang_masked_softmax.cpp",
|
"cuda/pybind/scaled_upper_triang_masked_softmax.cpp",
|
||||||
"cuda/scaled_upper_triang_masked_softmax_cuda.cu",
|
"cuda/scaled_upper_triang_masked_softmax_kernel.cu",
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
Loading…
Reference in New Issue