[fp8] hotfix backward hook (#6053)

* [fp8] hotfix backward hook

* [fp8] hotfix pipeline loss accumulation
Hongxin Liu, committed via GitHub
parent c54c4fcd15
commit 13946c4448

@@ -216,7 +216,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        with self._wait_all_gather():
+        with self._hook_context():
             return super().forward(*args, **kwargs)

     def unwrap(self):
@@ -229,12 +229,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         for p in self.module.parameters():
             wait_all_gather_handle(p)

-    def _wait_all_gather(self):
-        return (
-            ColoParamOpHookManager.use_hooks(*self.op_hooks)
-            if (self.overlap_allgather or self.use_fp8)
-            else nullcontext()
-        )
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


 def get_param_info(optim: Optimizer):
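
For orientation, here is a minimal, self-contained sketch of the pattern the new `_hook_context` helper follows: the wrapper collects its param-op hooks once (ZeRO overlap-allgather and/or FP8 hooks), and the helper simply activates whatever is in that list, falling back to `nullcontext()` when it is empty. The hook manager and hook names below are illustrative stand-ins, not ColossalAI's API.

```python
# Illustrative sketch only; DummyHookManager stands in for ColoParamOpHookManager.
from contextlib import contextmanager, nullcontext


class DummyHookManager:
    @staticmethod
    @contextmanager
    def use_hooks(*hooks):
        print("hooks on:", hooks)       # activate param-op hooks
        try:
            yield
        finally:
            print("hooks off:", hooks)  # always deactivate, even on error


class WrappedModule:
    def __init__(self, overlap_allgather=False, use_fp8=False):
        # Hooks are collected once at construction time...
        self.op_hooks = []
        if overlap_allgather:
            self.op_hooks.append("zero_allgather_hook")
        if use_fp8:
            self.op_hooks.append("fp8_linear_hook")

    def _hook_context(self):
        # ...and the helper only checks whether any hook is registered.
        return DummyHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


with WrappedModule(use_fp8=True)._hook_context():
    pass  # forward (and, after this commit, backward) runs with the hooks active
```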
@@ -317,7 +313,8 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@ -540,7 +537,8 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
None
"""
# Call the superclass backward method to compute gradients.
super().backward(loss, *args, **kwargs)
with self.model._hook_context():
super().backward(loss, *args, **kwargs)
if self.model.require_grad_sync:
# If gradient synchronization is required, sync sequence parallelism gradients.
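
The two optimizer hunks above share one idea: the param-op hooks that wrap the forward pass must also be active while autograd runs, otherwise the FP8/all-gather hooks never fire for parameters touched during backward. A hedged sketch of that wrapper pattern, with hypothetical class names:

```python
# Hypothetical wrapper classes; only the shape of the pattern matches the diff.
import torch


class OptimizerWrapperSketch:
    def __init__(self, optim: torch.optim.Optimizer):
        self.optim = optim

    def backward(self, loss: torch.Tensor, *args, **kwargs):
        loss.backward(*args, **kwargs)


class HookAwareOptimizerSketch(OptimizerWrapperSketch):
    def __init__(self, optim, model_wrapper):
        super().__init__(optim)
        self.model = model_wrapper  # assumed to expose _hook_context() -> context manager

    def backward(self, loss, *args, **kwargs):
        # Enter the model's hook context before delegating, mirroring
        # `with self.model._hook_context(): super().backward(...)` above.
        with self.model._hook_context():
            super().backward(loss, *args, **kwargs)
```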
@@ -683,6 +681,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ):
         self.model = model
         self.param_info = param_info
@@ -712,6 +711,8 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
+            backward_context=model._hook_context,
         )

     def sync_dp_grads(self):
@@ -1206,6 +1207,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             partition_grad=(self.zero_stage == 2),
             forced_dtype=PRECISION_TORCH_TYPE[precision],
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
         )

         self.max_norm = max_norm
@ -1371,7 +1373,7 @@ class HybridParallelPlugin(PipelinePluginBase):
# so we disable it, performing manual reduction instead.
ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
with ctx, model._wait_all_gather():
with ctx, model._hook_context():
outputs = self.schedule.forward_backward_step(
model, data_iter, criterion, optimizer, return_loss, return_outputs
)
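
For the pipeline path, two contexts are stacked: gradient synchronization is suppressed for the whole forward-backward schedule (it is reduced manually afterwards), and the param-op hooks stay active throughout. A short illustration with hypothetical stand-ins:

```python
# Illustration only; `model`, `optimizer` and `step_fn` are hypothetical stand-ins.
def run_schedule_step(model, optimizer, step_fn):
    # Prefer the ZeRO optimizer's no_sync() when available, else the model's.
    ctx = optimizer.no_sync() if hasattr(optimizer, "no_sync") else model.no_sync()
    with ctx, model._hook_context():
        return step_fn()
```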

@@ -100,14 +100,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext()
-        with ctx:
+        with self._hook_context():
             return super().forward(*args, **kwargs)

     def _force_wait_all_gather(self):
         for p in self.module.parameters():
             wait_all_gather_handle(p)

+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -520,7 +522,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-                optimizer, **zero_optim_kwargs, verbose=self.verbose
+                optimizer, **zero_optim_kwargs, verbose=self.verbose, backward_context=model._hook_context
             )
             # inject update_master_params
             model.update_master_params = MethodType(optimizer.update_master_params, model)

@@ -9,6 +9,7 @@ import os
 # https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

+import torch
 import torch.distributed as dist

 from colossalai.accelerator import get_accelerator
@@ -64,6 +65,11 @@ def launch(
     set_seed(seed)

+    try:
+        torch._dynamo.config.optimize_ddp = world_size > 1
+    except AttributeError:
+        pass
+
     if verbose:
         logger = get_dist_logger()
         logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])

@@ -318,7 +318,7 @@ class InterleavedSchedule(PipelineSchedule):
         if self.stage_manager.is_last_stage():
             loss = criterion(output_obj, micro_batch) / self.num_microbatch
             if accum_loss is not None:
-                accum_loss.add_(loss.detach())
+                accum_loss.add_(loss.data)
             if outputs is not None:
                 outputs.append(tree_map(detach, output_obj))
             return loss

@ -273,7 +273,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
accum_loss.add_(loss.data)
if outputs is not None:
outputs.append(tree_map_hf(detach, output_obj))
return loss
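
The second half of the hotfix ("pipeline loss accumulation" in the commit message) switches the reported-loss accumulation in both schedules from `loss.detach()` to `loss.data`; either form adds the scalar value without extending the autograd graph. A small self-contained illustration of the accumulation pattern (made-up model and values):

```python
# Accumulate per-microbatch losses for reporting without tracking gradients.
import torch

num_microbatches = 4
accum_loss = torch.zeros(1)

for _ in range(num_microbatches):
    logits = torch.randn(8, 10, requires_grad=True)
    loss = logits.square().mean() / num_microbatches
    loss.backward()               # per-microbatch backward, as the schedule does
    accum_loss.add_(loss.data)    # .data: same storage, no autograd history

print(accum_loss.item())          # loss reported for the whole batch
```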

@@ -1,6 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
 from weakref import proxy
@@ -88,6 +88,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         master_weights: bool = True,  # master weights
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        backward_context=None,
     ):
         super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
@@ -130,6 +131,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         self._reduce_bucket_size = reduce_bucket_size
         self._communication_dtype = communication_dtype
         self._fp8_communication = fp8_communication
+        self._backward_context = backward_context

         # gradient clipping
         self._clip_grad_norm = clip_grad_norm
@@ -429,7 +431,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         if self.mixed_precision_mixin is not None:
             loss = self.mixed_precision_mixin.pre_backward(loss)
-        loss.backward(retain_graph=retain_graph)
+        ctx = nullcontext() if self._backward_context is None else self._backward_context()
+        with ctx:
+            loss.backward(retain_graph=retain_graph)

         if not self.require_grad_sync:
             return
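
End to end, the plugins now hand the model wrapper's `_hook_context` (a zero-argument callable returning a context manager) to the ZeRO optimizer as `backward_context`, and the optimizer enters it around `loss.backward` when it is provided. A hedged sketch of that wiring with hypothetical classes:

```python
# Hypothetical classes; only the backward_context plumbing mirrors the diff.
from contextlib import contextmanager, nullcontext

import torch


class ModelWrapperSketch(torch.nn.Module):
    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module

    @contextmanager
    def _hook_context(self):
        # The real helper would activate ColoParamOpHookManager hooks here.
        yield

    def forward(self, *args, **kwargs):
        with self._hook_context():
            return self.module(*args, **kwargs)


class ZeroOptimizerSketch:
    def __init__(self, optim: torch.optim.Optimizer, backward_context=None):
        self.optim = optim
        self._backward_context = backward_context  # callable -> context manager, or None

    def backward(self, loss: torch.Tensor, retain_graph: bool = False):
        ctx = nullcontext() if self._backward_context is None else self._backward_context()
        with ctx:
            loss.backward(retain_graph=retain_graph)


model = ModelWrapperSketch(torch.nn.Linear(4, 2))
optimizer = ZeroOptimizerSketch(torch.optim.SGD(model.parameters(), lr=0.1),
                                backward_context=model._hook_context)
loss = model(torch.randn(3, 4)).sum()
optimizer.backward(loss)
```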
