mirror of https://github.com/hpcaitech/ColossalAI
Frank Lee
1 year ago
committed by
GitHub
5 changed files with 51 additions and 11 deletions
@ -0,0 +1,39 @@
|
||||
from .builder import Builder |
||||
|
||||
|
||||
class ElixirSimulatorBuilder(Builder): |
||||
NAME = "elixir_simulator" |
||||
PREBUILT_IMPORT_PATH = "colossalai._C.elixir_simulator" |
||||
|
||||
def __init__(self): |
||||
super().__init__(name=ElixirSimulatorBuilder.NAME, |
||||
prebuilt_import_path=ElixirSimulatorBuilder.PREBUILT_IMPORT_PATH) |
||||
self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5'] |
||||
|
||||
# necessary 4 functions |
||||
def sources_files(self): |
||||
ret = [ |
||||
self.relative_to_abs_path('elixir/simulator.cpp'), |
||||
] |
||||
return ret |
||||
|
||||
def include_dirs(self): |
||||
return [] |
||||
|
||||
def cxx_flags(self): |
||||
return ['-O3'] + self.version_dependent_macros |
||||
|
||||
def nvcc_flags(self): |
||||
return [] |
||||
|
||||
def builder(self) -> 'CppExtension': |
||||
""" |
||||
This function should return a CppExtension object. |
||||
""" |
||||
from torch.utils.cpp_extension import CppExtension |
||||
|
||||
return CppExtension(name=self.prebuilt_import_path, |
||||
sources=self.strip_empty_entries(self.sources_files()), |
||||
extra_compile_args={ |
||||
'cxx': self.strip_empty_entries(self.cxx_flags()), |
||||
}) |
Loading…
Reference in new issue