From ae02d4e4f70e8ba4f8ae1058ac48bd08b06b6d24 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Mon, 5 Jun 2023 15:58:31 +0800 Subject: [PATCH] [bf16] add bf16 support (#3882) * [bf16] add bf16 support for fused adam (#3844) * [bf16] fused adam kernel support bf16 * [test] update fused adam kernel test * [test] update fused adam test * [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860) * [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869) * [bf16] add mixed precision mixin * [bf16] low level zero optim support bf16 * [text] update low level zero test * [text] fix low level zero grad acc test * [bf16] add bf16 support for gemini (#3872) * [bf16] gemini support bf16 * [test] update gemini bf16 test * [doc] update gemini docstring * [bf16] add bf16 support for plugins (#3877) * [bf16] add bf16 support for legacy zero (#3879) * [zero] init context support bf16 * [zero] legacy zero support bf16 * [test] add zero bf16 test * [doc] add bf16 related docstring for legacy zero --- .../mixed_precision_mixin/__init__.py | 9 ++ .../naive_amp/mixed_precision_mixin/base.py | 91 ++++++++++++ .../naive_amp/mixed_precision_mixin/bf16.py | 23 +++ .../naive_amp/mixed_precision_mixin/fp16.py | 84 +++++++++++ colossalai/booster/plugin/gemini_plugin.py | 9 +- .../booster/plugin/low_level_zero_plugin.py | 33 +++-- .../kernel/cuda_native/csrc/type_shim.h | 15 ++ colossalai/nn/optimizer/cpu_adam.py | 23 ++- colossalai/nn/optimizer/fused_adam.py | 4 +- colossalai/nn/optimizer/hybrid_adam.py | 37 ++--- colossalai/zero/gemini/gemini_ddp.py | 24 +++- colossalai/zero/gemini/gemini_optimizer.py | 92 ++++++------ .../zero/legacy/init_ctx/init_context.py | 11 +- .../zero/legacy/sharded_model/_utils.py | 10 +- .../legacy/sharded_model/sharded_model_v2.py | 7 +- .../legacy/sharded_optim/sharded_optim_v2.py | 39 ++++-- colossalai/zero/low_level/low_level_optim.py | 106 +++++++------- tests/test_optimizer/test_adam_kernel.py | 131 ++++++++++++++++++ tests/test_optimizer/test_adam_optim.py | 86 ++++++++++++ tests/test_optimizer/test_cpu_adam.py | 121 ---------------- tests/test_optimizer/test_fused_adam.py | 64 --------- .../test_optimizer/test_fused_adam_kernel.py | 95 ------------- tests/test_optimizer/test_hybrid_adam.py | 42 ------ tests/test_zero/test_gemini/test_optim.py | 46 ++++-- .../test_zero/test_legacy/test_zero_engine.py | 21 ++- .../test_zero/test_low_level/test_grad_acc.py | 5 +- .../test_zero/test_low_level/test_zero1_2.py | 35 +++-- 27 files changed, 738 insertions(+), 525 deletions(-) create mode 100644 colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py create mode 100644 colossalai/amp/naive_amp/mixed_precision_mixin/base.py create mode 100644 colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py create mode 100644 colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py create mode 100644 tests/test_optimizer/test_adam_kernel.py create mode 100644 tests/test_optimizer/test_adam_optim.py delete mode 100644 tests/test_optimizer/test_cpu_adam.py delete mode 100644 tests/test_optimizer/test_fused_adam.py delete mode 100644 tests/test_optimizer/test_fused_adam_kernel.py delete mode 100644 tests/test_optimizer/test_hybrid_adam.py diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py new file mode 100644 index 000000000..b0348e147 --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/__init__.py @@ -0,0 +1,9 @@ +from .base import MixedPrecisionMixin +from .bf16 import BF16MixedPrecisionMixin +from .fp16 import FP16MixedPrecisionMixin + +__all__ = [ + 'MixedPrecisionMixin', + 'FP16MixedPrecisionMixin', + 'BF16MixedPrecisionMixin', +] diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/base.py b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py new file mode 100644 index 000000000..a52a9747a --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/base.py @@ -0,0 +1,91 @@ +from abc import ABC, abstractmethod + +import torch +from torch import Tensor + + +class MixedPrecisionMixin(ABC): + """A helper class for mixed precision training. This mixin is used in mixed precision optimizers. + + Attributes: + dtype (torc.dtype): The expected dtype of the gradients. + + Examples: + ```python + class MyMixedPrecisionOptimizer(OptimizerWrapper): + def __init__(self, optim: Optimizer): + super().__init__(optim) + self.mixed_precision = MixedPrecisionMixin() + + def backward(self, loss): + loss = self.mixed_precision.pre_backward(loss) + loss.backward() + + def backward_by_grad(self, tensor, grad): + grad = self.mixed_precision.pre_backward_by_grad(tensor, grad) + tensor.backward(grad) + + def step(self): + if self.mixed_precision.should_skip_step(): + self.zero_grad() + return + div_scale = self.mixed_precision.get_grad_div_scale() + # maybe clip grad here + # maybe scale grad here + self.optim.step() + + def zero_grad(self): + self.mixed_precision.pre_zero_grad() + return self.optim.zero_grad() + ``` + """ + dtype: torch.dtype + + @abstractmethod + def pre_backward(self, loss: Tensor) -> Tensor: + """Called before backward. + + Args: + loss (Tensor): Loss value. + + Returns: + Tensor: Loss value (possibly scaled). + """ + pass + + @abstractmethod + def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: + """Called before backward by grad. This is helpful for pipeline parallelism. + + Args: + tensor (Tensor): Tensor to backward. + grad (Tensor): Gradient of the tensor. + + Returns: + Tensor: Gradient of the tensor (possibly scaled). + """ + pass + + @abstractmethod + def should_skip_step(self) -> bool: + """Called before step. + + Returns: + bool: Whether to skip the step. + """ + pass + + @abstractmethod + def pre_zero_grad(self) -> None: + """Called before zero_grad. + """ + pass + + @abstractmethod + def get_grad_div_scale(self) -> float: + """Called before step or clip_grad. To keep computation efficiency, this method does not (maybe) unscale grads. + + Returns: + float: A divisor for gradient clipping or step. + """ + pass diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py new file mode 100644 index 000000000..9454f6eb8 --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/bf16.py @@ -0,0 +1,23 @@ +import torch +from torch import Tensor + +from .base import MixedPrecisionMixin + + +class BF16MixedPrecisionMixin(MixedPrecisionMixin): + dtype = torch.bfloat16 + + def pre_backward(self, loss: Tensor) -> Tensor: + return loss + + def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: + return grad + + def should_skip_step(self) -> bool: + return False + + def pre_zero_grad(self) -> None: + pass + + def get_grad_div_scale(self) -> float: + return 1.0 diff --git a/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py new file mode 100644 index 000000000..1ce8e42eb --- /dev/null +++ b/colossalai/amp/naive_amp/mixed_precision_mixin/fp16.py @@ -0,0 +1,84 @@ +from abc import abstractmethod +from enum import Enum + +import torch +import torch.distributed as dist +from torch import Tensor + +from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.utils import get_current_device + +from .base import MixedPrecisionMixin + + +class OptimState(Enum): + SCALED = 0 + UNSCALED = 1 + + +class FP16MixedPrecisionMixin(MixedPrecisionMixin): + dtype = torch.float16 + + def __init__(self, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32) -> None: + super().__init__() + self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + self.optim_state = OptimState.UNSCALED + self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device()) + + @property + def loss_scale(self) -> float: + return self.grad_scaler.scale.item() + + @abstractmethod + def check_local_overflow(self) -> bool: + """Check whether there is overflow in the local process. This method should be implemented by subclasses. + + Returns: + bool: Whether there is overflow in the local process. + """ + pass + + def check_overflow(self) -> bool: + # clear previous overflow record + self.found_overflow.fill_(0.0) + if self.check_local_overflow(): + self.found_overflow.fill_(1.0) + dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX) + return self.found_overflow.item() > 0 + + def pre_backward(self, loss: Tensor) -> Tensor: + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + return loss + + def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor: + self.optim_state = OptimState.SCALED + return grad + + def should_skip_step(self) -> bool: + found_inf = self.check_overflow() + self.grad_scaler.update(found_inf) + if found_inf: + self.optim_state = OptimState.UNSCALED + return found_inf + + def pre_zero_grad(self) -> None: + pass + + def get_grad_div_scale(self) -> float: + assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping' + self.optim_state = OptimState.UNSCALED + return self.loss_scale diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index adbf4803e..46714fe1c 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -23,6 +23,9 @@ from .dp_plugin_base import DPPluginBase __all__ = ['GeminiPlugin'] +SUPPORTED_PRECISION = ['fp16', 'bf16'] +PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16} + class GeminiCheckpointIO(GeneralCheckpointIO): @@ -171,6 +174,7 @@ class GeminiPlugin(DPPluginBase): Args: device (torch.device): device to place the model. placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu". + precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'. pin_memory (bool, optional): use pin memory on CPU. Defaults to False. force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False. strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False. @@ -203,6 +207,7 @@ class GeminiPlugin(DPPluginBase): self, device: Optional[torch.device] = None, placement_policy: str = "cpu", + precision: str = "fp16", pin_memory: bool = False, force_outputs_fp32: bool = False, strict_ddp_mode: bool = False, @@ -223,6 +228,7 @@ class GeminiPlugin(DPPluginBase): verbose: bool = False, ) -> None: super().__init__() + assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported' self.gemini_config = dict( device=(device or get_current_device()), placement_policy=placement_policy, @@ -233,6 +239,7 @@ class GeminiPlugin(DPPluginBase): hidden_dim=hidden_dim, min_chunk_size_mb=min_chunk_size_mb, memstats=memstats, + mixed_precision=PRECISION_STR_TO_DTYPE[precision], ) self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,) self.optim_kwargs = dict(initial_scale=initial_scale, @@ -253,7 +260,7 @@ class GeminiPlugin(DPPluginBase): return True def supported_precisions(self) -> List[str]: - return ['fp16'] + return SUPPORTED_PRECISION def control_device(self) -> bool: return True diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index 5d93cf0e3..2b312d0f9 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -1,4 +1,5 @@ import warnings +from functools import partial from typing import Callable, Iterator, List, Optional, Tuple, Union import torch @@ -20,12 +21,15 @@ from .torch_ddp_plugin import TorchDDPCheckpointIO __all__ = ['LowLevelZeroPlugin'] -def _convert_to_fp16(x): +def _convert_floating_point(x, dtype: torch.dtype = torch.float16): if isinstance(x, torch.Tensor) and torch.is_floating_point(x): - return x.half() + return x.to(dtype) return x +SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32'] + + class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO): def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): @@ -49,17 +53,24 @@ class LowLevelZeroModel(ModelWrapper): def __init__(self, module: nn.Module, stage: int, precision: str) -> None: super().__init__(module) - self.convert_inputs = (precision == 'fp16') - module = zero_model_wrapper(module, zero_stage=stage) + self.dtype = None if precision == 'fp16': - module = module.half() + self.dtype = torch.float16 + elif precision == 'bf16': + self.dtype = torch.bfloat16 + module = zero_model_wrapper(module, zero_stage=stage) + if self.dtype is not None: + module = module.to(self.dtype) module = module.to(get_current_device()) self.module = module + self.convert_fn = None + if self.dtype is not None: + self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) def forward(self, *args, **kwargs): - if self.convert_inputs: - args = tree_map(_convert_to_fp16, args) - kwargs = tree_map(_convert_to_fp16, kwargs) + if self.convert_fn is not None: + args = tree_map(self.convert_fn, args) + kwargs = tree_map(self.convert_fn, kwargs) return super().forward(*args, **kwargs) @@ -110,7 +121,7 @@ class LowLevelZeroPlugin(DPPluginBase): Args: strage (int, optional): ZeRO stage. Defaults to 1. - precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'. + precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'. initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32. min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1. growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2. @@ -149,7 +160,7 @@ class LowLevelZeroPlugin(DPPluginBase): ) -> None: super().__init__() assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training' - assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training' + assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training' self.stage = stage self.precision = precision @@ -175,7 +186,7 @@ class LowLevelZeroPlugin(DPPluginBase): return True def supported_precisions(self) -> List[str]: - return ['fp16', 'fp32'] + return SUPPORTED_PRECISION def control_device(self) -> bool: return True diff --git a/colossalai/kernel/cuda_native/csrc/type_shim.h b/colossalai/kernel/cuda_native/csrc/type_shim.h index 2f180a778..03ccc0263 100644 --- a/colossalai/kernel/cuda_native/csrc/type_shim.h +++ b/colossalai/kernel/cuda_native/csrc/type_shim.h @@ -171,6 +171,21 @@ using g_scalar_t_##LEVEL = at::Half; \ using p_scalar_t_##LEVEL = at::Half; \ __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::Float && \ + PTYPE == at::ScalarType::BFloat16) { \ + using g_scalar_t_##LEVEL = float; \ + using p_scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::BFloat16 && \ + PTYPE == at::ScalarType::Float) { \ + using g_scalar_t_##LEVEL = at::BFloat16; \ + using p_scalar_t_##LEVEL = float; \ + __VA_ARGS__; \ + } else if (GTYPE == at::ScalarType::BFloat16 && \ + PTYPE == at::ScalarType::BFloat16) { \ + using g_scalar_t_##LEVEL = at::BFloat16; \ + using p_scalar_t_##LEVEL = at::BFloat16; \ + __VA_ARGS__; \ } else { \ AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \ "'"); \ diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index bb561a106..7070c0a1e 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -93,8 +93,7 @@ class CPUAdam(NVMeOptimizer): bias_correction1, bias_correction2, use_adamw=False): - # FIXME(ver217): remove the below line when replace torch adam with fused adam - grad = grad.float() + grad = grad.to(data.dtype) if weight_decay != 0: if use_adamw: @@ -133,10 +132,12 @@ class CPUAdam(NVMeOptimizer): if len(state) == 0: state['step'] = 0 + # FIXME(ver217): CPU adam kernel only supports fp32 states now + assert p.dtype is torch.float, "CPUAdam only support fp32 parameters" # gradient momentums - state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state['exp_avg'] = torch.zeros_like(p, device=target_device) # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state['exp_avg_sq'] = torch.zeros_like(p, device=target_device) self._post_state_init(p) state['step'] += 1 @@ -147,9 +148,17 @@ class CPUAdam(NVMeOptimizer): assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" self._pre_update(p, 'exp_avg', 'exp_avg_sq') - self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], - group['bias_correction'], p.data, p.grad.data, state['exp_avg'], - state['exp_avg_sq'], div_scale) + if p.grad.dtype is torch.bfloat16: + # cpu adam kernel does not support bf16 now + bias_correction1 = 1 - beta1**state['step'] + bias_correction2 = 1 - beta2**state['step'] + self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], + beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, + bias_correction2, self.adamw_mode) + else: + self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], + group['weight_decay'], group['bias_correction'], p.data, p.grad.data, + state['exp_avg'], state['exp_avg_sq'], div_scale) self._post_update(p, 'exp_avg', 'exp_avg_sq') elif target_device.type == 'cuda': assert div_scale == -1, "div_scale should remain default" diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index 987af8a96..82a6250f1 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -134,8 +134,8 @@ class FusedAdam(torch.optim.Optimizer): # Exponential moving average of squared gradient values state['exp_avg_sq'] = torch.zeros_like(p) - if p.dtype not in [torch.float16, torch.float32]: - raise RuntimeError('FusedAdam only support fp16 and fp32.') + if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]: + raise RuntimeError('FusedAdam only support fp16, fp32 and bf16.') g_l.append(p.grad.data) p_l.append(p.data) diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index be6311c6c..526071b06 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -1,16 +1,17 @@ from typing import Any, Optional import torch +from torch.optim import Adam -from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder +from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.registry import OPTIMIZERS from colossalai.utils import multi_tensor_applier -from .nvme_optimizer import NVMeOptimizer +from .cpu_adam import CPUAdam @OPTIMIZERS.register_module -class HybridAdam(NVMeOptimizer): +class HybridAdam(CPUAdam): """Implements Adam algorithm. Supports parameters updating on both GPU and CPU, depanding on the device of parameters. @@ -74,15 +75,9 @@ class HybridAdam(NVMeOptimizer): nvme_offload_dir: Optional[str] = None, **defaults: Any): - default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction) - super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir) - self.adamw_mode = adamw_mode - - # build during runtime if not found - cpu_optim = CPUAdamBuilder().load() + super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction, + nvme_offload_dir) fused_optim = FusedOptimBuilder().load() - self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode) - self.gpu_adam_op = fused_optim.multi_tensor_adam self._dummy_overflow_buf = torch.cuda.IntTensor([0]) @@ -108,10 +103,12 @@ class HybridAdam(NVMeOptimizer): if len(state) == 0: state['step'] = 0 + # FIXME(ver217): CPU adam kernel only supports fp32 states now + assert p.dtype is torch.float, "HybridAdam only support fp32 parameters" # gradient momentums - state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state['exp_avg'] = torch.zeros_like(p, device=target_device) # gradient variances - state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device) + state['exp_avg_sq'] = torch.zeros_like(p, device=target_device) self._post_state_init(p) state['step'] += 1 @@ -122,9 +119,17 @@ class HybridAdam(NVMeOptimizer): assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu" assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu" self._pre_update(p, 'exp_avg', 'exp_avg_sq') - self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'], - group['bias_correction'], p.data, p.grad.data, state['exp_avg'], - state['exp_avg_sq'], div_scale) + if p.grad.dtype is torch.bfloat16: + # cpu adam kernel does not support bf16 now + bias_correction1 = 1 - beta1**state['step'] + bias_correction2 = 1 - beta2**state['step'] + self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'], + beta1, beta2, group['eps'], group['weight_decay'], bias_correction1, + bias_correction2, self.adamw_mode) + else: + self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], + group['weight_decay'], group['bias_correction'], p.data, p.grad.data, + state['exp_avg'], state['exp_avg_sq'], div_scale) self._post_update(p, 'exp_avg', 'exp_avg_sq') elif target_device.type == 'cuda': diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index fd49362d6..7e23fdb42 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP): strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated. Defaults to False. Users can set it to True, when they clearly know that they only need DDP. scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference. + mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16. """ def __init__(self, @@ -59,7 +60,9 @@ class ZeroDDP(ColoDDP): pin_memory: bool = False, force_outputs_fp32: bool = False, strict_ddp_mode: bool = False, - scatter_after_inference: bool = True) -> None: + scatter_after_inference: bool = True, + mixed_precision: torch.dtype = torch.float16) -> None: + assert mixed_precision in (torch.float16, torch.bfloat16) self.gemini_manager = gemini_manager self.chunk_manager: ChunkManager = gemini_manager.chunk_manager self.force_outputs_fp32 = force_outputs_fp32 @@ -71,6 +74,7 @@ class ZeroDDP(ColoDDP): self.param2name: Dict[nn.Parameter, str] = dict() self.name2param: Dict[str, nn.Parameter] = dict() self.scatter_after_inference = scatter_after_inference + self.mixed_precision = mixed_precision self._logger = get_dist_logger() @@ -151,7 +155,7 @@ class ZeroDDP(ColoDDP): assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup( ), "You should run a completed iteration as your warmup iter" - args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half) + args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision) self.module.zero_grad(set_to_none=True) if not grad_flag: outputs = self._inference_forward(*args, **kwargs) @@ -570,14 +574,14 @@ class ZeroDDP(ColoDDP): # move ignored parameters to CUDA if is_ddp_ignored(p): - p.data = p.data.to(device=get_current_device(), dtype=torch.float16) + p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision) continue # create a fp32 parameter fp32_data = p.data.float() fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group)) # create a fp16 parameter - p.data = p.data.half() + p.data = p.data.to(self.mixed_precision) # register the fp16 parameter and fp32 parameter in the chunk manager dp_world_size = p.process_group.dp_world_size() @@ -613,7 +617,7 @@ class ZeroDDP(ColoDDP): buffer.materialize() buffer.data = buffer.cuda() if torch.is_floating_point(buffer): - buffer.data = buffer.half() + buffer.data = buffer.to(self.mixed_precision) def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None: """Convert parameter to ColoParameter in-place. @@ -736,6 +740,7 @@ class GeminiDDP(ZeroDDP): hidden_dim: Optional[int] = None, min_chunk_size_mb: float = 32, memstats: Optional[MemStats] = None, + mixed_precision: torch.dtype = torch.float16, verbose: bool = False) -> None: """ A torch.Module wrapper using ZeRO-DP and Gemini. @@ -776,5 +781,10 @@ class GeminiDDP(ZeroDDP): strict_ddp_flag=strict_ddp_mode, verbose=verbose) gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats) - super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode, - scatter_after_inference) + super().__init__(module, + gemini_manager, + pin_memory, + force_outputs_fp32, + strict_ddp_mode, + scatter_after_inference, + mixed_precision=mixed_precision) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 71c4f65cb..267deb1e8 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -1,7 +1,6 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch import math import warnings -from enum import Enum from typing import Any, Dict, Set, Tuple import torch @@ -9,7 +8,7 @@ import torch.distributed as dist from torch.nn import Parameter from torch.optim import Optimizer -from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam from colossalai.utils import disposable, get_current_device, is_ddp_ignored @@ -22,9 +21,26 @@ __all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer'] _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam} -class OptimState(Enum): - SCALED = 0 - UNSCALED = 1 +class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + + def __init__(self, + module: ZeroDDP, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32) -> None: + super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, + max_scale) + self.module = module + + def check_local_overflow(self) -> bool: + return self.module.overflow_counter > 0 + + def pre_zero_grad(self) -> None: + self.module.overflow_counter = 0 class ZeroOptimizer(ColossalaiOptimizer): @@ -79,7 +95,6 @@ class ZeroOptimizer(ColossalaiOptimizer): self.module = module self.gemini_manager = module.gemini_manager self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager - self.optim_state = OptimState.UNSCALED self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict() self.param_to_chunk32: Dict[Parameter, Chunk] = dict() self.chunk16_set: Set[Chunk] = set() @@ -107,15 +122,20 @@ class ZeroOptimizer(ColossalaiOptimizer): self.__init__optimizer() - # Grad scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale) - self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device()) + if module.mixed_precision is torch.float16: + self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + elif module.mixed_precision is torch.bfloat16: + self.mix_precision_mixin = BF16MixedPrecisionMixin() + else: + raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}") + self._logger = get_dist_logger() self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) @@ -151,15 +171,6 @@ class ZeroOptimizer(ColossalaiOptimizer): for chunk16 in self.chunk16_set: chunk16.optim_update() - def _check_overflow(self): - # clear previous overflow record - self._found_overflow.fill_(self.module.overflow_counter) - - # all-reduce across global group - dist.all_reduce(self._found_overflow) - - return self._found_overflow.item() > 0 - def _clear_global_norm(self) -> None: for c16 in self.chunk16_set: c16.l2_norm = None @@ -190,40 +201,25 @@ class ZeroOptimizer(ColossalaiOptimizer): return global_norm def _get_combined_scale(self): - loss_scale = 1 + div_scale = self.mix_precision_mixin.get_grad_div_scale() - if self.optim_state == OptimState.SCALED: - loss_scale = self.loss_scale - self.optim_state = OptimState.UNSCALED - - combined_scale = loss_scale if self.clipping_flag: total_norm = self._calc_global_norm() - clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm + clip = ((total_norm / div_scale) + 1e-6) / self.max_norm if clip > 1: - combined_scale = clip * loss_scale + div_scale = clip * div_scale - if combined_scale == 1: - return -1 - else: - return combined_scale - - @property - def loss_scale(self): - return self.grad_scaler.scale.item() + return -1 if div_scale == 1.0 else div_scale def zero_grad(self, *args, **kwargs): - self.module.overflow_counter = 0 + self.mix_precision_mixin.pre_zero_grad() return self.optim.zero_grad(set_to_none=True) def step(self, *args, **kwargs): self._maybe_move_fp32_params() self._set_grad_ptr() - found_inf = self._check_overflow() - if found_inf: - self.optim_state = OptimState.UNSCALED # no need to unscale grad - self.grad_scaler.update(found_inf) # update gradient scaler + if self.mix_precision_mixin.should_skip_step(): if self.verbose: self._logger.info(f'Found overflow. Skip step') self._clear_global_norm() # clear recorded norm @@ -234,7 +230,6 @@ class ZeroOptimizer(ColossalaiOptimizer): # get combined scale. combined scale = loss scale * clipping norm # so that gradient = gradient / combined scale combined_scale = self._get_combined_scale() - self.grad_scaler.update(found_inf) ret = self.optim.step(div_scale=combined_scale, *args, **kwargs) self._register_states() @@ -246,8 +241,7 @@ class ZeroOptimizer(ColossalaiOptimizer): raise NotImplementedError def backward(self, loss: torch.Tensor): - loss = self.loss_scale * loss - self.optim_state = OptimState.SCALED + loss = self.mix_precision_mixin.pre_backward(loss) self.module.backward(loss) def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor): @@ -255,7 +249,7 @@ class ZeroOptimizer(ColossalaiOptimizer): # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - self.optim_state = OptimState.SCALED + grad = self.mix_precision_mixin.pre_backward_by_grad(grad) self.module.backward_by_grad(tensor, grad) def _maybe_move_fp32_params(self): diff --git a/colossalai/zero/legacy/init_ctx/init_context.py b/colossalai/zero/legacy/init_ctx/init_context.py index a921ca0aa..a3fa46b38 100644 --- a/colossalai/zero/legacy/init_ctx/init_context.py +++ b/colossalai/zero/legacy/init_ctx/init_context.py @@ -14,7 +14,7 @@ from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses from colossalai.zero.legacy.shard_utils import BaseShardStrategy -from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16 +from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16 from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2 from colossalai.zero.legacy.sharded_param import ShardedParamV2 @@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): seed (int, optional): Random seed for weight initialization shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False. default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16. + bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False. model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int). """ @@ -64,6 +65,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): seed: int = 2**10 - 1, shard_param: bool = False, default_dtype: Optional[torch.dtype] = None, + bf16: bool = False, model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)): super().__init__(default_dtype=default_dtype) @@ -71,6 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): self.param_list = [] self.model_numel_tensor = model_numel_tensor self.seed = seed + self.bf16 = bf16 self.dp_process_group = gpc.get_group(ParallelMode.DATA) self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param) @@ -183,9 +186,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): NOTE() The module may be passed to this function multiple times. """ self.top_module = module + half_dtype = torch.float16 if not self.bf16 else torch.bfloat16 def half_fn(t: torch.Tensor): - return t.half() if t.is_floating_point() else t + return t.to(half_dtype) if t.is_floating_point() else t for param in module.parameters(recurse=False): # avoid adapting a param to ShardedParam twice @@ -226,9 +230,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): # We must cast buffers # If we use BN, buffers may be on CPU and Float # We must cast them + cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16 for buffer in module.buffers(recurse=False): buffer.data = buffer.data.to(device=torch.cuda.current_device()) - buffer.data = cast_tensor_to_fp16(buffer.data) + buffer.data = cast_fn(buffer.data) class ZeroContextMgr(metaclass=SingletonMeta): diff --git a/colossalai/zero/legacy/sharded_model/_utils.py b/colossalai/zero/legacy/sharded_model/_utils.py index 2bd01531a..f1d642cf3 100644 --- a/colossalai/zero/legacy/sharded_model/_utils.py +++ b/colossalai/zero/legacy/sharded_model/_utils.py @@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Te if isinstance(tensor, StatefulTensor): tensor = tensor.payload - if torch.is_floating_point(tensor) and tensor.dtype is torch.float16: + if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16): return tensor.float() return tensor +def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, StatefulTensor): + tensor = tensor.payload + if torch.is_floating_point(tensor) and tensor.dtype is torch.float32: + return tensor.bfloat16() + return tensor + + def apply_to_tensors(x: Any, fn: Callable): if torch.is_tensor(x): return fn(x) diff --git a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py index b3a83b741..be3842beb 100644 --- a/colossalai/zero/legacy/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/legacy/sharded_model/sharded_model_v2.py @@ -28,6 +28,7 @@ from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBuc from ._utils import ( cast_float_arguments, + cast_tensor_to_bf16, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, @@ -74,6 +75,7 @@ class ShardedModelV2(nn.Module): In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad). We find that PyTorch's optimizers don't support mixed precision, so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False. + bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False. """ def __init__(self, @@ -86,11 +88,13 @@ class ShardedModelV2(nn.Module): tensor_placement_policy: str = 'cuda', gradient_predivide_factor: Optional[float] = 1.0, reuse_fp16_shard: bool = False, + bf16: bool = False, *args, **kwargs): assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.' super().__init__() self.logger = get_dist_logger() + self.bf16 = bf16 # We force users to use ZeroInitContext for submodule in module.modules(): @@ -232,7 +236,8 @@ class ShardedModelV2(nn.Module): def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor: self._pre_forward_operations(*args) - args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs) + cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16 + args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs) outputs = self.module(*args, **kwargs) self._post_forward_operations() return outputs diff --git a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py index be60209af..41dd174cb 100644 --- a/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py +++ b/colossalai/zero/legacy/sharded_optim/sharded_optim_v2.py @@ -94,6 +94,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): super().__init__(optimizer) self.shard_strategy = sharded_model.shard_strategy self.model: ShardedModelV2 = sharded_model + self.bf16 = sharded_model.bf16 self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio) assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0' @@ -117,6 +118,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device()) self._logger = get_dist_logger("ShardedOptimizerV2") self._verbose = verbose + self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward # Store fp32 param shards self._register_master_weight() @@ -166,8 +168,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer): self._zero_grad() def backward(self, loss: Tensor) -> None: - loss = self.loss_scale * loss - self.optim_state = OptimState.SCALED + if not self.bf16: + loss = self.loss_scale * loss + self.optim_state = OptimState.SCALED + self._grad_prepared = False self.model.backward(loss) def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None: @@ -175,30 +179,33 @@ class ShardedOptimizerV2(ColossalaiOptimizer): # It receives the scaled grad from the previous rank # No need to scale the grad again # Need to unscale when optimizing - self.optim_state = OptimState.SCALED + if not self.bf16: + self.optim_state = OptimState.SCALED + self._grad_prepared = False self.model.backward_by_grad(tensor, grad) def clip_grad_norm(self, model: nn.Module, max_norm: float): - if self.optim_state == OptimState.SCALED: - self._prepare_grads() + self._prepare_grads() + if not self.bf16 and self.optim_state == OptimState.SCALED: self._unscale_grads() return super().clip_grad_norm(model, max_norm) def step(self, *args, **kwargs): + self._prepare_grads() # unscale grads if scaled - if self.optim_state == OptimState.SCALED: - self._prepare_grads() + if not self.bf16 and self.optim_state == OptimState.SCALED: self._unscale_grads() self._maybe_move_fp32_shards() - found_inf = self._check_overflow() - self.grad_scaler.update(found_inf) + if not self.bf16: + found_inf = self._check_overflow() + self.grad_scaler.update(found_inf) - if found_inf: - self._logger.warning('found inf during ShardedOptimV2 step') - self._zero_grad(recover_data=True) - return + if found_inf: + self._logger.warning('found inf during ShardedOptimV2 step') + self._zero_grad(recover_data=True) + return self._point_param_fp16_to_master_param() @@ -304,6 +311,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer): state[k] = v.cuda() def _prepare_grads(self): + if self._grad_prepared: + return for group in self.optim.param_groups: for p in group['params']: if p.colo_attr.saved_grad.is_null(): @@ -320,6 +329,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer): p.grad = p.colo_attr.grad_payload # Set p.data to empty tensor, in case of memory leaking p.colo_attr.set_data_none() + self._grad_prepared = True def _point_param_fp16_to_master_param(self): # assign master param pointers to p.data. @@ -357,7 +367,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer): torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device)) # TODO() optimize this line CPU (fp32) -> GPU (fp16) - p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach()) + half_dtype = torch.bfloat16 if self.bf16 else torch.float16 + p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach()) p.colo_attr.set_data_none() if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated: diff --git a/colossalai/zero/low_level/low_level_optim.py b/colossalai/zero/low_level/low_level_optim.py index 3e7661eca..d4d03e5b5 100644 --- a/colossalai/zero/low_level/low_level_optim.py +++ b/colossalai/zero/low_level/low_level_optim.py @@ -6,7 +6,11 @@ import torch import torch.distributed as dist from torch.optim import Optimizer -from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler +from colossalai.amp.naive_amp.mixed_precision_mixin import ( + BF16MixedPrecisionMixin, + FP16MixedPrecisionMixin, + MixedPrecisionMixin, +) from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.logging import get_dist_logger @@ -27,6 +31,31 @@ from ._utils import ( from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket +class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin): + + def __init__(self, + num_working_param_groups: int, + grad_store: GradientStore, + initial_scale: float = 2**16, + min_scale: float = 1, + growth_factor: float = 2, + backoff_factor: float = 0.5, + growth_interval: int = 1000, + hysteresis: int = 2, + max_scale: float = 2**32) -> None: + super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, + max_scale) + self.num_working_param_groups = num_working_param_groups + self.grad_store = grad_store + + def check_local_overflow(self) -> bool: + for group_id in range(self.num_working_param_groups): + for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id): + if avg_grad is not None and has_inf_or_nan(avg_grad): + return True + return False + + class LowLevelZeroOptimizer(ColossalaiOptimizer): """Optimizer used for ZeRO-1 and ZeRO-2. """ @@ -100,17 +129,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): self._reduce_bucket_size = reduce_bucket_size self._communication_dtype = communication_dtype - # gradient scaler - self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale, - min_scale=min_scale, - growth_factor=growth_factor, - backoff_factor=backoff_factor, - growth_interval=growth_interval, - hysteresis=hysteresis, - max_scale=max_scale, - verbose=verbose) - self._found_overflow = torch.FloatTensor([0]).to(get_current_device()) - # gradient clipping self._clip_grad_norm = clip_grad_norm @@ -200,14 +218,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): if self._overlap_communication or self._partition_grads: self._attach_reduction_hook() + # initialize mixed precision mixin + self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None + if self._dtype is torch.float16: + self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups, + self._grad_store, + initial_scale=initial_scale, + min_scale=min_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + hysteresis=hysteresis, + max_scale=max_scale) + elif self._dtype is torch.bfloat16: + self.mixed_precision_mixin = BF16MixedPrecisionMixin() + @property def dtype(self): return self._dtype - @property - def loss_scale(self): - return self.grad_scaler.scale - @property def num_param_groups(self): return len(self._working_param_groups) @@ -392,7 +421,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): ################################ def backward(self, loss, retain_graph=False, sync_grad=True): - loss = self.loss_scale * loss + if self.mixed_precision_mixin is not None: + loss = self.mixed_precision_mixin.pre_backward(loss) loss.backward(retain_graph=retain_graph) # finish gradient reduction @@ -419,6 +449,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): :param set_to_none: Whether set the gradient to None. Default value is True. :type set_to_none: bool """ + if self.mixed_precision_mixin is not None: + self.mixed_precision_mixin.pre_zero_grad() for _, param_group in self._working_param_groups.items(): for param in param_group: if set_to_none: @@ -435,12 +467,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): def step(self, closure=None): assert closure is None, 'closure is not supported by step()' - # check for overflow - found_inf = self._check_overflow() - self.grad_scaler.update(found_inf) - - # update loss scale if overflow occurs - if found_inf: + if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step(): self._grad_store.reset_all_average_gradients() if self._verbose: self._logger.info(f'Found overflow. Skip step') @@ -507,41 +534,20 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer): # Mixed Precision Utilities # ############################# - def _check_overflow(self): - # clear previous overflow record - self._found_overflow.fill_(0.0) - - # check for overflow - for group_id in range(len(self._working_param_groups)): - for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id): - if avg_grad is not None and has_inf_or_nan(avg_grad): - self._found_overflow.fill_(1.0) - break - - # all-reduce across dp group - dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group) - - # all-reduce over model parallel group - if self._mp_torch_group: - dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group) - - if self._found_overflow.item() > 0: - return True - else: - return False - def _unscale_and_clip_grads(self, grad_groups_flat, total_norm): # compute combined scale factor for this group - combined_scale = self.loss_scale + div_scale = 1.0 + if self.mixed_precision_mixin is not None: + div_scale = self.mixed_precision_mixin.get_grad_div_scale() if self._clip_grad_norm > 0.: # norm is in fact norm*scale - clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm + clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm if clip > 1: - combined_scale = clip * self.loss_scale + div_scale = clip * div_scale for grad in grad_groups_flat: - grad.data.mul_(1. / combined_scale) + grad.data.mul_(1. / div_scale) ############################ # Gradient Synchronization # diff --git a/tests/test_optimizer/test_adam_kernel.py b/tests/test_optimizer/test_adam_kernel.py new file mode 100644 index 000000000..2186a421f --- /dev/null +++ b/tests/test_optimizer/test_adam_kernel.py @@ -0,0 +1,131 @@ +# This test checks adam kernels +# Baseline is pure fp32 torch adam optimizer +import math +from abc import abstractmethod +from typing import Type + +import pytest +import torch +from torch import Tensor + +from colossalai.utils import get_current_device, multi_tensor_applier + +_FUSED_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float), + (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16), + (torch.bfloat16, torch.bfloat16)] + +_CPU_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float), + (torch.half, torch.half)] + + +class AdamKernel: + + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: + self.lr = lr + self.beta1 = beta1 + self.beta2 = beta2 + self.eps = eps + self.weight_decay = weight_decay + self.use_adamw = use_adamw + + @abstractmethod + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + pass + + +class TorchAdamKernel(AdamKernel): + + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + bias_correction1 = 1 - self.beta1**step + bias_correction2 = 1 - self.beta2**step + + if self.weight_decay != 0: + if self.use_adamw: + # Perform stepweight decay + param.mul_(1 - self.lr * self.weight_decay) + else: + grad = grad.add(param, alpha=self.weight_decay) + + # Decay the first and second moment running average coefficient + exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1) + exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2) + denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps) + + step_size = self.lr / bias_correction1 + + param.addcdiv_(exp_avg, denom, value=-step_size) + + +class FusedAdamKernel(AdamKernel): + + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: + super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) + from colossalai.kernel.op_builder import FusedOptimBuilder + fused_optim = FusedOptimBuilder().load() + self.fused_adam = fused_optim.multi_tensor_adam + self.dummy_overflow_buf = torch.cuda.IntTensor([0]) + + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, [[grad], [param], [exp_avg], [exp_avg_sq]], + self.lr, self.beta1, self.beta2, self.eps, step, self.use_adamw, True, self.weight_decay, + -1) + + +class CPUAdamKernel(AdamKernel): + + def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None: + super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw) + from colossalai.kernel.op_builder import CPUAdamBuilder + cpu_optim = CPUAdamBuilder().load() + + self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw) + + def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor): + self.cpu_adam_op.step(step, self.lr, self.beta1, self.beta2, self.eps, self.weight_decay, True, param.view(-1), + grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1), -1) + + +def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float, p_dtype: torch.dtype, + g_dtype: torch.dtype, device: torch.device, n_steps: int, rtol: float, atol: float): + lr = 1e-3 + beta1, beta2 = 0.9, 0.999 + eps = 1e-8 + torch_adam = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw) + adam_kernel = kernel(lr, beta1, beta2, eps, weight_decay, adamw) + master_p = torch.rand(64, device=device) + master_g = torch.rand_like(master_p) + master_exp_avg = torch.zeros_like(master_p) + master_exp_avg_sq = torch.zeros_like(master_p) + p = master_p.clone().to(p_dtype) + g = master_g.clone().to(g_dtype) + exp_avg = master_exp_avg.clone() + exp_avg_sq = master_exp_avg_sq.clone() + + for step in range(1, 1 + n_steps): + torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq) + adam_kernel.update(step, p, g, exp_avg, exp_avg_sq) + # if overflow, the weight won't be updated. so there will be no nan in p + assert not torch.isnan(p).any() + assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol) + + +@pytest.mark.parametrize('adamw', [False, True]) +@pytest.mark.parametrize('weight_decay', [0.0, 0.1]) +@pytest.mark.parametrize('p_dtype, g_dtype', _FUSED_ALLOWED_P_G_TYPES) +def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): + rtol, atol = 1e-5, 1e-8 + if p_dtype is torch.float16 or g_dtype is torch.float16: + rtol, atol = 1e-3, 1e-3 + if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: + rtol, atol = 4e-3, 4e-3 + check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol) + + +@pytest.mark.parametrize('adamw', [False, True]) +@pytest.mark.parametrize('weight_decay', [0.0, 0.1]) +@pytest.mark.parametrize('p_dtype, g_dtype', _CPU_ALLOWED_P_G_TYPES) +def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype): + rtol, atol = 1e-5, 1e-8 + if p_dtype is torch.float16 or g_dtype is torch.float16: + rtol, atol = 1e-3, 1e-3 + check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device('cpu'), 3, rtol, atol) diff --git a/tests/test_optimizer/test_adam_optim.py b/tests/test_optimizer/test_adam_optim.py new file mode 100644 index 000000000..0f72bc134 --- /dev/null +++ b/tests/test_optimizer/test_adam_optim.py @@ -0,0 +1,86 @@ +from copy import deepcopy +from typing import Type, Union + +import pytest +import torch +import torch.nn as nn +from torch.optim import Adam, AdamW + +from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam +from tests.kit.model_zoo import model_zoo + +_ALLOWED_OPTIM_DEVICES = [ + (FusedAdam, torch.device('cuda:0')), + (CPUAdam, torch.device('cpu')), + (CPUAdam, torch.device('cuda:0')), + (HybridAdam, torch.device('cpu')), + (HybridAdam, torch.device('cuda:0')), +] + +_ALLOWED_P_G_TYPES = [ + (torch.float, torch.float), # pure fp32 + (torch.float, torch.half), # fp16 amp + (torch.float, torch.bfloat16), # bfloat16 amp + # (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16 + # (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16 +] + +N_STEPS = 3 + + +def setup_param_groups(bert_model: nn.Module) -> list: + no_decay = ["bias", "LayerNorm.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": 0.1, + }, + { + "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None: + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + torch_p.grad = torch.rand_like(torch_p) + # avoid inconsistent grad and param dtype error + orig_p = p.data + p.data = torch_p.grad.clone().to(g_dtype) + p.grad = p.data + p.data = orig_p + + +@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES) +@pytest.mark.parametrize('adamw', [False, True]) +@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES) +def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device, + adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None: + model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values())) + torch_model = model_fn().to(device) + model = deepcopy(torch_model).to(p_dtype) + lr = 1e-3 + beta1, beta2 = 0.9, 0.999 + eps = 1e-8 + torch_optim_cls = AdamW if adamw else Adam + torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps) + optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw) + + rtol, atol = 1e-5, 1e-5 + if p_dtype is torch.float16 or g_dtype is torch.float16: + rtol, atol = 2e-3, 2e-3 + if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16: + rtol, atol = 4e-3, 4e-3 + + for _ in range(N_STEPS): + set_grad(model, torch_model, g_dtype) + torch_optim.step() + optim.step() + torch_optim.zero_grad() + optim.zero_grad() + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + # if overflow, the weight won't be updated. so there will be no nan in p + assert not torch.isnan(p).any() + assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol) diff --git a/tests/test_optimizer/test_cpu_adam.py b/tests/test_optimizer/test_cpu_adam.py deleted file mode 100644 index 8b3ecf851..000000000 --- a/tests/test_optimizer/test_cpu_adam.py +++ /dev/null @@ -1,121 +0,0 @@ -import math - -import torch - -from colossalai.testing import clear_cache_before_run, parameterize - - -def torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - param, - grad, - exp_avg, - exp_avg_sq, - use_adamw, -): - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - if weight_decay != 0: - if use_adamw: - # Perform stepweight decay - param.mul_(1 - lr * weight_decay) - else: - grad = grad.add(param, alpha=weight_decay) - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) - - step_size = lr / bias_correction1 - - param.addcdiv_(exp_avg, denom, value=-step_size) - - -def assertLess(data_diff, threshold, msg): - assert data_diff < threshold, msg - - -def assertTrue(condition, msg): - assert condition, msg - - -@clear_cache_before_run() -@parameterize('adamw', [True, False]) -@parameterize('step', [1, 2]) -@parameterize('p_dtype', [torch.float, torch.half]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_cpu_adam(adamw, step, p_dtype, g_dtype): - lr = 1e-3 - beta1, beta2 = 0.9, 0.999 - eps = 1e-8 - weight_decay = 0 - - for i in range(3): - p_data = torch.rand(64, dtype=p_dtype) - p_data_copy = p_data.clone().float() - p_grad = torch.rand(64, dtype=g_dtype) - p_grad_copy = p_grad.clone().float() - exp_avg = torch.rand(p_data.shape) - exp_avg_copy = exp_avg.clone() - exp_avg_sq = torch.rand(p_data.shape) - exp_avg_sq_copy = exp_avg_sq.clone() - - from colossalai.kernel.op_builder import CPUAdamBuilder - cpu_optim = CPUAdamBuilder().load() - - cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw) - - cpu_adam_op.step( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - True, - p_data.view(-1), # fp32 data - p_grad.view(-1), # fp32 grad - exp_avg.view(-1), - exp_avg_sq.view(-1), - -1, - ) - - torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - p_data_copy, # fp32 data - p_grad_copy, # fp32 grad - exp_avg_copy, - exp_avg_sq_copy, - adamw, - ) - var = p_data_copy - p_data - data_diff = torch.max(torch.abs(var)) - threshold = 1e-3 - assertLess( - data_diff, - threshold, - f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps " - f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}", - ) - max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad)) - assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}") - max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg)) - assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}") - max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq)) - assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}") - - -if __name__ == '__main__': - test_cpu_adam() diff --git a/tests/test_optimizer/test_fused_adam.py b/tests/test_optimizer/test_fused_adam.py deleted file mode 100644 index 114d5293d..000000000 --- a/tests/test_optimizer/test_fused_adam.py +++ /dev/null @@ -1,64 +0,0 @@ -import torch -import torch.nn as nn -from torch.optim import AdamW -from torch.optim.adam import Adam - -from colossalai.nn.optimizer.fused_adam import FusedAdam -from colossalai.testing import clear_cache_before_run, parameterize - - -class FC(nn.Module): - - def __init__(self) -> None: - super().__init__() - self.fc = nn.Sequential(nn.Linear(64, 64)) - - def forward(self, x): - return self.fc(x) - - -@clear_cache_before_run() -@parameterize('adamw', [False, True]) -@parameterize('p_dtype', [torch.float, torch.half]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_adam(adamw, p_dtype, g_dtype): - model = FC().cuda().to(p_dtype) - state = model.state_dict() - model_copy = FC().cuda().to(p_dtype) - model_copy.load_state_dict(state.copy()) - - if adamw: - optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True) - torch_optim = AdamW(model_copy.parameters(), lr=1e-3) - else: - optim = FusedAdam(model.parameters(), lr=1e-3) - torch_optim = Adam(model_copy.parameters(), lr=1e-3) - - data = torch.rand(1024, 64).cuda().to(p_dtype) - data_copy = data.clone() - label = torch.rand(1024, 64).cuda().to(p_dtype) - - for d, l in zip(data, label): - y = model(d) - loss = ((l - y)**2).sum() - optim.zero_grad() - loss.backward() - if p_dtype != g_dtype: - for i in range(len(optim.param_groups[0]['params'])): - optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype) - optim.step() - - for d, l in zip(data_copy, label): - y = model_copy(d) - loss = ((l - y)**2).sum() - torch_optim.zero_grad() - loss.backward() - torch_optim.step() - - assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params']) - - for i in range(len(optim.param_groups[0]['params'])): - if torch.isnan(optim.param_groups[0]['params'][i]).any() \ - or torch.isnan(torch_optim.param_groups[0]['params'][i]).any(): - continue - assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3) diff --git a/tests/test_optimizer/test_fused_adam_kernel.py b/tests/test_optimizer/test_fused_adam_kernel.py deleted file mode 100644 index 4afa13349..000000000 --- a/tests/test_optimizer/test_fused_adam_kernel.py +++ /dev/null @@ -1,95 +0,0 @@ -import math - -import torch -import torch.nn as nn -from numpy import dtype - -from colossalai.testing import clear_cache_before_run, parameterize -from colossalai.utils import multi_tensor_applier - - -def torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - param, - grad, - exp_avg, - exp_avg_sq, - use_adamw, -): - bias_correction1 = 1 - beta1**step - bias_correction2 = 1 - beta2**step - - if weight_decay != 0: - if use_adamw: - # Perform stepweight decay - param.mul_(1 - lr * weight_decay) - else: - grad = grad.add(param, alpha=weight_decay) - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps) - - step_size = lr / bias_correction1 - - param.addcdiv_(exp_avg, denom, value=-step_size) - - -@clear_cache_before_run() -@parameterize('adamw', [False, True]) -@parameterize('step', [1, 2]) -@parameterize('p_dtype', [torch.float, torch.half]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_adam(adamw, step, p_dtype, g_dtype): - from colossalai.kernel.op_builder import FusedOptimBuilder - fused_optim = FusedOptimBuilder().load() - fused_adam = fused_optim.multi_tensor_adam - - dummy_overflow_buf = torch.cuda.IntTensor([0]) - - count = 0 - - for i in range(3): - p = torch.rand(64, dtype=p_dtype).cuda() - p_copy = p.clone().float() - g = torch.rand(p.shape, dtype=g_dtype).cuda() - g_copy = g.clone().float() - m = torch.rand(p.shape).cuda() - m_copy = m.clone() - v = torch.rand(p.shape).cuda() - v_copy = v.clone() - - lr = 1e-3 - beta1, beta2 = 0.9, 0.999 - eps = 1e-8 - weight_decay = 0 - - multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw, - True, weight_decay, -1) - - torch_adam_update( - step, - lr, - beta1, - beta2, - eps, - weight_decay, - p_copy, # fp32 data - g_copy, # fp32 grad - m_copy, - v_copy, - adamw, - ) - - if torch.isnan(p).any() or torch.isnan(p_copy).any(): - count += 1 - continue - assert count < 200, "too many nans" - assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5, - 1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}" diff --git a/tests/test_optimizer/test_hybrid_adam.py b/tests/test_optimizer/test_hybrid_adam.py deleted file mode 100644 index d075149df..000000000 --- a/tests/test_optimizer/test_hybrid_adam.py +++ /dev/null @@ -1,42 +0,0 @@ -import torch -import torch.nn as nn -from torch.optim import AdamW -from torch.optim.adam import Adam - -from colossalai.nn.optimizer.hybrid_adam import HybridAdam -from colossalai.testing import clear_cache_before_run, parameterize - -RE = 3 - - -@clear_cache_before_run() -@parameterize('adamw', [False, True]) -@parameterize('device', ['cpu', 'cuda:0']) -@parameterize('p_dtype', [torch.float]) -@parameterize('g_dtype', [torch.float, torch.half]) -def test_adam(adamw, device, p_dtype, g_dtype): - rng_state = torch.get_rng_state() - p = nn.Parameter(torch.rand(64).to(device, p_dtype)) - torch.set_rng_state(rng_state) - p_copy = nn.Parameter(torch.rand(64).to(device).float()) - - if adamw: - optim = HybridAdam([p], lr=1e-3, adamw_mode=True) - torch_optim = AdamW([p_copy], lr=1e-3) - else: - optim = HybridAdam([p], lr=1e-3) - torch_optim = Adam([p_copy], lr=1e-3) - - print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}") - for i in range(RE): - p.grad = torch.rand(64).to(device, p_dtype) - p_copy.grad = p.grad.clone().float() - p.grad.data = p.grad.data.to(g_dtype) - - optim.step() - torch_optim.step() - - if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any(): - continue - assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \ - f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}" diff --git a/tests/test_zero/test_gemini/test_optim.py b/tests/test_zero/test_gemini/test_optim.py index 8ce20c16e..66611bcd2 100644 --- a/tests/test_zero/test_gemini/test_optim.py +++ b/tests/test_zero/test_gemini/test_optim.py @@ -21,23 +21,40 @@ TEST_MODELS = ['gpt2'] # these models are too small, all parameters in these models are compacted into one chunk EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers'] +# bfloat16 cannot represent them exactly +BF16_IGNORED_KEYS = [ + 'albert.embeddings.word_embeddings.weight', + 'albert.embeddings.position_embeddings.weight', + 'masked_bias', +] -def check_param(model: ZeroDDP, torch_model: torch.nn.Module): - zero_dict = model.state_dict(only_rank_0=False) + +def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype): + zero_dict = model.state_dict(only_rank_0=False, dtype=dtype) torch_dict = torch_model.state_dict() for key, value in torch_dict.items(): # key is 'module.model.PARAMETER', so we truncate it key = key[7:] assert key in zero_dict, "{} not in ZeRO dictionary.".format(key) - temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) + temp_zero_value = zero_dict[key].to(device=value.device) + if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS): + continue + rtol, atol = 1e-3, 4e-3 + if dtype is torch.bfloat16: + rtol, atol = 4e-3, 8e-3 # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) - assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3) + assert_close(value.float(), + temp_zero_value.float(), + rtol=rtol, + atol=atol, + msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}') @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('model_name', TEST_MODELS) -def exam_model_step(placement_policy, model_name: str): +@parameterize('mixed_precision', [torch.half, torch.bfloat16]) +def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype): set_seed(42) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -65,7 +82,7 @@ def exam_model_step(placement_policy, model_name: str): init_device = None chunk_manager = ChunkManager(config_dict, init_device=init_device) gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128) @@ -74,6 +91,7 @@ def exam_model_step(placement_policy, model_name: str): torch_model.eval() set_seed(dist.get_rank() * 3 + 128) + rtol, atol = 1e-4, 1e-5 for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break @@ -83,17 +101,18 @@ def exam_model_step(placement_policy, model_name: str): torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss) + assert_close(torch_loss, loss, rtol=rtol, atol=atol) zero_optim.step() torch_optim.step() - check_param(model, torch_model) + check_param(model, torch_model, mixed_precision) @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const']) @parameterize('model_name', EXAMPLE_MODELS) -def exam_tiny_example(placement_policy, model_name: str): +@parameterize('mixed_precision', [torch.half, torch.bfloat16]) +def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype): set_seed(2008) get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() @@ -113,7 +132,7 @@ def exam_tiny_example(placement_policy, model_name: str): chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1) gemini_manager = GeminiManager(placement_policy, chunk_manager) - model = ZeroDDP(model, gemini_manager, pin_memory=True) + model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision) optimizer = HybridAdam(model.parameters(), lr=1e-3) zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2) @@ -121,6 +140,9 @@ def exam_tiny_example(placement_policy, model_name: str): torch_model.eval() set_seed(dist.get_rank() * 3 + 128) + rtol, atol = 1.5e-6, 2e-5 + if mixed_precision is torch.bfloat16: + rtol, atol = 2e-3, 2e-3 for i, (input_ids, label) in enumerate(train_dataloader): if i > 2: break @@ -133,12 +155,12 @@ def exam_tiny_example(placement_policy, model_name: str): torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) - assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5) # atol should be 2e-5 for torch lower than 1.12 + assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12 zero_optim.step() torch_optim.step() - check_param(model, torch_model) + check_param(model, torch_model, mixed_precision) def run_dist(rank, world_size, port): diff --git a/tests/test_zero/test_legacy/test_zero_engine.py b/tests/test_zero/test_legacy/test_zero_engine.py index dc8847ce5..826a543db 100644 --- a/tests/test_zero/test_legacy/test_zero_engine.py +++ b/tests/test_zero/test_legacy/test_zero_engine.py @@ -16,7 +16,11 @@ from colossalai.zero.low_level._utils import has_inf_or_nan from tests.components_to_test.registry import non_distributed_component_funcs -def run_dist(rank, world_size, port, parallel_config): +def run_dist(rank, world_size, port, parallel_config, bf16): + is_mp_config = parallel_config == MP_PARALLEL_CONFIG + is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG + if bf16: + parallel_config['zero']['model_config']['bf16'] = True colossalai.launch(config=parallel_config, rank=rank, world_size=world_size, @@ -30,7 +34,8 @@ def run_dist(rank, world_size, port, parallel_config): model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func() with ZeroInitContext(target_device=torch.cuda.current_device(), shard_strategy=gpc.config.zero.model_config.shard_strategy, - shard_param=True): + shard_param=True, + bf16=bf16): colo_model = model_builder(checkpoint=True) colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3) @@ -38,7 +43,8 @@ def run_dist(rank, world_size, port, parallel_config): optimizer=colo_optimizer, criterion=criterion, train_dataloader=train_dataloader) - torch_model = model_builder(checkpoint=True).half() + dtype = torch.bfloat16 if bf16 else torch.float16 + torch_model = model_builder(checkpoint=True).to(dtype) col_model_deepcopy(engine.model, torch_model) torch_model = torch_model.cuda().float() @@ -80,9 +86,9 @@ def run_dist(rank, world_size, port, parallel_config): torch_optimizer.step() i += 1 - if parallel_config == MP_PARALLEL_CONFIG: + if is_mp_config: check_params(torch_model, colo_model, loose=True) - elif parallel_config == ZERO_PARALLEL_CONFIG: + elif is_zero_config: check_sharded_model_params(torch_model, colo_model, loose=True) @@ -97,9 +103,10 @@ def test_mp_engine(world_size): @pytest.mark.dist @pytest.mark.parametrize("world_size", [1, 2]) +@pytest.mark.parametrize("bf16", [True, False]) @rerun_if_address_is_in_use() -def test_zero_engine(world_size): - spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG) +def test_zero_engine(world_size, bf16): + spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16) if __name__ == '__main__': diff --git a/tests/test_zero/test_low_level/test_grad_acc.py b/tests/test_zero/test_low_level/test_grad_acc.py index 2ae1f3a99..c264a8077 100644 --- a/tests/test_zero/test_low_level/test_grad_acc.py +++ b/tests/test_zero/test_low_level/test_grad_acc.py @@ -82,7 +82,6 @@ def exam_zero_1_2_grad_acc(): def exam_zero_1_grad_acc(): local_rank = torch.distributed.get_rank() - grad_scale = 32 seed_all(2008) # create models @@ -101,7 +100,6 @@ def exam_zero_1_grad_acc(): # level 1 and 2 will produce exactly the same results zero_optimizer = LowLevelZeroOptimizer(zero_optimizer, overlap_communication=False, - initial_scale=grad_scale, reduce_bucket_size=262144, clip_grad_norm=1.0) @@ -128,9 +126,8 @@ def exam_zero_1_grad_acc(): if check_flag: # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - unscale_grad = z1p.grad / grad_scale # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad))) - assert torch.equal(p.grad, unscale_grad) + assert torch.equal(p.grad, z1p.grad) zero_optimizer._sync_grad() diff --git a/tests/test_zero/test_low_level/test_zero1_2.py b/tests/test_zero/test_low_level/test_zero1_2.py index 4086af9d8..8e2206fe6 100644 --- a/tests/test_zero/test_low_level/test_zero1_2.py +++ b/tests/test_zero/test_low_level/test_zero1_2.py @@ -7,7 +7,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close import colossalai -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing.random import seed_all from colossalai.zero import LowLevelZeroOptimizer @@ -25,15 +25,18 @@ class MlpModel(nn.Module): return x -def half_close(a, b, loose=False): +def loose_close(a, b, dtype: torch.dtype = torch.float32): rtol = None atol = None - if loose: + if dtype is torch.float16: rtol = 5e-2 atol = 5e-4 + elif dtype is torch.bfloat16: + rtol = 4e-3 + atol = 4e-3 - a = a.detach().half() - b = b.detach().half() + a = a.detach().to(dtype) + b = b.detach().to(dtype) assert_close(a, b, rtol=rtol, atol=atol) @@ -96,7 +99,8 @@ def exam_zero_1_2(): assert torch.equal(z1p.data, z2p.data) -def exam_zero_1_torch_ddp(): +@parameterize('dtype', [torch.float16, torch.bfloat16]) +def exam_zero_1_torch_ddp(dtype: torch.dtype): """ In this test, two pairs of model and optimizers are created. 1. zero: use sharded optimizer and fp16 parameters @@ -109,15 +113,10 @@ def exam_zero_1_torch_ddp(): seed_all(1453) # create models - zero_model = MlpModel() - torch_model = copy.deepcopy(zero_model) + torch_model = MlpModel().cuda() + zero_model = copy.deepcopy(torch_model).to(dtype) - zero_model = zero_model.cuda().half() - torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0) - torch_model = torch_model.cuda() - - # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - # half_close(p.data, z1p.data) + torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda() # create optimizer zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1) @@ -137,11 +136,11 @@ def exam_zero_1_torch_ddp(): input_data = torch.rand(32, 128).cuda() # zero-dp forward - zero_output = zero_model(input_data.half()) + zero_output = zero_model(input_data.to(dtype)) # torch-ddp forward torch_output = torch_model(input_data) - half_close(zero_output, torch_output, loose=True) + loose_close(zero_output, torch_output, dtype=dtype) # zero-dp backward zero_optimizer.backward(zero_output.mean().float(), sync_grad=False) @@ -151,7 +150,7 @@ def exam_zero_1_torch_ddp(): # check grad for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): - half_close(p.grad, z1p.grad, loose=True) + loose_close(p.grad, z1p.grad, dtype=dtype) # zero-dp step zero_optimizer._sync_grad() @@ -163,7 +162,7 @@ def exam_zero_1_torch_ddp(): # check updated param for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()): # print(n, torch.max(torch.abs(p.data - z1p.data))) - half_close(p.data, z1p.data, loose=True) + loose_close(p.data, z1p.data, dtype=dtype) def run_dist(rank, world_size, port):