[zero]support no_sync method for zero1 plugin (#4138)

* support no sync for zero1 plugin

* polish

* polish
pull/4359/head
LuGY 2023-07-04 12:00:33 +08:00 committed by Hongxin Liu
parent c6ab96983a
commit 79cf1b5f33
8 changed files with 45 additions and 49 deletions

View File

@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@ -153,18 +153,20 @@ class Booster:
# return loss or outputs if needed
pass
def no_sync(self, model: nn.Module) -> contextmanager:
def no_sync(self, model: nn.Module = None, optimizer: OptimizerWrapper = None) -> contextmanager:
"""Context manager to disable gradient synchronization across DP process groups.
Support torch DDP and Low Level ZeRO-1 for now.
Args:
model (nn.Module): The model to be disabled gradient synchronization.
model (nn.Module): The model to be disabled gradient synchronization, for DDP
optimizer (OptimizerWrapper): The optimizer to be disabled gradient synchronization, for ZeRO1-1
Returns:
contextmanager: Context to disable gradient synchronization.
"""
assert self.plugin is not None, f'no_sync is only enabled when a plugin is provided and the plugin supports no_sync.'
assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model)
assert self.plugin.support_no_sync(), f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
return self.plugin.no_sync(model, optimizer)
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
"""Load model from checkpoint.

View File

@ -408,5 +408,5 @@ class GeminiPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO:
return GeminiCheckpointIO()
def no_sync(self, model: nn.Module) -> Iterator[None]:
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError

View File

@ -179,8 +179,11 @@ class LowLevelZeroPlugin(DPPluginBase):
norm_type=norm_type)
self.verbose = verbose
# set class name with stage, for better error message
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
def support_no_sync(self) -> bool:
return False
return self.stage == 1
def control_precision(self) -> bool:
return True
@ -219,5 +222,6 @@ class LowLevelZeroPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO:
return LowLevelZeroCheckpointIO()
def no_sync(self, model: nn.Module) -> Iterator[None]:
raise NotImplementedError
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(optimizer, LowLevelZeroOptimizer)
return optimizer.optim.no_sync()

View File

@ -61,7 +61,7 @@ class Plugin(ABC):
pass
@abstractmethod
def no_sync(self, model: nn.Module) -> Iterator[None]:
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
"""
Context manager to disable gradient synchronization.
"""

View File

@ -168,6 +168,6 @@ class TorchDDPPlugin(DPPluginBase):
def get_checkpoint_io(self) -> CheckpointIO:
return TorchDDPCheckpointIO()
def no_sync(self, model: nn.Module) -> Iterator[None]:
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
assert isinstance(model, TorchDDPModel), 'Model is not boosted by TorchDDPPlugin.'
return model.module.no_sync()

View File

@ -177,7 +177,7 @@ class TorchFSDPPlugin(DPPluginBase):
def support_no_sync(self) -> bool:
False
def no_sync(self, model: nn.Module) -> Iterator[None]:
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
def control_precision(self) -> bool:

View File

@ -14,10 +14,10 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
)
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor import ColoParameter, ProcessGroup
from colossalai.utils import conditional_context
from colossalai.utils.cuda import get_current_device
from ._utils import (
@ -56,7 +56,7 @@ class LowLevelZeroFP16MixedPrecisionMixin(FP16MixedPrecisionMixin):
return False
class LowLevelZeroOptimizer(ColossalaiOptimizer):
class LowLevelZeroOptimizer(OptimizerWrapper):
"""Optimizer used for ZeRO-1 and ZeRO-2.
"""
@ -77,11 +77,12 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
overlap_communication: bool = False,
partition_grad: bool = False, # stage 2 flag
cpu_offload: bool = False, # cpu offload
grad_accumulate_interval: int = 1,
forced_dtype: Optional[torch.dtype] = None):
assert not (partition_grad and grad_accumulate_interval > 1), \
"gradient accumulation is not compatible with ZeRO-2"
# TODO:
# 1. process group api
# 2. checkpoint IO
super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
self._dtype = self.optim.param_groups[0]['params'][0].dtype
self._logger = get_dist_logger()
@ -94,8 +95,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
# grad accumulation
self.require_grad_sync = True
self._accumulate_intervel = grad_accumulate_interval
self._accumulate_step = 0
colo_pg = self._search_colo_process_group()
if isinstance(colo_pg, ProcessGroup):
@ -340,15 +339,15 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
################################
def backward(self, loss, retain_graph=False):
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
if self.mixed_precision_mixin is not None:
loss = self.mixed_precision_mixin.pre_backward(loss)
self._accumulate_step += 1
no_sync = self._accumulate_step < self._accumulate_intervel
with conditional_context(self.no_sync(), enable=no_sync):
loss.backward(retain_graph=retain_graph)
loss.backward(retain_graph=retain_graph)
if no_sync:
if not self.require_grad_sync:
return
self._reduce_grad(self._partition_grads)
@ -385,7 +384,7 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
if not self._accumulate_step == self._accumulate_intervel:
if not self.require_grad_sync:
return
if self.mixed_precision_mixin is not None and self.mixed_precision_mixin.should_skip_step():
@ -393,7 +392,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
if self._verbose:
self._logger.info(f'Found overflow. Skip step')
self.zero_grad()
self._accumulate_step -= 1
return
# record all grads for unscale and clip
@ -463,9 +461,6 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
self.optim.param_groups[group_id]['params'] = self._master_param_groups_of_current_rank[group_id]
# reset accumulate step
self._accumulate_step = 0
#############################
# Mixed Precision Utilities #
#############################

View File

@ -9,6 +9,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.testing import spawn
from colossalai.testing.random import seed_all
from colossalai.utils import conditional_context
from colossalai.zero import LowLevelZeroOptimizer
@ -39,14 +40,12 @@ def exam_zero_1_2_grad_acc():
overlap_communication=True,
initial_scale=32,
clip_grad_norm=1.0,
grad_accumulate_interval=2,
verbose=True)
zero2_optimizer = LowLevelZeroOptimizer(zero2_optimizer,
overlap_communication=True,
partition_grad=True,
initial_scale=32,
clip_grad_norm=1.0,
grad_accumulate_interval=2)
clip_grad_norm=1.0)
# create data
seed_all(2021 + local_rank)
input_data1 = torch.randn(32, 128).cuda()
@ -59,8 +58,11 @@ def exam_zero_1_2_grad_acc():
assert torch.equal(zero1_output, zero2_output)
# zero-dp backward
zero1_optimizer.backward(zero1_output.sum().float())
zero2_optimizer.backward(zero2_output.sum().float())
no_sync = number == 0
with conditional_context(zero1_optimizer.no_sync(), no_sync):
zero1_optimizer.backward(zero1_output.sum().float())
with conditional_context(zero2_optimizer.no_sync(), no_sync):
zero2_optimizer.backward(zero2_output.sum().float())
if check_flag:
for (n, z1p), z2p in zip(zero1_model.named_parameters(), zero2_model.parameters()):
@ -101,8 +103,7 @@ def exam_zero_1_grad_acc():
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=False,
reduce_bucket_size=262144,
clip_grad_norm=1.0,
grad_accumulate_interval=2)
clip_grad_norm=1.0)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
@ -112,20 +113,15 @@ def exam_zero_1_grad_acc():
input_data2 = torch.randn(32, 128).cuda()
def fwd_bwd_func(number, cur_data, check_flag):
# zero-dp forward
zero_output = zero_model(cur_data)
# torch-ddp forward
no_sync = number == 0
# zero1 fwd and bwd
with conditional_context(zero_optimizer.no_sync(), no_sync):
zero_output = zero_model(cur_data)
zero_optimizer.backward(zero_output.sum().float())
# zero-dp backward
zero_optimizer.backward(zero_output.sum().float())
# torch-ddp backward
if number < 1:
with torch_model.no_sync():
torch_output = torch_model(cur_data)
assert torch.equal(zero_output, torch_output)
torch_output.sum().backward()
else:
# torch-ddp fwd and bwd
with conditional_context(torch_model.no_sync(), no_sync):
torch_output = torch_model(cur_data)
assert torch.equal(zero_output, torch_output)
torch_output.sum().backward()
@ -133,7 +129,6 @@ def exam_zero_1_grad_acc():
if check_flag:
# check grad
for (n, p), z1p in zip(torch_model.named_parameters(), zero_model.parameters()):
# print(n, p.shape, torch.max(torch.abs(p.grad - unscale_grad)))
assert torch.equal(p.grad, z1p.grad)
fwd_bwd_func(0, input_data1, True)