2022-03-03 07:06:18 +00:00
|
|
|
from enum import Enum
|
2022-03-29 01:08:18 +00:00
|
|
|
from os import stat
|
|
|
|
from typing import Dict, Optional, Tuple
|
2022-03-03 07:06:18 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.nn as nn
|
2022-03-29 07:45:48 +00:00
|
|
|
from torch import Tensor
|
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
from torch.nn.parameter import Parameter
|
|
|
|
from torch.optim import Optimizer
|
|
|
|
|
2022-03-15 02:05:38 +00:00
|
|
|
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
2022-03-03 07:06:18 +00:00
|
|
|
from colossalai.context.parallel_mode import ParallelMode
|
|
|
|
from colossalai.core import global_context as gpc
|
2022-03-18 08:18:31 +00:00
|
|
|
from colossalai.logging import get_dist_logger
|
2022-03-03 07:06:18 +00:00
|
|
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
2022-03-29 04:48:00 +00:00
|
|
|
from colossalai.utils.memory_utils.utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
|
2022-03-03 07:06:18 +00:00
|
|
|
from colossalai.zero.sharded_model import ShardedModelV2
|
2022-03-25 06:54:39 +00:00
|
|
|
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
|
2022-03-29 04:48:00 +00:00
|
|
|
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
|
2022-03-29 07:45:48 +00:00
|
|
|
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
|
|
|
|
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
|
|
|
|
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
|
|
|
|
2022-03-03 07:06:18 +00:00
|
|
|
|
|
|
|
|
|
|
|
class OptimState(Enum):
    """Tracks whether the gradients are currently multiplied by the loss scale."""

    # Gradients still carry the loss-scale factor.
    SCALED = 1
    # Gradients have been divided by the loss scale (or were never scaled).
    UNSCALED = 2
|
|
|
|
|
|
|
|
|
2022-03-03 07:55:27 +00:00
|
|
|
class ShardedOptimizerV2(ColossalaiOptimizer):
    """A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO).

    By default the ZeRO optimizer stage 3 offloads Optimizer States (OS) on CPU.

    We apply the Device-aware Operator Placement technique for OS placement from the following paper.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818

    GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory,
    which is detected by a runtime memory tracer.

    We place as many OS chunks in the margin space as possible.

    The size of margin space can be controlled by `gpu_margin_mem_ratio`.
    If it is set as 0.0, it is the same as classical ZeRO optimizer.

    NOTE: You must use `ShardedOptimizerV2` with `ShardedModelV2`.

    Args:
        sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
            shard strategy provided by sharded model to shard param fp32 tensors.
        optimizer (Optimizer): An Optimizer instance.
        cpu_offload (bool, optional): Whether to offload the optimizer states to CPU. Defaults to False.
        gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
            which will be used when using hybrid CPU optimizer. Must be within [0.0, 1.0]. Defaults to 0.0.
        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
        max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
        use_memory_tracer (bool, optional): Currently not read by this class; whether the optimizer registers
            itself with the global model-data tracer is decided by `sharded_model.use_memory_tracer`.
            Defaults to False.
        dp_process_group (Optional[ProcessGroup], optional): data parallel process group. Defaults to None.
        mp_process_group (Optional[ProcessGroup], optional): model parallel process group. Defaults to None.

    Raises:
        RuntimeError: If `cpu_offload` is requested but `sharded_model` was not built with cpu_offload.
    """

    def __init__(self,
                 sharded_model: ShardedModelV2,
                 optimizer: Optimizer,
                 cpu_offload: bool = False,
                 gpu_margin_mem_ratio: float = 0.0,
                 initial_scale: float = 2**32,
                 min_scale: float = 1,
                 growth_factor: float = 2,
                 backoff_factor: float = 0.5,
                 growth_interval: float = 1000,
                 hysteresis: float = 2,
                 max_scale: int = 2**32,
                 use_memory_tracer=False,
                 dp_process_group: Optional[ProcessGroup] = None,
                 mp_process_group: Optional[ProcessGroup] = None) -> None:
        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'

        super().__init__(optimizer)
        self.shard_strategy = sharded_model.shard_strategy
        self.model: ShardedModelV2 = sharded_model
        if cpu_offload and not sharded_model.cpu_offload:
            raise RuntimeError(
                "ShardedOptimizerV2 using cpu_offload, but the sharded_model used to initialize it dose not use cpu_offload"
            )
        self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
        assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, 'gpu_margin_mem_ratio must >=0.0 and <=1.0'
        # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
        # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
        # and it must set `num_fp32_shards_per_param` correctly
        self._should_move_fp32_shards_h2d: bool = cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
            optimizer, 'num_fp32_shards_per_param', 0) >= 2
        # Device on which the fp32 master shards are stored.
        self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
        self.optim_state: OptimState = OptimState.UNSCALED
        self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
        self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
        # Grad scaler
        self.grad_scaler = DynamicGradScaler(initial_scale=initial_scale,
                                             min_scale=min_scale,
                                             growth_factor=growth_factor,
                                             backoff_factor=backoff_factor,
                                             growth_interval=growth_interval,
                                             hysteresis=hysteresis,
                                             max_scale=max_scale)
        # One-element flag tensor on the current CUDA device; it is all-reduced in `_check_overflow`.
        self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())
        self._logger = get_dist_logger("ShardedOptimizerV2")

        # Store fp32 param shards
        self.master_params: Dict[Parameter, Tensor] = {}

        for group in self.optim.param_groups:
            for p in group['params']:
                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
                if not is_param_sharded:
                    # TODO (ver217): we may not use shard / gather here
                    # Param is not sharded, which means we use ZeRO-2 here
                    # As we only store param shard, we shard it here
                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
                if not is_param_sharded:
                    # In this branch, there's no need to shard param
                    # So we gather here
                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)

        self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                           ranks=[0])

        # NOTE: the tracer switch is taken from the model, not from the `use_memory_tracer` argument.
        self._use_memory_tracer = self.model.use_memory_tracer
        if self._use_memory_tracer:
            GLOBAL_MODEL_DATA_TRACER.register_optimizer(self)

    def get_memory_usage(self) -> Tuple[int, int]:
        """
        Get the memory usage of the optimizer. Including master_params (param fp32),
        momentum (self.state[p]['exp_avg']) variance (self.state[p]['exp_avg_sq'])

        Returns:
            Tuple[int, int]: cuda/cpu memory usage in Byte.
        """
        cuda_use = 0
        cpu_use = 0

        def update_mem_use(t):
            # Accumulate the per-device usage reported by `colo_tensor_mem_usage` for one object.
            nonlocal cuda_use
            nonlocal cpu_use
            t_cuda_use, t_cpu_use = colo_tensor_mem_usage(t)
            cuda_use += t_cuda_use
            cpu_use += t_cpu_use

        for p_fp32 in self.master_params.values():
            update_mem_use(p_fp32)
        for group in self.optim.param_groups:
            for p in group['params']:
                # Every value stored in the inner optimizer's state for this param is counted.
                for v in self.optim.state[p].values():
                    update_mem_use(v)

        return cuda_use, cpu_use

    def step(self, *args, **kwargs):
        """Run one optimizer step on the fp32 master shards and write the result back to the model.

        The step is skipped (grads are cleared and ``None`` is returned) when an inf/nan
        gradient is detected on any rank of the data-parallel or model-parallel group.

        Args:
            *args: Forwarded to the inner optimizer's ``step``.
            **kwargs: Forwarded to the inner optimizer's ``step``.

        Returns:
            Whatever the inner optimizer's ``step`` returns, or ``None`` when the step is skipped.
        """
        self._maybe_move_fp32_shards()

        # unscale grads if scaled
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()

        found_inf = self._check_overflow()
        self.grad_scaler.update(found_inf)

        if found_inf:
            self._logger.warning('found inf during ShardedOptimV2 step')
            self.zero_grad()
            return

        # assign master param pointers to p.data.
        # We will not trigger data copy here.
        for group in self.optim.param_groups:
            for p in group['params']:
                p.data = self.master_params[p]
                # Now p.data is sharded
                # So optimizer states are sharded naturally

        self._logger.debug(f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                           ranks=[0])

        ret = self.optim.step(*args, **kwargs)

        self._logger.debug(f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                           ranks=[0])
        # Copy master param data (fp32) to payload of col_attr (fp16)
        # TODO() improve efficiency by gathering tensors into a chunk and transferring
        # a chunk.
        for group in self.optim.param_groups:
            for p in group['params']:
                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
                if not is_param_sharded:
                    # We use ZeRO-2 here
                    # The `p.col_attr.sharded_data_tensor` saves full fp16 param
                    # But we only have updated fp32 param shard here
                    # So we first shard full fp16 param and copy fp32 param shard to it
                    # Then we will gather them
                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
                # p.data is the fp32 master shard at this point while the sharded data tensor holds
                # fp16, so the payload is replaced by a half-precision clone of p targeting the
                # current CUDA device (presumably a device-placed copy -- confirm against
                # `colo_model_tensor_clone`).

                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                p.col_attr.sharded_data_tensor.reset_payload(
                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))

                if not is_param_sharded:
                    # We gather full fp16 param here
                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
                # Point p.data back at the model's payload instead of the fp32 master shard.
                p.data = p.col_attr.sharded_data_tensor.payload
        return ret

    def backward(self, loss: Tensor) -> None:
        """Multiply ``loss`` by the current loss scale and run backward through the sharded model.

        Marks the optimizer state as SCALED so the grads get unscaled before they are used.

        Args:
            loss (Tensor): The unscaled loss.
        """
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        self.model.backward(loss)

    def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
        """Delegate to ``ShardedModelV2.backward_by_grad``; the scaled/unscaled state is left untouched.

        Args:
            tensor (Tensor): Forwarded to the model's ``backward_by_grad``.
            grad (Tensor): Forwarded to the model's ``backward_by_grad``.
        """
        self.model.backward_by_grad(tensor, grad)

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Unscale the grads if they are still scaled, then clip via the parent implementation.

        Args:
            model (nn.Module): Forwarded to the parent ``clip_grad_norm``.
            max_norm (float): Forwarded to the parent ``clip_grad_norm``.

        Returns:
            Whatever the parent ``clip_grad_norm`` returns.
        """
        if self.optim_state == OptimState.SCALED:
            self._unscale_grads()
        return super().clip_grad_norm(model, max_norm)

    @property
    def loss_scale(self):
        """Current loss scale of the grad scaler as a Python number."""
        return self.grad_scaler.scale.item()

    def _check_overflow(self):
        """Check whether any gradient contains inf/nan on any rank.

        Returns:
            bool: True if an overflow was found locally or on another rank of the
            data-parallel / model-parallel groups.
        """
        # clear previous overflow record
        self._found_overflow.fill_(0.0)

        # check for overflow
        # Params without a grad are skipped (same convention as `_unscale_grads`);
        # `any` stops at the first offending grad across all groups.
        local_overflow = any(p.grad is not None and has_inf_or_nan(p.grad)
                             for group in self.optim.param_groups
                             for p in group['params'])
        if local_overflow:
            self._found_overflow.fill_(1.0)

        # all-reduce across dp group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.dp_process_group)

        # all-reduce over model parallel group
        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=self.mp_process_group)

        return self._found_overflow.item() > 0

    def _unscale_grads(self):
        """Divide every existing grad by the loss scale in place and mark the state as UNSCALED."""
        assert self.optim_state == OptimState.SCALED
        for group in self.optim.param_groups:
            for p in group['params']:
                if p.grad is not None:
                    p.grad.data.div_(self.loss_scale)
        self.optim_state = OptimState.UNSCALED

    def zero_grad(self, *args, **kwargs):
        """Clear the grads of the inner optimizer; ``*args`` / ``**kwargs`` are accepted but ignored."""
        # We must set grad to None
        # Because we will judge whether local grad accumulation
        # is enabled by whether grad is None
        self.optim.zero_grad(set_to_none=True)

    def sync_grad(self):
        """No-op."""
        pass

    def _maybe_move_fp32_shards(self):
        """Move fp32 master shards (and their grads) from host to the current CUDA device, once.

        Runs only while `_should_move_fp32_shards_h2d` is set and clears that flag, so later calls
        do nothing. Shards are moved in param-group order as long as their accumulated size stays
        strictly below the budget
        ``model.cuda_margin_space * gpu_margin_mem_ratio / optim.num_fp32_shards_per_param``.
        """
        if self._should_move_fp32_shards_h2d:
            self._should_move_fp32_shards_h2d = False
            available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
            fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
            fp32_shards_used_cuda_margin_mem = 0
            cuda_device = torch.cuda.current_device()
            for group in self.optim.param_groups:
                for p in group['params']:
                    shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
                    if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                        self.master_params[p] = self.master_params[p].to(cuda_device)
                        # A param may have no grad; only move the grad when there is one.
                        if p.grad is not None:
                            p.grad.data = p.grad.data.to(cuda_device)
                        p.col_attr.offload_grad = False
                        fp32_shards_used_cuda_margin_mem += shard_mem
|