mirror of https://github.com/hpcaitech/ColossalAI
[bf16] add bf16 support (#3882)
* [bf16] add bf16 support for fused adam (#3844)
* [bf16] fused adam kernel support bf16
* [test] update fused adam kernel test
* [test] update fused adam test
* [bf16] cpu adam and hybrid adam optimizers support bf16 (#3860)
* [bf16] implement mixed precision mixin and add bf16 support for low level zero (#3869)
* [bf16] add mixed precision mixin
* [bf16] low level zero optim support bf16
* [test] update low level zero test
* [test] fix low level zero grad acc test
* [bf16] add bf16 support for gemini (#3872)
* [bf16] gemini support bf16
* [test] update gemini bf16 test
* [doc] update gemini docstring
* [bf16] add bf16 support for plugins (#3877)
* [bf16] add bf16 support for legacy zero (#3879)
* [zero] init context support bf16
* [zero] legacy zero support bf16
* [test] add zero bf16 test
* [doc] add bf16 related docstring for legacy zero
parent 07cb21142f
commit ae02d4e4f7
@@ -0,0 +1,9 @@
+from .base import MixedPrecisionMixin
+from .bf16 import BF16MixedPrecisionMixin
+from .fp16 import FP16MixedPrecisionMixin
+
+__all__ = [
+    'MixedPrecisionMixin',
+    'FP16MixedPrecisionMixin',
+    'BF16MixedPrecisionMixin',
+]
@ -0,0 +1,91 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class MixedPrecisionMixin(ABC):
|
||||||
|
"""A helper class for mixed precision training. This mixin is used in mixed precision optimizers.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
dtype (torc.dtype): The expected dtype of the gradients.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
```python
|
||||||
|
class MyMixedPrecisionOptimizer(OptimizerWrapper):
|
||||||
|
def __init__(self, optim: Optimizer):
|
||||||
|
super().__init__(optim)
|
||||||
|
self.mixed_precision = MixedPrecisionMixin()
|
||||||
|
|
||||||
|
def backward(self, loss):
|
||||||
|
loss = self.mixed_precision.pre_backward(loss)
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
def backward_by_grad(self, tensor, grad):
|
||||||
|
grad = self.mixed_precision.pre_backward_by_grad(tensor, grad)
|
||||||
|
tensor.backward(grad)
|
||||||
|
|
||||||
|
def step(self):
|
||||||
|
if self.mixed_precision.should_skip_step():
|
||||||
|
self.zero_grad()
|
||||||
|
return
|
||||||
|
div_scale = self.mixed_precision.get_grad_div_scale()
|
||||||
|
# maybe clip grad here
|
||||||
|
# maybe scale grad here
|
||||||
|
self.optim.step()
|
||||||
|
|
||||||
|
def zero_grad(self):
|
||||||
|
self.mixed_precision.pre_zero_grad()
|
||||||
|
return self.optim.zero_grad()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
dtype: torch.dtype
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def pre_backward(self, loss: Tensor) -> Tensor:
|
||||||
|
"""Called before backward.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss (Tensor): Loss value.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Loss value (possibly scaled).
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
|
||||||
|
"""Called before backward by grad. This is helpful for pipeline parallelism.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (Tensor): Tensor to backward.
|
||||||
|
grad (Tensor): Gradient of the tensor.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor: Gradient of the tensor (possibly scaled).
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def should_skip_step(self) -> bool:
|
||||||
|
"""Called before step.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: Whether to skip the step.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def pre_zero_grad(self) -> None:
|
||||||
|
"""Called before zero_grad.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_grad_div_scale(self) -> float:
|
||||||
|
"""Called before step or clip_grad. To keep computation efficiency, this method does not (maybe) unscale grads.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: A divisor for gradient clipping or step.
|
||||||
|
"""
|
||||||
|
pass
|
|
@@ -0,0 +1,23 @@
+import torch
+from torch import Tensor
+
+from .base import MixedPrecisionMixin
+
+
+class BF16MixedPrecisionMixin(MixedPrecisionMixin):
+    dtype = torch.bfloat16
+
+    def pre_backward(self, loss: Tensor) -> Tensor:
+        return loss
+
+    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
+        return grad
+
+    def should_skip_step(self) -> bool:
+        return False
+
+    def pre_zero_grad(self) -> None:
+        pass
+
+    def get_grad_div_scale(self) -> float:
+        return 1.0
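Since the bf16 mixin is stateless (no loss scaling, never skips a step), wiring it into an optimizer wrapper is mechanical. A minimal sketch, assuming ColossalAI is installed so the import resolves; `ToyMixedPrecisionOptimizer` is a hypothetical wrapper written for illustration, not a library class:

```python
import torch
from torch import nn
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin


class ToyMixedPrecisionOptimizer:    # hypothetical helper, not part of ColossalAI
    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim
        self.mixed_precision = BF16MixedPrecisionMixin()

    def backward(self, loss: torch.Tensor):
        loss = self.mixed_precision.pre_backward(loss)   # identity for bf16
        loss.backward()

    def step(self):
        if self.mixed_precision.should_skip_step():      # always False for bf16
            self.zero_grad()
            return
        self.optim.step()

    def zero_grad(self):
        self.mixed_precision.pre_zero_grad()
        self.optim.zero_grad()


model = nn.Linear(4, 4).bfloat16()
opt = ToyMixedPrecisionOptimizer(torch.optim.SGD(model.parameters(), lr=0.1))
out = model(torch.randn(2, 4, dtype=torch.bfloat16))
opt.backward(out.sum())
opt.step()
opt.zero_grad()
```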
@@ -0,0 +1,84 @@
+from abc import abstractmethod
+from enum import Enum
+
+import torch
+import torch.distributed as dist
+from torch import Tensor
+
+from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.utils import get_current_device
+
+from .base import MixedPrecisionMixin
+
+
+class OptimState(Enum):
+    SCALED = 0
+    UNSCALED = 1
+
+
+class FP16MixedPrecisionMixin(MixedPrecisionMixin):
+    dtype = torch.float16
+
+    def __init__(self,
+                 initial_scale: float = 2**16,
+                 min_scale: float = 1,
+                 growth_factor: float = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32) -> None:
+        super().__init__()
+        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
+                                             min_scale=min_scale,
+                                             growth_factor=growth_factor,
+                                             backoff_factor=backoff_factor,
+                                             growth_interval=growth_interval,
+                                             hysteresis=hysteresis,
+                                             max_scale=max_scale)
+        self.optim_state = OptimState.UNSCALED
+        self.found_overflow = torch.zeros(1, dtype=torch.float, device=get_current_device())
+
+    @property
+    def loss_scale(self) -> float:
+        return self.grad_scaler.scale.item()
+
+    @abstractmethod
+    def check_local_overflow(self) -> bool:
+        """Check whether there is overflow in the local process. This method should be implemented by subclasses.
+
+        Returns:
+            bool: Whether there is overflow in the local process.
+        """
+        pass
+
+    def check_overflow(self) -> bool:
+        # clear previous overflow record
+        self.found_overflow.fill_(0.0)
+        if self.check_local_overflow():
+            self.found_overflow.fill_(1.0)
+        dist.all_reduce(self.found_overflow, op=dist.ReduceOp.MAX)
+        return self.found_overflow.item() > 0
+
+    def pre_backward(self, loss: Tensor) -> Tensor:
+        loss = self.loss_scale * loss
+        self.optim_state = OptimState.SCALED
+        return loss
+
+    def pre_backward_by_grad(self, tensor: Tensor, grad: Tensor) -> Tensor:
+        self.optim_state = OptimState.SCALED
+        return grad
+
+    def should_skip_step(self) -> bool:
+        found_inf = self.check_overflow()
+        self.grad_scaler.update(found_inf)
+        if found_inf:
+            self.optim_state = OptimState.UNSCALED
+        return found_inf
+
+    def pre_zero_grad(self) -> None:
+        pass
+
+    def get_grad_div_scale(self) -> float:
+        assert self.optim_state == OptimState.SCALED, 'grads should be scaled before clipping'
+        self.optim_state = OptimState.UNSCALED
+        return self.loss_scale
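FP16MixedPrecisionMixin packages the usual dynamic loss-scaling loop: scale the loss before backward, agree across ranks on whether any gradient overflowed, then either skip the step (and back off the scale) or hand the optimizer a divisor. A self-contained, single-process sketch of that control flow in plain PyTorch (illustrative constants, no dist.all_reduce, not the library's implementation):

```python
import torch
from torch import nn

# Illustrative dynamic loss scaling, mirroring FP16MixedPrecisionMixin's control flow.
scale, backoff_factor = 2.0**16, 0.5

model = nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    (scale * loss).backward()                       # pre_backward: scale the loss

    grads = [p.grad for p in model.parameters()]
    found_inf = any(not torch.isfinite(g).all() for g in grads)
    if found_inf:                                   # should_skip_step -> True
        scale *= backoff_factor                     # back off the scale
        opt.zero_grad()
        continue
    # (the real DynamicGradScaler also grows the scale every growth_interval good steps)

    div_scale = scale                               # get_grad_div_scale
    for g in grads:
        g.div_(div_scale)                           # unscale before the update
    opt.step()
    opt.zero_grad()
```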
@@ -23,6 +23,9 @@ from .dp_plugin_base import DPPluginBase
 
 __all__ = ['GeminiPlugin']
 
+SUPPORTED_PRECISION = ['fp16', 'bf16']
+PRECISION_STR_TO_DTYPE = {'fp16': torch.half, 'bf16': torch.bfloat16}
+
 
 class GeminiCheckpointIO(GeneralCheckpointIO):
 
@@ -171,6 +174,7 @@ class GeminiPlugin(DPPluginBase):
     Args:
         device (torch.device): device to place the model.
         placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
+        precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
         pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
         force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
         strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
@@ -203,6 +207,7 @@ class GeminiPlugin(DPPluginBase):
         self,
         device: Optional[torch.device] = None,
         placement_policy: str = "cpu",
+        precision: str = "fp16",
         pin_memory: bool = False,
         force_outputs_fp32: bool = False,
         strict_ddp_mode: bool = False,
@@ -223,6 +228,7 @@ class GeminiPlugin(DPPluginBase):
         verbose: bool = False,
     ) -> None:
         super().__init__()
+        assert precision in SUPPORTED_PRECISION, f'precision {precision} is not supported'
         self.gemini_config = dict(
             device=(device or get_current_device()),
             placement_policy=placement_policy,
@@ -233,6 +239,7 @@ class GeminiPlugin(DPPluginBase):
             hidden_dim=hidden_dim,
             min_chunk_size_mb=min_chunk_size_mb,
             memstats=memstats,
+            mixed_precision=PRECISION_STR_TO_DTYPE[precision],
         )
         self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio,)
         self.optim_kwargs = dict(initial_scale=initial_scale,
@@ -253,7 +260,7 @@ class GeminiPlugin(DPPluginBase):
         return True
 
     def supported_precisions(self) -> List[str]:
-        return ['fp16']
+        return SUPPORTED_PRECISION
 
     def control_device(self) -> bool:
         return True
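A hedged usage sketch of the new option: the plugin can now be constructed with precision='bf16', which maps to torch.bfloat16 via PRECISION_STR_TO_DTYPE. The launch/booster calls below follow the ColossalAI API of this release and assume a CUDA device plus a distributed launcher (e.g. torchrun); treat it as a sketch, not a verified script.

```python
import torch
from torch import nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam

# Sketch only: assumes the launcher set the distributed env vars and a GPU is available.
colossalai.launch_from_torch(config={})

model = nn.Sequential(nn.Linear(32, 32), nn.GELU(), nn.Linear(32, 2))
optimizer = HybridAdam(model.parameters(), lr=1e-3)

plugin = GeminiPlugin(precision='bf16')   # previously only 'fp16' was accepted
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

x = torch.randn(8, 32, device='cuda')
loss = model(x).sum()
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
```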
@@ -1,4 +1,5 @@
 import warnings
+from functools import partial
 from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -20,12 +21,15 @@ from .torch_ddp_plugin import TorchDDPCheckpointIO
 __all__ = ['LowLevelZeroPlugin']
 
 
-def _convert_to_fp16(x):
+def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
-        return x.half()
+        return x.to(dtype)
     return x
 
 
+SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
     def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
@@ -49,17 +53,24 @@ class LowLevelZeroModel(ModelWrapper):
 
     def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
         super().__init__(module)
-        self.convert_inputs = (precision == 'fp16')
-        module = zero_model_wrapper(module, zero_stage=stage)
+        self.dtype = None
         if precision == 'fp16':
-            module = module.half()
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        module = zero_model_wrapper(module, zero_stage=stage)
+        if self.dtype is not None:
+            module = module.to(self.dtype)
         module = module.to(get_current_device())
         self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
 
     def forward(self, *args, **kwargs):
-        if self.convert_inputs:
-            args = tree_map(_convert_to_fp16, args)
-            kwargs = tree_map(_convert_to_fp16, kwargs)
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
         return super().forward(*args, **kwargs)
 
 
@@ -110,7 +121,7 @@ class LowLevelZeroPlugin(DPPluginBase):
 
     Args:
         strage (int, optional): ZeRO stage. Defaults to 1.
-        precision (str, optional): precision. Support 'fp16' and 'fp32'. Defaults to 'fp16'.
+        precision (str, optional): precision. Support 'fp16', 'bf16' and 'fp32'. Defaults to 'fp16'.
         initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
         min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
         growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
@@ -149,7 +160,7 @@ class LowLevelZeroPlugin(DPPluginBase):
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
-        assert precision in ('fp16', 'fp32'), f'LowLevelZeroPlugin only supports fp16/fp32 training'
+        assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
 
         self.stage = stage
         self.precision = precision
@@ -175,7 +186,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         return True
 
     def supported_precisions(self) -> List[str]:
-        return ['fp16', 'fp32']
+        return SUPPORTED_PRECISION
 
     def control_device(self) -> bool:
         return True
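The renamed helper is what LowLevelZeroModel applies to every forward argument through tree_map: only floating-point tensors are cast, everything else passes through. A standalone sketch of that behaviour (the helper is re-declared here so the snippet runs without ColossalAI; tree_map below comes from torch.utils._pytree, while the plugin's own import is not visible in this hunk):

```python
import torch
from functools import partial
from torch.utils._pytree import tree_map

def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # cast only floating-point tensors; ints, bools and non-tensors are untouched
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x

convert_fn = partial(_convert_floating_point, dtype=torch.bfloat16)

batch = {
    'input': torch.randn(2, 8),                    # fp32 -> bf16
    'mask': torch.ones(2, 8, dtype=torch.bool),    # untouched
    'label': torch.tensor([0, 1]),                 # int64, untouched
}
batch = tree_map(convert_fn, batch)
print({k: v.dtype for k, v in batch.items()})
# {'input': torch.bfloat16, 'mask': torch.bool, 'label': torch.int64}
```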
@@ -171,6 +171,21 @@
     using g_scalar_t_##LEVEL = at::Half;                                     \
     using p_scalar_t_##LEVEL = at::Half;                                     \
     __VA_ARGS__;                                                             \
+  } else if (GTYPE == at::ScalarType::Float &&                               \
+             PTYPE == at::ScalarType::BFloat16) {                            \
+    using g_scalar_t_##LEVEL = float;                                        \
+    using p_scalar_t_##LEVEL = at::BFloat16;                                 \
+    __VA_ARGS__;                                                             \
+  } else if (GTYPE == at::ScalarType::BFloat16 &&                            \
+             PTYPE == at::ScalarType::Float) {                               \
+    using g_scalar_t_##LEVEL = at::BFloat16;                                 \
+    using p_scalar_t_##LEVEL = float;                                        \
+    __VA_ARGS__;                                                             \
+  } else if (GTYPE == at::ScalarType::BFloat16 &&                            \
+             PTYPE == at::ScalarType::BFloat16) {                            \
+    using g_scalar_t_##LEVEL = at::BFloat16;                                 \
+    using p_scalar_t_##LEVEL = at::BFloat16;                                 \
+    __VA_ARGS__;                                                             \
   } else {                                                                   \
     AT_ERROR(#NAME, "not implemented for '", toString(GTYPE), toString(PTYPE), \
              "'");                                                           \
@@ -93,8 +93,7 @@ class CPUAdam(NVMeOptimizer):
                           bias_correction1,
                           bias_correction2,
                           use_adamw=False):
-        # FIXME(ver217): remove the below line when replace torch adam with fused adam
-        grad = grad.to(data.dtype)
+        grad = grad.float()
 
         if weight_decay != 0:
             if use_adamw:
@@ -133,10 +132,12 @@ class CPUAdam(NVMeOptimizer):
                 if len(state) == 0:
                     state['step'] = 0
 
+                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
+                    assert p.dtype is torch.float, "CPUAdam only support fp32 parameters"
                     # gradient momentums
-                    state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
+                    state['exp_avg'] = torch.zeros_like(p, device=target_device)
                     # gradient variances
-                    state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
+                    state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
                     self._post_state_init(p)
 
                 state['step'] += 1
@@ -147,9 +148,17 @@ class CPUAdam(NVMeOptimizer):
                 assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                 assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
                 self._pre_update(p, 'exp_avg', 'exp_avg_sq')
-                self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
-                                      group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                      state['exp_avg_sq'], div_scale)
+                if p.grad.dtype is torch.bfloat16:
+                    # cpu adam kernel does not support bf16 now
+                    bias_correction1 = 1 - beta1**state['step']
+                    bias_correction2 = 1 - beta2**state['step']
+                    self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
+                                           beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
+                                           bias_correction2, self.adamw_mode)
+                else:
+                    self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
+                                          group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
+                                          state['exp_avg'], state['exp_avg_sq'], div_scale)
                 self._post_update(p, 'exp_avg', 'exp_avg_sq')
             elif target_device.type == 'cuda':
                 assert div_scale == -1, "div_scale should remain default"
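For bf16 gradients the CPU branch skips the C++ kernel and calls torch_adam_update with explicitly computed bias corrections. A self-contained sketch of the update that fallback performs (illustrative hyperparameters, not the optimizer's exact code): the bf16 grad is upcast to fp32, fp32 moments are updated, and the result is written back to the low-precision parameter.

```python
import torch

def torch_adam_step(param, grad, exp_avg, exp_avg_sq, step,
                    lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.0, adamw=True):
    grad = grad.float()                       # bf16 grad -> fp32 for the math
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    if weight_decay != 0 and adamw:
        param.mul_(1 - lr * weight_decay)     # decoupled weight decay (AdamW)
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    update = (exp_avg / bias_correction1) / denom
    param.add_(-lr * update.to(param.dtype))  # cast the step back to the param dtype

p = torch.randn(16, dtype=torch.bfloat16)
g = torch.randn(16, dtype=torch.bfloat16)
m = torch.zeros(16)                           # fp32 optimizer states
v = torch.zeros(16)
torch_adam_step(p, g, m, v, step=1)
```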
@@ -134,8 +134,8 @@ class FusedAdam(torch.optim.Optimizer):
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p)
 
-                if p.dtype not in [torch.float16, torch.float32]:
-                    raise RuntimeError('FusedAdam only support fp16 and fp32.')
+                if p.dtype not in [torch.float16, torch.float32, torch.bfloat16]:
+                    raise RuntimeError('FusedAdam only support fp16, fp32 and bf16.')
 
                 g_l.append(p.grad.data)
                 p_l.append(p.data)
@@ -1,16 +1,17 @@
 from typing import Any, Optional
 
 import torch
+from torch.optim import Adam
 
-from colossalai.kernel.op_builder import CPUAdamBuilder, FusedOptimBuilder
+from colossalai.kernel.op_builder import FusedOptimBuilder
 from colossalai.registry import OPTIMIZERS
 from colossalai.utils import multi_tensor_applier
 
-from .nvme_optimizer import NVMeOptimizer
+from .cpu_adam import CPUAdam
 
 
 @OPTIMIZERS.register_module
-class HybridAdam(NVMeOptimizer):
+class HybridAdam(CPUAdam):
     """Implements Adam algorithm.
 
     Supports parameters updating on both GPU and CPU, depanding on the device of parameters.
@@ -74,15 +75,9 @@ class HybridAdam(NVMeOptimizer):
                  nvme_offload_dir: Optional[str] = None,
                  **defaults: Any):
-
-        default_args = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, bias_correction=bias_correction)
-        super(HybridAdam, self).__init__(model_params, default_args, nvme_offload_fraction, nvme_offload_dir)
-        self.adamw_mode = adamw_mode
-
-        # build during runtime if not found
-        cpu_optim = CPUAdamBuilder().load()
+        super().__init__(model_params, lr, bias_correction, betas, eps, weight_decay, adamw_mode, nvme_offload_fraction,
+                         nvme_offload_dir)
         fused_optim = FusedOptimBuilder().load()
-        self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)
-
         self.gpu_adam_op = fused_optim.multi_tensor_adam
         self._dummy_overflow_buf = torch.cuda.IntTensor([0])
 
@@ -108,10 +103,12 @@ class HybridAdam(NVMeOptimizer):
                 if len(state) == 0:
                     state['step'] = 0
 
+                    # FIXME(ver217): CPU adam kernel only supports fp32 states now
+                    assert p.dtype is torch.float, "HybridAdam only support fp32 parameters"
                     # gradient momentums
-                    state['exp_avg'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
+                    state['exp_avg'] = torch.zeros_like(p, device=target_device)
                     # gradient variances
-                    state['exp_avg_sq'] = torch.zeros_like(p, dtype=torch.float, device=target_device)
+                    state['exp_avg_sq'] = torch.zeros_like(p, device=target_device)
                     self._post_state_init(p)
 
                 state['step'] += 1
@@ -122,9 +119,17 @@ class HybridAdam(NVMeOptimizer):
                 assert state['exp_avg'].device.type == 'cpu', "exp_avg should stay on cpu"
                 assert state['exp_avg_sq'].device.type == 'cpu', "exp_avg should stay on cpu"
                 self._pre_update(p, 'exp_avg', 'exp_avg_sq')
-                self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'], group['weight_decay'],
-                                      group['bias_correction'], p.data, p.grad.data, state['exp_avg'],
-                                      state['exp_avg_sq'], div_scale)
+                if p.grad.dtype is torch.bfloat16:
+                    # cpu adam kernel does not support bf16 now
+                    bias_correction1 = 1 - beta1**state['step']
+                    bias_correction2 = 1 - beta2**state['step']
+                    self.torch_adam_update(p.data, p.grad.data, state['exp_avg'], state['exp_avg_sq'], group['lr'],
+                                           beta1, beta2, group['eps'], group['weight_decay'], bias_correction1,
+                                           bias_correction2, self.adamw_mode)
+                else:
+                    self.cpu_adam_op.step(state['step'], group['lr'], beta1, beta2, group['eps'],
+                                          group['weight_decay'], group['bias_correction'], p.data, p.grad.data,
+                                          state['exp_avg'], state['exp_avg_sq'], div_scale)
                 self._post_update(p, 'exp_avg', 'exp_avg_sq')
 
             elif target_device.type == 'cuda':
@@ -51,6 +51,7 @@ class ZeroDDP(ColoDDP):
         strict_ddp_mode (bool): If set to True, there is no tensor sharding, each tensor is replicated.
             Defaults to False. Users can set it to True, when they clearly know that they only need DDP.
         scatter_after_inference (bool): If set to True, the model will be scattered after inference. This will save memory but slow down the consecutive inference.
+        mixed_precision (torch.dtype): If set to torch.float16, the model will be trained in fp16. Otherwise, the model will be trained in bf16. Defaults to torch.float16.
     """
 
     def __init__(self,
@@ -59,7 +60,9 @@ class ZeroDDP(ColoDDP):
                  pin_memory: bool = False,
                  force_outputs_fp32: bool = False,
                  strict_ddp_mode: bool = False,
-                 scatter_after_inference: bool = True) -> None:
+                 scatter_after_inference: bool = True,
+                 mixed_precision: torch.dtype = torch.float16) -> None:
+        assert mixed_precision in (torch.float16, torch.bfloat16)
         self.gemini_manager = gemini_manager
         self.chunk_manager: ChunkManager = gemini_manager.chunk_manager
         self.force_outputs_fp32 = force_outputs_fp32
@@ -71,6 +74,7 @@ class ZeroDDP(ColoDDP):
         self.param2name: Dict[nn.Parameter, str] = dict()
         self.name2param: Dict[str, nn.Parameter] = dict()
         self.scatter_after_inference = scatter_after_inference
+        self.mixed_precision = mixed_precision
 
         self._logger = get_dist_logger()
 
@@ -151,7 +155,7 @@ class ZeroDDP(ColoDDP):
         assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
         ), "You should run a completed iteration as your warmup iter"
 
-        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
+        args, kwargs = _cast_float(args, self.mixed_precision), _cast_float(kwargs, self.mixed_precision)
         self.module.zero_grad(set_to_none=True)
         if not grad_flag:
             outputs = self._inference_forward(*args, **kwargs)
@@ -570,14 +574,14 @@ class ZeroDDP(ColoDDP):
 
             # move ignored parameters to CUDA
             if is_ddp_ignored(p):
-                p.data = p.data.to(device=get_current_device(), dtype=torch.float16)
+                p.data = p.data.to(device=get_current_device(), dtype=self.mixed_precision)
                 continue
 
             # create a fp32 parameter
             fp32_data = p.data.float()
             fp32_p = ColoTensor(fp32_data, spec=ColoTensorSpec(p.process_group))
             # create a fp16 parameter
-            p.data = p.data.half()
+            p.data = p.data.to(self.mixed_precision)
 
             # register the fp16 parameter and fp32 parameter in the chunk manager
             dp_world_size = p.process_group.dp_world_size()
@@ -613,7 +617,7 @@ class ZeroDDP(ColoDDP):
                 buffer.materialize()
             buffer.data = buffer.cuda()
             if torch.is_floating_point(buffer):
-                buffer.data = buffer.half()
+                buffer.data = buffer.to(self.mixed_precision)
 
     def _preprocess_param(self, p: Union[nn.Parameter, ColoParameter, 'LazyTensor']) -> None:
         """Convert parameter to ColoParameter in-place.
@@ -736,6 +740,7 @@ class GeminiDDP(ZeroDDP):
                  hidden_dim: Optional[int] = None,
                  min_chunk_size_mb: float = 32,
                  memstats: Optional[MemStats] = None,
+                 mixed_precision: torch.dtype = torch.float16,
                  verbose: bool = False) -> None:
         """
         A torch.Module wrapper using ZeRO-DP and Gemini.
@@ -776,5 +781,10 @@ class GeminiDDP(ZeroDDP):
                                        strict_ddp_flag=strict_ddp_mode,
                                        verbose=verbose)
         gemini_manager = GeminiManager(placement_policy, chunk_manager, memstats)
-        super().__init__(module, gemini_manager, pin_memory, force_outputs_fp32, strict_ddp_mode,
-                         scatter_after_inference)
+        super().__init__(module,
+                         gemini_manager,
+                         pin_memory,
+                         force_outputs_fp32,
+                         strict_ddp_mode,
+                         scatter_after_inference,
+                         mixed_precision=mixed_precision)
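The core of this change is that ZeroDDP now casts working parameters, ignored parameters, inputs and buffers to self.mixed_precision instead of hard-coding torch.half, while the optimizer keeps fp32 master copies. A minimal standalone sketch of that fp32-master / low-precision-working split (plain PyTorch, no chunks or sharding, toy learning rate):

```python
import torch
from torch import nn

mixed_precision = torch.bfloat16          # or torch.float16

model = nn.Linear(8, 8)
master_params = []
for p in model.parameters():
    master = p.data.float().clone()       # fp32 copy kept for the optimizer
    p.data = p.data.to(mixed_precision)   # working copy used in forward/backward
    master_params.append(master)

x = torch.randn(4, 8, dtype=mixed_precision)
loss = model(x).sum()
loss.backward()                           # grads come out in bf16

# the optimizer updates the fp32 masters, then copies back into the working params
for p, master in zip(model.parameters(), master_params):
    master -= 0.1 * p.grad.float()        # toy SGD step in fp32
    p.data.copy_(master.to(mixed_precision))
```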
@@ -1,7 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import math
 import warnings
-from enum import Enum
 from typing import Any, Dict, Set, Tuple
 
 import torch
@@ -9,7 +8,7 @@ import torch.distributed as dist
 from torch.nn import Parameter
 from torch.optim import Optimizer
 
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer, CPUAdam, FusedAdam, HybridAdam
 from colossalai.utils import disposable, get_current_device, is_ddp_ignored
@@ -22,9 +21,26 @@ __all__ = ['ZeroOptimizer', 'GeminiAdamOptimizer']
 _AVAIL_OPTIM_LIST = {FusedAdam, CPUAdam, HybridAdam}
 
 
-class OptimState(Enum):
-    SCALED = 0
-    UNSCALED = 1
+class GeminiFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
+
+    def __init__(self,
+                 module: ZeroDDP,
+                 initial_scale: float = 2**16,
+                 min_scale: float = 1,
+                 growth_factor: float = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32) -> None:
+        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
+                         max_scale)
+        self.module = module
+
+    def check_local_overflow(self) -> bool:
+        return self.module.overflow_counter > 0
+
+    def pre_zero_grad(self) -> None:
+        self.module.overflow_counter = 0
 
 
 class ZeroOptimizer(ColossalaiOptimizer):
@@ -79,7 +95,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
         self.module = module
         self.gemini_manager = module.gemini_manager
         self.chunk_manager: ChunkManager = self.gemini_manager.chunk_manager
-        self.optim_state = OptimState.UNSCALED
         self.param_to_range: Dict[Parameter, Tuple[int, int]] = dict()
         self.param_to_chunk32: Dict[Parameter, Chunk] = dict()
         self.chunk16_set: Set[Chunk] = set()
@@ -107,15 +122,20 @@ class ZeroOptimizer(ColossalaiOptimizer):
 
         self.__init__optimizer()
 
-        # Grad scaler
-        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
-                                             min_scale=min_scale,
-                                             growth_factor=growth_factor,
-                                             backoff_factor=backoff_factor,
-                                             growth_interval=growth_interval,
-                                             hysteresis=hysteresis,
-                                             max_scale=max_scale)
-        self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=get_current_device())
+        if module.mixed_precision is torch.float16:
+            self.mix_precision_mixin = GeminiFP16MixedPrecisionMixin(module,
+                                                                     initial_scale=initial_scale,
+                                                                     min_scale=min_scale,
+                                                                     growth_factor=growth_factor,
+                                                                     backoff_factor=backoff_factor,
+                                                                     growth_interval=growth_interval,
+                                                                     hysteresis=hysteresis,
+                                                                     max_scale=max_scale)
+        elif module.mixed_precision is torch.bfloat16:
+            self.mix_precision_mixin = BF16MixedPrecisionMixin()
+        else:
+            raise RuntimeError(f"Unsupported mixed precision type: {module.mixed_precision}")
+
         self._logger = get_dist_logger()
 
         self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
@@ -151,15 +171,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
         for chunk16 in self.chunk16_set:
             chunk16.optim_update()
 
-    def _check_overflow(self):
-        # clear previous overflow record
-        self._found_overflow.fill_(self.module.overflow_counter)
-
-        # all-reduce across global group
-        dist.all_reduce(self._found_overflow)
-
-        return self._found_overflow.item() > 0
-
     def _clear_global_norm(self) -> None:
         for c16 in self.chunk16_set:
             c16.l2_norm = None
@@ -190,40 +201,25 @@ class ZeroOptimizer(ColossalaiOptimizer):
         return global_norm
 
     def _get_combined_scale(self):
-        loss_scale = 1
-
-        if self.optim_state == OptimState.SCALED:
-            loss_scale = self.loss_scale
-            self.optim_state = OptimState.UNSCALED
-
-        combined_scale = loss_scale
+        div_scale = self.mix_precision_mixin.get_grad_div_scale()
+
         if self.clipping_flag:
             total_norm = self._calc_global_norm()
-            clip = ((total_norm / loss_scale) + 1e-6) / self.max_norm
+            clip = ((total_norm / div_scale) + 1e-6) / self.max_norm
             if clip > 1:
-                combined_scale = clip * loss_scale
+                div_scale = clip * div_scale
 
-        if combined_scale == 1:
-            return -1
-        else:
-            return combined_scale
-
-    @property
-    def loss_scale(self):
-        return self.grad_scaler.scale.item()
+        return -1 if div_scale == 1.0 else div_scale
 
     def zero_grad(self, *args, **kwargs):
-        self.module.overflow_counter = 0
+        self.mix_precision_mixin.pre_zero_grad()
         return self.optim.zero_grad(set_to_none=True)
 
     def step(self, *args, **kwargs):
         self._maybe_move_fp32_params()
         self._set_grad_ptr()
 
-        found_inf = self._check_overflow()
-        if found_inf:
-            self.optim_state = OptimState.UNSCALED    # no need to unscale grad
-            self.grad_scaler.update(found_inf)    # update gradient scaler
+        if self.mix_precision_mixin.should_skip_step():
             if self.verbose:
                 self._logger.info(f'Found overflow. Skip step')
             self._clear_global_norm()    # clear recorded norm
@@ -234,7 +230,6 @@ class ZeroOptimizer(ColossalaiOptimizer):
         # get combined scale. combined scale = loss scale * clipping norm
         # so that gradient = gradient / combined scale
         combined_scale = self._get_combined_scale()
-        self.grad_scaler.update(found_inf)
 
         ret = self.optim.step(div_scale=combined_scale, *args, **kwargs)
         self._register_states()
@@ -246,8 +241,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         raise NotImplementedError
 
     def backward(self, loss: torch.Tensor):
-        loss = self.loss_scale * loss
-        self.optim_state = OptimState.SCALED
+        loss = self.mix_precision_mixin.pre_backward(loss)
         self.module.backward(loss)
 
     def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
@@ -255,7 +249,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         # It receives the scaled grad from the previous rank
         # No need to scale the grad again
         # Need to unscale when optimizing
-        self.optim_state = OptimState.SCALED
+        grad = self.mix_precision_mixin.pre_backward_by_grad(grad)
         self.module.backward_by_grad(tensor, grad)
 
     def _maybe_move_fp32_params(self):
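After the refactor, _get_combined_scale asks the mixin for the gradient divisor (the loss scale for fp16, 1.0 for bf16) and folds gradient clipping into the same divisor, so the Adam kernels only apply a single division per gradient. A small numeric sketch of that arithmetic (all numbers are made up):

```python
# Illustrative arithmetic behind the combined divisor.
loss_scale = 2.0**16          # from the fp16 mixin; would be 1.0 for bf16
total_norm = 1310720.0        # global grad norm measured on the *scaled* grads
max_norm = 1.0                # clipping threshold

div_scale = loss_scale
clip = ((total_norm / div_scale) + 1e-6) / max_norm
if clip > 1:                  # unscaled norm exceeds max_norm -> fold clipping in
    div_scale = clip * div_scale

raw_grad = 65536.0            # a scaled gradient element
effective_grad = raw_grad / div_scale    # what the optimizer kernel actually uses
print(div_scale, effective_grad)         # ~1310720 and ~0.05
```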
@@ -14,7 +14,7 @@ from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
 from colossalai.zero.legacy.shard_utils import BaseShardStrategy
-from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_fp16
+from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
 from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
 from colossalai.zero.legacy.sharded_param import ShardedParamV2
 
@@ -55,6 +55,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         seed (int, optional): Random seed for weight initialization
         shard_param (bool, optional): Is param sharded after exiting the context. Defaults to False.
         default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
+        bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
         model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of model. Defaults to torch.zeros(1, dtype=torch.int).
     """
 
@@ -64,6 +65,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                  seed: int = 2**10 - 1,
                  shard_param: bool = False,
                  default_dtype: Optional[torch.dtype] = None,
+                 bf16: bool = False,
                  model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):
 
         super().__init__(default_dtype=default_dtype)
@@ -71,6 +73,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.param_list = []
         self.model_numel_tensor = model_numel_tensor
         self.seed = seed
+        self.bf16 = bf16
         self.dp_process_group = gpc.get_group(ParallelMode.DATA)
 
         self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)
@@ -183,9 +186,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         NOTE() The module may be passed to this function multiple times.
         """
         self.top_module = module
+        half_dtype = torch.float16 if not self.bf16 else torch.bfloat16
 
         def half_fn(t: torch.Tensor):
-            return t.half() if t.is_floating_point() else t
+            return t.to(half_dtype) if t.is_floating_point() else t
 
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
@@ -226,9 +230,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             # We must cast buffers
             # If we use BN, buffers may be on CPU and Float
             # We must cast them
+            cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
             for buffer in module.buffers(recurse=False):
                 buffer.data = buffer.data.to(device=torch.cuda.current_device())
-                buffer.data = cast_tensor_to_fp16(buffer.data)
+                buffer.data = cast_fn(buffer.data)
 
 
 class ZeroContextMgr(metaclass=SingletonMeta):
@@ -43,11 +43,19 @@ def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor:
     if isinstance(tensor, StatefulTensor):
         tensor = tensor.payload
 
-    if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
+    if torch.is_floating_point(tensor) and tensor.dtype in (torch.float16, torch.bfloat16):
         return tensor.float()
     return tensor
 
 
+def cast_tensor_to_bf16(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, StatefulTensor):
+        tensor = tensor.payload
+    if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
+        return tensor.bfloat16()
+    return tensor
+
+
 def apply_to_tensors(x: Any, fn: Callable):
     if torch.is_tensor(x):
        return fn(x)
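The legacy cast helpers are deliberately dtype-guarded: cast_tensor_to_bf16 only touches fp32 floats, cast_tensor_to_fp32 now promotes both fp16 and bf16, and integer buffers pass through untouched. A quick standalone check of that behaviour (plain PyTorch, without the StatefulTensor unwrapping):

```python
import torch

# Standalone versions of the dtype guards (no StatefulTensor handling)
def cast_to_bf16(t):
    return t.bfloat16() if torch.is_floating_point(t) and t.dtype is torch.float32 else t

def cast_to_fp32(t):
    return t.float() if t.dtype in (torch.float16, torch.bfloat16) else t

w = torch.randn(4)                           # fp32 weight
idx = torch.arange(4)                        # int64 buffer
print(cast_to_bf16(w).dtype)                 # torch.bfloat16
print(cast_to_bf16(idx).dtype)               # torch.int64 (untouched)
print(cast_to_fp32(cast_to_bf16(w)).dtype)   # torch.float32
```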
@@ -28,6 +28,7 @@ from colossalai.zero.legacy.sharded_model.reduce_scatter import ReduceScatterBuc
 
 from ._utils import (
     cast_float_arguments,
+    cast_tensor_to_bf16,
     cast_tensor_to_fp16,
     cast_tensor_to_fp32,
     chunk_and_pad,
@@ -74,6 +75,7 @@ class ShardedModelV2(nn.Module):
             In this mode, grad will be fp16. Make sure your optimizer supports mixed precision (fp32 param and fp16 grad).
             We find that PyTorch's optimizers don't support mixed precision,
             so we recommend you enable this only when using our CPUAdam with CPU offload. Defaults to False.
+        bf16 (bool, optional): Whether to use bfloat16 for param and grad. Defaults to False.
     """
 
     def __init__(self,
@@ -86,11 +88,13 @@ class ShardedModelV2(nn.Module):
                  tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
                  reuse_fp16_shard: bool = False,
+                 bf16: bool = False,
                  *args,
                  **kwargs):
         assert not isinstance(module, ShardedModelV2), 'Nested ShardedModelV2 is not supported.'
         super().__init__()
         self.logger = get_dist_logger()
+        self.bf16 = bf16
 
         # We force users to use ZeroInitContext
         for submodule in module.modules():
@@ -232,7 +236,8 @@ class ShardedModelV2(nn.Module):
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         self._pre_forward_operations(*args)
-        args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
+        cast_fn = cast_tensor_to_bf16 if self.bf16 else cast_tensor_to_fp16
+        args, kwargs = cast_float_arguments(cast_fn, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
         self._post_forward_operations()
         return outputs
@ -94,6 +94,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
super().__init__(optimizer)
|
super().__init__(optimizer)
|
||||||
self.shard_strategy = sharded_model.shard_strategy
|
self.shard_strategy = sharded_model.shard_strategy
|
||||||
self.model: ShardedModelV2 = sharded_model
|
self.model: ShardedModelV2 = sharded_model
|
||||||
|
self.bf16 = sharded_model.bf16
|
||||||
|
|
||||||
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
||||||
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
|
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
|
||||||
|
@ -117,6 +118,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
|
self._found_overflow: Tensor = torch.IntTensor([0]).to(torch.cuda.current_device())
|
||||||
self._logger = get_dist_logger("ShardedOptimizerV2")
|
self._logger = get_dist_logger("ShardedOptimizerV2")
|
||||||
self._verbose = verbose
|
self._verbose = verbose
|
||||||
|
self._grad_prepared: bool = False # this should be set to true when _prepare_grads() and reset to false when backward
|
||||||
|
|
||||||
# Store fp32 param shards
|
# Store fp32 param shards
|
||||||
self._register_master_weight()
|
self._register_master_weight()
|
||||||
|
@ -166,8 +168,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
self._zero_grad()
|
self._zero_grad()
|
||||||
|
|
||||||
def backward(self, loss: Tensor) -> None:
|
def backward(self, loss: Tensor) -> None:
|
||||||
loss = self.loss_scale * loss
|
if not self.bf16:
|
||||||
self.optim_state = OptimState.SCALED
|
loss = self.loss_scale * loss
|
||||||
|
self.optim_state = OptimState.SCALED
|
||||||
|
self._grad_prepared = False
|
||||||
self.model.backward(loss)
|
self.model.backward(loss)
|
||||||
|
|
||||||
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
|
def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
|
||||||
|
@ -175,30 +179,33 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
# It receives the scaled grad from the previous rank
|
# It receives the scaled grad from the previous rank
|
||||||
# No need to scale the grad again
|
# No need to scale the grad again
|
||||||
# Need to unscale when optimizing
|
# Need to unscale when optimizing
|
||||||
-        self.optim_state = OptimState.SCALED
+        if not self.bf16:
+            self.optim_state = OptimState.SCALED
+        self._grad_prepared = False
         self.model.backward_by_grad(tensor, grad)
 
     def clip_grad_norm(self, model: nn.Module, max_norm: float):
-        if self.optim_state == OptimState.SCALED:
-            self._prepare_grads()
+        self._prepare_grads()
+        if not self.bf16 and self.optim_state == OptimState.SCALED:
             self._unscale_grads()
         return super().clip_grad_norm(model, max_norm)
 
     def step(self, *args, **kwargs):
+        self._prepare_grads()
         # unscale grads if scaled
-        if self.optim_state == OptimState.SCALED:
-            self._prepare_grads()
+        if not self.bf16 and self.optim_state == OptimState.SCALED:
             self._unscale_grads()
 
         self._maybe_move_fp32_shards()
-        found_inf = self._check_overflow()
-        self.grad_scaler.update(found_inf)
 
-        if found_inf:
-            self._logger.warning('found inf during ShardedOptimV2 step')
-            self._zero_grad(recover_data=True)
-            return
+        if not self.bf16:
+            found_inf = self._check_overflow()
+            self.grad_scaler.update(found_inf)
+
+            if found_inf:
+                self._logger.warning('found inf during ShardedOptimV2 step')
+                self._zero_grad(recover_data=True)
+                return
 
         self._point_param_fp16_to_master_param()

@@ -304,6 +311,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     state[k] = v.cuda()
 
     def _prepare_grads(self):
+        if self._grad_prepared:
+            return
         for group in self.optim.param_groups:
             for p in group['params']:
                 if p.colo_attr.saved_grad.is_null():

@@ -320,6 +329,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 p.grad = p.colo_attr.grad_payload
                 # Set p.data to empty tensor, in case of memory leaking
                 p.colo_attr.set_data_none()
+        self._grad_prepared = True
 
     def _point_param_fp16_to_master_param(self):
         # assign master param pointers to p.data.

@@ -357,7 +367,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
 
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
+                half_dtype = torch.bfloat16 if self.bf16 else torch.float16
+                p.colo_attr.sharded_data_tensor.payload_copy(p.to(half_dtype).detach())
                 p.colo_attr.set_data_none()
 
                 if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
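The `ShardedOptimizerV2` hunks above gate every piece of loss-scaling bookkeeping behind the new `bf16` flag: the scaled/unscaled state, gradient unscaling and the overflow check only run on the fp16 path. A minimal, self-contained sketch of that control flow (illustrative only; `Bf16AwareStep`, its fields and its simple inline overflow check are stand-ins, not the optimizer in the diff):

```python
import torch


class Bf16AwareStep:
    """Sketch: fp16 keeps the scale/unscale/overflow machinery, bf16 skips it entirely."""

    def __init__(self, optim: torch.optim.Optimizer, bf16: bool, loss_scale: float = 2.0**16):
        self.optim = optim
        self.bf16 = bf16
        self.loss_scale = 1.0 if bf16 else loss_scale

    def backward(self, loss: torch.Tensor):
        # fp16: scale the loss so tiny gradients survive the cast; bf16: plain backward
        (loss * self.loss_scale).backward()

    def step(self):
        if not self.bf16:
            # unscale gradients in place and skip the update if any of them overflowed
            found_inf = False
            for group in self.optim.param_groups:
                for p in group['params']:
                    if p.grad is None:
                        continue
                    p.grad.div_(self.loss_scale)
                    found_inf |= not torch.isfinite(p.grad).all().item()
            if found_inf:
                self.optim.zero_grad()
                return
        self.optim.step()
```

With `bf16=True` the step degenerates to a plain optimizer step, which is why the diff can drop the scaler updates and the inf warning on that path.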
@@ -6,7 +6,11 @@ import torch
 import torch.distributed as dist
 from torch.optim import Optimizer
 
-from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
+from colossalai.amp.naive_amp.mixed_precision_mixin import (
+    BF16MixedPrecisionMixin,
+    FP16MixedPrecisionMixin,
+    MixedPrecisionMixin,
+)
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger

@@ -27,6 +31,31 @@ from ._utils import (
 from .bookkeeping import BucketStore, GradientStore, ParameterStore, TensorBucket
 
 
+class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
+
+    def __init__(self,
+                 num_working_param_groups: int,
+                 grad_store: GradientStore,
+                 initial_scale: float = 2**16,
+                 min_scale: float = 1,
+                 growth_factor: float = 2,
+                 backoff_factor: float = 0.5,
+                 growth_interval: int = 1000,
+                 hysteresis: int = 2,
+                 max_scale: float = 2**32) -> None:
+        super().__init__(initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis,
+                         max_scale)
+        self.num_working_param_groups = num_working_param_groups
+        self.grad_store = grad_store
+
+    def check_local_overflow(self) -> bool:
+        for group_id in range(self.num_working_param_groups):
+            for avg_grad in self.grad_store.get_averaged_gradients_by_group(group_id):
+                if avg_grad is not None and has_inf_or_nan(avg_grad):
+                    return True
+        return False
+
+
 class LowLevelZeroOptimizer(ColossalaiOptimizer):
     """Optimizer used for ZeRO-1 and ZeRO-2.
     """

@@ -100,17 +129,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
 
-        # gradient scaler
-        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
-                                             min_scale=min_scale,
-                                             growth_factor=growth_factor,
-                                             backoff_factor=backoff_factor,
-                                             growth_interval=growth_interval,
-                                             hysteresis=hysteresis,
-                                             max_scale=max_scale,
-                                             verbose=verbose)
-        self._found_overflow = torch.FloatTensor([0]).to(get_current_device())
-
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm

@@ -200,14 +218,25 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         if self._overlap_communication or self._partition_grads:
             self._attach_reduction_hook()
 
+        # initialize mixed precision mixin
+        self.mixed_precision_mixin: Optional[MixedPrecisionMixin] = None
+        if self._dtype is torch.float16:
+            self.mixed_precision_mixin = LowLevelZeroFP16MixedPrecisionMixin(self.num_param_groups,
+                                                                             self._grad_store,
+                                                                             initial_scale=initial_scale,
+                                                                             min_scale=min_scale,
+                                                                             growth_factor=growth_factor,
+                                                                             backoff_factor=backoff_factor,
+                                                                             growth_interval=growth_interval,
+                                                                             hysteresis=hysteresis,
+                                                                             max_scale=max_scale)
+        elif self._dtype is torch.bfloat16:
+            self.mixed_precision_mixin = BF16MixedPrecisionMixin()
+
     @property
     def dtype(self):
         return self._dtype
 
-    @property
-    def loss_scale(self):
-        return self.grad_scaler.scale
-
     @property
     def num_param_groups(self):
         return len(self._working_param_groups)

@@ -392,7 +421,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     ################################
 
     def backward(self, loss, retain_graph=False, sync_grad=True):
-        loss = self.loss_scale * loss
+        if self.mixed_precision_mixin is not None:
+            loss = self.mixed_precision_mixin.pre_backward(loss)
         loss.backward(retain_graph=retain_graph)
 
         # finish gradient reduction

@@ -419,6 +449,8 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
         :param set_to_none: Whether set the gradient to None. Default value is True.
         :type set_to_none: bool
         """
+        if self.mixed_precision_mixin is not None:
+            self.mixed_precision_mixin.pre_zero_grad()
         for _, param_group in self._working_param_groups.items():
             for param in param_group:
                 if set_to_none:

@@ -435,12 +467,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def step(self, closure=None):
         assert closure is None, 'closure is not supported by step()'
 
-        # check for overflow
-        found_inf = self._check_overflow()
-        self.grad_scaler.update(found_inf)
-
-        # update loss scale if overflow occurs
-        if found_inf:
+        if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
             self._grad_store.reset_all_average_gradients()
             if self._verbose:
                 self._logger.info(f'Found overflow. Skip step')

@@ -507,41 +534,20 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     # Mixed Precision Utilities #
     #############################
 
-    def _check_overflow(self):
-        # clear previous overflow record
-        self._found_overflow.fill_(0.0)
-
-        # check for overflow
-        for group_id in range(len(self._working_param_groups)):
-            for avg_grad in self._grad_store.get_averaged_gradients_by_group(group_id):
-                if avg_grad is not None and has_inf_or_nan(avg_grad):
-                    self._found_overflow.fill_(1.0)
-                    break
-
-        # all-reduce across dp group
-        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._dp_torch_group)
-
-        # all-reduce over model parallel group
-        if self._mp_torch_group:
-            dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self._mp_torch_group)
-
-        if self._found_overflow.item() > 0:
-            return True
-        else:
-            return False
-
     def _unscale_and_clip_grads(self, grad_groups_flat, total_norm):
         # compute combined scale factor for this group
-        combined_scale = self.loss_scale
+        div_scale = 1.0
+        if self.mixed_precision_mixin is not None:
+            div_scale = self.mixed_precision_mixin.get_grad_div_scale()
 
         if self._clip_grad_norm > 0.:
             # norm is in fact norm*scale
-            clip = ((total_norm / self.loss_scale) + 1e-6) / self._clip_grad_norm
+            clip = ((total_norm / div_scale) + 1e-6) / self._clip_grad_norm
             if clip > 1:
-                combined_scale = clip * self.loss_scale
+                div_scale = clip * div_scale
 
         for grad in grad_groups_flat:
-            grad.data.mul_(1. / combined_scale)
+            grad.data.mul_(1. / div_scale)
 
     ############################
     # Gradient Synchronization #
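In the rewritten `_unscale_and_clip_grads`, the old `combined_scale` derived from `self.loss_scale` becomes a `div_scale` obtained from the mixed-precision mixin (the current loss scale for fp16, effectively 1.0 for bf16), optionally enlarged by the clipping coefficient. A small standalone sketch of that arithmetic (the helper name and the toy numbers below are illustrative, not part of the optimizer):

```python
import torch


def unscale_and_clip(grads, total_norm, div_scale, clip_grad_norm):
    """Divide flattened gradients by div_scale, additionally shrinking them when
    the unscaled norm (total_norm / div_scale) exceeds clip_grad_norm."""
    if clip_grad_norm > 0.0:
        clip = ((total_norm / div_scale) + 1e-6) / clip_grad_norm
        if clip > 1:
            div_scale = clip * div_scale
    for g in grads:
        g.mul_(1.0 / div_scale)


# toy usage: four loss-scaled gradient entries of 1024 each -> norm 2048
grads = [torch.full((4,), 1024.0)]
unscale_and_clip(grads, total_norm=grads[0].norm(), div_scale=1024.0, clip_grad_norm=1.0)
print(grads[0])    # tensor of 0.5: unscaled to 1.0, then clipped to norm ~1.0
```

With a loss scale of 1024 and `clip_grad_norm=1.0`, the unscaled norm is 2.0, so `div_scale` doubles and each gradient is unscaled and clipped in a single in-place multiply, exactly the combined factor the method computes.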
@@ -0,0 +1,131 @@
# This test checks adam kernels
# Baseline is pure fp32 torch adam optimizer
import math
from abc import abstractmethod
from typing import Type

import pytest
import torch
from torch import Tensor

from colossalai.utils import get_current_device, multi_tensor_applier

_FUSED_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
                            (torch.half, torch.half), (torch.bfloat16, torch.float), (torch.float, torch.bfloat16),
                            (torch.bfloat16, torch.bfloat16)]

_CPU_ALLOWED_P_G_TYPES = [(torch.float, torch.half), (torch.float, torch.float), (torch.half, torch.float),
                          (torch.half, torch.half)]


class AdamKernel:

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.weight_decay = weight_decay
        self.use_adamw = use_adamw

    @abstractmethod
    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        pass


class TorchAdamKernel(AdamKernel):

    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        bias_correction1 = 1 - self.beta1**step
        bias_correction2 = 1 - self.beta2**step

        if self.weight_decay != 0:
            if self.use_adamw:
                # Perform stepweight decay
                param.mul_(1 - self.lr * self.weight_decay)
            else:
                grad = grad.add(param, alpha=self.weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)

        step_size = self.lr / bias_correction1

        param.addcdiv_(exp_avg, denom, value=-step_size)


class FusedAdamKernel(AdamKernel):

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
        from colossalai.kernel.op_builder import FusedOptimBuilder
        fused_optim = FusedOptimBuilder().load()
        self.fused_adam = fused_optim.multi_tensor_adam
        self.dummy_overflow_buf = torch.cuda.IntTensor([0])

    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        multi_tensor_applier(self.fused_adam, self.dummy_overflow_buf, [[grad], [param], [exp_avg], [exp_avg_sq]],
                             self.lr, self.beta1, self.beta2, self.eps, step, self.use_adamw, True, self.weight_decay,
                             -1)


class CPUAdamKernel(AdamKernel):

    def __init__(self, lr: float, beta1: float, beta2: float, eps: float, weight_decay: float, use_adamw: bool) -> None:
        super().__init__(lr, beta1, beta2, eps, weight_decay, use_adamw)
        from colossalai.kernel.op_builder import CPUAdamBuilder
        cpu_optim = CPUAdamBuilder().load()

        self.cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, use_adamw)

    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        self.cpu_adam_op.step(step, self.lr, self.beta1, self.beta2, self.eps, self.weight_decay, True, param.view(-1),
                              grad.view(-1), exp_avg.view(-1), exp_avg_sq.view(-1), -1)


def check_adam_kernel(kernel: Type[AdamKernel], adamw: bool, weight_decay: float, p_dtype: torch.dtype,
                      g_dtype: torch.dtype, device: torch.device, n_steps: int, rtol: float, atol: float):
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_adam = TorchAdamKernel(lr, beta1, beta2, eps, weight_decay, adamw)
    adam_kernel = kernel(lr, beta1, beta2, eps, weight_decay, adamw)
    master_p = torch.rand(64, device=device)
    master_g = torch.rand_like(master_p)
    master_exp_avg = torch.zeros_like(master_p)
    master_exp_avg_sq = torch.zeros_like(master_p)
    p = master_p.clone().to(p_dtype)
    g = master_g.clone().to(g_dtype)
    exp_avg = master_exp_avg.clone()
    exp_avg_sq = master_exp_avg_sq.clone()

    for step in range(1, 1 + n_steps):
        torch_adam.update(step, master_p, master_g, master_exp_avg, master_exp_avg_sq)
        adam_kernel.update(step, p, g, exp_avg, exp_avg_sq)
        # if overflow, the weight won't be updated. so there will be no nan in p
        assert not torch.isnan(p).any()
        assert torch.allclose(master_p, p.float(), rtol=rtol, atol=atol)


@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
@pytest.mark.parametrize('p_dtype, g_dtype', _FUSED_ALLOWED_P_G_TYPES)
def test_fused_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    rtol, atol = 1e-5, 1e-8
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 1e-3, 1e-3
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3
    check_adam_kernel(FusedAdamKernel, adamw, weight_decay, p_dtype, g_dtype, get_current_device(), 3, rtol, atol)


@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('weight_decay', [0.0, 0.1])
@pytest.mark.parametrize('p_dtype, g_dtype', _CPU_ALLOWED_P_G_TYPES)
def test_cpu_adam_kernel(adamw, weight_decay, p_dtype, g_dtype):
    rtol, atol = 1e-5, 1e-8
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 1e-3, 1e-3
    check_adam_kernel(CPUAdamKernel, adamw, weight_decay, p_dtype, g_dtype, torch.device('cpu'), 3, rtol, atol)
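The per-dtype tolerances the new kernel test selects (`4e-3` for bf16 versus `1e-3` for fp16) track the formats' machine epsilons; a quick check in plain PyTorch (not part of the test file):

```python
import torch

# bf16 keeps 8 significand bits (7 stored), so its epsilon is ~8x coarser than fp16's
print(torch.finfo(torch.bfloat16).eps)    # 0.0078125
print(torch.finfo(torch.float16).eps)     # 0.0009765625
print(torch.finfo(torch.float32).eps)     # ~1.19e-07
```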
@@ -0,0 +1,86 @@
from copy import deepcopy
from typing import Type, Union

import pytest
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
from tests.kit.model_zoo import model_zoo

_ALLOWED_OPTIM_DEVICES = [
    (FusedAdam, torch.device('cuda:0')),
    (CPUAdam, torch.device('cpu')),
    (CPUAdam, torch.device('cuda:0')),
    (HybridAdam, torch.device('cpu')),
    (HybridAdam, torch.device('cuda:0')),
]

_ALLOWED_P_G_TYPES = [
    (torch.float, torch.float),    # pure fp32
    (torch.float, torch.half),    # fp16 amp
    (torch.float, torch.bfloat16),    # bfloat16 amp
    # (torch.half, torch.half),    # FIXME(ver217): cpu adam kernel does not support pure fp16
    # (torch.bfloat16, torch.bfloat16),    # FIXME(ver217): cpu adam kernel does not support pure bfloat16
]

N_STEPS = 3


def setup_param_groups(bert_model: nn.Module) -> list:
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in bert_model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": 0.1,
        },
        {
            "params": [p for n, p in bert_model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    return optimizer_grouped_parameters


def set_grad(model: nn.Module, torch_model: nn.Module, g_dtype: torch.dtype) -> None:
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        torch_p.grad = torch.rand_like(torch_p)
        # avoid inconsistent grad and param dtype error
        orig_p = p.data
        p.data = torch_p.grad.clone().to(g_dtype)
        p.grad = p.data
        p.data = orig_p


@pytest.mark.parametrize('optim_cls, device', _ALLOWED_OPTIM_DEVICES)
@pytest.mark.parametrize('adamw', [False, True])
@pytest.mark.parametrize('p_dtype, g_dtype', _ALLOWED_P_G_TYPES)
def test_adam_optim_on_bert(optim_cls: Union[Type[FusedAdam], Type[CPUAdam], Type[HybridAdam]], device: torch.device,
                            adamw: bool, p_dtype: torch.dtype, g_dtype: torch.dtype) -> None:
    model_fn, *_ = next(iter(model_zoo.get_sub_registry('transformers_bert_for_sequence_classification').values()))
    torch_model = model_fn().to(device)
    model = deepcopy(torch_model).to(p_dtype)
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    torch_optim_cls = AdamW if adamw else Adam
    torch_optim = torch_optim_cls(setup_param_groups(torch_model), lr=lr, betas=(beta1, beta2), eps=eps)
    optim = optim_cls(setup_param_groups(model), lr=lr, betas=(beta1, beta2), eps=eps, adamw_mode=adamw)

    rtol, atol = 1e-5, 1e-5
    if p_dtype is torch.float16 or g_dtype is torch.float16:
        rtol, atol = 2e-3, 2e-3
    if p_dtype is torch.bfloat16 or g_dtype is torch.bfloat16:
        rtol, atol = 4e-3, 4e-3

    for _ in range(N_STEPS):
        set_grad(model, torch_model, g_dtype)
        torch_optim.step()
        optim.step()
        torch_optim.zero_grad()
        optim.zero_grad()
        for p, torch_p in zip(model.parameters(), torch_model.parameters()):
            # if overflow, the weight won't be updated. so there will be no nan in p
            assert not torch.isnan(p).any()
            assert torch.allclose(p.float(), torch_p, rtol=rtol, atol=atol)
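`set_grad` in the new optimizer test briefly repoints `p.data` at a tensor of the gradient dtype before attaching the grad, because PyTorch rejects a `.grad` whose dtype differs from the parameter it is assigned to. A standalone illustration of both the failure and the workaround (hypothetical snippet, not from the test):

```python
import torch

p = torch.nn.Parameter(torch.zeros(2))                 # fp32 parameter
try:
    p.grad = torch.zeros(2, dtype=torch.bfloat16)      # direct assignment of a bf16 grad is rejected
except (RuntimeError, TypeError) as err:
    print(f'direct assignment fails: {err}')

# the trick set_grad uses: make p.data match the grad dtype just long enough to
# attach the grad, then restore the original fp32 payload
orig = p.data
p.data = torch.zeros(2, dtype=torch.bfloat16)
p.grad = p.data
p.data = orig
print(p.dtype, p.grad.dtype)    # torch.float32 torch.bfloat16
```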
@@ -1,121 +0,0 @@
import math

import torch

from colossalai.testing import clear_cache_before_run, parameterize


def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    use_adamw,
):
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    if weight_decay != 0:
        if use_adamw:
            # Perform stepweight decay
            param.mul_(1 - lr * weight_decay)
        else:
            grad = grad.add(param, alpha=weight_decay)

    # Decay the first and second moment running average coefficient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

    step_size = lr / bias_correction1

    param.addcdiv_(exp_avg, denom, value=-step_size)


def assertLess(data_diff, threshold, msg):
    assert data_diff < threshold, msg


def assertTrue(condition, msg):
    assert condition, msg


@clear_cache_before_run()
@parameterize('adamw', [True, False])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_cpu_adam(adamw, step, p_dtype, g_dtype):
    lr = 1e-3
    beta1, beta2 = 0.9, 0.999
    eps = 1e-8
    weight_decay = 0

    for i in range(3):
        p_data = torch.rand(64, dtype=p_dtype)
        p_data_copy = p_data.clone().float()
        p_grad = torch.rand(64, dtype=g_dtype)
        p_grad_copy = p_grad.clone().float()
        exp_avg = torch.rand(p_data.shape)
        exp_avg_copy = exp_avg.clone()
        exp_avg_sq = torch.rand(p_data.shape)
        exp_avg_sq_copy = exp_avg_sq.clone()

        from colossalai.kernel.op_builder import CPUAdamBuilder
        cpu_optim = CPUAdamBuilder().load()

        cpu_adam_op = cpu_optim.CPUAdamOptimizer(lr, beta1, beta2, eps, weight_decay, adamw)

        cpu_adam_op.step(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            True,
            p_data.view(-1),    # fp32 data
            p_grad.view(-1),    # fp32 grad
            exp_avg.view(-1),
            exp_avg_sq.view(-1),
            -1,
        )

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_data_copy,    # fp32 data
            p_grad_copy,    # fp32 grad
            exp_avg_copy,
            exp_avg_sq_copy,
            adamw,
        )
        var = p_data_copy - p_data
        data_diff = torch.max(torch.abs(var))
        threshold = 1e-3
        assertLess(
            data_diff,
            threshold,
            f"p_data diff {data_diff}. failed check, step {step}, lr {lr}, eps "
            f"{eps} beta1 {beta1} beta2 {beta2} weight_decay {weight_decay} p_dtype {p_dtype}, g_dtype {g_dtype}",
        )
        max_grad_diff = torch.max(torch.abs(p_grad_copy - p_grad))
        assertTrue(max_grad_diff < threshold, f"diff {max_grad_diff}")
        max_exp_avg_diff = torch.max(torch.abs(exp_avg_copy - exp_avg))
        assertTrue(max_exp_avg_diff < threshold, f"max_exp_avg_diff {max_exp_avg_diff}")
        max_exp_avg_sq_diff = torch.max(torch.abs(exp_avg_sq_copy - exp_avg_sq))
        assertTrue(max_exp_avg_sq_diff < threshold, f"max_exp_avg_sq_diff {max_exp_avg_sq_diff}")


if __name__ == '__main__':
    test_cpu_adam()
@@ -1,64 +0,0 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.adam import Adam

from colossalai.nn.optimizer.fused_adam import FusedAdam
from colossalai.testing import clear_cache_before_run, parameterize


class FC(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(64, 64))

    def forward(self, x):
        return self.fc(x)


@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, p_dtype, g_dtype):
    model = FC().cuda().to(p_dtype)
    state = model.state_dict()
    model_copy = FC().cuda().to(p_dtype)
    model_copy.load_state_dict(state.copy())

    if adamw:
        optim = FusedAdam(model.parameters(), lr=1e-3, adamw_mode=True)
        torch_optim = AdamW(model_copy.parameters(), lr=1e-3)
    else:
        optim = FusedAdam(model.parameters(), lr=1e-3)
        torch_optim = Adam(model_copy.parameters(), lr=1e-3)

    data = torch.rand(1024, 64).cuda().to(p_dtype)
    data_copy = data.clone()
    label = torch.rand(1024, 64).cuda().to(p_dtype)

    for d, l in zip(data, label):
        y = model(d)
        loss = ((l - y)**2).sum()
        optim.zero_grad()
        loss.backward()
        if p_dtype != g_dtype:
            for i in range(len(optim.param_groups[0]['params'])):
                optim.param_groups[0]['params'][i].grad.data = optim.param_groups[0]['params'][i].grad.data.to(g_dtype)
        optim.step()

    for d, l in zip(data_copy, label):
        y = model_copy(d)
        loss = ((l - y)**2).sum()
        torch_optim.zero_grad()
        loss.backward()
        torch_optim.step()

    assert len(optim.param_groups[0]['params']) == len(torch_optim.param_groups[0]['params'])

    for i in range(len(optim.param_groups[0]['params'])):
        if torch.isnan(optim.param_groups[0]['params'][i]).any() \
           or torch.isnan(torch_optim.param_groups[0]['params'][i]).any():
            continue
        assert torch.allclose(optim.param_groups[0]['params'][i], torch_optim.param_groups[0]['params'][i], 2e-3, 2e-3)
@@ -1,95 +0,0 @@
import math

import torch
import torch.nn as nn
from numpy import dtype

from colossalai.testing import clear_cache_before_run, parameterize
from colossalai.utils import multi_tensor_applier


def torch_adam_update(
    step,
    lr,
    beta1,
    beta2,
    eps,
    weight_decay,
    param,
    grad,
    exp_avg,
    exp_avg_sq,
    use_adamw,
):
    bias_correction1 = 1 - beta1**step
    bias_correction2 = 1 - beta2**step

    if weight_decay != 0:
        if use_adamw:
            # Perform stepweight decay
            param.mul_(1 - lr * weight_decay)
        else:
            grad = grad.add(param, alpha=weight_decay)

    # Decay the first and second moment running average coefficient
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)

    step_size = lr / bias_correction1

    param.addcdiv_(exp_avg, denom, value=-step_size)


@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('step', [1, 2])
@parameterize('p_dtype', [torch.float, torch.half])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, step, p_dtype, g_dtype):
    from colossalai.kernel.op_builder import FusedOptimBuilder
    fused_optim = FusedOptimBuilder().load()
    fused_adam = fused_optim.multi_tensor_adam

    dummy_overflow_buf = torch.cuda.IntTensor([0])

    count = 0

    for i in range(3):
        p = torch.rand(64, dtype=p_dtype).cuda()
        p_copy = p.clone().float()
        g = torch.rand(p.shape, dtype=g_dtype).cuda()
        g_copy = g.clone().float()
        m = torch.rand(p.shape).cuda()
        m_copy = m.clone()
        v = torch.rand(p.shape).cuda()
        v_copy = v.clone()

        lr = 1e-3
        beta1, beta2 = 0.9, 0.999
        eps = 1e-8
        weight_decay = 0

        multi_tensor_applier(fused_adam, dummy_overflow_buf, [[g], [p], [m], [v]], lr, beta1, beta2, eps, step, adamw,
                             True, weight_decay, -1)

        torch_adam_update(
            step,
            lr,
            beta1,
            beta2,
            eps,
            weight_decay,
            p_copy,    # fp32 data
            g_copy,    # fp32 grad
            m_copy,
            v_copy,
            adamw,
        )

        if torch.isnan(p).any() or torch.isnan(p_copy).any():
            count += 1
            continue
        assert count < 200, "too many nans"
        assert torch.allclose(p.to(torch.float), p_copy.to(torch.float), 1e-5,
                              1e-5), f"failed check, adamw {adamw}, p_dtype {p_dtype}, g_dtype {g_dtype}"
@@ -1,42 +0,0 @@
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.adam import Adam

from colossalai.nn.optimizer.hybrid_adam import HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize

RE = 3


@clear_cache_before_run()
@parameterize('adamw', [False, True])
@parameterize('device', ['cpu', 'cuda:0'])
@parameterize('p_dtype', [torch.float])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, device, p_dtype, g_dtype):
    rng_state = torch.get_rng_state()
    p = nn.Parameter(torch.rand(64).to(device, p_dtype))
    torch.set_rng_state(rng_state)
    p_copy = nn.Parameter(torch.rand(64).to(device).float())

    if adamw:
        optim = HybridAdam([p], lr=1e-3, adamw_mode=True)
        torch_optim = AdamW([p_copy], lr=1e-3)
    else:
        optim = HybridAdam([p], lr=1e-3)
        torch_optim = Adam([p_copy], lr=1e-3)

    print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}")
    for i in range(RE):
        p.grad = torch.rand(64).to(device, p_dtype)
        p_copy.grad = p.grad.clone().float()
        p.grad.data = p.grad.data.to(g_dtype)

        optim.step()
        torch_optim.step()

        if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any():
            continue
        assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \
            f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"
@@ -21,23 +21,40 @@ TEST_MODELS = ['gpt2']
 # these models are too small, all parameters in these models are compacted into one chunk
 EXAMPLE_MODELS = ['albert', 'beit', 'bert', 'hanging_param_model', 'nested_model', 'repeated_computed_layers']
 
+# bfloat16 cannot represent them exactly
+BF16_IGNORED_KEYS = [
+    'albert.embeddings.word_embeddings.weight',
+    'albert.embeddings.position_embeddings.weight',
+    'masked_bias',
+]
+
 
-def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
-    zero_dict = model.state_dict(only_rank_0=False)
+def check_param(model: ZeroDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
+    zero_dict = model.state_dict(only_rank_0=False, dtype=dtype)
     torch_dict = torch_model.state_dict()
 
     for key, value in torch_dict.items():
         # key is 'module.model.PARAMETER', so we truncate it
         key = key[7:]
         assert key in zero_dict, "{} not in ZeRO dictionary.".format(key)
-        temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
+        temp_zero_value = zero_dict[key].to(device=value.device)
+        if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS):
+            continue
+        rtol, atol = 1e-3, 4e-3
+        if dtype is torch.bfloat16:
+            rtol, atol = 4e-3, 8e-3
         # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
-        assert_close(value, temp_zero_value, rtol=1e-3, atol=4e-3)
+        assert_close(value.float(),
+                     temp_zero_value.float(),
+                     rtol=rtol,
+                     atol=atol,
+                     msg=lambda s: s + f'\n{key}\n{temp_zero_value.dtype}')
 
 
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 @parameterize('model_name', TEST_MODELS)
-def exam_model_step(placement_policy, model_name: str):
+@parameterize('mixed_precision', [torch.half, torch.bfloat16])
+def exam_model_step(placement_policy, model_name: str, mixed_precision: torch.dtype):
     set_seed(42)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -65,7 +82,7 @@ def exam_model_step(placement_policy, model_name: str):
         init_device = None
     chunk_manager = ChunkManager(config_dict, init_device=init_device)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
 
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = ZeroOptimizer(optimizer, model, initial_scale=128)

@@ -74,6 +91,7 @@ def exam_model_step(placement_policy, model_name: str):
     torch_model.eval()
 
     set_seed(dist.get_rank() * 3 + 128)
+    rtol, atol = 1e-4, 1e-5
     for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break

@@ -83,17 +101,18 @@ def exam_model_step(placement_policy, model_name: str):
 
         torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
-        assert_close(torch_loss, loss)
+        assert_close(torch_loss, loss, rtol=rtol, atol=atol)
 
         zero_optim.step()
         torch_optim.step()
 
-        check_param(model, torch_model)
+        check_param(model, torch_model, mixed_precision)
 
 
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto', 'const'])
 @parameterize('model_name', EXAMPLE_MODELS)
-def exam_tiny_example(placement_policy, model_name: str):
+@parameterize('mixed_precision', [torch.half, torch.bfloat16])
+def exam_tiny_example(placement_policy, model_name: str, mixed_precision: torch.dtype):
     set_seed(2008)
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -113,7 +132,7 @@ def exam_tiny_example(placement_policy, model_name: str):
 
     chunk_manager = init_chunk_manager(model=model, init_device=get_current_device(), search_range_mb=1)
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    model = ZeroDDP(model, gemini_manager, pin_memory=True)
+    model = ZeroDDP(model, gemini_manager, pin_memory=True, mixed_precision=mixed_precision)
     optimizer = HybridAdam(model.parameters(), lr=1e-3)
     zero_optim = ZeroOptimizer(optimizer, model, initial_scale=2)
 

@@ -121,6 +140,9 @@ def exam_tiny_example(placement_policy, model_name: str):
     torch_model.eval()
 
     set_seed(dist.get_rank() * 3 + 128)
+    rtol, atol = 1.5e-6, 2e-5
+    if mixed_precision is torch.bfloat16:
+        rtol, atol = 2e-3, 2e-3
     for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break

@@ -133,12 +155,12 @@ def exam_tiny_example(placement_policy, model_name: str):
 
         torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
         loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
-        assert_close(torch_loss, loss, rtol=1.5e-6, atol=2e-5)    # atol should be 2e-5 for torch lower than 1.12
+        assert_close(torch_loss, loss, rtol=rtol, atol=atol)    # atol should be 2e-5 for torch lower than 1.12
 
         zero_optim.step()
         torch_optim.step()
 
-        check_param(model, torch_model)
+        check_param(model, torch_model, mixed_precision)
 
 
 def run_dist(rank, world_size, port):
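`BF16_IGNORED_KEYS` and the widened `rtol/atol` above compensate for bf16's short mantissa: many fp32 values shift visibly when round-tripped through bf16, so an exact comparison against the fp32 reference model cannot hold. A quick illustration in plain PyTorch (unrelated to the test models):

```python
import torch

x = torch.tensor(0.1234567)
print(x.to(torch.bfloat16).float().item())    # ~0.12353515625: only 2-3 decimal digits survive
print(torch.tensor(65504.0).to(torch.bfloat16).float().item())    # 65536.0: large values round as well
```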
@@ -16,7 +16,11 @@ from colossalai.zero.low_level._utils import has_inf_or_nan
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 
-def run_dist(rank, world_size, port, parallel_config):
+def run_dist(rank, world_size, port, parallel_config, bf16):
+    is_mp_config = parallel_config == MP_PARALLEL_CONFIG
+    is_zero_config = parallel_config == ZERO_PARALLEL_CONFIG
+    if bf16:
+        parallel_config['zero']['model_config']['bf16'] = True
     colossalai.launch(config=parallel_config,
                       rank=rank,
                       world_size=world_size,

@@ -30,7 +34,8 @@ def run_dist(rank, world_size, port, parallel_config):
     model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
     with ZeroInitContext(target_device=torch.cuda.current_device(),
                          shard_strategy=gpc.config.zero.model_config.shard_strategy,
-                         shard_param=True):
+                         shard_param=True,
+                         bf16=bf16):
         colo_model = model_builder(checkpoint=True)
 
     colo_optimizer = optimizer_class(colo_model.parameters(), lr=1e-3)

@@ -38,7 +43,8 @@ def run_dist(rank, world_size, port, parallel_config):
                                                          optimizer=colo_optimizer,
                                                          criterion=criterion,
                                                          train_dataloader=train_dataloader)
-    torch_model = model_builder(checkpoint=True).half()
+    dtype = torch.bfloat16 if bf16 else torch.float16
+    torch_model = model_builder(checkpoint=True).to(dtype)
     col_model_deepcopy(engine.model, torch_model)
     torch_model = torch_model.cuda().float()

@@ -80,9 +86,9 @@ def run_dist(rank, world_size, port, parallel_config):
         torch_optimizer.step()
         i += 1
 
-    if parallel_config == MP_PARALLEL_CONFIG:
+    if is_mp_config:
         check_params(torch_model, colo_model, loose=True)
-    elif parallel_config == ZERO_PARALLEL_CONFIG:
+    elif is_zero_config:
         check_sharded_model_params(torch_model, colo_model, loose=True)

@@ -97,9 +103,10 @@ def test_mp_engine(world_size):
 
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2])
+@pytest.mark.parametrize("bf16", [True, False])
 @rerun_if_address_is_in_use()
-def test_zero_engine(world_size):
-    spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG)
+def test_zero_engine(world_size, bf16):
+    spawn(run_dist, world_size, parallel_config=ZERO_PARALLEL_CONFIG, bf16=bf16)
 
 
 if __name__ == '__main__':
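The engine test now reaches bf16 through the legacy zero config, flipping `parallel_config['zero']['model_config']['bf16']` before launch. A hypothetical sketch of how such a config could be shaped and mutated; apart from the `zero.model_config.shard_strategy` and `bf16` keys, every name below (the dummy strategy class and the helper) is invented for illustration:

```python
# Hypothetical illustration of how the bf16 switch reaches the legacy zero config.
class DummyShardStrategy:
    """Stand-in for the real shard strategy object referenced via
    gpc.config.zero.model_config.shard_strategy in the test."""
    pass


ZERO_PARALLEL_CONFIG = dict(zero=dict(model_config=dict(
    shard_strategy=DummyShardStrategy(),
    bf16=False,
)))


def enable_bf16(parallel_config: dict) -> dict:
    # mirrors what run_dist does before colossalai.launch when the bf16 parametrization is on
    parallel_config['zero']['model_config']['bf16'] = True
    return parallel_config


print(enable_bf16(ZERO_PARALLEL_CONFIG)['zero']['model_config']['bf16'])    # True
```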
@@ -82,7 +82,6 @@ def exam_zero_1_2_grad_acc():
 
 def exam_zero_1_grad_acc():
     local_rank = torch.distributed.get_rank()
-    grad_scale = 32
     seed_all(2008)
 
     # create models

@@ -101,7 +100,6 @@ def exam_zero_1_grad_acc():
     # level 1 and 2 will produce exactly the same results
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                            overlap_communication=False,
-                                           initial_scale=grad_scale,
                                            reduce_bucket_size=262144,
                                            clip_grad_norm=1.0)
 

@@ -128,9 +126,8 @@ def exam_zero_1_grad_acc():
         if check_flag:
             # check grad
             for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-                unscale_grad = z1p.grad / grad_scale
                 # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
-                assert torch.equal(p.grad, unscale_grad)
+                assert torch.equal(p.grad, z1p.grad)
 
         zero_optimizer._sync_grad()
 
@@ -7,7 +7,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close
 
 import colossalai
-from colossalai.testing import rerun_if_address_is_in_use, spawn
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 from colossalai.zero import LowLevelZeroOptimizer
 

@@ -25,15 +25,18 @@ class MlpModel(nn.Module):
         return x
 
 
-def half_close(a, b, loose=False):
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
     rtol = None
     atol = None
-    if loose:
+    if dtype is torch.float16:
         rtol = 5e-2
         atol = 5e-4
+    elif dtype is torch.bfloat16:
+        rtol = 4e-3
+        atol = 4e-3
 
-    a = a.detach().half()
-    b = b.detach().half()
+    a = a.detach().to(dtype)
+    b = b.detach().to(dtype)
 
     assert_close(a, b, rtol=rtol, atol=atol)
 

@@ -96,7 +99,8 @@ def exam_zero_1_2():
             assert torch.equal(z1p.data, z2p.data)
 
 
-def exam_zero_1_torch_ddp():
+@parameterize('dtype', [torch.float16, torch.bfloat16])
+def exam_zero_1_torch_ddp(dtype: torch.dtype):
     """
     In this test, two pairs of model and optimizers are created.
     1. zero: use sharded optimizer and fp16 parameters

@@ -109,15 +113,10 @@ def exam_zero_1_torch_ddp():
     seed_all(1453)
 
     # create models
-    zero_model = MlpModel()
-    torch_model = copy.deepcopy(zero_model)
+    torch_model = MlpModel().cuda()
+    zero_model = copy.deepcopy(torch_model).to(dtype)
 
-    zero_model = zero_model.cuda().half()
-    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0)
-    torch_model = torch_model.cuda()
-
-    # for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-    #     half_close(p.data, z1p.data)
+    torch_model = DDP(torch_model.cuda(), bucket_cap_mb=0).cuda()
 
     # create optimizer
     zero_optimizer = torch.optim.SGD(zero_model.parameters(), lr=1)

@@ -137,11 +136,11 @@ def exam_zero_1_torch_ddp():
     input_data = torch.rand(32, 128).cuda()
 
     # zero-dp forward
-    zero_output = zero_model(input_data.half())
+    zero_output = zero_model(input_data.to(dtype))
 
     # torch-ddp forward
     torch_output = torch_model(input_data)
-    half_close(zero_output, torch_output, loose=True)
+    loose_close(zero_output, torch_output, dtype=dtype)
 
     # zero-dp backward
     zero_optimizer.backward(zero_output.mean().float(), sync_grad=False)

@@ -151,7 +150,7 @@ def exam_zero_1_torch_ddp():
 
     # check grad
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
-        half_close(p.grad, z1p.grad, loose=True)
+        loose_close(p.grad, z1p.grad, dtype=dtype)
 
     # zero-dp step
     zero_optimizer._sync_grad()

@@ -163,7 +162,7 @@ def exam_zero_1_torch_ddp():
     # check updated param
     for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
         # print(n, torch.max(torch.abs(p.data - z1p.data)))
-        half_close(p.data, z1p.data, loose=True)
+        loose_close(p.data, z1p.data, dtype=dtype)
 
 
 def run_dist(rank, world_size, port):