|
|
|
@ -28,48 +28,29 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
"""A wrapper for optimizer. `ShardedOptimizerV2` and `ShardedModelV2` implement Zero Redundancy Optimizer (ZeRO) stage 3.
|
|
|
|
|
You must use `ShardedOptimizerV2` with `ShardedModelV2`.
|
|
|
|
|
|
|
|
|
|
:param sharded_model: A sharded model initialized by class ShardedModelV2. The optimizer will use the
|
|
|
|
|
shard strategy provided by sharded model to shard param fp32 tensors.
|
|
|
|
|
:type sharded_model: sharded_model
|
|
|
|
|
|
|
|
|
|
:param optimizer: A Optimizer instance.
|
|
|
|
|
:type optimizer: Optimizer
|
|
|
|
|
|
|
|
|
|
:param cpu_offload: is offloading the optimizer states to CPU.
|
|
|
|
|
:type cpu_offload: bool
|
|
|
|
|
|
|
|
|
|
:param initial_scale: initial scale used by DynamicGradScaler
|
|
|
|
|
:type initial_scale: float
|
|
|
|
|
|
|
|
|
|
:param min_scale: min scale used by DynamicGradScaler
|
|
|
|
|
:type min_scale: float
|
|
|
|
|
|
|
|
|
|
:param growth_factor: growth_factor used by DynamicGradScaler
|
|
|
|
|
:type growth_factor: float
|
|
|
|
|
|
|
|
|
|
:param backoff_factor: backoff_factor used by DynamicGradScaler
|
|
|
|
|
:type backoff_factor: float
|
|
|
|
|
|
|
|
|
|
:param growth_interval: growth_interval used by DynamicGradScaler
|
|
|
|
|
:type growth_interval: float
|
|
|
|
|
|
|
|
|
|
:param hysteresis: hysteresis used by DynamicGradScaler
|
|
|
|
|
:type hysteresis: float
|
|
|
|
|
|
|
|
|
|
:param max_scale: max_scale used by DynamicGradScaler
|
|
|
|
|
:type max_scale: float
|
|
|
|
|
|
|
|
|
|
:param dp_process_group: data paralle process group
|
|
|
|
|
:type dp_process_group: Optional[ProcessGroup]
|
|
|
|
|
|
|
|
|
|
:param mp_process_group: model paralle process group
|
|
|
|
|
:type mp_process_group: Optional[ProcessGroup]
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
|
|
|
|
|
shard strategy provided by sharded model to shard param fp32 tensors.
|
|
|
|
|
optimizer (Optimizer): An Optimizer instance.
|
|
|
|
|
cpu_offload (bool, optional): Is offloading the optimizer states to CPU.. Defaults to False.
|
|
|
|
|
gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
|
|
|
|
|
which will be used when using hybrid CPU optimizer. Defaults to 0.0.
|
|
|
|
|
initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
|
|
|
|
|
min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
|
|
|
|
|
growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
|
|
|
|
|
backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
|
|
|
|
|
growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
|
|
|
|
|
hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
|
|
|
|
|
max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
|
|
|
|
|
dp_process_group (Optional[ProcessGroup], optional): data paralle process group. Defaults to None.
|
|
|
|
|
mp_process_group (Optional[ProcessGroup], optional): model paralle process group. Defaults to None.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
sharded_model: ShardedModelV2,
|
|
|
|
|
optimizer: Optimizer,
|
|
|
|
|
cpu_offload: bool = False,
|
|
|
|
|
gpu_margin_mem_ratio: float = 0.0,
|
|
|
|
|
initial_scale: float = 2**32,
|
|
|
|
|
min_scale: float = 1,
|
|
|
|
|
growth_factor: float = 2,
|
|
|
|
@ -88,6 +69,13 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"ShardedOptimizerV2 using cpu_offload, but the sharded_model used to initialize it dose not use cpu_offload"
|
|
|
|
|
)
|
|
|
|
|
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
|
|
|
|
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
|
|
|
|
|
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
|
|
|
|
|
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
|
|
|
|
|
# and it must set `num_fp32_shards_per_param` correctly
|
|
|
|
|
self._should_move_fp32_shards_h2d: bool = cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
|
|
|
|
|
optimizer, 'num_fp32_shards_per_param', 0) >= 2
|
|
|
|
|
self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
|
|
|
|
|
self.optim_state: OptimState = OptimState.UNSCALED
|
|
|
|
|
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
|
|
|
|
@ -122,6 +110,20 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|
|
|
|
self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
|
|
|
|
|
|
|
|
|
|
def step(self, *args, **kwargs):
|
|
|
|
|
if self._should_move_fp32_shards_h2d:
|
|
|
|
|
self._should_move_fp32_shards_h2d = False
|
|
|
|
|
available_cuda_margin_mem = self.model.cuda_margin_space * self.gpu_margin_mem_ratio
|
|
|
|
|
fp32_shards_available_cuda_margin_mem = available_cuda_margin_mem / self.optim.num_fp32_shards_per_param
|
|
|
|
|
fp32_shards_used_cuda_margin_mem = 0
|
|
|
|
|
for group in self.optim.param_groups:
|
|
|
|
|
for p in group['params']:
|
|
|
|
|
shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
|
|
|
|
|
if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
|
|
|
|
|
self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
|
|
|
|
|
p.grad.data = p.grad.data.to(torch.cuda.current_device())
|
|
|
|
|
p.col_attr.offload_fp32_grad = False
|
|
|
|
|
fp32_shards_used_cuda_margin_mem += shard_mem
|
|
|
|
|
|
|
|
|
|
# unscale grads if scaled
|
|
|
|
|
if self.optim_state == OptimState.SCALED:
|
|
|
|
|
self._unscale_grads()
|
|
|
|
|