@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, Optional, Union
+from typing import Dict, Optional

 import torch
 import torch.distributed as dist
@@ -8,7 +8,9 @@ from colossalai.amp.naive_amp._fp16_optimizer import DynamicGradScaler
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp32
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -26,7 +28,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def __init__(self,
                  optimizer: Optimizer,
-                 sharded_model: Union[nn.Module, ShardedModelV2],
+                 sharded_model: ShardedModelV2,
+                 shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -37,9 +40,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
+        assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
         super().__init__(optimizer)
-        self.model: Union[nn.Module, ShardedModelV2] = sharded_model
-        self.model_is_sharded = isinstance(sharded_model, ShardedModelV2)
+        self.shard_strategy = shard_strategy
+        self.model: ShardedModelV2 = sharded_model
         self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
@@ -52,20 +56,25 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                                              growth_interval=growth_interval,
                                              hysteresis=hysteresis,
                                              max_scale=max_scale)
-        self._found_overflow: Tensor = torch.FloatTensor([0]).to(self.device)
+        self._found_overflow: Tensor = torch.FloatTensor([0]).to(torch.cuda.current_device())

-        # Store fp32 params
+        # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}

         for group in optimizer.param_groups:
             for p in group['params']:
-                if hasattr(p, 'ca_attr'):
-                    assert p.ca_attr.is_sharded, 'ShardedAdam can be only used with sharded model'
-                    self.master_params[p] = p.ca_attr.payload(self.device)
-                else:
-                    self.master_params[p] = p.data.to(device=self.device)
-                if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
-                    self.master_params[p] = self.master_params[p].to(torch.float)
+                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.col_attr.data.is_sharded
+                if not is_param_sharded:
+                    # TODO (ver217): we may not use shard / gather here
+                    # The param is not sharded, which means we use ZeRO-2 here
+                    # As we only store the param shard, we shard it here
+                    self.shard_strategy.shard([p.col_attr.data])
+                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
+                if not is_param_sharded:
+                    # In this branch, there is no need to keep the param sharded
+                    # So we gather it back here
+                    self.shard_strategy.gather([p.col_attr.data])

     def step(self, *args, **kwargs):
         # unscale grads if scaled
@@ -83,28 +92,36 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 p.data = self.master_params[p]
+        # Now p.data is sharded
+        # So optimizer states are sharded naturally
         ret = self.optim.step(*args, **kwargs)
         # Write master param to payload
         for group in self.optim.param_groups:
             for p in group['params']:
-                if hasattr(p, 'ca_attr'):
-                    p.ca_attr.set_payload(p.data)
-                    p.data = p.ca_attr.payload()
+                is_param_sharded = p.col_attr.data.is_sharded
+                if not is_param_sharded:
+                    # We use ZeRO-2 here
+                    # `p.col_attr.data` saves the full fp16 param
+                    # But we only have the updated fp32 param shard here
+                    # So we first shard the full fp16 param and copy the fp32 param shard into it
+                    # Then we will gather it back
+                    self.shard_strategy.shard([p.col_attr.data])
+                # We have to use `copy_payload` instead of `reset_payload`
+                # because p.data is fp32 while p.col_attr.data is fp16
+                p.col_attr.data.copy_payload(p.data)
+                if not is_param_sharded:
+                    # We gather the full fp16 param here
+                    self.shard_strategy.gather([p.col_attr.data])
+                p.data = p.col_attr.data.payload
         return ret

     def backward(self, loss: Tensor) -> None:
         loss = self.loss_scale * loss
         self.optim_state = OptimState.SCALED
-        if self.model_is_sharded:
-            self.model.backward(loss)
-        else:
-            super().backward(loss)
+        self.model.backward(loss)

     def backward_by_grad(self, tensor: Tensor, grad: Tensor) -> None:
-        if self.model_is_sharded:
-            self.model.backward_by_grad(tensor, grad)
-        else:
-            super().backward_by_grad(tensor, grad)
+        self.model.backward_by_grad(tensor, grad)

     def clip_grad_norm(self, model: nn.Module, max_norm: float):
         if self.optim_state == OptimState.SCALED:
@@ -113,7 +130,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

     @property
     def loss_scale(self):
-        return self.grad_scaler.scale
+        return self.grad_scaler.scale.item()

     def _check_overflow(self):
         # clear previous overflow record
@@ -141,3 +158,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if p.grad is not None:
                     p.grad.data.div_(self.loss_scale)
         self.optim_state = OptimState.UNSCALED
+
+    def zero_grad(self, *args, **kwargs):
+        # We must set grads to None
+        # because we judge whether local grad accumulation
+        # is enabled by whether the grad is None
+        self.optim.zero_grad(set_to_none=True)
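
One subtle change above: `_found_overflow` now lives on the current CUDA device instead of `self.device`, because with `cpu_offload=True` the latter is the CPU, and NCCL collectives such as `all_reduce` only accept CUDA tensors. The sketch below shows the usual overflow-agreement pattern this buffer supports; the body of `_check_overflow` is not part of this diff, so the function name and the reduction order here are illustrative assumptions, not the library's code.

import torch
import torch.distributed as dist

def check_overflow_sketch(params, found_overflow: torch.Tensor,
                          dp_group, mp_group) -> bool:
    # Illustrative only: agree on grad overflow across all ranks.
    # The flag must be a CUDA tensor: NCCL all_reduce cannot operate on
    # CPU tensors, even when optimizer state is offloaded to the CPU.
    found_overflow.fill_(0.0)
    for p in params:
        if p.grad is not None and not torch.isfinite(p.grad).all():
            found_overflow.fill_(1.0)
            break
    # MAX-reduce so that any single rank's overflow cancels the step everywhere
    dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=dp_group)
    dist.all_reduce(found_overflow, op=dist.ReduceOp.MAX, group=mp_group)
    return found_overflow.item() > 0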
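The `copy_payload` vs. `reset_payload` comment in `step()` is a dtype point: an in-place copy casts the updated fp32 master shard back into the existing fp16 payload storage, whereas rebinding the reference would silently promote the payload to fp32. A plain-PyTorch illustration, assuming `copy_payload` wraps an in-place `copy_`:

import torch

fp16_payload = torch.zeros(4, dtype=torch.float16)         # stands in for p.col_attr.data's payload
fp32_master = torch.full((4,), 0.1, dtype=torch.float32)   # the updated fp32 master shard

fp16_payload.copy_(fp32_master)   # in-place copy casts fp32 -> fp16; storage keeps its dtype
assert fp16_payload.dtype == torch.float16

fp16_payload = fp32_master        # a bare rebind would instead leak fp32 into the payload
assert fp16_payload.dtype == torch.float32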
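Putting the new signature together, a training loop would look roughly like this. This is a hedged sketch: `TensorShardStrategy` is one concrete `BaseShardStrategy`, the `colossalai.zero.sharded_optim` import path and the `ShardedModelV2(module, strategy)` constructor are inferred from this diff's imports, and `Net`, `criterion`, and `loader` are placeholders.

import torch
from colossalai.zero.shard_utils import TensorShardStrategy   # one concrete BaseShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2  # assumed module path

shard_strategy = TensorShardStrategy()
model = ShardedModelV2(Net(), shard_strategy)      # Net is a placeholder nn.Module
optim = ShardedOptimizerV2(torch.optim.Adam(model.parameters(), lr=1e-3),
                           model,                  # must be a ShardedModelV2 (asserted in __init__)
                           shard_strategy,         # the same strategy the model was sharded with
                           cpu_offload=False,
                           initial_scale=2**5)

for data, label in loader:                         # loader / criterion are placeholders
    optim.zero_grad()                              # sets grads to None (see zero_grad above)
    loss = criterion(model(data), label)
    optim.backward(loss)                           # scales the loss, then model.backward(loss)
    optim.step()                                   # steps fp32 shards, writes back to fp16 payloads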