mirror of https://github.com/hpcaitech/ColossalAI
[zero] support no_sync method for zero1 plugin (#4138)
* support no sync for zero1 plugin
* polish
* polish

pull/4359/head
parent c6ab96983a
commit 79cf1b5f33
@@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import GeneralCheckpointIO
-from colossalai.interface import ModelWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -153,18 +153,20 @@ class Booster:
         # return loss or outputs if needed
         pass
 
-    def no_sync(self, model: nn.Module) -> contextmanager:
+    def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
         """Context manager to disable gradient synchronization across DP process groups.
+        Currently supports torch DDP and low-level ZeRO-1.
 
         Args:
-            model (nn.Module): The model to be disabled gradient synchronization.
+            model (nn.Module): The model whose gradient synchronization is to be disabled (for DDP).
+            optimizer (OptimizerWrapper): The optimizer whose gradient synchronization is to be disabled (for ZeRO-1).
 
         Returns:
             contextmanager: Context to disable gradient synchronization.
         """
         assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
-        assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
-        return self.plugin.no_sync(model)
+        assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
+        return self.plugin.no_sync(model, optimizer)
 
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.
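For orientation, here is a minimal sketch of how the new signature can be driven for gradient accumulation with the low-level ZeRO-1 plugin. It is an illustration only: the model, data, and accumulation interval are placeholders, and the launch call and plugin arguments follow the usual ColossalAI pattern but are not part of this diff.

# Sketch only, not part of this diff: gradient accumulation through the new
# Booster.no_sync(model, optimizer) signature with ZeRO stage 1.
import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

plugin = LowLevelZeroPlugin(stage=1)    # only stage 1 reports support_no_sync() == True
booster = Booster(plugin=plugin)

model = nn.Linear(128, 32).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

accumulate_interval = 4
for step in range(16):
    inputs = torch.randn(8, 128).cuda()
    if (step + 1) % accumulate_interval != 0:
        # accumulation step: gradient synchronization is skipped inside this context
        with booster.no_sync(model=model, optimizer=optimizer):
            booster.backward(model(inputs).mean(), optimizer)
    else:
        # final micro-step: gradients are reduced, then the optimizer steps
        booster.backward(model(inputs).mean(), optimizer)
        optimizer.step()
        optimizer.zero_grad()

On the accumulation steps the bucket reduction is skipped because backward() returns before _reduce_grad once require_grad_sync is cleared (see the LowLevelZeroOptimizer hunks further down).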
@@ -408,5 +408,5 @@ class GeminiPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return GeminiCheckpointIO()
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
@@ -179,8 +179,11 @@ class LowLevelZeroPlugin(DPPluginBase):
                                            norm_type=norm_type)
         self.verbose = verbose
 
+        # set class name with stage, for better error message
+        setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
+
     def support_no_sync(self) -> bool:
-        return False
+        return self.stage == 1
 
     def control_precision(self) -> bool:
         return True
@@ -219,5 +222,6 @@ class LowLevelZeroPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return LowLevelZeroCheckpointIO()
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
-        raise NotImplementedError
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
+        assert isinstance(optimizer, LowLevelZeroOptimizer)
+        return optimizer.optim.no_sync()
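Since the plugin-level no_sync above ignores the model argument and simply forwards to the wrapped LowLevelZeroOptimizer, a caller working at the plugin level would typically gate on support_no_sync() first. A hypothetical helper (names are placeholders, not part of the diff):

# Hypothetical guard around the plugin-level API shown above; `plugin` is a
# LowLevelZeroPlugin and `optimizer` the wrapper returned by Booster.boost().
from contextlib import nullcontext

def maybe_no_sync(plugin, optimizer):
    if plugin.support_no_sync():    # True only for ZeRO stage 1 after this change
        return plugin.no_sync(model=None, optimizer=optimizer)
    return nullcontext()            # otherwise keep normal synchronization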
@@ -61,7 +61,7 @@ class Plugin(ABC):
         pass
 
     @abstractmethod
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         """
         Context manager to disable gradient synchronization.
         """
@@ -168,6 +168,6 @@ class TorchDDPPlugin(DPPluginBase):
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchDDPCheckpointIO()
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
         return model.module.no_sync()
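After this change the argument a caller passes to no_sync depends on the plugin: TorchDDPPlugin consumes the boosted model (delegating to DistributedDataParallel.no_sync), while LowLevelZeroPlugin consumes the boosted optimizer. A small sketch of the two calling conventions; `booster` and the boosted objects are placeholders returned by Booster.boost():

# Sketch of the two calling conventions after this change; all names are placeholders.
def accumulation_step(booster, model, optimizer, batch, use_ddp: bool):
    # TorchDDPPlugin looks at `model`; LowLevelZeroPlugin (stage 1) looks at `optimizer`.
    ctx = booster.no_sync(model=model) if use_ddp else booster.no_sync(optimizer=optimizer)
    with ctx:
        loss = model(batch).mean()
        booster.backward(loss, optimizer)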
@@ -177,7 +177,7 @@ class TorchFSDPPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         False
 
-    def no_sync(self, model: nn.Module) -> Iterator[None]:
+    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
 
     def control_precision(self) -> bool:
@@ -14,10 +14,10 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 )
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
-from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.tensor import ColoParameter, ProcessGroup
 from colossalai.utils import conditional_context
 from colossalai.utils.cuda import get_current_device
 
 from ._utils import (
@@ -56,7 +56,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
         return False
 
 
-class LowLevelZeroOptimizer(ColossalaiOptimizer):
+class LowLevelZeroOptimizer(OptimizerWrapper):
     """Optimizer used for ZeRO-1 and ZeRO-2.
     """
 
@@ -77,11 +77,12 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
                  overlap_communication: bool = False,
                  partition_grad: bool = False,    # stage 2 flag
                  cpu_offload: bool = False,    # cpu offload
-                 grad_accumulate_interval: int = 1,
                  forced_dtype: Optional[torch.dtype] = None):
 
-        assert not (partition_grad and grad_accumulate_interval > 1), \
-            "gradient accumulation is not compatible with ZeRO-2"
+        # TODO:
+        # 1. process group api
+        # 2. checkpoint IO
+
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
         self._dtype = self.optim.param_groups[0]['params'][0].dtype
         self._logger = get_dist_logger()
@@ -94,8 +95,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
 
         # grad accumulation
         self.require_grad_sync = True
-        self._accumulate_intervel = grad_accumulate_interval
-        self._accumulate_step = 0
 
         colo_pg = self._search_colo_process_group()
         if isinstance(colo_pg, ProcessGroup):
@@ -340,15 +339,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     ################################
 
     def backward(self, loss, retain_graph=False):
+        assert not(self._partition_grads and not self.require_grad_sync), \
+            "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
+
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
 
-        self._accumulate_step += 1
-        no_sync = self._accumulate_step < self._accumulate_intervel
-        with conditional_context(self.no_sync(), enable=no_sync):
-            loss.backward(retain_graph=retain_graph)
+        loss.backward(retain_graph=retain_graph)
 
-        if no_sync:
+        if not self.require_grad_sync:
             return
 
         self._reduce_grad(self._partition_grads)
@@ -385,7 +384,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
     def step(self, closure=None):
         assert closure is None, 'closure is not supported by step()'
 
-        if not self._accumulate_step == self._accumulate_intervel:
+        if not self.require_grad_sync:
             return
 
         if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
@@ -393,7 +392,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             if self._verbose:
                 self._logger.info(f'Found overflow. Skip step')
             self.zero_grad()
-            self._accumulate_step -= 1
             return
 
         # record all grads for unscale and clip
@@ -463,9 +461,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
 
             self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
 
-        # reset accumulate step
-        self._accumulate_step = 0
-
     #############################
     # Mixed Precision Utilities #
     #############################
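The no_sync context manager of LowLevelZeroOptimizer itself is not shown in these hunks; a minimal sketch that is consistent with the require_grad_sync flag consulted in backward() and step() above (an assumption, not the committed implementation) would be:

from contextlib import contextmanager

# Assumed shape of LowLevelZeroOptimizer.no_sync(): temporarily clear
# require_grad_sync so that backward() skips _reduce_grad and step() returns
# early until synchronization is re-enabled.
@contextmanager
def no_sync(self):
    old_require_grad_sync = self.require_grad_sync
    self.require_grad_sync = False
    try:
        yield
    finally:
        self.require_grad_sync = old_require_grad_sync

Moving the accumulation counter out of the optimizer and letting the caller decide when to synchronize mirrors the torch DDP no_sync API, which is what allows the grad_accumulate_interval plumbing above to be removed.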
@@ -9,6 +9,7 @@ from torch.testing import assert_close
 
 import colossalai
 from colossalai.testing import spawn
 from colossalai.testing.random import seed_all
+from colossalai.utils import conditional_context
 from colossalai.zero import LowLevelZeroOptimizer
 
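The updated tests below gate synchronization with conditional_context from colossalai.utils. As a rough sketch of the semantics the tests rely on (not the library's actual source), it enters the given context manager only when the flag is true:

from contextlib import contextmanager

# Rough sketch of the assumed behaviour of colossalai.utils.conditional_context.
@contextmanager
def conditional_context(context_manager, enable: bool = True):
    if enable:
        with context_manager:
            yield
    else:
        yield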
@@ -39,14 +40,12 @@ def exam_zero_1_2_grad_acc():
                                             overlap_communication=True,
                                             initial_scale=32,
                                             clip_grad_norm=1.0,
-                                            grad_accumulate_interval=2,
                                             verbose=True)
     zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
                                             overlap_communication=True,
                                             partition_grad=True,
                                             initial_scale=32,
-                                            clip_grad_norm=1.0,
-                                            grad_accumulate_interval=2)
+                                            clip_grad_norm=1.0)
     # create data
     seed_all(2021 + local_rank)
     input_data1 = torch.randn(32, 128).cuda()
@@ -59,8 +58,11 @@ def exam_zero_1_2_grad_acc():
         assert torch.equal(zero1_output, zero2_output)
 
         # zero-dp backward
-        zero1_optimizer.backward(zero1_output.sum().float())
-        zero2_optimizer.backward(zero2_output.sum().float())
+        no_sync = number == 0
+        with conditional_context(zero1_optimizer.no_sync(), no_sync):
+            zero1_optimizer.backward(zero1_output.sum().float())
+        with conditional_context(zero2_optimizer.no_sync(), no_sync):
+            zero2_optimizer.backward(zero2_output.sum().float())
 
         if check_flag:
             for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
@@ -101,8 +103,7 @@ def exam_zero_1_grad_acc():
     zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                            overlap_communication=False,
                                            reduce_bucket_size=262144,
-                                           clip_grad_norm=1.0,
-                                           grad_accumulate_interval=2)
+                                           clip_grad_norm=1.0)
 
     torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
 
@@ -112,20 +113,15 @@ def exam_zero_1_grad_acc():
     input_data2 = torch.randn(32, 128).cuda()
 
     def fwd_bwd_func(number, cur_data, check_flag):
-        # zero-dp forward
-        zero_output = zero_model(cur_data)
-
-        # torch-ddp forward
+        no_sync = number == 0
+        # zero1 fwd and bwd
+        with conditional_context(zero_optimizer.no_sync(), no_sync):
+            zero_output = zero_model(cur_data)
+            zero_optimizer.backward(zero_output.sum().float())
 
-        # zero-dp backward
-        zero_optimizer.backward(zero_output.sum().float())
-        # torch-ddp backward
-        if number < 1:
-            with torch_model.no_sync():
-                torch_output = torch_model(cur_data)
-                assert torch.equal(zero_output, torch_output)
-                torch_output.sum().backward()
-        else:
+        # torch-ddp fwd and bwd
+        with conditional_context(torch_model.no_sync(), no_sync):
             torch_output = torch_model(cur_data)
             assert torch.equal(zero_output, torch_output)
             torch_output.sum().backward()
@@ -133,7 +129,6 @@ def exam_zero_1_grad_acc():
         if check_flag:
             # check grad
             for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
                 # print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
                 assert torch.equal(p.grad, z1p.grad)
-
    fwd_bwd_func(0, input_data1, True)