mirror of https://github.com/hpcaitech/ColossalAI
ai big-model data-parallelism deep-learning distributed-computing foundation-models heterogeneous-training hpc inference large-scale model-parallelism pipeline-parallelism
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
134 lines
4.6 KiB
134 lines
4.6 KiB
import importlib |
|
import os |
|
import time |
|
from abc import abstractmethod |
|
from pathlib import Path |
|
from typing import List |
|
|
|
from .base_extension import _Extension |
|
|
|
__all__ = ["_CppExtension"] |
|
|
|
|
|
class _CppExtension(_Extension):
    """
    Base class for C++ extensions that can be pre-built (AOT) or compiled at runtime (JIT).

    Subclasses provide the source files, include directories and compiler flags through
    ``sources_files``, ``include_dirs`` and ``cxx_flags``; this class turns them into a
    ``torch.utils.cpp_extension.CppExtension`` (AOT) or a runtime ``load`` call (JIT).
    """

    def __init__(self, name: str, priority: int = 1):
        super().__init__(name, support_aot=True, support_jit=True, priority=priority)

        # we store the op as an attribute to avoid repeated building and loading
        # (populated by ``load`` on its first successful call)
        self.cached_op = None

        # build-related variables
        self.prebuilt_module_path = "colossalai._C"
        self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}"
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    def csrc_abs_path(self, path: str) -> str:
        """
        Return the absolute path of ``path`` inside the ``csrc`` folder of the extensions directory.
        """
        return os.path.join(self.relative_to_abs_path("csrc"), path)

    def relative_to_abs_path(self, code_path: str) -> str:
        """
        This function takes in a path relative to the ``extensions`` directory and returns the joined path.

        The ``extensions`` directory is the nearest entry named ``extensions`` found by walking
        upwards from this file.

        Raises:
            FileNotFoundError: if neither this file nor any of its ancestors is named ``extensions``.
        """
        this_file = Path(__file__)

        # Walk from this file up to the filesystem root. ``parents`` is finite, so unlike an
        # open-ended ``.parent`` loop this cannot spin forever once the root is reached
        # (the parent of the root is the root itself).
        for candidate in (this_file, *this_file.parents):
            if candidate.name == "extensions":
                return str(candidate.joinpath(code_path))

        raise FileNotFoundError(
            f"Could not locate an 'extensions' directory among the ancestors of {this_file}"
        )

    def strip_empty_entries(self, args: List[str]) -> List[str]:
        """
        Drop any empty strings from the list of compile and link flags
        """
        return [entry for entry in args if entry]

    def import_op(self):
        """
        This function will import the op module by its string name.
        """
        return importlib.import_module(self.prebuilt_import_path)

    def build_aot(self) -> "CppExtension":
        """
        Describe the extension for ahead-of-time building.

        Returns a ``torch.utils.cpp_extension.CppExtension`` named after the pre-built import
        path, with empty entries removed from the sources, include directories and flags.
        """
        # imported lazily so that torch's build tooling is only needed when building
        from torch.utils.cpp_extension import CppExtension

        return CppExtension(
            name=self.prebuilt_import_path,
            sources=self.strip_empty_entries(self.sources_files()),
            include_dirs=self.strip_empty_entries(self.include_dirs()),
            extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
        )

    def build_jit(self):
        """
        Compile (or re-load) the extension at runtime and return the loaded op module.

        The build happens inside the folder returned by
        ``_Extension.get_jit_extension_folder_path()``, which is created if it does not exist.
        Progress and timing messages are printed to stdout.
        """
        from torch.utils.cpp_extension import load

        build_directory = Path(_Extension.get_jit_extension_folder_path())
        build_directory.mkdir(parents=True, exist_ok=True)

        # check if the kernel has been built; this only selects the wording of the messages below.
        # NOTE(review): torch appears to name object files after the source files and the final
        # library after ``name`` (".so", or ".pyd" on Windows), so those are probed in addition
        # to the historical "<name>.o" check -- confirm against the torch version in use.
        compiled_before = any(
            build_directory.joinpath(f"{self.name}{suffix}").exists() for suffix in (".o", ".so", ".pyd")
        )

        # load the kernel
        if compiled_before:
            print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
        else:
            print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")

        build_start = time.time()
        op_kernel = load(
            name=self.name,
            sources=self.strip_empty_entries(self.sources_files()),
            extra_include_paths=self.strip_empty_entries(self.include_dirs()),
            # strip empty flags here as well, consistent with build_aot
            extra_cflags=self.strip_empty_entries(self.cxx_flags()),
            extra_ldflags=[],
            build_directory=str(build_directory),
        )
        build_duration = time.time() - build_start

        if compiled_before:
            print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
        else:
            print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")

        return op_kernel

    # the following methods must be overridden by subclasses
    @abstractmethod
    def sources_files(self) -> List[str]:
        """
        This function should return a list of source files for extensions.
        """

    @abstractmethod
    def include_dirs(self) -> List[str]:
        """
        This function should return a list of include directories for extensions.
        """

    @abstractmethod
    def cxx_flags(self) -> List[str]:
        """
        This function should return a list of cxx compilation flags for extensions.
        """

    def load(self):
        """
        Return the op module, importing the pre-built one if possible and JIT-building otherwise.

        The result is stored in ``self.cached_op`` so that later calls return it directly
        instead of importing or building again.
        """
        # ``getattr`` tolerates subclasses that never ran this class's ``__init__``
        cached = getattr(self, "cached_op", None)
        if cached is not None:
            return cached

        try:
            op_kernel = self.import_op()
        except ImportError:
            # if import error occurs (ModuleNotFoundError is a subclass of ImportError),
            # it means that the kernel is not pre-built, so we build it jit
            op_kernel = self.build_jit()

        self.cached_op = op_kernel
        return op_kernel
|
|
|