diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py
index 064e55a40..adc65d654 100644
--- a/colossalai/nn/optimizer/fused_adam.py
+++ b/colossalai/nn/optimizer/fused_adam.py
@@ -65,11 +65,14 @@ class FusedAdam(torch.optim.Optimizer):
         self.adamw_mode = 1 if adamw_mode else 0
         self.set_grad_none = set_grad_none
         if multi_tensor_applier.available:
-            import colossalai._C.fused_optim
-
+            try:
+                from colossalai._C import fused_optim
+            except:
+                from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder
+                fused_optim = FusedOptimBuilder().load()
             # Skip buffer
             self._dummy_overflow_buf = torch.cuda.IntTensor([0])
-            self.multi_tensor_adam = colossalai._C.fused_optim.multi_tensor_adam
+            self.multi_tensor_adam = fused_optim.multi_tensor_adam
         else:
             raise RuntimeError('FusedAdam requires cuda extensions')
 
diff --git a/colossalai/nn/optimizer/fused_lamb.py b/colossalai/nn/optimizer/fused_lamb.py
index 2e33d7032..b480b8cd5 100644
--- a/colossalai/nn/optimizer/fused_lamb.py
+++ b/colossalai/nn/optimizer/fused_lamb.py
@@ -76,13 +76,18 @@ class FusedLAMB(torch.optim.Optimizer):
                         max_grad_norm=max_grad_norm)
         super(FusedLAMB, self).__init__(params, defaults)
         if multi_tensor_applier.available:
-            import colossalai._C.fused_optim
-            self.multi_tensor_l2norm = colossalai._C.fused_optim.multi_tensor_l2norm
+            try:
+                from colossalai._C import fused_optim
+            except:
+                from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder
+                fused_optim = FusedOptimBuilder().load()
+
+            self.multi_tensor_l2norm = fused_optim.multi_tensor_l2norm
             # Skip buffer
             self._dummy_overflow_buf = torch.tensor([0],
                                                     dtype=torch.int,
                                                     device=self.param_groups[0]["params"][0].device)
-            self.multi_tensor_lamb = colossalai._C.fused_optim.multi_tensor_lamb
+            self.multi_tensor_lamb = fused_optim.multi_tensor_lamb
         else:
             raise RuntimeError('FusedLAMB requires cuda extensions')
 
diff --git a/colossalai/nn/optimizer/fused_sgd.py b/colossalai/nn/optimizer/fused_sgd.py
index 03c3da28d..a0141473b 100644
--- a/colossalai/nn/optimizer/fused_sgd.py
+++ b/colossalai/nn/optimizer/fused_sgd.py
@@ -80,13 +80,16 @@ class FusedSGD(Optimizer):
         self.wd_after_momentum = wd_after_momentum
 
         if multi_tensor_applier.available:
-            import colossalai._C.fused_optim
-
+            try:
+                from colossalai._C import fused_optim
+            except:
+                from colossalai.kernel.op_builder import FusedOptimBuilder
+                fused_optim = FusedOptimBuilder().load()
             # Skip buffer
             self._dummy_overflow_buf = torch.tensor([0],
                                                     dtype=torch.int,
                                                     device=self.param_groups[0]["params"][0].device)
-            self.multi_tensor_sgd = colossalai._C.fused_optim.multi_tensor_sgd
+            self.multi_tensor_sgd = fused_optim.multi_tensor_sgd
         else:
             raise RuntimeError('FusedSGD requires cuda extensions')
 
diff --git a/colossalai/utils/common.py b/colossalai/utils/common.py
index d8cd709b3..496ac136a 100644
--- a/colossalai/utils/common.py
+++ b/colossalai/utils/common.py
@@ -12,9 +12,10 @@ from torch._six import inf
 from torch.nn.parameter import Parameter
 
 try:
-    import colossalai._C.fused_optim
+    from colossalai._C import fused_optim
 except:
-    pass
+    from colossalai.kernel.op_builder import FusedOptimBuilder
+    fused_optim = FusedOptimBuilder().load()
 
 from collections import defaultdict
 from contextlib import contextmanager
@@ -133,7 +134,7 @@ def _calc_l2_norm(grads):
     if len(grads) > 0:
         dummy_overflow_buf = torch.cuda.IntTensor([0])
         norm, _ = multi_tensor_applier(
-            colossalai._C.fused_optim.multi_tensor_l2norm,
+            fused_optim.multi_tensor_l2norm,
             dummy_overflow_buf,
             [grads],
             False    # no per-parameter norm
@@ -270,8 +271,8 @@ def _clip_grad_norm(parameters, max_norm: float, total_norm: float) -> None:
                 cpu_grads.append(p.grad.detach())
         if len(cuda_grads) > 0:
             dummy_overflow_buf = torch.cuda.IntTensor([0])
-            multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf,
-                                 [cuda_grads, cuda_grads], clip_coef)
+            multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [cuda_grads, cuda_grads],
+                                 clip_coef)
         for g in cpu_grads:
             g.mul_(clip_coef)
 
@@ -397,8 +398,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         if enable_cuda_kernels:
             grads = [p.grad.detach() for p in params]
             dummy_overflow_buf = torch.cuda.IntTensor([0])
-            multi_tensor_applier(colossalai._C.fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads],
-                                 clip_coeff)
+            multi_tensor_applier(fused_optim.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coeff)
         else:
             for p in params:
                 p.grad.detach().mul_(clip_coeff)
diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py
index d95a23702..0668e7a46 100644
--- a/tests/test_optimizer/test_fused_adam_kernel.py
+++ b/tests/test_optimizer/test_fused_adam_kernel.py
@@ -49,9 +49,12 @@ def test_adam(adamw, step, p_dtype, g_dtype):
     try:
         import colossalai._C.fused_optim
         fused_adam = colossalai._C.fused_optim.multi_tensor_adam
-        dummy_overflow_buf = torch.cuda.IntTensor([0])
     except:
-        raise ImportError("No colossalai._C.fused_optim kernel installed.")
+        from colossalai.kernel.op_builder import FusedOptimBuilder
+        fused_optim = FusedOptimBuilder().load()
+        fused_adam = fused_optim.multi_tensor_adam
+
+    dummy_overflow_buf = torch.cuda.IntTensor([0])
 
     count = 0