mirror of https://github.com/hpcaitech/ColossalAI
[builder] runtime adam and fused_optim builder (#2184)
parent
550f8f8905
commit
d42afd30f8
|
@ -5,9 +5,11 @@ import torch
|
|||
import torch.distributed as dist
|
||||
|
||||
try:
|
||||
import colossalai._C.fused_optim
|
||||
from colossalai._C import fused_optim
|
||||
except:
|
||||
print('Colossalai should be built with cuda extension to use the FP16 optimizer')
|
||||
from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.optim import Optimizer
|
||||
|
@ -35,7 +37,7 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
|
|||
if overflow_buf:
|
||||
overflow_buf.fill_(0)
|
||||
# Scaling with factor `1.0` is equivalent to copy.
|
||||
multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
|
||||
multi_tensor_applier(fused_optim.multi_tensor_scale, overflow_buf, [this, that], 1.0)
|
||||
else:
|
||||
for this_, that_ in zip(this, that):
|
||||
that_.copy_(this_)
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
from .cpu_adam import CPUAdamBuilder
|
||||
from .fused_optim import FusedOptimBuilder
|
||||
|
||||
__all__ = ['CPUAdamBuilder', 'FusedOptimBuilder']
|
|
@ -0,0 +1,45 @@
|
|||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class Builder(object):
|
||||
|
||||
def colossalai_src_path(self, code_path):
|
||||
if os.path.isabs(code_path):
|
||||
return code_path
|
||||
else:
|
||||
return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
|
||||
|
||||
def strip_empty_entries(self, args):
|
||||
'''
|
||||
Drop any empty strings from the list of compile and link flags
|
||||
'''
|
||||
return [x for x in args if len(x) > 0]
|
||||
|
||||
def load(self, verbose=True):
|
||||
"""
|
||||
|
||||
load and compile cpu_adam lib at runtime
|
||||
|
||||
Args:
|
||||
verbose (bool, optional): show detailed info. Defaults to True.
|
||||
"""
|
||||
import time
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
start_build = time.time()
|
||||
|
||||
op_module = load(name=self.name,
|
||||
sources=self.strip_empty_entries(self.sources),
|
||||
extra_include_paths=self.strip_empty_entries(self.extra_include_paths),
|
||||
extra_cflags=self.extra_cxx_flags,
|
||||
extra_cuda_cflags=self.extra_cuda_flags,
|
||||
extra_ldflags=[],
|
||||
verbose=verbose)
|
||||
|
||||
build_duration = time.time() - start_build
|
||||
if verbose:
|
||||
print(f"Time to load {self.name} op: {build_duration} seconds")
|
||||
|
||||
return op_module
|
|
@ -0,0 +1,84 @@
|
|||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from .builder import Builder
|
||||
|
||||
|
||||
class CPUAdamBuilder(Builder):
|
||||
NAME = "cpu_adam"
|
||||
BASE_DIR = "cuda_native"
|
||||
|
||||
def __init__(self):
|
||||
self.name = CPUAdamBuilder.NAME
|
||||
super().__init__()
|
||||
|
||||
self.sources = [self.colossalai_src_path(path) for path in self.sources_files()]
|
||||
self.extra_include_paths = [self.colossalai_src_path(path) for path in self.include_paths()]
|
||||
self.extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native']
|
||||
self.extra_cuda_flags = [
|
||||
'-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
|
||||
'-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
|
||||
]
|
||||
self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
|
||||
|
||||
def sources_files(self):
|
||||
return [
|
||||
os.path.join(CPUAdamBuilder.BASE_DIR, "csrc/cpu_adam.cpp"),
|
||||
]
|
||||
|
||||
def include_paths(self):
|
||||
import torch
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
cuda_include = os.path.join(CUDA_HOME, "include")
|
||||
return [os.path.join(CPUAdamBuilder.BASE_DIR, "includes"), cuda_include]
|
||||
|
||||
def colossalai_src_path(self, code_path):
|
||||
if os.path.isabs(code_path):
|
||||
return code_path
|
||||
else:
|
||||
return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
|
||||
|
||||
def strip_empty_entries(self, args):
|
||||
'''
|
||||
Drop any empty strings from the list of compile and link flags
|
||||
'''
|
||||
return [x for x in args if len(x) > 0]
|
||||
|
||||
def builder(self):
|
||||
from torch.utils.cpp_extension import CUDAExtension
|
||||
return CUDAExtension(
|
||||
name=self.name,
|
||||
sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources],
|
||||
include_dirs=self.extra_include_paths,
|
||||
extra_compile_args={
|
||||
'cxx': ['-O3'] + self.version_dependent_macros + self.extra_cxx_flags,
|
||||
'nvcc': ['-O3', '--use_fast_math'] + self.extra_cuda_flags
|
||||
})
|
||||
|
||||
def load(self, verbose=True):
|
||||
"""
|
||||
|
||||
load and compile cpu_adam lib at runtime
|
||||
|
||||
Args:
|
||||
verbose (bool, optional): show detailed info. Defaults to True.
|
||||
"""
|
||||
import time
|
||||
|
||||
from torch.utils.cpp_extension import load
|
||||
start_build = time.time()
|
||||
|
||||
op_module = load(name=self.name,
|
||||
sources=self.strip_empty_entries(self.sources),
|
||||
extra_include_paths=self.strip_empty_entries(self.extra_include_paths),
|
||||
extra_cflags=self.extra_cxx_flags,
|
||||
extra_cuda_cflags=self.extra_cuda_flags,
|
||||
extra_ldflags=[],
|
||||
verbose=verbose)
|
||||
|
||||
build_duration = time.time() - start_build
|
||||
if verbose:
|
||||
print(f"Time to load {self.name} op: {build_duration} seconds")
|
||||
|
||||
return op_module
|
|
@ -0,0 +1,53 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
import torch
|
||||
|
||||
from .builder import Builder
|
||||
|
||||
|
||||
class FusedOptimBuilder(Builder):
|
||||
NAME = "fused_optim"
|
||||
BASE_DIR = "cuda_native/csrc"
|
||||
|
||||
def __init__(self):
|
||||
self.name = FusedOptimBuilder.NAME
|
||||
super().__init__()
|
||||
|
||||
self.extra_cxx_flags = []
|
||||
self.extra_cuda_flags = ['-lineinfo']
|
||||
for arch in torch.cuda.get_arch_list():
|
||||
res = re.search(r'sm_(\d+)', arch)
|
||||
if res:
|
||||
arch_cap = res[1]
|
||||
if int(arch_cap) >= 60:
|
||||
self.extra_cuda_flags.extend(['-gencode', f'arch=compute_{arch_cap},code={arch}'])
|
||||
|
||||
self.sources = [self.colossalai_src_path(path) for path in self.sources_files()]
|
||||
self.extra_include_paths = [self.colossalai_src_path(path) for path in self.include_paths()]
|
||||
self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
|
||||
|
||||
def sources_files(self):
|
||||
return [
|
||||
os.path.join(FusedOptimBuilder.BASE_DIR, fname) for fname in [
|
||||
'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
|
||||
'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu'
|
||||
]
|
||||
]
|
||||
|
||||
def include_paths(self):
|
||||
import torch
|
||||
from torch.utils.cpp_extension import CUDA_HOME
|
||||
cuda_include = os.path.join(CUDA_HOME, "include")
|
||||
return [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), cuda_include]
|
||||
|
||||
def builder(self):
|
||||
from torch.utils.cpp_extension import CUDAExtension
|
||||
return CUDAExtension(
|
||||
name=self.name,
|
||||
sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources],
|
||||
include_dirs=self.extra_include_paths,
|
||||
extra_compile_args={
|
||||
'cxx': ['-O3'] + self.version_dependent_macros + self.extra_cxx_flags,
|
||||
'nvcc': ['-O3', '--use_fast_math'] + self.extra_cuda_flags
|
||||
})
|
|
@ -77,15 +77,15 @@ class HybridAdam(NVMeOptimizer):
|
|||
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
||||
self.adamw_mode = adamw_mode
|
||||
try:
|
||||
import colossalai._C.cpu_optim
|
||||
import colossalai._C.fused_optim
|
||||
from colossalai._C import cpu_optim, fused_optim
|
||||
except ImportError:
|
||||
raise ImportError('Please install colossalai from source code to use HybridAdam')
|
||||
from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
cpu_optim = CPUAdamBuilder().load()
|
||||
|
||||
self.cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay,
|
||||
adamw_mode)
|
||||
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
||||
|
||||
self.gpu_adam_op = colossalai._C.fused_optim.multi_tensor_adam
|
||||
self.gpu_adam_op = fused_optim.multi_tensor_adam
|
||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
|
||||
@torch.no_grad()
|
||||
|
|
|
@ -69,8 +69,12 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
|
|||
try:
|
||||
import colossalai._C.cpu_optim
|
||||
cpu_adam_op = colossalai._C.cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
|
||||
print("use prebuilt CPUAdamOptimizer")
|
||||
except:
|
||||
raise ImportError("Import cpu adam error, please install colossal from source code")
|
||||
from colossalai.kernel.op_builder.cpu_adam import CPUAdamBuilder
|
||||
lib = CPUAdamBuilder().load()
|
||||
cpu_adam_op = lib.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
|
||||
print("build CPUAdamOptimizer at runtime")
|
||||
|
||||
cpu_adam_op.step(
|
||||
step,
|
||||
|
@ -115,3 +119,7 @@ def test_cpu_adam(adamw, step, p_dtype, g_dtype):
|
|||
assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
|
||||
max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
|
||||
assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_cpu_adam()
|
||||
|
|
Loading…
Reference in New Issue