[bf16] add bf16 support (#3882)

* [bf16] add bf16 support for fused adam (#3844)

* [bf16] fused adam kernel support bf16

* [test] update fused adam kernel test

* [test] update fused adam test

* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)

* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)

* [bf16] add mixed precision mixin

* [bf16] low level zero optim support bf16

* [text] update low level zero test

* [text] fix low level zero grad acc test

* [bf16] add bf16 support for gemini (#3872)

* [bf16] gemini support bf16

* [test] update gemini bf16 test

* [doc] update gemini docstring

* [bf16] add bf16 support for plugins (#3877)

* [bf16] add bf16 support for legacy zero (#3879)

* [zero] init context support bf16

* [zero] legacy zero support bf16

* [test] add zero bf16 test

* [doc] add bf16 related docstring for legacy zero
pull/3898/head^2
Hongxin Liu 2023-06-05 15:58:31 +08:00 committed by GitHub
parent 07cb21142f
commit ae02d4e4f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 738 additions and 525 deletions

View File

@ -0,0 +1,9 @@
from .base import MixedPrecisionMixin
from .bf16 import BF16MixedPrecisionMixin
from .fp16 import FP16MixedPrecisionMixin
__all__ = [
'MixedPrecisionMixin',
'FP16MixedPrecisionMixin',
'BF16MixedPrecisionMixin',
]

View File

@ -0,0 +1,91 @@
from abc import ABC, abstractmethod
import torch
from torch import Tensor
class MixedPrecisionMixin(ABC):
"""A helper class for mixed precision training. This mixin is used in mixed precision optimizers.
Attributes:
dtype (torc.dtype): The expected dtype of the gradients.
Examples:
```python
class MyMixedPrecisionOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer):
super().__init__(optim)
self.mixed_precision = MixedPrecisionMixin()
def backward(self, loss):
loss = self.mixed_precision.pre_backward(loss)
loss.backward()
def backward_by_grad(self, tensor, grad):
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
tensor.backward(grad)
def step(self):
if self.mixed_precision.should_skip_step():
self.zero_grad()
return
div_scale = self.mixed_precision.get_grad_div_scale()
# maybe clip grad here
# maybe scale grad here
self.optim.step()
def zero_grad(self):
self.mixed_precision.pre_zero_grad()
return self.optim.zero_grad()
```
"""
dtype: torch.dtype
@abstractmethod
def pre_backward(self, loss: Tensor) -> Tensor:
"""Called before backward.
Args:
loss (Tensor): Loss value.
Returns:
Tensor: Loss value (possibly scaled).
"""
pass
@abstractmethod
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
"""Called before backward by grad. This is helpful for pipeline parallelism.
Args:
tensor (Tensor): Tensor to backward.
grad (Tensor): Gradient of the tensor.
Returns:
Tensor: Gradient of the tensor (possibly scaled).
"""
pass
@abstractmethod
def should_skip_step(self) -> bool:
"""Called before step.
Returns:
bool: Whether to skip the step.
"""
pass
@abstractmethod
def pre_zero_grad(self) -> None:
"""Called before zero_grad.
"""
pass
@abstractmethod
def get_grad_div_scale(self) -> float:
"""Called before step or clip_grad. To keep computation efficiency, this method does not (maybe) unscale grads.
Returns:
float: A divisor for gradient clipping or step.
"""
pass

View File

@ -0,0 +1,23 @@
import torch
from torch import Tensor
from .base import MixedPrecisionMixin
class BF16MixedPrecisionMixin(MixedPrecisionMixin):
dtype = torch.bfloat16
def pre_backward(self, loss: Tensor) -> Tensor:
return loss
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
return grad
def should_skip_step(self) -> bool:
return False
def pre_zero_grad(self) -> None:
pass
def get_grad_div_scale(self) -> float:
return 1.0

View File

@ -0,0 +1,84 @@
from abc import abstractmethod
from enum import Enum
import torch
import torch.distributed as dist
from torch import Tensor
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.utils import get_current_device
from .base import MixedPrecisionMixin
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
class FP16MixedPrecisionMixin(MixedPrecisionMixin):
dtype = torch.float16
def __init__(self,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__()
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self.optim_state = OptimState.UNSCALED
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
@property
def loss_scale(self) -> float:
return self.grad_scaler.scale.item()
@abstractmethod
def check_local_overflow(self) -> bool:
"""Check whether there is overflow in the local process. This method should be implemented by subclasses.
Returns:
bool: Whether there is overflow in the local process.
"""
pass
def check_overflow(self) -> bool:
# clear previous overflow record
self.found_overflow.fill_(0.0)
if self.check_local_overflow():
self.found_overflow.fill_(1.0)
dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX)
return self.found_overflow.item() > 0
def pre_backward(self, loss: Tensor) -> Tensor:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
return loss
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
self.optim_state = OptimState.SCALED
return grad
def should_skip_step(self) -> bool:
found_inf = self.check_overflow()
self.grad_scaler.update(found_inf)
if found_inf:
self.optim_state = OptimState.UNSCALED
return found_inf
def pre_zero_grad(self) -> None:
pass
def get_grad_div_scale(self) -> float:
assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
self.optim_state = OptimState.UNSCALED
return self.loss_scale

View File

@ -23,6 +23,9 @@ from .dp_plugin_base import DPPluginBase
__all__ = ['GeminiPlugin']
SUPPORTED_PRECISION = ['fp16', 'bf16']
PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}
class GeminiCheckpointIO(GeneralCheckpointIO):
@ -171,6 +174,7 @@ class GeminiPlugin(DPPluginBase):
Args:
device (torch.device): device to place the model.
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@ -203,6 +207,7 @@ class GeminiPlugin(DPPluginBase):
self,
device: Optional[torch.device] = None,
placement_policy: str = "cpu",
precision: str = "fp16",
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
@ -223,6 +228,7 @@ class GeminiPlugin(DPPluginBase):
verbose: bool = False,
) -> None:
super().__init__()
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
self.gemini_config = dict(
device=(device or get_current_device()),
placement_policy=placement_policy,
@ -233,6 +239,7 @@ class GeminiPlugin(DPPluginBase):
hidden_dim=hidden_dim,
min_chunk_size_mb=min_chunk_size_mb,
memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
)
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
self.optim_kwargs = dict(initial_scale=initial_scale,
@ -253,7 +260,7 @@ class GeminiPlugin(DPPluginBase):
return True
def supported_precisions(self) -> List[str]:
return ['fp16']
return SUPPORTED_PRECISION
def control_device(self) -> bool:
return True

View File

@ -1,4 +1,5 @@
import warnings
from functools import partial
from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch
@ -20,12 +21,15 @@ from .torch_ddp_plugin import TorchDDPCheckpointIO
__all__ = ['LowLevelZeroPlugin']
def _convert_to_fp16(x):
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
return x.half()
return x.to(dtype)
return x
SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
@ -49,17 +53,24 @@ class LowLevelZeroModel(ModelWrapper):
def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
super().__init__(module)
self.convert_inputs = (precision == 'fp16')
module = zero_model_wrapper(module, zero_stage=stage)
self.dtype = None
if precision == 'fp16':
module = module.half()
self.dtype = torch.float16
elif precision == 'bf16':
self.dtype = torch.bfloat16
module = zero_model_wrapper(module, zero_stage=stage)
if self.dtype is not None:
module = module.to(self.dtype)
module = module.to(get_current_device())
self.module = module
self.convert_fn = None
if self.dtype is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
def forward(self, *args, **kwargs):
if self.convert_inputs:
args = tree_map(_convert_to_fp16, args)
kwargs = tree_map(_convert_to_fp16, kwargs)
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
@ -110,7 +121,7 @@ class LowLevelZeroPlugin(DPPluginBase):
Args:
strage (int, optional): ZeRO stage. Defaults to 1.
precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'.
precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
@ -149,7 +160,7 @@ class LowLevelZeroPlugin(DPPluginBase):
) -> None:
super().__init__()
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
self.stage = stage
self.precision = precision
@ -175,7 +186,7 @@ class LowLevelZeroPlugin(DPPluginBase):
return True
def supported_precisions(self) -> List[str]:
return ['fp16', 'fp32']
return SUPPORTED_PRECISION
def control_device(self) -> bool:
return True

View File

@ -171,6 +171,21 @@
using g_scalar_t_##LEVEL = at::Half; \
using p_scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::Float && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = float; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::Float) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = float; \
__VA_ARGS__; \
} else if (GTYPE == at::ScalarType::BFloat16 && \
PTYPE == at::ScalarType::BFloat16) { \
using g_scalar_t_##LEVEL = at::BFloat16; \
using p_scalar_t_##LEVEL = at::BFloat16; \
__VA_ARGS__; \
} else { \
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
"'"); \

View File

@ -93,8 +93,7 @@ class CPUAdam(NVMeOptimizer):
bias_correction1,
bias_correction2,
use_adamw=False):
# FIXME(ver217): remove the below line when replace torch adam with fused adam
grad = grad.float()
grad = grad.to(data.dtype)
if weight_decay != 0:
if use_adamw:
@ -133,10 +132,12 @@ class CPUAdam(NVMeOptimizer):
if len(state) == 0:
state['step'] = 0
# FIXME(ver217): CPU adam kernel only supports fp32 states now
assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
# gradient momentums
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg'] = torch.zeros_like(p, device=target_device)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
self._post_state_init(p)
state['step'] += 1
@ -147,9 +148,17 @@ class CPUAdam(NVMeOptimizer):
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], div_scale)
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
bias_correction2, self.adamw_mode)
else:
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda':
assert div_scale == -1, "div_scale should remain default"

View File

@ -134,8 +134,8 @@ class FusedAdam(torch.optim.Optimizer):
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p)
if p.dtype not in [torch.float16, torch.float32]:
raise RuntimeError('FusedAdam only support fp16 and fp32.')
if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]:
raise RuntimeError('FusedAdam only support fp16, fp32 and bf16.')
g_l.append(p.grad.data)
p_l.append(p.data)

View File

@ -1,16 +1,17 @@
from typing import Any, Optional
import torch
from torch.optim import Adam
from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.registry import OPTIMIZERS
from colossalai.utils import multi_tensor_applier
from .nvme_optimizer import NVMeOptimizer
from .cpu_adam import CPUAdam
@OPTIMIZERS.register_module
class HybridAdam(NVMeOptimizer):
class HybridAdam(CPUAdam):
"""Implements Adam algorithm.
Supports parameters updating on both GPU and CPU, depanding on the device of parameters.
@ -74,15 +75,9 @@ class HybridAdam(NVMeOptimizer):
nvme_offload_dir: Optional[str] = None,
**defaults: Any):
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
self.adamw_mode = adamw_mode
# build during runtime if not found
cpu_optim = CPUAdamBuilder().load()
super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction,
nvme_offload_dir)
fused_optim = FusedOptimBuilder().load()
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
self.gpu_adam_op = fused_optim.multi_tensor_adam
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
@ -108,10 +103,12 @@ class HybridAdam(NVMeOptimizer):
if len(state) == 0:
state['step'] = 0
# FIXME(ver217): CPU adam kernel only supports fp32 states now
assert p.dtype is torch.float, "HybridAdam only support fp32 parameters"
# gradient momentums
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg'] = torch.zeros_like(p, device=target_device)
# gradient variances
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
self._post_state_init(p)
state['step'] += 1
@ -122,9 +119,17 @@ class HybridAdam(NVMeOptimizer):
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
state['exp_avg_sq'], div_scale)
if p.grad.dtype is torch.bfloat16:
# cpu adam kernel does not support bf16 now
bias_correction1 = 1 - beta1**state['step']
bias_correction2 = 1 - beta2**state['step']
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
bias_correction2, self.adamw_mode)
else:
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
state['exp_avg'], state['exp_avg_sq'], div_scale)
self._post_update(p, 'exp_avg', 'exp_avg_sq')
elif target_device.type == 'cuda':

View File

@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP):
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
"""
def __init__(self,
@ -59,7 +60,9 @@ class ZeroDDP(ColoDDP):
pin_memory: bool = False,
force_outputs_fp32: bool = False,
strict_ddp_mode: bool = False,
scatter_after_inference: bool = True) -> None:
scatter_after_inference: bool = True,
mixed_precision: torch.dtype = torch.float16) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
self.gemini_manager = gemini_manager
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
@ -71,6 +74,7 @@ class ZeroDDP(ColoDDP):
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
self.scatter_after_inference = scatter_after_inference
self.mixed_precision = mixed_precision
self._logger = get_dist_logger()
@ -151,7 +155,7 @@ class ZeroDDP(ColoDDP):
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
), "You should run a completed iteration as your warmup iter"
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
self.module.zero_grad(set_to_none=True)
if not grad_flag:
outputs = self._inference_forward(*args, **kwargs)
@ -570,14 +574,14 @@ class ZeroDDP(ColoDDP):
# move ignored parameters to CUDA
if is_ddp_ignored(p):
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
continue
# create a fp32 parameter
fp32_data = p.data.float()
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
# create a fp16 parameter
p.data = p.data.half()
p.data = p.data.to(self.mixed_precision)
# register the fp16 parameter and fp32 parameter in the chunk manager
dp_world_size = p.process_group.dp_world_size()
@ -613,7 +617,7 @@ class ZeroDDP(ColoDDP):
buffer.materialize()
buffer.data = buffer.cuda()
if torch.is_floating_point(buffer):
buffer.data = buffer.half()
buffer.data = buffer.to(self.mixed_precision)
def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
"""Convert parameter to ColoParameter in-place.
@ -736,6 +740,7 @@ class GeminiDDP(ZeroDDP):
hidden_dim: Optional[int] = None,
min_chunk_size_mb: float = 32,
memstats: Optional[MemStats] = None,
mixed_precision: torch.dtype = torch.float16,
verbose: bool = False) -> None:
"""
A torch.Module wrapper using ZeRO-DP and Gemini.
@ -776,5 +781,10 @@ class GeminiDDP(ZeroDDP):
strict_ddp_flag=strict_ddp_mode,
verbose=verbose)
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
scatter_after_inference)
super().__init__(module,
gemini_manager,
pin_memory,
force_outputs_fp32,
strict_ddp_mode,
scatter_after_inference,
mixed_precision=mixed_precision)

View File

@ -1,7 +1,6 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import math
import warnings
from enum import Enum
from typing import Any, Dict, Set, Tuple
import torch
@ -9,7 +8,7 @@ import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
@ -22,9 +21,26 @@ __all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
class OptimState(Enum):
SCALED = 0
UNSCALED = 1
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
module: ZeroDDP,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
self.module = module
def check_local_overflow(self) -> bool:
return self.module.overflow_counter > 0
def pre_zero_grad(self) -> None:
self.module.overflow_counter = 0
class ZeroOptimizer(ColossalaiOptimizer):
@ -79,7 +95,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.module = module
self.gemini_manager = module.gemini_manager
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
self.optim_state = OptimState.UNSCALED
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
self.chunk16_set: Set[Chunk] = set()
@ -107,15 +122,20 @@ class ZeroOptimizer(ColossalaiOptimizer):
self.__init__optimizer()
# Grad scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
if module.mixed_precision is torch.float16:
self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
elif module.mixed_precision is torch.bfloat16:
self.mix_precision_mixin = BF16MixedPrecisionMixin()
else:
raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}")
self._logger = get_dist_logger()
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
@ -151,15 +171,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
for chunk16 in self.chunk16_set:
chunk16.optim_update()
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(self.module.overflow_counter)
# all-reduce across global group
dist.all_reduce(self._found_overflow)
return self._found_overflow.item() > 0
def _clear_global_norm(self) -> None:
for c16 in self.chunk16_set:
c16.l2_norm = None
@ -190,40 +201,25 @@ class ZeroOptimizer(ColossalaiOptimizer):
return global_norm
def _get_combined_scale(self):
loss_scale = 1
div_scale = self.mix_precision_mixin.get_grad_div_scale()
if self.optim_state == OptimState.SCALED:
loss_scale = self.loss_scale
self.optim_state = OptimState.UNSCALED
combined_scale = loss_scale
if self.clipping_flag:
total_norm = self._calc_global_norm()
clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
if clip > 1:
combined_scale = clip * loss_scale
div_scale = clip * div_scale
if combined_scale == 1:
return -1
else:
return combined_scale
@property
def loss_scale(self):
return self.grad_scaler.scale.item()
return -1 if div_scale == 1.0 else div_scale
def zero_grad(self, *args, **kwargs):
self.module.overflow_counter = 0
self.mix_precision_mixin.pre_zero_grad()
return self.optim.zero_grad(set_to_none=True)
def step(self, *args, **kwargs):
self._maybe_move_fp32_params()
self._set_grad_ptr()
found_inf = self._check_overflow()
if found_inf:
self.optim_state = OptimState.UNSCALED # no need to unscale grad
self.grad_scaler.update(found_inf) # update gradient scaler
if self.mix_precision_mixin.should_skip_step():
if self.verbose:
self._logger.info(f'Found overflow. Skip step')
self._clear_global_norm() # clear recorded norm
@ -234,7 +230,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
# get combined scale. combined scale = loss scale * clipping norm
# so that gradient = gradient / combined scale
combined_scale = self._get_combined_scale()
self.grad_scaler.update(found_inf)
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
self._register_states()
@ -246,8 +241,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
raise NotImplementedError
def backward(self, loss: torch.Tensor):
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
loss = self.mix_precision_mixin.pre_backward(loss)
self.module.backward(loss)
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
@ -255,7 +249,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
self.optim_state = OptimState.SCALED
grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
self.module.backward_by_grad(tensor, grad)
def _maybe_move_fp32_params(self):

View File

@ -14,7 +14,7 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.legacy.sharded_param import ShardedParamV2
@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
seed (int, optional): Random seed for weight initialization
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
"""
@ -64,6 +65,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
seed: int = 2**10 - 1,
shard_param: bool = False,
default_dtype: Optional[torch.dtype] = None,
bf16: bool = False,
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):
super().__init__(default_dtype=default_dtype)
@ -71,6 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
self.param_list = []
self.model_numel_tensor = model_numel_tensor
self.seed = seed
self.bf16 = bf16
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
@ -183,9 +186,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
NOTE() The module may be passed to this function multiple times.
"""
self.top_module = module
half_dtype = torch.float16 if not self.bf16 else torch.bfloat16
def half_fn(t: torch.Tensor):
return t.half() if t.is_floating_point() else t
return t.to(half_dtype) if t.is_floating_point() else t
for param in module.parameters(recurse=False):
# avoid adapting a param to ShardedParam twice
@ -226,9 +230,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
# We must cast buffers
# If we use BN, buffers may be on CPU and Float
# We must cast them
cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
for buffer in module.buffers(recurse=False):
buffer.data = buffer.data.to(device=torch.cuda.current_device())
buffer.data = cast_tensor_to_fp16(buffer.data)
buffer.data = cast_fn(buffer.data)
class ZeroContextMgr(metaclass=SingletonMeta):

View File

@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Te
if isinstance(tensor, StatefulTensor):
tensor = tensor.payload
if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16):
return tensor.float()
return tensor
def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor:
if isinstance(tensor, StatefulTensor):
tensor = tensor.payload
if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
return tensor.bfloat16()
return tensor
def apply_to_tensors(x: Any, fn: Callable):
if torch.is_tensor(x):
return fn(x)

View File

@ -28,6 +28,7 @@ from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBuc
from ._utils import (
cast_float_arguments,
cast_tensor_to_bf16,
cast_tensor_to_fp16,
cast_tensor_to_fp32,
chunk_and_pad,
@ -74,6 +75,7 @@ class ShardedModelV2(nn.Module):
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
We find that PyTorch's optimizers don't support mixed precision,
so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
"""
def __init__(self,
@ -86,11 +88,13 @@ class ShardedModelV2(nn.Module):
tensor_placement_policy: str = 'cuda',
gradient_predivide_factor: Optional[float] = 1.0,
reuse_fp16_shard: bool = False,
bf16: bool = False,
*args,
**kwargs):
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
super().__init__()
self.logger = get_dist_logger()
self.bf16 = bf16
# We force users to use ZeroInitContext
for submodule in module.modules():
@ -232,7 +236,8 @@ class ShardedModelV2(nn.Module):
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._pre_forward_operations(*args)
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16
args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)
outputs = self.module(*args, **kwargs)
self._post_forward_operations()
return outputs

View File

@ -94,6 +94,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
super().__init__(optimizer)
self.shard_strategy = sharded_model.shard_strategy
self.model: ShardedModelV2 = sharded_model
self.bf16 = sharded_model.bf16
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
@ -117,6 +118,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
self._logger = get_dist_logger("ShardedOptimizerV2")
self._verbose = verbose
self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward
# Store fp32 param shards
self._register_master_weight()
@ -166,8 +168,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._zero_grad()
def backward(self, loss: Tensor) -> None:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
if not self.bf16:
loss = self.loss_scale * loss
self.optim_state = OptimState.SCALED
self._grad_prepared = False
self.model.backward(loss)
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
@ -175,30 +179,33 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# It receives the scaled grad from the previous rank
# No need to scale the grad again
# Need to unscale when optimizing
self.optim_state = OptimState.SCALED
if not self.bf16:
self.optim_state = OptimState.SCALED
self._grad_prepared = False
self.model.backward_by_grad(tensor, grad)
def clip_grad_norm(self, model: nn.Module, max_norm: float):
if self.optim_state == OptimState.SCALED:
self._prepare_grads()
self._prepare_grads()
if not self.bf16 and self.optim_state == OptimState.SCALED:
self._unscale_grads()
return super().clip_grad_norm(model, max_norm)
def step(self, *args, **kwargs):
self._prepare_grads()
# unscale grads if scaled
if self.optim_state == OptimState.SCALED:
self._prepare_grads()
if not self.bf16 and self.optim_state == OptimState.SCALED:
self._unscale_grads()
self._maybe_move_fp32_shards()
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
if not self.bf16:
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
if found_inf:
self._logger.warning('found inf during ShardedOptimV2 step')
self._zero_grad(recover_data=True)
return
if found_inf:
self._logger.warning('found inf during ShardedOptimV2 step')
self._zero_grad(recover_data=True)
return
self._point_param_fp16_to_master_param()
@ -304,6 +311,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
state[k] = v.cuda()
def _prepare_grads(self):
if self._grad_prepared:
return
for group in self.optim.param_groups:
for p in group['params']:
if p.colo_attr.saved_grad.is_null():
@ -320,6 +329,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
p.grad = p.colo_attr.grad_payload
# Set p.data to empty tensor, in case of memory leaking
p.colo_attr.set_data_none()
self._grad_prepared = True
def _point_param_fp16_to_master_param(self):
# assign master param pointers to p.data.
@ -357,7 +367,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
half_dtype = torch.bfloat16 if self.bf16 else torch.float16
p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach())
p.colo_attr.set_data_none()
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:

View File

@ -6,7 +6,11 @@ import torch
import torch.distributed as dist
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from colossalai.amp.naive_amp.mixed_precision_mixin import (
BF16MixedPrecisionMixin,
FP16MixedPrecisionMixin,
MixedPrecisionMixin,
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
@ -27,6 +31,31 @@ from ._utils import (
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
def __init__(self,
num_working_param_groups: int,
grad_store: GradientStore,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32) -> None:
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
max_scale)
self.num_working_param_groups = num_working_param_groups
self.grad_store = grad_store
def check_local_overflow(self) -> bool:
for group_id in range(self.num_working_param_groups):
for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
return True
return False
class LowLevelZeroOptimizer(ColossalaiOptimizer):
"""Optimizer used for ZeRO-1 and ZeRO-2.
"""
@ -100,17 +129,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self._reduce_bucket_size = reduce_bucket_size
self._communication_dtype = communication_dtype
# gradient scaler
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
verbose=verbose)
self._found_overflow = torch.FloatTensor([0]).to(get_current_device())
# gradient clipping
self._clip_grad_norm = clip_grad_norm
@ -200,14 +218,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
if self._overlap_communication or self._partition_grads:
self._attach_reduction_hook()
# initialize mixed precision mixin
self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
if self._dtype is torch.float16:
self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
self._grad_store,
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
backoff_factor=backoff_factor,
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale)
elif self._dtype is torch.bfloat16:
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
@property
def dtype(self):
return self._dtype
@property
def loss_scale(self):
return self.grad_scaler.scale
@property
def num_param_groups(self):
return len(self._working_param_groups)
@ -392,7 +421,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
################################
def backward(self, loss, retain_graph=False, sync_grad=True):
loss = self.loss_scale * loss
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
loss.backward(retain_graph=retain_graph)
# finish gradient reduction
@ -419,6 +449,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
:param set_to_none: Whether set the gradient to None. Default value is True.
:type set_to_none: bool
"""
if self.mixed_precision_mixin is not None:
self.mixed_precision_mixin.pre_zero_grad()
for _, param_group in self._working_param_groups.items():
for param in param_group:
if set_to_none:
@ -435,12 +467,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
# check for overflow
found_inf = self._check_overflow()
self.grad_scaler.update(found_inf)
# update loss scale if overflow occurs
if found_inf:
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
self._grad_store.reset_all_average_gradients()
if self._verbose:
self._logger.info(f'Found overflow. Skip step')
@ -507,41 +534,20 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# Mixed Precision Utilities #
#############################
def _check_overflow(self):
# clear previous overflow record
self._found_overflow.fill_(0.0)
# check for overflow
for group_id in range(len(self._working_param_groups)):
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
if avg_grad is not None and has_inf_or_nan(avg_grad):
self._found_overflow.fill_(1.0)
break
# all-reduce across dp group
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)
# all-reduce over model parallel group
if self._mp_torch_group:
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)
if self._found_overflow.item() > 0:
return True
else:
return False
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
# compute combined scale factor for this group
combined_scale = self.loss_scale
div_scale = 1.0
if self.mixed_precision_mixin is not None:
div_scale = self.mixed_precision_mixin.get_grad_div_scale()
if self._clip_grad_norm > 0.:
# norm is in fact norm*scale
clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
if clip > 1:
combined_scale = clip * self.loss_scale
div_scale = clip * div_scale
for grad in grad_groups_flat:
grad.data.mul_(1. / combined_scale)
grad.data.mul_(1. / div_scale)
############################
# Gradient Synchronization #

View File

@ -0,0 +1,131 @@
# This test checks adam kernels
# Baseline is pure fp32 torch adam optimizer
import math
from abc import abstractmethod
from typing import Type
import pytest
import torch
from torch import Tensor
from colossalai.utils import get_current_device, multi_tensor_applier
_FUSED_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
(torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16),
(torch.bfloat16, torch.bfloat16)]
_CPU_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
(torch.half, torch.half)]
class AdamKernel:
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
self.lr = lr
self.beta1 = beta1
self.beta2 = beta2
self.eps = eps
self.weight_decay = weight_decay
self.use_adamw = use_adamw
@abstractmethod
def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
pass
class TorchAdamKernel(AdamKernel):
def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
bias_correction1 = 1 - self.beta1**step
bias_correction2 = 1 - self.beta2**step
if self.weight_decay != 0:
if self.use_adamw:
# Perform stepweight decay
param.mul_(1 - self.lr * self.weight_decay)
else:
grad = grad.add(param, alpha=self.weight_decay)
# Decay the first and second moment running average coefficient
exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
step_size = self.lr / bias_correction1
param.addcdiv_(exp_avg, denom, value=-step_size)
class FusedAdamKernel(AdamKernel):
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
self.fused_adam = fused_optim.multi_tensor_adam
self.dummy_overflow_buf = torch.cuda.IntTensor([0])
def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, [[grad], [param], [exp_avg], [exp_avg_sq]],
self.lr, self.beta1, self.beta2, self.eps, step, self.use_adamw, True, self.weight_decay,
-1)
class CPUAdamKernel(AdamKernel):
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
from colossalai.kernel.op_builder import CPUAdamBuilder
cpu_optim = CPUAdamBuilder().load()
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)
def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
self.cpu_adam_op.step(step, self.lr, self.beta1, self.beta2, self.eps, self.weight_decay, True, param.view(-1),
grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1), -1)
def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float, p_dtype: torch.dtype,
g_dtype: torch.dtype, device: torch.device, n_steps: int, rtol: float, atol: float):
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8
torch_adam = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw)
adam_kernel = kernel(lr, beta1, beta2, eps, weight_decay, adamw)
master_p = torch.rand(64, device=device)
master_g = torch.rand_like(master_p)
master_exp_avg = torch.zeros_like(master_p)
master_exp_avg_sq = torch.zeros_like(master_p)
p = master_p.clone().to(p_dtype)
g = master_g.clone().to(g_dtype)
exp_avg = master_exp_avg.clone()
exp_avg_sq = master_exp_avg_sq.clone()
for step in range(1, 1 + n_steps):
torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
adam_kernel.update(step, p, g, exp_avg, exp_avg_sq)
# if overflow, the weight won't be updated. so there will be no nan in p
assert not torch.isnan(p).any()
assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol)
@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
@pytest.mark.parametrize('p_dtype, g_dtype', _FUSED_ALLOWED_P_G_TYPES)
def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
rtol, atol = 1e-5, 1e-8
if p_dtype is torch.float16 or g_dtype is torch.float16:
rtol, atol = 1e-3, 1e-3
if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
rtol, atol = 4e-3, 4e-3
check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol)
@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
@pytest.mark.parametrize('p_dtype, g_dtype', _CPU_ALLOWED_P_G_TYPES)
def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
rtol, atol = 1e-5, 1e-8
if p_dtype is torch.float16 or g_dtype is torch.float16:
rtol, atol = 1e-3, 1e-3
check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device('cpu'), 3, rtol, atol)

View File

@ -0,0 +1,86 @@
from copy import deepcopy
from typing import Type, Union
import pytest
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo
_ALLOWED_OPTIM_DEVICES = [
(FusedAdam, torch.device('cuda:0')),
(CPUAdam, torch.device('cpu')),
(CPUAdam, torch.device('cuda:0')),
(HybridAdam, torch.device('cpu')),
(HybridAdam, torch.device('cuda:0')),
]
_ALLOWED_P_G_TYPES = [
(torch.float, torch.float), # pure fp32
(torch.float, torch.half), # fp16 amp
(torch.float, torch.bfloat16), # bfloat16 amp
# (torch.half, torch.half), # FIXME(ver217): cpu adam kernel does not support pure fp16
# (torch.bfloat16, torch.bfloat16), # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]
N_STEPS = 3
def setup_param_groups(bert_model: nn.Module) -> list:
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": 0.1,
},
{
"params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
return optimizer_grouped_parameters
def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
torch_p.grad = torch.rand_like(torch_p)
# avoid inconsistent grad and param dtype error
orig_p = p.data
p.data = torch_p.grad.clone().to(g_dtype)
p.grad = p.data
p.data = orig_p
@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES)
@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device,
adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None:
model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values()))
torch_model = model_fn().to(device)
model = deepcopy(torch_model).to(p_dtype)
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8
torch_optim_cls = AdamW if adamw else Adam
torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)
rtol, atol = 1e-5, 1e-5
if p_dtype is torch.float16 or g_dtype is torch.float16:
rtol, atol = 2e-3, 2e-3
if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
rtol, atol = 4e-3, 4e-3
for _ in range(N_STEPS):
set_grad(model, torch_model, g_dtype)
torch_optim.step()
optim.step()
torch_optim.zero_grad()
optim.zero_grad()
for p, torch_p in zip(model.parameters(), torch_model.parameters()):
# if overflow, the weight won't be updated. so there will be no nan in p
assert not torch.isnan(p).any()
assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)

View File

@ -1,121 +0,0 @@
import math
import torch
from colossalai.testing import clear_cache_before_run, parameterize
def torch_adam_update(
step,
lr,
beta1,
beta2,
eps,
weight_decay,
param,
grad,
exp_avg,
exp_avg_sq,
use_adamw,
):
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
if weight_decay != 0:
if use_adamw:
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
step_size = lr / bias_correction1
param.addcdiv_(exp_avg, denom, value=-step_size)
def assertLess(data_diff, threshold, msg):
assert data_diff < threshold, msg
def assertTrue(condition, msg):
assert condition, msg
@clear_cache_before_run()
@parameterize('adamw', [True, False])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_cpu_adam(adamw, step, p_dtype, g_dtype):
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8
weight_decay = 0
for i in range(3):
p_data = torch.rand(64, dtype=p_dtype)
p_data_copy = p_data.clone().float()
p_grad = torch.rand(64, dtype=g_dtype)
p_grad_copy = p_grad.clone().float()
exp_avg = torch.rand(p_data.shape)
exp_avg_copy = exp_avg.clone()
exp_avg_sq = torch.rand(p_data.shape)
exp_avg_sq_copy = exp_avg_sq.clone()
from colossalai.kernel.op_builder import CPUAdamBuilder
cpu_optim = CPUAdamBuilder().load()
cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)
cpu_adam_op.step(
step,
lr,
beta1,
beta2,
eps,
weight_decay,
True,
p_data.view(-1), # fp32 data
p_grad.view(-1), # fp32 grad
exp_avg.view(-1),
exp_avg_sq.view(-1),
-1,
)
torch_adam_update(
step,
lr,
beta1,
beta2,
eps,
weight_decay,
p_data_copy, # fp32 data
p_grad_copy, # fp32 grad
exp_avg_copy,
exp_avg_sq_copy,
adamw,
)
var = p_data_copy - p_data
data_diff = torch.max(torch.abs(var))
threshold = 1e-3
assertLess(
data_diff,
threshold,
f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps "
f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
)
max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")
if __name__ == '__main__':
test_cpu_adam()

View File

@ -1,64 +0,0 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.adam import Adam
from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import clear_cache_before_run, parameterize
class FC(nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc = nn.Sequential(nn.Linear(64, 64))
def forward(self, x):
return self.fc(x)
@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, p_dtype, g_dtype):
model = FC().cuda().to(p_dtype)
state = model.state_dict()
model_copy = FC().cuda().to(p_dtype)
model_copy.load_state_dict(state.copy())
if adamw:
optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
else:
optim = FusedAdam(model.parameters(), lr=1e-3)
torch_optim = Adam(model_copy.parameters(), lr=1e-3)
data = torch.rand(1024, 64).cuda().to(p_dtype)
data_copy = data.clone()
label = torch.rand(1024, 64).cuda().to(p_dtype)
for d, l in zip(data, label):
y = model(d)
loss = ((l - y)**2).sum()
optim.zero_grad()
loss.backward()
if p_dtype != g_dtype:
for i in range(len(optim.param_groups[0]['params'])):
optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype)
optim.step()
for d, l in zip(data_copy, label):
y = model_copy(d)
loss = ((l - y)**2).sum()
torch_optim.zero_grad()
loss.backward()
torch_optim.step()
assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])
for i in range(len(optim.param_groups[0]['params'])):
if torch.isnan(optim.param_groups[0]['params'][i]).any() \
or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
continue
assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3)

View File

@ -1,95 +0,0 @@
import math
import torch
import torch.nn as nn
from numpy import dtype
from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils import multi_tensor_applier
def torch_adam_update(
step,
lr,
beta1,
beta2,
eps,
weight_decay,
param,
grad,
exp_avg,
exp_avg_sq,
use_adamw,
):
bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step
if weight_decay != 0:
if use_adamw:
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
step_size = lr / bias_correction1
param.addcdiv_(exp_avg, denom, value=-step_size)
@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, step, p_dtype, g_dtype):
from colossalai.kernel.op_builder import FusedOptimBuilder
fused_optim = FusedOptimBuilder().load()
fused_adam = fused_optim.multi_tensor_adam
dummy_overflow_buf = torch.cuda.IntTensor([0])
count = 0
for i in range(3):
p = torch.rand(64, dtype=p_dtype).cuda()
p_copy = p.clone().float()
g = torch.rand(p.shape, dtype=g_dtype).cuda()
g_copy = g.clone().float()
m = torch.rand(p.shape).cuda()
m_copy = m.clone()
v = torch.rand(p.shape).cuda()
v_copy = v.clone()
lr = 1e-3
beta1, beta2 = 0.9, 0.999
eps = 1e-8
weight_decay = 0
multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
True, weight_decay, -1)
torch_adam_update(
step,
lr,
beta1,
beta2,
eps,
weight_decay,
p_copy, # fp32 data
g_copy, # fp32 grad
m_copy,
v_copy,
adamw,
)
if torch.isnan(p).any() or torch.isnan(p_copy).any():
count += 1
continue
assert count < 200, "too many nans"
assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5,
1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"

View File

@ -1,42 +0,0 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.adam import Adam
from colossalai.nn.optimizer.hybrid_adam import HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize
RE = 3
@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('device', ['cpu', 'cuda:0'])
@parameterize('p_dtype', [torch.float])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, device, p_dtype, g_dtype):
rng_state = torch.get_rng_state()
p = nn.Parameter(torch.rand(64).to(device, p_dtype))
torch.set_rng_state(rng_state)
p_copy = nn.Parameter(torch.rand(64).to(device).float())
if adamw:
optim = HybridAdam([p], lr=1e-3, adamw_mode=True)
torch_optim = AdamW([p_copy], lr=1e-3)
else:
optim = HybridAdam([p], lr=1e-3)
torch_optim = Adam([p_copy], lr=1e-3)
print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}")
for i in range(RE):
p.grad = torch.rand(64).to(device, p_dtype)
p_copy.grad = p.grad.clone().float()
p.grad.data = p.grad.data.to(g_dtype)
optim.step()
torch_optim.step()
if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any():
continue
assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \
f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"

View File

@ -21,23 +21,40 @@ TEST_MODELS = ['gpt2']
# these models are too small, all parameters in these models are compacted into one chunk
EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers']
# bfloat16 cannot represent them exactly
BF16_IGNORED_KEYS = [
'albert.embeddings.word_embeddings.weight',
'albert.embeddings.position_embeddings.weight',
'masked_bias',
]
def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
zero_dict = model.state_dict(only_rank_0=False)
def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
torch_dict = torch_model.state_dict()
for key, value in torch_dict.items():
# key is 'module.model.PARAMETER', so we truncate it
key = key[7:]
assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
temp_zero_value = zero_dict[key].to(device=value.device)
if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS):
continue
rtol, atol = 1e-3, 4e-3
if dtype is torch.bfloat16:
rtol, atol = 4e-3, 8e-3
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
assert_close(value.float(),
temp_zero_value.float(),
rtol=rtol,
atol=atol,
msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', TEST_MODELS)
def exam_model_step(placement_policy, model_name: str):
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -65,7 +82,7 @@ def exam_model_step(placement_policy, model_name: str):
init_device = None
chunk_manager = ChunkManager(config_dict, init_device=init_device)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)
@ -74,6 +91,7 @@ def exam_model_step(placement_policy, model_name: str):
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
rtol, atol = 1e-4, 1e-5
for i, (input_ids, label) in enumerate(train_dataloader):
if i > 2:
break
@ -83,17 +101,18 @@ def exam_model_step(placement_policy, model_name: str):
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss)
assert_close(torch_loss, loss, rtol=rtol, atol=atol)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
check_param(model, torch_model, mixed_precision)
@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', EXAMPLE_MODELS)
def exam_tiny_example(placement_policy, model_name: str):
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
set_seed(2008)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -113,7 +132,7 @@ def exam_tiny_example(placement_policy, model_name: str):
chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pin_memory=True)
model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
optimizer = HybridAdam(model.parameters(), lr=1e-3)
zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
@ -121,6 +140,9 @@ def exam_tiny_example(placement_policy, model_name: str):
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
rtol, atol = 1.5e-6, 2e-5
if mixed_precision is torch.bfloat16:
rtol, atol = 2e-3, 2e-3
for i, (input_ids, label) in enumerate(train_dataloader):
if i > 2:
break
@ -133,12 +155,12 @@ def exam_tiny_example(placement_policy, model_name: str):
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5) # atol should be 2e-5 for torch lower than 1.12
assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
check_param(model, torch_model, mixed_precision)
def run_dist(rank, world_size, port):

View File

@ -16,7 +16,11 @@ from colossalai.zero.low_level._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs
def run_dist(rank, world_size, port, parallel_config):
def run_dist(rank, world_size, port, parallel_config, bf16):
is_mp_config = parallel_config == MP_PARALLEL_CONFIG
is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG
if bf16:
parallel_config['zero']['model_config']['bf16'] = True
colossalai.launch(config=parallel_config,
rank=rank,
world_size=world_size,
@ -30,7 +34,8 @@ def run_dist(rank, world_size, port, parallel_config):
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
with ZeroInitContext(target_device=torch.cuda.current_device(),
shard_strategy=gpc.config.zero.model_config.shard_strategy,
shard_param=True):
shard_param=True,
bf16=bf16):
colo_model = model_builder(checkpoint=True)
colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)
@ -38,7 +43,8 @@ def run_dist(rank, world_size, port, parallel_config):
optimizer=colo_optimizer,
criterion=criterion,
train_dataloader=train_dataloader)
torch_model = model_builder(checkpoint=True).half()
dtype = torch.bfloat16 if bf16 else torch.float16
torch_model = model_builder(checkpoint=True).to(dtype)
col_model_deepcopy(engine.model, torch_model)
torch_model = torch_model.cuda().float()
@ -80,9 +86,9 @@ def run_dist(rank, world_size, port, parallel_config):
torch_optimizer.step()
i += 1
if parallel_config == MP_PARALLEL_CONFIG:
if is_mp_config:
check_params(torch_model, colo_model, loose=True)
elif parallel_config == ZERO_PARALLEL_CONFIG:
elif is_zero_config:
check_sharded_model_params(torch_model, colo_model, loose=True)
@ -97,9 +103,10 @@ def test_mp_engine(world_size):
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("bf16", [True, False])
@rerun_if_address_is_in_use()
def test_zero_engine(world_size):
spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG)
def test_zero_engine(world_size, bf16):
spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16)
if __name__ == '__main__':

View File

@ -82,7 +82,6 @@ def exam_zero_1_2_grad_acc():
def exam_zero_1_grad_acc():
local_rank = torch.distributed.get_rank()
grad_scale = 32
seed_all(2008)
# create models
@ -101,7 +100,6 @@ def exam_zero_1_grad_acc():
# level 1 and 2 will produce exactly the same results
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=False,
initial_scale=grad_scale,
reduce_bucket_size=262144,
clip_grad_norm=1.0)
@ -128,9 +126,8 @@ def exam_zero_1_grad_acc():
if check_flag:
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
unscale_grad = z1p.grad / grad_scale
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
assert torch.equal(p.grad, unscale_grad)
assert torch.equal(p.grad, z1p.grad)
zero_optimizer._sync_grad()

View File

@ -7,7 +7,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
@ -25,15 +25,18 @@ class MlpModel(nn.Module):
return x
def half_close(a, b, loose=False):
def loose_close(a, b, dtype: torch.dtype = torch.float32):
rtol = None
atol = None
if loose:
if dtype is torch.float16:
rtol = 5e-2
atol = 5e-4
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
a = a.detach().half()
b = b.detach().half()
a = a.detach().to(dtype)
b = b.detach().to(dtype)
assert_close(a, b, rtol=rtol, atol=atol)
@ -96,7 +99,8 @@ def exam_zero_1_2():
assert torch.equal(z1p.data, z2p.data)
def exam_zero_1_torch_ddp():
@parameterize('dtype', [torch.float16, torch.bfloat16])
def exam_zero_1_torch_ddp(dtype: torch.dtype):
"""
In this test, two pairs of model and optimizers are created.
1. zero: use sharded optimizer and fp16 parameters
@ -109,15 +113,10 @@ def exam_zero_1_torch_ddp():
seed_all(1453)
# create models
zero_model = MlpModel()
torch_model = copy.deepcopy(zero_model)
torch_model = MlpModel().cuda()
zero_model = copy.deepcopy(torch_model).to(dtype)
zero_model = zero_model.cuda().half()
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
torch_model = torch_model.cuda()
# for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# half_close(p.data, z1p.data)
torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda()
# create optimizer
zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)
@ -137,11 +136,11 @@ def exam_zero_1_torch_ddp():
input_data = torch.rand(32, 128).cuda()
# zero-dp forward
zero_output = zero_model(input_data.half())
zero_output = zero_model(input_data.to(dtype))
# torch-ddp forward
torch_output = torch_model(input_data)
half_close(zero_output, torch_output, loose=True)
loose_close(zero_output, torch_output, dtype=dtype)
# zero-dp backward
zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)
@ -151,7 +150,7 @@ def exam_zero_1_torch_ddp():
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
half_close(p.grad, z1p.grad, loose=True)
loose_close(p.grad, z1p.grad, dtype=dtype)
# zero-dp step
zero_optimizer._sync_grad()
@ -163,7 +162,7 @@ def exam_zero_1_torch_ddp():
# check updated param
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# print(n, torch.max(torch.abs(p.data - z1p.data)))
half_close(p.data, z1p.data, loose=True)
loose_close(p.data, z1p.data, dtype=dtype)
def run_dist(rank, world_size, port):