|
|
|
@ -1,5 +1,6 @@
|
|
|
|
|
from enum import Enum |
|
|
|
|
from typing import Dict, Optional |
|
|
|
|
from os import stat |
|
|
|
|
from typing import Dict, Optional, Tuple |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
import torch.distributed as dist |
|
|
|
@ -16,7 +17,7 @@ from torch.distributed import ProcessGroup
|
|
|
|
|
from torch.nn.parameter import Parameter |
|
|
|
|
from torch.optim import Optimizer |
|
|
|
|
from colossalai.zero.sharded_optim._utils import has_inf_or_nan |
|
|
|
|
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move |
|
|
|
|
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OptimState(Enum): |
|
|
|
@ -26,14 +27,20 @@ class OptimState(Enum):
|
|
|
|
|
|
|
|
|
|
class ShardedOptimizerV2(ColossalaiOptimizer): |
|
|
|
|
"""A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO). |
|
|
|
|
|
|
|
|
|
By default the ZeRO optimizer stage 3 offloads Optimizer States to CPU. |
|
|
|
|
|
|
|
|
|
We apply the Device-aware Operator Placement technique for OS placement from the following paper. |
|
|
|
|
|
|
|
|
|
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management |
|
|
|
|
https://arxiv.org/abs/2108.05818 |
|
|
|
|
|
|
|
|
|
GPU margin space is the remaining space after removing peak non-model data from the overall GPU memory, |
|
|
|
|
which is detected by a runtime memory tracer. |
|
|
|
|
|
|
|
|
|
We place as many OS chunks in the margin space as possible. |
|
|
|
|
The size of margin space can be controlled by `gpu_margin_mem_ratio` |
|
|
|
|
|
|
|
|
|
The size of margin space can be controlled by `gpu_margin_mem_ratio`. |
|
|
|
|
If it is set as 0.0, it is the same as classical ZeRO optimizer. |
|
|
|
|
|
|
|
|
|
NOTE() You must use `ShardedOptimizerV2` with `ShardedModelV2`. |
|
|
|
@ -99,7 +106,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
hysteresis=hysteresis, |
|
|
|
|
max_scale=max_scale) |
|
|
|
|
self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device()) |
|
|
|
|
self._logger = get_dist_logger() |
|
|
|
|
self._logger = get_dist_logger("ShardedOptimizerV2") |
|
|
|
|
|
|
|
|
|
# Store fp32 param shards |
|
|
|
|
self.master_params: Dict[Parameter, Tensor] = {} |
|
|
|
@ -119,6 +126,37 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
# So we gather here |
|
|
|
|
self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group) |
|
|
|
|
|
|
|
|
|
self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!", |
|
|
|
|
ranks=[0]) |
|
|
|
|
|
|
|
|
|
def get_memory_usage(self) -> Tuple[int, int]:
    """Report how much memory the optimizer currently holds.

    The total covers the fp32 master copies kept in ``self.master_params``
    and every value stored in the wrapped optimizer's per-parameter state
    (for Adam-style optimizers this is presumably ``exp_avg`` and
    ``exp_avg_sq``).

    Returns:
        Tuple[int, int]: accumulated (cuda, cpu) memory usage in bytes, as
        reported by ``colo_tensor_mem_usage``.
    """

    def _tracked_tensors():
        # Yield lazily so each tensor is measured right after it is looked up.
        yield from self.master_params.values()
        for param_group in self.optim.param_groups:
            for param in param_group['params']:
                yield from self.optim.state[param].values()

    total_cuda = 0
    total_cpu = 0
    for tensor in _tracked_tensors():
        on_cuda, on_cpu = colo_tensor_mem_usage(tensor)
        total_cuda += on_cuda
        total_cpu += on_cpu

    return total_cuda, total_cpu
|
|
|
|
|
|
|
|
|
def step(self, *args, **kwargs): |
|
|
|
|
self._maybe_move_fp32_shards() |
|
|
|
|
|
|
|
|
@ -130,7 +168,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
self.grad_scaler.update(found_inf) |
|
|
|
|
|
|
|
|
|
if found_inf: |
|
|
|
|
self._logger.info('found inf during ShardedOptimV2 step') |
|
|
|
|
self._logger.warning('found inf during ShardedOptimV2 step') |
|
|
|
|
self.zero_grad() |
|
|
|
|
return |
|
|
|
|
|
|
|
|
@ -142,8 +180,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
# Now p.data is sharded |
|
|
|
|
# So optimizer states are sharded naturally |
|
|
|
|
|
|
|
|
|
self._logger.debug(f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!", |
|
|
|
|
ranks=[0]) |
|
|
|
|
|
|
|
|
|
ret = self.optim.step(*args, **kwargs) |
|
|
|
|
|
|
|
|
|
self._logger.debug(f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!", |
|
|
|
|
ranks=[0]) |
|
|
|
|
# Copy master param data (fp32) to payload of col_attr (fp16) |
|
|
|
|
# TODO() improve efficiency by gathering tensors into a chunk and transferring |
|
|
|
|
# a chunk. |
|
|
|
|