[fp8] hotfix backward hook (#6053)

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation
pull/6063/head
Hongxin Liu 3 months ago committed by GitHub
parent c54c4fcd15
commit 13946c4448

@@ -216,7 +216,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        with self._wait_all_gather():
+        with self._hook_context():
             return super().forward(*args, **kwargs)
 
     def unwrap(self):
@@ -229,12 +229,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         for p in self.module.parameters():
             wait_all_gather_handle(p)
 
-    def _wait_all_gather(self):
-        return (
-            ColoParamOpHookManager.use_hooks(*self.op_hooks)
-            if (self.overlap_allgather or self.use_fp8)
-            else nullcontext()
-        )
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
 
 
 def get_param_info(optim: Optimizer):
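The refactor above replaces a flag check (`self.overlap_allgather or self.use_fp8`) with a check on the registered hook list itself, so any future hook only needs to be appended to `self.op_hooks`. A minimal, self-contained sketch of that pattern, using dummy stand-ins rather than ColossalAI's real hook manager:

    # Illustrative sketch only: dummy stand-ins for the real param-op hooks and manager.
    from contextlib import contextmanager, nullcontext

    class DummyHook:
        """Stand-in for a parameter-op hook (e.g. overlap all-gather, fp8 cast)."""
        def __init__(self, name):
            self.name = name

    class DummyHookManager:
        """Stand-in for ColoParamOpHookManager: activates hooks for a scoped region."""
        active = []

        @classmethod
        @contextmanager
        def use_hooks(cls, *hooks):
            old, cls.active = cls.active, list(hooks)
            try:
                yield
            finally:
                cls.active = old

    class WrapperSketch:
        def __init__(self, overlap_allgather=False, use_fp8=False):
            # The hook list is built once from the feature flags...
            self.op_hooks = []
            if overlap_allgather:
                self.op_hooks.append(DummyHook("overlap_allgather"))
            if use_fp8:
                self.op_hooks.append(DummyHook("fp8"))

        def _hook_context(self):
            # ...so the context manager only asks whether any hooks exist,
            # instead of re-deriving the condition from individual flags.
            return DummyHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()

    with WrapperSketch(use_fp8=True)._hook_context():
        assert len(DummyHookManager.active) == 1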
@@ -317,7 +313,8 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)
 
         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -540,7 +537,8 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)
 
         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
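Both `backward` overrides above now enter the model's `_hook_context()` before delegating to the superclass, so the parameter-op hooks registered for fp8 or overlapped all-gather stay active while autograd runs. A toy sketch of the wrapping pattern (all names here are illustrative, not the real classes):

    # Illustrative sketch: wrap an inherited backward() in the model's hook context.
    from contextlib import contextmanager, nullcontext

    events = []

    class ModelSketch:
        def __init__(self, op_hooks):
            self.op_hooks = op_hooks

        def _hook_context(self):
            return self._use_hooks() if self.op_hooks else nullcontext()

        @contextmanager
        def _use_hooks(self):
            events.append("hooks on")
            try:
                yield
            finally:
                events.append("hooks off")

    class BaseOptimizerSketch:
        def backward(self, loss):
            events.append("backward")  # stands in for the real loss.backward()

    class HookedOptimizerSketch(BaseOptimizerSketch):
        def __init__(self, model):
            self.model = model

        def backward(self, loss):
            # Same shape as the patched optimizers: enter the hook context,
            # then defer to the superclass backward.
            with self.model._hook_context():
                super().backward(loss)

    HookedOptimizerSketch(ModelSketch(op_hooks=["fp8"])).backward(loss=None)
    assert events == ["hooks on", "backward", "hooks off"]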
@@ -683,6 +681,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ):
         self.model = model
         self.param_info = param_info
@@ -712,6 +711,8 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
+            backward_context=model._hook_context,
         )
 
     def sync_dp_grads(self):
@@ -1206,6 +1207,7 @@ class HybridParallelPlugin(PipelinePluginBase):
            partition_grad=(self.zero_stage == 2),
            forced_dtype=PRECISION_TORCH_TYPE[precision],
            overlap_allgather=overlap_allgather,
+           fp8_communication=fp8_communication,
         )
 
         self.max_norm = max_norm
@@ -1371,7 +1373,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         # so we disable it, performing manual reduction instead.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
 
-        with ctx, model._wait_all_gather():
+        with ctx, model._hook_context():
             outputs = self.schedule.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )

@@ -100,14 +100,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext()
-        with ctx:
+        with self._hook_context():
             return super().forward(*args, **kwargs)
 
     def _force_wait_all_gather(self):
         for p in self.module.parameters():
             wait_all_gather_handle(p)
 
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
+
 
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -520,7 +522,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **zero_optim_kwargs, verbose=self.verbose
+                optimizer, **zero_optim_kwargs, verbose=self.verbose, backward_context=model._hook_context
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)
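Here the plugin passes the model's bound `_hook_context` method (a zero-argument callable returning a context manager) into the optimizer as `backward_context`; the optimizer later calls it around `loss.backward()` (see the low-level ZeRO hunk further down). A hedged sketch of that wiring with placeholder classes:

    # Illustrative sketch: inject a context-factory callable into an optimizer wrapper.
    from contextlib import nullcontext
    import torch

    class ZeroOptimizerSketch:
        """Stand-in for the backward_context plumbing in LowLevelZeroOptimizer."""
        def __init__(self, backward_context=None):
            self._backward_context = backward_context

        def backward(self, loss):
            ctx = nullcontext() if self._backward_context is None else self._backward_context()
            with ctx:
                loss.backward()

    class TinyModel(torch.nn.Linear):
        def _hook_context(self):
            # The real wrapper returns the param-op hook context here.
            return nullcontext()

    model = TinyModel(2, 2)
    optim = ZeroOptimizerSketch(backward_context=model._hook_context)  # bound method as factory
    optim.backward(model(torch.randn(1, 2)).sum())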

@@ -9,6 +9,7 @@ import os
 # https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 
+import torch
 import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator
@@ -64,6 +65,11 @@ def launch(
     set_seed(seed)
 
+    try:
+        torch._dynamo.config.optimize_ddp = world_size > 1
+    except AttributeError:
+        pass
+
     if verbose:
         logger = get_dist_logger()
         logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])

@@ -318,7 +318,7 @@ class InterleavedSchedule(PipelineSchedule):
             if self.stage_manager.is_last_stage():
                 loss = criterion(output_obj, micro_batch) / self.num_microbatch
                 if accum_loss is not None:
-                    accum_loss.add_(loss.detach())
+                    accum_loss.add_(loss.data)
                 if outputs is not None:
                     outputs.append(tree_map(detach, output_obj))
                 return loss

@@ -273,7 +273,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
             loss = criterion(output_obj, micro_batch) / self.num_microbatches
             if accum_loss is not None:
-                accum_loss.add_(loss.detach())
+                accum_loss.add_(loss.data)
             if outputs is not None:
                 outputs.append(tree_map_hf(detach, output_obj))
             return loss
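Both pipeline schedules (interleaved and 1F1B) now accumulate the per-microbatch loss via `loss.data` rather than `loss.detach()`; either way the accumulator stays outside the autograd graph. A small sketch of the accumulation pattern itself (the model, criterion, and microbatch count here are made up):

    # Illustrative sketch of microbatch loss accumulation outside the autograd graph.
    import torch

    num_microbatches = 4
    model = torch.nn.Linear(8, 1)
    accum_loss = torch.zeros(1)  # plain buffer, requires_grad=False

    for _ in range(num_microbatches):
        x, y = torch.randn(2, 8), torch.randn(2, 1)
        loss = torch.nn.functional.mse_loss(model(x), y) / num_microbatches
        # In-place add of loss.data records the value without making the buffer
        # hold a reference into the graph that backward() is about to consume.
        accum_loss.add_(loss.data)
        loss.backward()

    print(accum_loss.item())  # total loss across the microbatches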

@@ -1,6 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
 from weakref import proxy
@@ -88,6 +88,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         master_weights: bool = True,  # master weights
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        backward_context=None,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -130,6 +131,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
         self._fp8_communication = fp8_communication
+        self._backward_context = backward_context
 
         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
@@ -429,7 +431,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
 
-        loss.backward(retain_graph=retain_graph)
+        ctx = nullcontext() if self._backward_context is None else self._backward_context()
+        with ctx:
+            loss.backward(retain_graph=retain_graph)
 
         if not self.require_grad_sync:
             return
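The new `backward_context` parameter lets callers scope arbitrary setup/teardown around `loss.backward()`; when no context is supplied, `nullcontext()` preserves the old behaviour. A self-contained sketch (the context here is a dummy, not ColossalAI's hook context) showing that gradient hooks fired by autograd run inside the injected context:

    # Illustrative sketch: backward-time hooks observe the injected context.
    import torch
    from contextlib import contextmanager, nullcontext

    state = {"hooks_active": False}

    @contextmanager
    def fake_backward_context():
        # Stand-in for model._hook_context(): marks the region where op hooks are live.
        state["hooks_active"] = True
        try:
            yield
        finally:
            state["hooks_active"] = False

    w = torch.randn(3, requires_grad=True)
    seen = []
    w.register_hook(lambda grad: seen.append(state["hooks_active"]))  # runs during backward

    loss = (w * 2).sum()
    backward_context = fake_backward_context  # injected; may also be None
    ctx = nullcontext() if backward_context is None else backward_context()
    with ctx:
        loss.backward()

    assert seen == [True]  # the grad hook ran while the injected context was active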
