mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
41 lines
1.0 KiB
41 lines
1.0 KiB
import platform |
|
|
|
from ..cpp_extension import _CppExtension |
|
|
|
|
|
class CpuAdamArmExtension(_CppExtension):
    """Build description for the ``cpu_adam_arm`` C++ extension.

    The extension is only considered usable when ``platform.machine()``
    reports ``aarch64``. It compiles a single source file,
    ``arm/cpu_adam_arm.cpp``, and supplies no extra include directories and
    no nvcc flags.
    """

    # The only ``platform.machine()`` value this extension accepts.
    _REQUIRED_ARCH = "aarch64"

    def __init__(self):
        super().__init__(name="cpu_adam_arm")

    def is_hardware_available(self) -> bool:
        """Return True when the host CPU architecture is aarch64."""
        # only arm allowed
        return platform.machine() == self._REQUIRED_ARCH

    def assert_hardware_compatible(self) -> None:
        """Fail loudly when the host CPU architecture is not aarch64.

        Raises:
            AssertionError: if ``platform.machine()`` is anything other than
                ``"aarch64"``.
        """
        arch = platform.machine()
        # Raised explicitly rather than via an ``assert`` statement so the
        # check is still enforced when Python runs with -O (which strips
        # assert statements). The exception type and message are unchanged.
        if arch != self._REQUIRED_ARCH:
            raise AssertionError(
                f"[extension] The {self.name} kernel requires the CPU architecture to be aarch64 but got {arch}"
            )

    # necessary 4 functions
    # (sources_files, include_dirs, cxx_flags, nvcc_flags -- presumably the
    # build hooks the base class expects; confirm against _CppExtension.)
    def sources_files(self) -> list:
        """Return the C++ sources to compile: just ``arm/cpu_adam_arm.cpp``."""
        # csrc_abs_path is provided outside this class; presumably it resolves
        # the relative name to an absolute path -- confirm against the base.
        return [self.csrc_abs_path("arm/cpu_adam_arm.cpp")]

    def include_dirs(self) -> list:
        """Return extra include directories; this extension needs none."""
        return []

    def cxx_flags(self) -> list:
        """Return the C++ compiler flags for this extension.

        The result is ``-O3``, followed by ``self.version_dependent_macros``,
        followed by the extension-specific flags listed below.
        """
        extra_cxx_flags = [
            # Exactly one language-standard flag. GCC/Clang honour only the
            # last -std option, so also passing -std=c++14 before this one
            # would have no effect beyond obscuring the real standard.
            "-std=c++17",
            "-g",
            "-Wno-reorder",
            "-fopenmp",
        ]
        return ["-O3"] + self.version_dependent_macros + extra_cxx_flags

    def nvcc_flags(self) -> list:
        """Return nvcc flags; always empty for this extension."""
        return []
|
|
|