mirror of https://github.com/hpcaitech/ColossalAI
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)
* [bf16] fused adam kernel support bf16
* [test] update fused adam kernel test
* [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
* [bf16] add mixed precision mixin
* [bf16] low level zero optim support bf16
* [test] update low level zero test
* [test] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
* [bf16] gemini support bf16
* [test] update gemini bf16 test
* [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
* [zero] init context support bf16
* [zero] legacy zero support bf16
* [test] add zero bf16 test
* [doc] add bf16 related docstring for legacy zero
parent 07cb21142f
commit ae02d4e4f7
@ -0,0 +1,9 @@
|
|||
from .base import MixedPrecisionMixin
|
||||
from .bf16 import BF16MixedPrecisionMixin
|
||||
from .fp16 import FP16MixedPrecisionMixin
|
||||
|
||||
__all__ = [
|
||||
'MixedPrecisionMixin',
|
||||
'FP16MixedPrecisionMixin',
|
||||
'BF16MixedPrecisionMixin',
|
||||
]
|
|
@ -0,0 +1,91 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
class MixedPrecisionMixin(ABC):
|
||||
"""A helper class for mixed precision training. This mixin is used in mixed precision optimizers.
|
||||
|
||||
Attributes:
|
||||
dtype (torch.dtype): The expected dtype of the gradients.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
class MyMixedPrecisionOptimizer(OptimizerWrapper):
|
||||
def __init__(self, optim: Optimizer):
|
||||
super().__init__(optim)
|
||||
self.mixed_precision = MixedPrecisionMixin()
|
||||
|
||||
def backward(self, loss):
|
||||
loss = self.mixed_precision.pre_backward(loss)
|
||||
loss.backward()
|
||||
|
||||
def backward_by_grad(self, tensor, grad):
|
||||
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
|
||||
tensor.backward(grad)
|
||||
|
||||
def step(self):
|
||||
if self.mixed_precision.should_skip_step():
|
||||
self.zero_grad()
|
||||
return
|
||||
div_scale = self.mixed_precision.get_grad_div_scale()
|
||||
# maybe clip grad here
|
||||
# maybe scale grad here
|
||||
self.optim.step()
|
||||
|
||||
def zero_grad(self):
|
||||
self.mixed_precision.pre_zero_grad()
|
||||
return self.optim.zero_grad()
|
||||
```
|
||||
"""
|
||||
dtype: torch.dtype
|
||||
|
||||
@abstractmethod
|
||||
def pre_backward(self, loss: Tensor) -> Tensor:
|
||||
"""Called before backward.
|
||||
|
||||
Args:
|
||||
loss (Tensor): Loss value.
|
||||
|
||||
Returns:
|
||||
Tensor: Loss value (possibly scaled).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
|
||||
"""Called before backward by grad. This is helpful for pipeline parallelism.
|
||||
|
||||
Args:
|
||||
tensor (Tensor): Tensor to backward.
|
||||
grad (Tensor): Gradient of the tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Gradient of the tensor (possibly scaled).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def should_skip_step(self) -> bool:
|
||||
"""Called before step.
|
||||
|
||||
Returns:
|
||||
bool: Whether to skip the step.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pre_zero_grad(self) -> None:
|
||||
"""Called before zero_grad.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_grad_div_scale(self) -> float:
|
||||
"""Called before step or clip_grad. To keep computation efficiency, this method does not (maybe) unscale grads.
|
||||
|
||||
Returns:
|
||||
float: A divisor for gradient clipping or step.
|
||||
"""
|
||||
pass
|
|
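The abstract base above only fixes the contract; note that the docstring example instantiates `MixedPrecisionMixin()` directly for brevity, whereas a real optimizer holds a concrete subclass. Below is a minimal sketch of such a subclass with a constant loss scale, written against the interface defined in this hunk; the import path matches the package `__init__.py` added earlier in this diff.

```python
import torch
from torch import Tensor
from colossalai.amp.naive_amp.mixed_precision_mixin import MixedPrecisionMixin


class StaticScaleMixedPrecisionMixin(MixedPrecisionMixin):
    """Toy mixin with a constant loss scale, illustrating what each hook is for."""
    dtype = torch.float16

    def __init__(self, scale: float = 1024.0) -> None:
        self.scale = scale

    def pre_backward(self, loss: Tensor) -> Tensor:
        # scale the loss so that small fp16 gradients do not underflow
        return loss * self.scale

    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
        # the incoming grad is assumed to be scaled already (pipeline-parallel case)
        return grad

    def should_skip_step(self) -> bool:
        # a real implementation would check the gradients for inf/nan here
        return False

    def pre_zero_grad(self) -> None:
        # nothing to reset for a static scale
        pass

    def get_grad_div_scale(self) -> float:
        # the optimizer divides gradients by this value before updating
        return self.scale
```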
@ -0,0 +1,23 @@
|
|||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from .base import MixedPrecisionMixin
|
||||
|
||||
|
||||
class BF16MixedPrecisionMixin(MixedPrecisionMixin):
|
||||
dtype = torch.bfloat16
|
||||
|
||||
def pre_backward(self, loss: Tensor) -> Tensor:
|
||||
return loss
|
||||
|
||||
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
|
||||
return grad
|
||||
|
||||
def should_skip_step(self) -> bool:
|
||||
return False
|
||||
|
||||
def pre_zero_grad(self) -> None:
|
||||
pass
|
||||
|
||||
def get_grad_div_scale(self) -> float:
|
||||
return 1.0
|
|
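Every hook in `BF16MixedPrecisionMixin` is a no-op because bf16 keeps fp32's 8-bit exponent, so gradients do not need loss scaling to avoid overflow or underflow; only mantissa precision is reduced. A quick check of the numeric ranges, using nothing beyond `torch.finfo`:

```python
import torch

# bf16 shares fp32's exponent width, so its dynamic range matches fp32 and no
# dynamic loss scaler is required; fp16's range is far smaller.
print(torch.finfo(torch.bfloat16).max)   # ~3.39e38, same order as fp32
print(torch.finfo(torch.float32).max)    # ~3.40e38
print(torch.finfo(torch.float16).max)    # 65504.0 -> overflow-prone, needs scaling
print(torch.finfo(torch.bfloat16).eps)   # ~7.8e-3: coarser precision is the trade-off
```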
@ -0,0 +1,84 @@
|
|||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .base import MixedPrecisionMixin
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
SCALED = 0
|
||||
UNSCALED = 1
|
||||
|
||||
|
||||
class FP16MixedPrecisionMixin(MixedPrecisionMixin):
|
||||
dtype = torch.float16
|
||||
|
||||
def __init__(self,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32) -> None:
|
||||
super().__init__()
|
||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
|
||||
|
||||
@property
|
||||
def loss_scale(self) -> float:
|
||||
return self.grad_scaler.scale.item()
|
||||
|
||||
@abstractmethod
|
||||
def check_local_overflow(self) -> bool:
|
||||
"""Check whether there is overflow in the local process. This method should be implemented by subclasses.
|
||||
|
||||
Returns:
|
||||
bool: Whether there is overflow in the local process.
|
||||
"""
|
||||
pass
|
||||
|
||||
def check_overflow(self) -> bool:
|
||||
# clear previous overflow record
|
||||
self.found_overflow.fill_(0.0)
|
||||
if self.check_local_overflow():
|
||||
self.found_overflow.fill_(1.0)
|
||||
dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX)
|
||||
return self.found_overflow.item() > 0
|
||||
|
||||
def pre_backward(self, loss: Tensor) -> Tensor:
|
||||
loss = self.loss_scale * loss
|
||||
self.optim_state = OptimState.SCALED
|
||||
return loss
|
||||
|
||||
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
|
||||
self.optim_state = OptimState.SCALED
|
||||
return grad
|
||||
|
||||
def should_skip_step(self) -> bool:
|
||||
found_inf = self.check_overflow()
|
||||
self.grad_scaler.update(found_inf)
|
||||
if found_inf:
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
return found_inf
|
||||
|
||||
def pre_zero_grad(self) -> None:
|
||||
pass
|
||||
|
||||
def get_grad_div_scale(self) -> float:
|
||||
assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
return self.loss_scale
|
|
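`FP16MixedPrecisionMixin` leaves only `check_local_overflow` abstract; the concrete subclasses added later in this PR (for Gemini and low-level ZeRO) implement it by inspecting their own gradient stores. A hypothetical minimal subclass is sketched below; it assumes a distributed process group is already initialized, since the inherited `check_overflow` all-reduces the flag across ranks.

```python
from typing import List, Optional

import torch
from colossalai.amp.naive_amp.mixed_precision_mixin import FP16MixedPrecisionMixin


class NaiveFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
    """Hypothetical subclass: report overflow by scanning a list of gradients."""

    def __init__(self, grads: List[Optional[torch.Tensor]], **scaler_kwargs) -> None:
        # scaler_kwargs forwards initial_scale, growth_factor, etc. to the base class
        super().__init__(**scaler_kwargs)
        self.grads = grads

    def check_local_overflow(self) -> bool:
        # check_overflow() in the base class all-reduces this flag across ranks,
        # so torch.distributed must be initialized before should_skip_step() is called
        return any(not torch.isfinite(g).all() for g in self.grads if g is not None)
```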
@ -23,6 +23,9 @@ from .dp_plugin_base import DPPluginBase
|
|||
|
||||
__all__ = ['GeminiPlugin']
|
||||
|
||||
SUPPORTED_PRECISION = ['fp16', 'bf16']
|
||||
PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}
|
||||
|
||||
|
||||
class GeminiCheckpointIO(GeneralCheckpointIO):
|
||||
|
||||
|
@ -171,6 +174,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
Args:
|
||||
device (torch.device): device to place the model.
|
||||
placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
|
||||
precision (str, optional): Precision of training. Supports 'fp16' and 'bf16'. Defaults to 'fp16'.
|
||||
pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
|
||||
force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
|
||||
strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
|
||||
|
@ -203,6 +207,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
placement_policy: str = "cpu",
|
||||
precision: str = "fp16",
|
||||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
|
@ -223,6 +228,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
verbose: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
|
||||
self.gemini_config = dict(
|
||||
device=(device or get_current_device()),
|
||||
placement_policy=placement_policy,
|
||||
|
@ -233,6 +239,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
hidden_dim=hidden_dim,
|
||||
min_chunk_size_mb=min_chunk_size_mb,
|
||||
memstats=memstats,
|
||||
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
|
||||
)
|
||||
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
|
||||
self.optim_kwargs = dict(initial_scale=initial_scale,
|
||||
|
@ -253,7 +260,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
return True
|
||||
|
||||
def supported_precisions(self) -> List[str]:
|
||||
return ['fp16']
|
||||
return SUPPORTED_PRECISION
|
||||
|
||||
def control_device(self) -> bool:
|
||||
return True
|
||||
|
|
|
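With `SUPPORTED_PRECISION` extended, `GeminiPlugin(precision='bf16')` selects `torch.bfloat16` for the wrapped model. A usage sketch via the booster follows; the `Booster` and `launch_from_torch` calls are assumed from the colossalai release this diff targets, and a torchrun-launched, CUDA-capable environment with the fused/CPU Adam extensions is required.

```python
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin   # plugin path assumed
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch(config={})               # torchrun-style launch

model = torch.nn.Linear(32, 32)
optimizer = HybridAdam(model.parameters(), lr=1e-3)

plugin = GeminiPlugin(precision='bf16')               # 'fp16' (default) or 'bf16'
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

x = torch.rand(8, 32, device='cuda')
loss = model(x).sum()
booster.backward(loss, optimizer)                     # bf16 path: no loss scaling
optimizer.step()
optimizer.zero_grad()
```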
@ -1,4 +1,5 @@
|
|||
import warnings
|
||||
from functools import partial
|
||||
from typing import Callable, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
@ -20,12 +21,15 @@ from .torch_ddp_plugin import TorchDDPCheckpointIO
|
|||
__all__ = ['LowLevelZeroPlugin']
|
||||
|
||||
|
||||
def _convert_to_fp16(x):
|
||||
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
|
||||
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
|
||||
return x.half()
|
||||
return x.to(dtype)
|
||||
return x
|
||||
|
||||
|
||||
SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
|
||||
|
||||
|
||||
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
||||
|
||||
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
|
||||
|
@ -49,17 +53,24 @@ class LowLevelZeroModel(ModelWrapper):
|
|||
|
||||
def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
|
||||
super().__init__(module)
|
||||
self.convert_inputs = (precision == 'fp16')
|
||||
module = zero_model_wrapper(module, zero_stage=stage)
|
||||
self.dtype = None
|
||||
if precision == 'fp16':
|
||||
module = module.half()
|
||||
self.dtype = torch.float16
|
||||
elif precision == 'bf16':
|
||||
self.dtype = torch.bfloat16
|
||||
module = zero_model_wrapper(module, zero_stage=stage)
|
||||
if self.dtype is not None:
|
||||
module = module.to(self.dtype)
|
||||
module = module.to(get_current_device())
|
||||
self.module = module
|
||||
self.convert_fn = None
|
||||
if self.dtype is not None:
|
||||
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.convert_inputs:
|
||||
args = tree_map(_convert_to_fp16, args)
|
||||
kwargs = tree_map(_convert_to_fp16, kwargs)
|
||||
if self.convert_fn is not None:
|
||||
args = tree_map(self.convert_fn, args)
|
||||
kwargs = tree_map(self.convert_fn, kwargs)
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
|
||||
|
@ -110,7 +121,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
|
||||
Args:
|
||||
stage (int, optional): ZeRO stage. Defaults to 1.
|
||||
precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'.
|
||||
precision (str, optional): Precision of training. Supports 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
|
||||
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
|
||||
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
|
||||
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
|
||||
|
@ -149,7 +160,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
) -> None:
|
||||
super().__init__()
|
||||
assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
|
||||
assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
|
||||
assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
|
||||
|
||||
self.stage = stage
|
||||
self.precision = precision
|
||||
|
@ -175,7 +186,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
return True
|
||||
|
||||
def supported_precisions(self) -> List[str]:
|
||||
return ['fp16', 'fp32']
|
||||
return SUPPORTED_PRECISION
|
||||
|
||||
def control_device(self) -> bool:
|
||||
return True
|
||||
|
|
|
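The `_convert_floating_point` helper replaces the old fp16-only `_convert_to_fp16`: the plugin applies it to forward inputs with `tree_map`, casting only floating-point tensors to the configured dtype. A standalone copy of the helper for illustration (only `torch` is required):

```python
import torch
from torch.utils._pytree import tree_map


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # mirrors the helper above: cast only floating-point tensors and pass
    # everything else (int/bool tensors, non-tensors) through untouched
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


batch = {'input': torch.rand(2, 4), 'mask': torch.ones(2, 4, dtype=torch.bool), 'step': 3}
bf16_batch = tree_map(lambda t: _convert_floating_point(t, torch.bfloat16), batch)
print(bf16_batch['input'].dtype)   # torch.bfloat16
print(bf16_batch['mask'].dtype)    # torch.bool (unchanged)
```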
@ -171,6 +171,21 @@
|
|||
using g_scalar_t_##LEVEL = at::Half; \
|
||||
using p_scalar_t_##LEVEL = at::Half; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::Float && \
|
||||
PTYPE == at::ScalarType::BFloat16) { \
|
||||
using g_scalar_t_##LEVEL = float; \
|
||||
using p_scalar_t_##LEVEL = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::BFloat16 && \
|
||||
PTYPE == at::ScalarType::Float) { \
|
||||
using g_scalar_t_##LEVEL = at::BFloat16; \
|
||||
using p_scalar_t_##LEVEL = float; \
|
||||
__VA_ARGS__; \
|
||||
} else if (GTYPE == at::ScalarType::BFloat16 && \
|
||||
PTYPE == at::ScalarType::BFloat16) { \
|
||||
using g_scalar_t_##LEVEL = at::BFloat16; \
|
||||
using p_scalar_t_##LEVEL = at::BFloat16; \
|
||||
__VA_ARGS__; \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
|
||||
"'"); \
|
||||
|
|
|
@ -93,8 +93,7 @@ class CPUAdam(NVMeOptimizer):
|
|||
bias_correction1,
|
||||
bias_correction2,
|
||||
use_adamw=False):
|
||||
# FIXME(ver217): remove the below line when replace torch adam with fused adam
|
||||
grad = grad.float()
|
||||
grad = grad.to(data.dtype)
|
||||
|
||||
if weight_decay != 0:
|
||||
if use_adamw:
|
||||
|
@ -133,10 +132,12 @@ class CPUAdam(NVMeOptimizer):
|
|||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
|
||||
# FIXME(ver217): CPU adam kernel only supports fp32 states now
|
||||
assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
|
||||
# gradient momentums
|
||||
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
|
||||
state['exp_avg'] = torch.zeros_like(p, device=target_device)
|
||||
# gradient variances
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
|
||||
self._post_state_init(p)
|
||||
|
||||
state['step'] += 1
|
||||
|
@ -147,9 +148,17 @@ class CPUAdam(NVMeOptimizer):
|
|||
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
|
||||
group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
|
||||
state['exp_avg_sq'], div_scale)
|
||||
if p.grad.dtype is torch.bfloat16:
|
||||
# cpu adam kernel does not support bf16 now
|
||||
bias_correction1 = 1 - beta1**state['step']
|
||||
bias_correction2 = 1 - beta2**state['step']
|
||||
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
|
||||
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
|
||||
bias_correction2, self.adamw_mode)
|
||||
else:
|
||||
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
|
||||
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
|
||||
state['exp_avg'], state['exp_avg_sq'], div_scale)
|
||||
self._post_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
elif target_device.type == 'cuda':
|
||||
assert div_scale == -1, "div_scale should remain default"
|
||||
|
|
|
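For bf16 gradients the C++ CPU kernel is bypassed and `torch_adam_update` is used instead. Below is a standalone restatement of that fallback math with the same argument order as the call above; it is a sketch of the logic rather than the exact method body.

```python
import math

import torch


def torch_adam_update(param, grad, exp_avg, exp_avg_sq, lr, beta1, beta2, eps,
                      weight_decay, bias_correction1, bias_correction2, use_adamw):
    # Plain-PyTorch Adam/AdamW step, used above when the CPU kernel cannot handle
    # bf16 gradients. The grad is first cast to the parameter dtype (fp32 master
    # params), matching `grad = grad.to(data.dtype)` in the diff.
    grad = grad.to(param.dtype)
    if weight_decay != 0:
        if use_adamw:
            param.mul_(1 - lr * weight_decay)           # decoupled weight decay
        else:
            grad = grad.add(param, alpha=weight_decay)  # L2 penalty folded into grad
    # first and second moment running averages
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
    param.addcdiv_(exp_avg, denom, value=-lr / bias_correction1)
```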
@ -134,8 +134,8 @@ class FusedAdam(torch.optim.Optimizer):
|
|||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p)
|
||||
|
||||
if p.dtype not in [torch.float16, torch.float32]:
|
||||
raise RuntimeError('FusedAdam only support fp16 and fp32.')
|
||||
if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]:
|
||||
raise RuntimeError('FusedAdam only supports fp16, fp32 and bf16.')
|
||||
|
||||
g_l.append(p.grad.data)
|
||||
p_l.append(p.data)
|
||||
|
|
|
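With the relaxed dtype check, `FusedAdam` accepts bf16 parameters and gradients. A short usage sketch; CUDA and a built colossalai fused-optim extension are assumed.

```python
import torch
from colossalai.nn.optimizer import FusedAdam

# bf16 parameters and bf16 gradients now dispatch to the bf16 branch of the
# multi-tensor Adam kernel added in this PR.
model = torch.nn.Linear(64, 64).cuda().to(torch.bfloat16)
optim = FusedAdam(model.parameters(), lr=1e-3)

out = model(torch.rand(8, 64, device='cuda', dtype=torch.bfloat16))
out.sum().backward()
optim.step()
optim.zero_grad()
```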
@ -1,16 +1,17 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.optim import Adam
|
||||
|
||||
from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
|
||||
from colossalai.kernel.op_builder import FusedOptimBuilder
|
||||
from colossalai.registry import OPTIMIZERS
|
||||
from colossalai.utils import multi_tensor_applier
|
||||
|
||||
from .nvme_optimizer import NVMeOptimizer
|
||||
from .cpu_adam import CPUAdam
|
||||
|
||||
|
||||
@OPTIMIZERS.register_module
|
||||
class HybridAdam(NVMeOptimizer):
|
||||
class HybridAdam(CPUAdam):
|
||||
"""Implements Adam algorithm.
|
||||
|
||||
Supports updating parameters on both GPU and CPU, depending on the device of the parameters.
|
||||
|
@ -74,15 +75,9 @@ class HybridAdam(NVMeOptimizer):
|
|||
nvme_offload_dir: Optional[str] = None,
|
||||
**defaults: Any):
|
||||
|
||||
default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
|
||||
super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
|
||||
self.adamw_mode = adamw_mode
|
||||
|
||||
# build during runtime if not found
|
||||
cpu_optim = CPUAdamBuilder().load()
|
||||
super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction,
|
||||
nvme_offload_dir)
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
|
||||
|
||||
self.gpu_adam_op = fused_optim.multi_tensor_adam
|
||||
self._dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
|
||||
|
@ -108,10 +103,12 @@ class HybridAdam(NVMeOptimizer):
|
|||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
|
||||
# FIXME(ver217): CPU adam kernel only supports fp32 states now
|
||||
assert p.dtype is torch.float, "HybridAdam only support fp32 parameters"
|
||||
# gradient momentums
|
||||
state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
|
||||
state['exp_avg'] = torch.zeros_like(p, device=target_device)
|
||||
# gradient variances
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
|
||||
self._post_state_init(p)
|
||||
|
||||
state['step'] += 1
|
||||
|
@ -122,9 +119,17 @@ class HybridAdam(NVMeOptimizer):
|
|||
assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||
assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
|
||||
self._pre_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
|
||||
group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
|
||||
state['exp_avg_sq'], div_scale)
|
||||
if p.grad.dtype is torch.bfloat16:
|
||||
# cpu adam kernel does not support bf16 now
|
||||
bias_correction1 = 1 - beta1**state['step']
|
||||
bias_correction2 = 1 - beta2**state['step']
|
||||
self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
|
||||
beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
|
||||
bias_correction2, self.adamw_mode)
|
||||
else:
|
||||
self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
|
||||
group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
|
||||
state['exp_avg'], state['exp_avg_sq'], div_scale)
|
||||
self._post_update(p, 'exp_avg', 'exp_avg_sq')
|
||||
|
||||
elif target_device.type == 'cuda':
|
||||
|
|
|
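After this change `HybridAdam` inherits the CPU path from `CPUAdam` (including the torch fallback for bf16 gradients) and keeps its own fused CUDA path. A usage sketch; a CUDA device and built CPU/fused Adam extensions are assumed.

```python
import torch
from colossalai.nn.optimizer import HybridAdam

# CPU-resident fp32 parameters are updated by the CPU kernel (or the torch
# fallback when their grads are bf16); CUDA parameters would go through the
# fused multi-tensor kernel instead.
model = torch.nn.Linear(64, 64)                       # parameters stay on CPU
optim = HybridAdam(model.parameters(), lr=1e-3, adamw_mode=True)

loss = model(torch.rand(8, 64)).sum()
loss.backward()
optim.step()
optim.zero_grad()
```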
@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP):
|
|||
strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
|
||||
Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
|
||||
scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
|
||||
mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -59,7 +60,9 @@ class ZeroDDP(ColoDDP):
|
|||
pin_memory: bool = False,
|
||||
force_outputs_fp32: bool = False,
|
||||
strict_ddp_mode: bool = False,
|
||||
scatter_after_inference: bool = True) -> None:
|
||||
scatter_after_inference: bool = True,
|
||||
mixed_precision: torch.dtype = torch.float16) -> None:
|
||||
assert mixed_precision in (torch.float16, torch.bfloat16)
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
|
@ -71,6 +74,7 @@ class ZeroDDP(ColoDDP):
|
|||
self.param2name: Dict[nn.Parameter, str] = dict()
|
||||
self.name2param: Dict[str, nn.Parameter] = dict()
|
||||
self.scatter_after_inference = scatter_after_inference
|
||||
self.mixed_precision = mixed_precision
|
||||
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
|
@ -151,7 +155,7 @@ class ZeroDDP(ColoDDP):
|
|||
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
|
||||
), "You should run a completed iteration as your warmup iter"
|
||||
|
||||
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
||||
args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
if not grad_flag:
|
||||
outputs = self._inference_forward(*args, **kwargs)
|
||||
|
@ -570,14 +574,14 @@ class ZeroDDP(ColoDDP):
|
|||
|
||||
# move ignored parameters to CUDA
|
||||
if is_ddp_ignored(p):
|
||||
p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
|
||||
p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
|
||||
continue
|
||||
|
||||
# create a fp32 parameter
|
||||
fp32_data = p.data.float()
|
||||
fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
|
||||
# create a fp16 parameter
|
||||
p.data = p.data.half()
|
||||
p.data = p.data.to(self.mixed_precision)
|
||||
|
||||
# register the fp16 parameter and fp32 parameter in the chunk manager
|
||||
dp_world_size = p.process_group.dp_world_size()
|
||||
|
@ -613,7 +617,7 @@ class ZeroDDP(ColoDDP):
|
|||
buffer.materialize()
|
||||
buffer.data = buffer.cuda()
|
||||
if torch.is_floating_point(buffer):
|
||||
buffer.data = buffer.half()
|
||||
buffer.data = buffer.to(self.mixed_precision)
|
||||
|
||||
def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
|
||||
"""Convert parameter to ColoParameter in-place.
|
||||
|
@ -736,6 +740,7 @@ class GeminiDDP(ZeroDDP):
|
|||
hidden_dim: Optional[int] = None,
|
||||
min_chunk_size_mb: float = 32,
|
||||
memstats: Optional[MemStats] = None,
|
||||
mixed_precision: torch.dtype = torch.float16,
|
||||
verbose: bool = False) -> None:
|
||||
"""
|
||||
A torch.Module wrapper using ZeRO-DP and Gemini.
|
||||
|
@ -776,5 +781,10 @@ class GeminiDDP(ZeroDDP):
|
|||
strict_ddp_flag=strict_ddp_mode,
|
||||
verbose=verbose)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
|
||||
super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
|
||||
scatter_after_inference)
|
||||
super().__init__(module,
|
||||
gemini_manager,
|
||||
pin_memory,
|
||||
force_outputs_fp32,
|
||||
strict_ddp_mode,
|
||||
scatter_after_inference,
|
||||
mixed_precision=mixed_precision)
|
||||
|
|
|
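`GeminiDDP` now forwards a `mixed_precision` dtype to `ZeroDDP`, which casts parameters and buffers to that dtype while keeping fp32 master copies in the chunk manager. A construction sketch follows; the import path and keyword names not visible in this hunk are assumptions, and a colossalai distributed environment must already be initialized.

```python
import torch
from colossalai.zero import GeminiDDP   # import path assumed

# assumes colossalai.launch*/torch.distributed has already set up the process group
model = torch.nn.Linear(32, 32)

# Parameters and buffers are cast to the chosen half-precision dtype; forward
# inputs are cast via _cast_float(..., self.mixed_precision) as shown above.
gemini_model = GeminiDDP(model,
                         device=torch.device('cuda'),      # keyword assumed
                         placement_policy='cpu',
                         mixed_precision=torch.bfloat16)
```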
@ -1,7 +1,6 @@
|
|||
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
|
||||
import math
|
||||
import warnings
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
@ -9,7 +8,7 @@ import torch.distributed as dist
|
|||
from torch.nn import Parameter
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
|
||||
from colossalai.utils import disposable, get_current_device, is_ddp_ignored
|
||||
|
@ -22,9 +21,26 @@ __all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
|
|||
_AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
SCALED = 0
|
||||
UNSCALED = 1
|
||||
class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
||||
|
||||
def __init__(self,
|
||||
module: ZeroDDP,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32) -> None:
|
||||
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
|
||||
max_scale)
|
||||
self.module = module
|
||||
|
||||
def check_local_overflow(self) -> bool:
|
||||
return self.module.overflow_counter > 0
|
||||
|
||||
def pre_zero_grad(self) -> None:
|
||||
self.module.overflow_counter = 0
|
||||
|
||||
|
||||
class ZeroOptimizer(ColossalaiOptimizer):
|
||||
|
@ -79,7 +95,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
self.module = module
|
||||
self.gemini_manager = module.gemini_manager
|
||||
self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
|
||||
self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
|
||||
self.chunk16_set: Set[Chunk] = set()
|
||||
|
@ -107,15 +122,20 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
self.__init__optimizer()
|
||||
|
||||
# Grad scaler
|
||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
|
||||
if module.mixed_precision is torch.float16:
|
||||
self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module,
|
||||
initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
elif module.mixed_precision is torch.bfloat16:
|
||||
self.mix_precision_mixin = BF16MixedPrecisionMixin()
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}")
|
||||
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
||||
|
@ -151,15 +171,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
for chunk16 in self.chunk16_set:
|
||||
chunk16.optim_update()
|
||||
|
||||
def _check_overflow(self):
|
||||
# clear previous overflow record
|
||||
self._found_overflow.fill_(self.module.overflow_counter)
|
||||
|
||||
# all-reduce across global group
|
||||
dist.all_reduce(self._found_overflow)
|
||||
|
||||
return self._found_overflow.item() > 0
|
||||
|
||||
def _clear_global_norm(self) -> None:
|
||||
for c16 in self.chunk16_set:
|
||||
c16.l2_norm = None
|
||||
|
@ -190,40 +201,25 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
return global_norm
|
||||
|
||||
def _get_combined_scale(self):
|
||||
loss_scale = 1
|
||||
div_scale = self.mix_precision_mixin.get_grad_div_scale()
|
||||
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
loss_scale = self.loss_scale
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
|
||||
combined_scale = loss_scale
|
||||
if self.clipping_flag:
|
||||
total_norm = self._calc_global_norm()
|
||||
clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
|
||||
clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
|
||||
if clip > 1:
|
||||
combined_scale = clip * loss_scale
|
||||
div_scale = clip * div_scale
|
||||
|
||||
if combined_scale == 1:
|
||||
return -1
|
||||
else:
|
||||
return combined_scale
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
return self.grad_scaler.scale.item()
|
||||
return -1 if div_scale == 1.0 else div_scale
|
||||
|
||||
def zero_grad(self, *args, **kwargs):
|
||||
self.module.overflow_counter = 0
|
||||
self.mix_precision_mixin.pre_zero_grad()
|
||||
return self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
self._maybe_move_fp32_params()
|
||||
self._set_grad_ptr()
|
||||
|
||||
found_inf = self._check_overflow()
|
||||
if found_inf:
|
||||
self.optim_state = OptimState.UNSCALED # no need to unscale grad
|
||||
self.grad_scaler.update(found_inf) # update gradient scaler
|
||||
if self.mix_precision_mixin.should_skip_step():
|
||||
if self.verbose:
|
||||
self._logger.info(f'Found overflow. Skip step')
|
||||
self._clear_global_norm() # clear recorded norm
|
||||
|
@ -234,7 +230,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
# get combined scale. combined scale = loss scale * clipping norm
|
||||
# so that gradient = gradient / combined scale
|
||||
combined_scale = self._get_combined_scale()
|
||||
self.grad_scaler.update(found_inf)
|
||||
|
||||
ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
|
||||
self._register_states()
|
||||
|
@ -246,8 +241,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
raise NotImplementedError
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
loss = self.loss_scale * loss
|
||||
self.optim_state = OptimState.SCALED
|
||||
loss = self.mix_precision_mixin.pre_backward(loss)
|
||||
self.module.backward(loss)
|
||||
|
||||
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
|
||||
|
@ -255,7 +249,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
# It receives the scaled grad from the previous rank
|
||||
# No need to scale the grad again
|
||||
# Need to unscale when optimizing
|
||||
self.optim_state = OptimState.SCALED
|
||||
grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
|
||||
self.module.backward_by_grad(tensor, grad)
|
||||
|
||||
def _maybe_move_fp32_params(self):
|
||||
|
|
|
@ -14,7 +14,7 @@ from colossalai.core import global_context as gpc
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||||
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
|
||||
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
|
||||
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
|
||||
from colossalai.zero.legacy.sharded_param import ShardedParamV2
|
||||
|
||||
|
@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
seed (int, optional): Random seed for weight initialization
|
||||
shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
|
||||
default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
|
||||
bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
|
||||
model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
|
||||
"""
|
||||
|
||||
|
@ -64,6 +65,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
seed: int = 2**10 - 1,
|
||||
shard_param: bool = False,
|
||||
default_dtype: Optional[torch.dtype] = None,
|
||||
bf16: bool = False,
|
||||
model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):
|
||||
|
||||
super().__init__(default_dtype=default_dtype)
|
||||
|
@ -71,6 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
self.param_list = []
|
||||
self.model_numel_tensor = model_numel_tensor
|
||||
self.seed = seed
|
||||
self.bf16 = bf16
|
||||
self.dp_process_group = gpc.get_group(ParallelMode.DATA)
|
||||
|
||||
self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
|
||||
|
@ -183,9 +186,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
NOTE() The module may be passed to this function multiple times.
|
||||
"""
|
||||
self.top_module = module
|
||||
half_dtype = torch.float16 if not self.bf16 else torch.bfloat16
|
||||
|
||||
def half_fn(t: torch.Tensor):
|
||||
return t.half() if t.is_floating_point() else t
|
||||
return t.to(half_dtype) if t.is_floating_point() else t
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
# avoid adapting a param to ShardedParam twice
|
||||
|
@ -226,9 +230,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
# We must cast buffers
|
||||
# If we use BN, buffers may be on CPU and Float
|
||||
# We must cast them
|
||||
cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
|
||||
for buffer in module.buffers(recurse=False):
|
||||
buffer.data = buffer.data.to(device=torch.cuda.current_device())
|
||||
buffer.data = cast_tensor_to_fp16(buffer.data)
|
||||
buffer.data = cast_fn(buffer.data)
|
||||
|
||||
|
||||
class ZeroContextMgr(metaclass=SingletonMeta):
|
||||
|
|
|
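With the new `bf16` flag, `ZeroInitContext` initializes parameters and buffers directly in bf16 instead of fp16. A hypothetical usage sketch; the module paths and `TensorShardStrategy` are assumptions (only `shard_utils` appears in this hunk's imports), and a colossalai launch is required beforehand.

```python
import torch
from colossalai.zero.legacy.init_ctx import ZeroInitContext          # path assumed
from colossalai.zero.legacy.shard_utils import TensorShardStrategy   # class assumed

# Legacy ZeRO: build the model inside the context so parameters are created
# directly in the requested half precision and (optionally) sharded.
with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True,
                     bf16=True):                  # new flag: init params as bf16
    model = torch.nn.Linear(32, 32)
```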
@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Te
|
|||
if isinstance(tensor, StatefulTensor):
|
||||
tensor = tensor.payload
|
||||
|
||||
if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
|
||||
if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16):
|
||||
return tensor.float()
|
||||
return tensor
|
||||
|
||||
|
||||
def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(tensor, StatefulTensor):
|
||||
tensor = tensor.payload
|
||||
if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
|
||||
return tensor.bfloat16()
|
||||
return tensor
|
||||
|
||||
|
||||
def apply_to_tensors(x: Any, fn: Callable):
|
||||
if torch.is_tensor(x):
|
||||
return fn(x)
|
||||
|
|
|
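`cast_tensor_to_bf16` mirrors the existing fp16 helper, and `cast_tensor_to_fp32` now recognizes bf16 inputs as well. A small usage sketch using the import path shown in the surrounding hunks:

```python
import torch
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp32

t = torch.rand(4)                       # fp32
half = cast_tensor_to_bf16(t)           # -> torch.bfloat16
back = cast_tensor_to_fp32(half)        # -> torch.float32 (now handles bf16 as well as fp16)

idx = torch.arange(4)
print(cast_tensor_to_bf16(idx).dtype)   # torch.int64: non-float tensors pass through unchanged
```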
@ -28,6 +28,7 @@ from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBuc
|
|||
|
||||
from ._utils import (
|
||||
cast_float_arguments,
|
||||
cast_tensor_to_bf16,
|
||||
cast_tensor_to_fp16,
|
||||
cast_tensor_to_fp32,
|
||||
chunk_and_pad,
|
||||
|
@ -74,6 +75,7 @@ class ShardedModelV2(nn.Module):
|
|||
In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
|
||||
We find that PyTorch's optimizers don't support mixed precision,
|
||||
so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
|
||||
bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
|
@ -86,11 +88,13 @@ class ShardedModelV2(nn.Module):
|
|||
tensor_placement_policy: str = 'cuda',
|
||||
gradient_predivide_factor: Optional[float] = 1.0,
|
||||
reuse_fp16_shard: bool = False,
|
||||
bf16: bool = False,
|
||||
*args,
|
||||
**kwargs):
|
||||
assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
|
||||
super().__init__()
|
||||
self.logger = get_dist_logger()
|
||||
self.bf16 = bf16
|
||||
|
||||
# We force users to use ZeroInitContext
|
||||
for submodule in module.modules():
|
||||
|
@ -232,7 +236,8 @@ class ShardedModelV2(nn.Module):
|
|||
|
||||
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||
self._pre_forward_operations(*args)
|
||||
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
|
||||
cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16
|
||||
args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)
|
||||
outputs = self.module(*args, **kwargs)
|
||||
self._post_forward_operations()
|
||||
return outputs
|
||||
|
|
|
@ -94,6 +94,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
super().__init__(optimizer)
|
||||
self.shard_strategy = sharded_model.shard_strategy
|
||||
self.model: ShardedModelV2 = sharded_model
|
||||
self.bf16 = sharded_model.bf16
|
||||
|
||||
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
||||
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
|
||||
|
@ -117,6 +118,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
|
||||
self._logger = get_dist_logger("ShardedOptimizerV2")
|
||||
self._verbose = verbose
|
||||
self._grad_prepared: bool = False    # set to True by _prepare_grads() and reset to False on backward
|
||||
|
||||
# Store fp32 param shards
|
||||
self._register_master_weight()
|
||||
|
@ -166,8 +168,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
self._zero_grad()
|
||||
|
||||
def backward(self, loss: Tensor) -> None:
|
||||
loss = self.loss_scale * loss
|
||||
self.optim_state = OptimState.SCALED
|
||||
if not self.bf16:
|
||||
loss = self.loss_scale * loss
|
||||
self.optim_state = OptimState.SCALED
|
||||
self._grad_prepared = False
|
||||
self.model.backward(loss)
|
||||
|
||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
|
||||
|
@ -175,30 +179,33 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
# It receives the scaled grad from the previous rank
|
||||
# No need to scale the grad again
|
||||
# Need to unscale when optimizing
|
||||
self.optim_state = OptimState.SCALED
|
||||
if not self.bf16:
|
||||
self.optim_state = OptimState.SCALED
|
||||
self._grad_prepared = False
|
||||
self.model.backward_by_grad(tensor, grad)
|
||||
|
||||
def clip_grad_norm(self, model: nn.Module, max_norm: float):
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._prepare_grads()
|
||||
self._prepare_grads()
|
||||
if not self.bf16 and self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
return super().clip_grad_norm(model, max_norm)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
|
||||
self._prepare_grads()
|
||||
# unscale grads if scaled
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._prepare_grads()
|
||||
if not self.bf16 and self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
|
||||
self._maybe_move_fp32_shards()
|
||||
found_inf = self._check_overflow()
|
||||
self.grad_scaler.update(found_inf)
|
||||
if not self.bf16:
|
||||
found_inf = self._check_overflow()
|
||||
self.grad_scaler.update(found_inf)
|
||||
|
||||
if found_inf:
|
||||
self._logger.warning('found inf during ShardedOptimV2 step')
|
||||
self._zero_grad(recover_data=True)
|
||||
return
|
||||
if found_inf:
|
||||
self._logger.warning('found inf during ShardedOptimV2 step')
|
||||
self._zero_grad(recover_data=True)
|
||||
return
|
||||
|
||||
self._point_param_fp16_to_master_param()
|
||||
|
||||
|
@ -304,6 +311,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
state[k] = v.cuda()
|
||||
|
||||
def _prepare_grads(self):
|
||||
if self._grad_prepared:
|
||||
return
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
if p.colo_attr.saved_grad.is_null():
|
||||
|
@ -320,6 +329,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
p.grad = p.colo_attr.grad_payload
|
||||
# Set p.data to empty tensor, in case of memory leaking
|
||||
p.colo_attr.set_data_none()
|
||||
self._grad_prepared = True
|
||||
|
||||
def _point_param_fp16_to_master_param(self):
|
||||
# assign master param pointers to p.data.
|
||||
|
@ -357,7 +367,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
|
||||
|
||||
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
|
||||
p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
|
||||
half_dtype = torch.bfloat16 if self.bf16 else torch.float16
|
||||
p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach())
|
||||
p.colo_attr.set_data_none()
|
||||
|
||||
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
|
||||
|
|
|
@ -6,7 +6,11 @@ import torch
|
|||
import torch.distributed as dist
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.amp.naive_amp.mixed_precision_mixin import (
|
||||
BF16MixedPrecisionMixin,
|
||||
FP16MixedPrecisionMixin,
|
||||
MixedPrecisionMixin,
|
||||
)
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
@ -27,6 +31,31 @@ from ._utils import (
|
|||
from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
|
||||
|
||||
|
||||
class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
|
||||
|
||||
def __init__(self,
|
||||
num_working_param_groups: int,
|
||||
grad_store: GradientStore,
|
||||
initial_scale: float = 2**16,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
max_scale: float = 2**32) -> None:
|
||||
super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
|
||||
max_scale)
|
||||
self.num_working_param_groups = num_working_param_groups
|
||||
self.grad_store = grad_store
|
||||
|
||||
def check_local_overflow(self) -> bool:
|
||||
for group_id in range(self.num_working_param_groups):
|
||||
for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
|
||||
if avg_grad is not None and has_inf_or_nan(avg_grad):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
||||
"""Optimizer used for ZeRO-1 and ZeRO-2.
|
||||
"""
|
||||
|
@ -100,17 +129,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
self._reduce_bucket_size = reduce_bucket_size
|
||||
self._communication_dtype = communication_dtype
|
||||
|
||||
# gradient scaler
|
||||
self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale,
|
||||
verbose=verbose)
|
||||
self._found_overflow = torch.FloatTensor([0]).to(get_current_device())
|
||||
|
||||
# gradient clipping
|
||||
self._clip_grad_norm = clip_grad_norm
|
||||
|
||||
|
@ -200,14 +218,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
if self._overlap_communication or self._partition_grads:
|
||||
self._attach_reduction_hook()
|
||||
|
||||
# initialize mixed precision mixin
|
||||
self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
|
||||
if self._dtype is torch.float16:
|
||||
self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
|
||||
self._grad_store,
|
||||
initial_scale=initial_scale,
|
||||
min_scale=min_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
max_scale=max_scale)
|
||||
elif self._dtype is torch.bfloat16:
|
||||
self.mixed_precision_mixin = BF16MixedPrecisionMixin()
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
return self.grad_scaler.scale
|
||||
|
||||
@property
|
||||
def num_param_groups(self):
|
||||
return len(self._working_param_groups)
|
||||
|
@ -392,7 +421,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
################################
|
||||
|
||||
def backward(self, loss, retain_graph=False, sync_grad=True):
|
||||
loss = self.loss_scale * loss
|
||||
if self.mixed_precision_mixin is not None:
|
||||
loss = self.mixed_precision_mixin.pre_backward(loss)
|
||||
loss.backward(retain_graph=retain_graph)
|
||||
|
||||
# finish gradient reduction
|
||||
|
@ -419,6 +449,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
:param set_to_none: Whether set the gradient to None. Default value is True.
|
||||
:type set_to_none: bool
|
||||
"""
|
||||
if self.mixed_precision_mixin is not None:
|
||||
self.mixed_precision_mixin.pre_zero_grad()
|
||||
for _, param_group in self._working_param_groups.items():
|
||||
for param in param_group:
|
||||
if set_to_none:
|
||||
|
@ -435,12 +467,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
def step(self, closure=None):
|
||||
assert closure is None, 'closure is not supported by step()'
|
||||
|
||||
# check for overflow
|
||||
found_inf = self._check_overflow()
|
||||
self.grad_scaler.update(found_inf)
|
||||
|
||||
# update loss scale if overflow occurs
|
||||
if found_inf:
|
||||
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
|
||||
self._grad_store.reset_all_average_gradients()
|
||||
if self._verbose:
|
||||
self._logger.info(f'Found overflow. Skip step')
|
||||
|
@ -507,41 +534,20 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
|
|||
# Mixed Precision Utilities #
|
||||
#############################
|
||||
|
||||
def _check_overflow(self):
|
||||
# clear previous overflow record
|
||||
self._found_overflow.fill_(0.0)
|
||||
|
||||
# check for overflow
|
||||
for group_id in range(len(self._working_param_groups)):
|
||||
for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
|
||||
if avg_grad is not None and has_inf_or_nan(avg_grad):
|
||||
self._found_overflow.fill_(1.0)
|
||||
break
|
||||
|
||||
# all-reduce across dp group
|
||||
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)
|
||||
|
||||
# all-reduce over model parallel group
|
||||
if self._mp_torch_group:
|
||||
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)
|
||||
|
||||
if self._found_overflow.item() > 0:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
|
||||
# compute combined scale factor for this group
|
||||
combined_scale = self.loss_scale
|
||||
div_scale = 1.0
|
||||
if self.mixed_precision_mixin is not None:
|
||||
div_scale = self.mixed_precision_mixin.get_grad_div_scale()
|
||||
|
||||
if self._clip_grad_norm > 0.:
|
||||
# norm is in fact norm*scale
|
||||
clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
|
||||
clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
|
||||
if clip > 1:
|
||||
combined_scale = clip * self.loss_scale
|
||||
div_scale = clip * div_scale
|
||||
|
||||
for grad in grad_groups_flat:
|
||||
grad.data.mul_(1. / combined_scale)
|
||||
grad.data.mul_(1. / div_scale)
|
||||
|
||||
############################
|
||||
# Gradient Synchronization #
|
||||
|
|
|
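The rewritten `_unscale_and_clip_grads` folds loss-scale removal and gradient clipping into a single divisor obtained from the mixin. A worked numeric example with made-up values shows the arithmetic; for bf16 the mixin returns a divisor of 1.0, so only clipping (if any) applies.

```python
import torch

# Suppose the fp16 mixin reports div_scale = 1024 (the current loss scale),
# the flattened grads have total_norm = 2048 (computed on the *scaled* grads),
# and the clipping threshold is 1.0.
div_scale = 1024.0
total_norm = 2048.0
clip_grad_norm = 1.0

clip = ((total_norm / div_scale) + 1e-6) / clip_grad_norm   # ~2.0 > 1, so clip
div_scale = clip * div_scale                                # ~2048

grad = torch.full((4,), 512.0)        # a scaled gradient entry
grad.mul_(1.0 / div_scale)            # unscale and clip in one multiply -> ~0.25
print(grad)                           # same result as unscaling to 0.5, then halving
```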
@ -0,0 +1,131 @@
|
|||
# This test checks adam kernels
|
||||
# Baseline is pure fp32 torch adam optimizer
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from typing import Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from colossalai.utils import get_current_device, multi_tensor_applier
|
||||
|
||||
_FUSED_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
|
||||
(torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16),
|
||||
(torch.bfloat16, torch.bfloat16)]
|
||||
|
||||
_CPU_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
|
||||
(torch.half, torch.half)]
|
||||
|
||||
|
||||
class AdamKernel:
|
||||
|
||||
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
|
||||
self.lr = lr
|
||||
self.beta1 = beta1
|
||||
self.beta2 = beta2
|
||||
self.eps = eps
|
||||
self.weight_decay = weight_decay
|
||||
self.use_adamw = use_adamw
|
||||
|
||||
@abstractmethod
|
||||
def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
|
||||
pass
|
||||
|
||||
|
||||
class TorchAdamKernel(AdamKernel):
|
||||
|
||||
def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
|
||||
bias_correction1 = 1 - self.beta1**step
|
||||
bias_correction2 = 1 - self.beta2**step
|
||||
|
||||
if self.weight_decay != 0:
|
||||
if self.use_adamw:
|
||||
# Perform stepweight decay
|
||||
param.mul_(1 - self.lr * self.weight_decay)
|
||||
else:
|
||||
grad = grad.add(param, alpha=self.weight_decay)
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
|
||||
exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)
|
||||
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)
|
||||
|
||||
step_size = self.lr / bias_correction1
|
||||
|
||||
param.addcdiv_(exp_avg, denom, value=-step_size)
|
||||
|
||||
|
||||
class FusedAdamKernel(AdamKernel):
|
||||
|
||||
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
|
||||
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
|
||||
from colossalai.kernel.op_builder import FusedOptimBuilder
|
||||
fused_optim = FusedOptimBuilder().load()
|
||||
self.fused_adam = fused_optim.multi_tensor_adam
|
||||
self.dummy_overflow_buf = torch.cuda.IntTensor([0])
|
||||
|
||||
def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
|
||||
multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, [[grad], [param], [exp_avg], [exp_avg_sq]],
|
||||
self.lr, self.beta1, self.beta2, self.eps, step, self.use_adamw, True, self.weight_decay,
|
||||
-1)
|
||||
|
||||
|
||||
class CPUAdamKernel(AdamKernel):
|
||||
|
||||
def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
|
||||
super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
|
||||
from colossalai.kernel.op_builder import CPUAdamBuilder
|
||||
cpu_optim = CPUAdamBuilder().load()
|
||||
|
||||
self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)
|
||||
|
||||
def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
|
||||
self.cpu_adam_op.step(step, self.lr, self.beta1, self.beta2, self.eps, self.weight_decay, True, param.view(-1),
|
||||
grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1), -1)
|
||||
|
||||
|
||||
def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float, p_dtype: torch.dtype,
|
||||
g_dtype: torch.dtype, device: torch.device, n_steps: int, rtol: float, atol: float):
|
||||
lr = 1e-3
|
||||
beta1, beta2 = 0.9, 0.999
|
||||
eps = 1e-8
|
||||
torch_adam = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw)
|
||||
adam_kernel = kernel(lr, beta1, beta2, eps, weight_decay, adamw)
|
||||
master_p = torch.rand(64, device=device)
|
||||
master_g = torch.rand_like(master_p)
|
||||
master_exp_avg = torch.zeros_like(master_p)
|
||||
master_exp_avg_sq = torch.zeros_like(master_p)
|
||||
p = master_p.clone().to(p_dtype)
|
||||
g = master_g.clone().to(g_dtype)
|
||||
exp_avg = master_exp_avg.clone()
|
||||
exp_avg_sq = master_exp_avg_sq.clone()
|
||||
|
||||
for step in range(1, 1 + n_steps):
|
||||
torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
|
||||
adam_kernel.update(step, p, g, exp_avg, exp_avg_sq)
|
||||
# if overflow, the weight won't be updated. so there will be no nan in p
|
||||
assert not torch.isnan(p).any()
|
||||
assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('adamw', [False, True])
|
||||
@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
|
||||
@pytest.mark.parametrize('p_dtype, g_dtype', _FUSED_ALLOWED_P_G_TYPES)
|
||||
def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
|
||||
rtol, atol = 1e-5, 1e-8
|
||||
if p_dtype is torch.float16 or g_dtype is torch.float16:
|
||||
rtol, atol = 1e-3, 1e-3
|
||||
if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
|
||||
rtol, atol = 4e-3, 4e-3
|
||||
check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('adamw', [False, True])
|
||||
@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
|
||||
@pytest.mark.parametrize('p_dtype, g_dtype', _CPU_ALLOWED_P_G_TYPES)
|
||||
def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
|
||||
rtol, atol = 1e-5, 1e-8
|
||||
if p_dtype is torch.float16 or g_dtype is torch.float16:
|
||||
rtol, atol = 1e-3, 1e-3
|
||||
check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device('cpu'), 3, rtol, atol)
|
|
@ -0,0 +1,86 @@
from copy import deepcopy
from typing import Type, Union

import pytest
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo

_ALLOWED_OPTIM_DEVICES = [
    (FusedAdam, torch.device('cuda:0')),
    (CPUAdam, torch.device('cpu')),
    (CPUAdam, torch.device('cuda:0')),
    (HybridAdam, torch.device('cpu')),
    (HybridAdam, torch.device('cuda:0')),
]

_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),    # pure fp32
    (torch.float, torch.half),    # fp16 amp
    (torch.float, torch.bfloat16),    # bfloat16 amp
    # (torch.half, torch.half),    # FIXME(ver217): cpu adam kernel does not support pure fp16
    # (torch.bfloat16, torch.bfloat16),    # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]

N_STEPS = 3


def setup_param_groups(bert_model: nn.Module) -> list:
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.1,
        },
        {
            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return optimizer_grouped_parameters


def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        torch_p.grad = torch.rand_like(torch_p)
        # avoid inconsistent grad and param dtype error
        orig_p = p.data
        p.data = torch_p.grad.clone().to(g_dtype)
        p.grad = p.data
        p.data = orig_p


@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES)
@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device,
                            adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None:
    model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values()))
    torch_model = model_fn().to(device)
    model = deepcopy(torch_model).to(p_dtype)
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_optim_cls = AdamW if adamw else Adam
    torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
    optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)

    rtol, atol = 1e-5, 1e-5
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 2e-3, 2e-3
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3

    for _ in range(N_STEPS):
        set_grad(model, torch_model, g_dtype)
        torch_optim.step()
        optim.step()
        torch_optim.zero_grad()
        optim.zero_grad()
        for p, torch_p in zip(model.parameters(), torch_model.parameters()):
            # if overflow, the weight won't be updated. so there will be no nan in p
            assert not torch.isnan(p).any()
            assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)

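The `set_grad` helper above leans on a small workaround: PyTorch rejects assigning a `p.grad` whose dtype differs from `p.data`, so the parameter data is temporarily swapped to the gradient dtype for the assignment and restored afterwards. A minimal stand-alone sketch of the same trick (hypothetical 4-element parameter, not part of this diff):

```python
import torch
import torch.nn as nn

p = nn.Parameter(torch.rand(4))            # fp32 master parameter
g = torch.rand(4, dtype=torch.bfloat16)    # gradient produced in bf16

orig_data = p.data
p.data = g.clone()    # temporarily give the param the grad's dtype
p.grad = g            # assignment now passes the dtype check
p.data = orig_data    # restore fp32 data; p.grad keeps its bf16 dtype
```
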
@ -1,121 +0,0 @@
import math

import torch

from colossalai.testing import clear_cache_before_run, parameterize


def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    use_adamw,
):
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    if weight_decay != 0:
        if use_adamw:
            # Perform stepweight decay
            param.mul_(1 - lr * weight_decay)
        else:
            grad = grad.add(param, alpha=weight_decay)

    # Decay the first and second moment running average coefficient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

    step_size = lr / bias_correction1

    param.addcdiv_(exp_avg, denom, value=-step_size)


def assertLess(data_diff, threshold, msg):
    assert data_diff < threshold, msg


def assertTrue(condition, msg):
    assert condition, msg


@clear_cache_before_run()
@parameterize('adamw', [True, False])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_cpu_adam(adamw, step, p_dtype, g_dtype):
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    weight_decay = 0

    for i in range(3):
        p_data = torch.rand(64, dtype=p_dtype)
        p_data_copy = p_data.clone().float()
        p_grad = torch.rand(64, dtype=g_dtype)
        p_grad_copy = p_grad.clone().float()
        exp_avg = torch.rand(p_data.shape)
        exp_avg_copy = exp_avg.clone()
        exp_avg_sq = torch.rand(p_data.shape)
        exp_avg_sq_copy = exp_avg_sq.clone()

        from colossalai.kernel.op_builder import CPUAdamBuilder
        cpu_optim = CPUAdamBuilder().load()

        cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)

        cpu_adam_op.step(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            True,
            p_data.view(-1),    # fp32 data
            p_grad.view(-1),    # fp32 grad
            exp_avg.view(-1),
            exp_avg_sq.view(-1),
            -1,
        )

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_data_copy,    # fp32 data
            p_grad_copy,    # fp32 grad
            exp_avg_copy,
            exp_avg_sq_copy,
            adamw,
        )
        var = p_data_copy - p_data
        data_diff = torch.max(torch.abs(var))
        threshold = 1e-3
        assertLess(
            data_diff,
            threshold,
            f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps "
            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
        )
        max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
        assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
        max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
        assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
        max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
        assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")


if __name__ == '__main__':
    test_cpu_adam()

@ -1,64 +0,0 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.adam import Adam

from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import clear_cache_before_run, parameterize


class FC(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(64, 64))

    def forward(self, x):
        return self.fc(x)


@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, p_dtype, g_dtype):
    model = FC().cuda().to(p_dtype)
    state = model.state_dict()
    model_copy = FC().cuda().to(p_dtype)
    model_copy.load_state_dict(state.copy())

    if adamw:
        optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
        torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
    else:
        optim = FusedAdam(model.parameters(), lr=1e-3)
        torch_optim = Adam(model_copy.parameters(), lr=1e-3)

    data = torch.rand(1024, 64).cuda().to(p_dtype)
    data_copy = data.clone()
    label = torch.rand(1024, 64).cuda().to(p_dtype)

    for d, l in zip(data, label):
        y = model(d)
        loss = ((l - y)**2).sum()
        optim.zero_grad()
        loss.backward()
        if p_dtype != g_dtype:
            for i in range(len(optim.param_groups[0]['params'])):
                optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype)
        optim.step()

    for d, l in zip(data_copy, label):
        y = model_copy(d)
        loss = ((l - y)**2).sum()
        torch_optim.zero_grad()
        loss.backward()
        torch_optim.step()

    assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])

    for i in range(len(optim.param_groups[0]['params'])):
        if torch.isnan(optim.param_groups[0]['params'][i]).any() \
                or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
            continue
        assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3)

@ -1,95 +0,0 @@
import math

import torch
import torch.nn as nn
from numpy import dtype

from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils import multi_tensor_applier


def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    use_adamw,
):
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    if weight_decay != 0:
        if use_adamw:
            # Perform stepweight decay
            param.mul_(1 - lr * weight_decay)
        else:
            grad = grad.add(param, alpha=weight_decay)

    # Decay the first and second moment running average coefficient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

    step_size = lr / bias_correction1

    param.addcdiv_(exp_avg, denom, value=-step_size)


@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, step, p_dtype, g_dtype):
    from colossalai.kernel.op_builder import FusedOptimBuilder
    fused_optim = FusedOptimBuilder().load()
    fused_adam = fused_optim.multi_tensor_adam

    dummy_overflow_buf = torch.cuda.IntTensor([0])

    count = 0

    for i in range(3):
        p = torch.rand(64, dtype=p_dtype).cuda()
        p_copy = p.clone().float()
        g = torch.rand(p.shape, dtype=g_dtype).cuda()
        g_copy = g.clone().float()
        m = torch.rand(p.shape).cuda()
        m_copy = m.clone()
        v = torch.rand(p.shape).cuda()
        v_copy = v.clone()

        lr = 1e-3
        beta1, beta2 = 0.9, 0.999
        eps = 1e-8
        weight_decay = 0

        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
                             True, weight_decay, -1)

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_copy,    # fp32 data
            g_copy,    # fp32 grad
            m_copy,
            v_copy,
            adamw,
        )

        if torch.isnan(p).any() or torch.isnan(p_copy).any():
            count += 1
            continue
        assert count < 200, "too many nans"
        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5,
                              1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"

@ -1,42 +0,0 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.adam import Adam

from colossalai.nn.optimizer.hybrid_adam import HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize

RE = 3


@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('device', ['cpu', 'cuda:0'])
@parameterize('p_dtype', [torch.float])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, device, p_dtype, g_dtype):
    rng_state = torch.get_rng_state()
    p = nn.Parameter(torch.rand(64).to(device, p_dtype))
    torch.set_rng_state(rng_state)
    p_copy = nn.Parameter(torch.rand(64).to(device).float())

    if adamw:
        optim = HybridAdam([p], lr=1e-3, adamw_mode=True)
        torch_optim = AdamW([p_copy], lr=1e-3)
    else:
        optim = HybridAdam([p], lr=1e-3)
        torch_optim = Adam([p_copy], lr=1e-3)

    print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}")
    for i in range(RE):
        p.grad = torch.rand(64).to(device, p_dtype)
        p_copy.grad = p.grad.clone().float()
        p.grad.data = p.grad.data.to(g_dtype)

        optim.step()
        torch_optim.step()

        if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any():
            continue
        assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \
            f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"

@ -21,23 +21,40 @@ TEST_MODELS = ['gpt2']
# these models are too small, all parameters in these models are compacted into one chunk
EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers']

# bfloat16 cannot represent them exactly
BF16_IGNORED_KEYS = [
    'albert.embeddings.word_embeddings.weight',
    'albert.embeddings.position_embeddings.weight',
    'masked_bias',
]

def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
    zero_dict = model.state_dict(only_rank_0=False)

def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
    zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
    torch_dict = torch_model.state_dict()

    for key, value in torch_dict.items():
        # key is 'module.model.PARAMETER', so we truncate it
        key = key[7:]
        assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
        temp_zero_value = zero_dict[key].to(device=value.device)
        if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS):
            continue
        rtol, atol = 1e-3, 4e-3
        if dtype is torch.bfloat16:
            rtol, atol = 4e-3, 8e-3
        # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
        assert_close(value.float(),
                     temp_zero_value.float(),
                     rtol=rtol,
                     atol=atol,
                     msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', TEST_MODELS)
def exam_model_step(placement_policy, model_name: str):
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
    set_seed(42)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@ -65,7 +82,7 @@ def exam_model_step(placement_policy, model_name: str):
    init_device = None
    chunk_manager = ChunkManager(config_dict, init_device=init_device)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)

    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)

@ -74,6 +91,7 @@ def exam_model_step(placement_policy, model_name: str):
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    rtol, atol = 1e-4, 1e-5
    for i, (input_ids, label) in enumerate(train_dataloader):
        if i > 2:
            break

@ -83,17 +101,18 @@ def exam_model_step(placement_policy, model_name: str):

        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss)
        assert_close(torch_loss, loss, rtol=rtol, atol=atol)

        zero_optim.step()
        torch_optim.step()

        check_param(model, torch_model)
        check_param(model, torch_model, mixed_precision)


@parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
@parameterize('model_name', EXAMPLE_MODELS)
def exam_tiny_example(placement_policy, model_name: str):
@parameterize('mixed_precision', [torch.half, torch.bfloat16])
def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
    set_seed(2008)
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@ -113,7 +132,7 @@ def exam_tiny_example(placement_policy, model_name: str):

    chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager, pin_memory=True)
    model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)

@ -121,6 +140,9 @@ def exam_tiny_example(placement_policy, model_name: str):
    torch_model.eval()

    set_seed(dist.get_rank() * 3 + 128)
    rtol, atol = 1.5e-6, 2e-5
    if mixed_precision is torch.bfloat16:
        rtol, atol = 2e-3, 2e-3
    for i, (input_ids, label) in enumerate(train_dataloader):
        if i > 2:
            break

@ -133,12 +155,12 @@ def exam_tiny_example(placement_policy, model_name: str):

        torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
        loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
        assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5)    # atol should be 2e-5 for torch lower than 1.12
        assert_close(torch_loss, loss, rtol=rtol, atol=atol)    # atol should be 2e-5 for torch lower than 1.12

        zero_optim.step()
        torch_optim.step()

        check_param(model, torch_model)
        check_param(model, torch_model, mixed_precision)


def run_dist(rank, world_size, port):

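A quick aside on the comparison pattern in `check_param` above: `torch.testing.assert_close` checks dtypes by default, so the (possibly bf16) ZeRO state dict is upcast to fp32 before it is compared against the fp32 reference, and the tolerances are widened to roughly bf16 rounding error. A minimal self-contained sketch (made-up tensors, not part of this diff):

```python
import torch
from torch.testing import assert_close

reference = torch.rand(1024)                  # fp32 reference weights
low_precision = reference.to(torch.bfloat16)  # what a bf16 rank would hold

# Comparing fp32 against bf16 directly would trip assert_close's dtype check,
# so both sides are upcast; bf16 carries only 2-3 significant decimal digits,
# hence tolerances on the order of 4e-3 / 8e-3.
assert_close(reference, low_precision.float(), rtol=4e-3, atol=8e-3)
```
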
@ -16,7 +16,11 @@ from colossalai.zero.low_level._utils import has_inf_or_nan
from tests.components_to_test.registry import non_distributed_component_funcs


def run_dist(rank, world_size, port, parallel_config):
def run_dist(rank, world_size, port, parallel_config, bf16):
    is_mp_config = parallel_config == MP_PARALLEL_CONFIG
    is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG
    if bf16:
        parallel_config['zero']['model_config']['bf16'] = True
    colossalai.launch(config=parallel_config,
                      rank=rank,
                      world_size=world_size,

@ -30,7 +34,8 @@ def run_dist(rank, world_size, port, parallel_config):
    model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
    with ZeroInitContext(target_device=torch.cuda.current_device(),
                         shard_strategy=gpc.config.zero.model_config.shard_strategy,
                         shard_param=True):
                         shard_param=True,
                         bf16=bf16):
        colo_model = model_builder(checkpoint=True)

    colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)

@ -38,7 +43,8 @@ def run_dist(rank, world_size, port, parallel_config):
                                                           optimizer=colo_optimizer,
                                                           criterion=criterion,
                                                           train_dataloader=train_dataloader)
    torch_model = model_builder(checkpoint=True).half()
    dtype = torch.bfloat16 if bf16 else torch.float16
    torch_model = model_builder(checkpoint=True).to(dtype)
    col_model_deepcopy(engine.model, torch_model)
    torch_model = torch_model.cuda().float()

@ -80,9 +86,9 @@ def run_dist(rank, world_size, port, parallel_config):
            torch_optimizer.step()
        i += 1

    if parallel_config == MP_PARALLEL_CONFIG:
    if is_mp_config:
        check_params(torch_model, colo_model, loose=True)
    elif parallel_config == ZERO_PARALLEL_CONFIG:
    elif is_zero_config:
        check_sharded_model_params(torch_model, colo_model, loose=True)


@ -97,9 +103,10 @@ def test_mp_engine(world_size):

@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2])
@pytest.mark.parametrize("bf16", [True, False])
@rerun_if_address_is_in_use()
def test_zero_engine(world_size):
    spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG)
def test_zero_engine(world_size, bf16):
    spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16)


if __name__ == '__main__':

@ -82,7 +82,6 @@ def exam_zero_1_2_grad_acc():

def exam_zero_1_grad_acc():
    local_rank = torch.distributed.get_rank()
    grad_scale = 32
    seed_all(2008)

    # create models

@ -101,7 +100,6 @@ def exam_zero_1_grad_acc():
    # level 1 and 2 will produce exactly the same results
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=False,
                                           initial_scale=grad_scale,
                                           reduce_bucket_size=262144,
                                           clip_grad_norm=1.0)

@ -128,9 +126,8 @@ def exam_zero_1_grad_acc():
    if check_flag:
        # check grad
        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
            unscale_grad = z1p.grad / grad_scale
            # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
            assert torch.equal(p.grad, unscale_grad)
            assert torch.equal(p.grad, z1p.grad)

    zero_optimizer._sync_grad()

@ -7,7 +7,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer

@ -25,15 +25,18 @@ class MlpModel(nn.Module):
        return x


def half_close(a, b, loose=False):
def loose_close(a, b, dtype: torch.dtype = torch.float32):
    rtol = None
    atol = None
    if loose:
    if dtype is torch.float16:
        rtol = 5e-2
        atol = 5e-4
    elif dtype is torch.bfloat16:
        rtol = 4e-3
        atol = 4e-3

    a = a.detach().half()
    b = b.detach().half()
    a = a.detach().to(dtype)
    b = b.detach().to(dtype)

    assert_close(a, b, rtol=rtol, atol=atol)


@ -96,7 +99,8 @@ def exam_zero_1_2():
        assert torch.equal(z1p.data, z2p.data)


def exam_zero_1_torch_ddp():
@parameterize('dtype', [torch.float16, torch.bfloat16])
def exam_zero_1_torch_ddp(dtype: torch.dtype):
    """
    In this test, two pairs of model and optimizers are created.
    1. zero: use sharded optimizer and fp16 parameters

@ -109,15 +113,10 @@ def exam_zero_1_torch_ddp():
    seed_all(1453)

    # create models
    zero_model = MlpModel()
    torch_model = copy.deepcopy(zero_model)
    torch_model = MlpModel().cuda()
    zero_model = copy.deepcopy(torch_model).to(dtype)

    zero_model = zero_model.cuda().half()
    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
    torch_model = torch_model.cuda()

    # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
    #     half_close(p.data, z1p.data)
    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda()

    # create optimizer
    zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)

@ -137,11 +136,11 @@ def exam_zero_1_torch_ddp():
        input_data = torch.rand(32, 128).cuda()

        # zero-dp forward
        zero_output = zero_model(input_data.half())
        zero_output = zero_model(input_data.to(dtype))

        # torch-ddp forward
        torch_output = torch_model(input_data)
        half_close(zero_output, torch_output, loose=True)
        loose_close(zero_output, torch_output, dtype=dtype)

        # zero-dp backward
        zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)

@ -151,7 +150,7 @@ def exam_zero_1_torch_ddp():

        # check grad
        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
            half_close(p.grad, z1p.grad, loose=True)
            loose_close(p.grad, z1p.grad, dtype=dtype)

        # zero-dp step
        zero_optimizer._sync_grad()

@ -163,7 +162,7 @@ def exam_zero_1_torch_ddp():
        # check updated param
        for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
            # print(n, torch.max(torch.abs(p.data - z1p.data)))
            half_close(p.data, z1p.data, loose=True)
            loose_close(p.data, z1p.data, dtype=dtype)


def run_dist(rank, world_size, port):
