mirror of https://github.com/hpcaitech/ColossalAI
[fp8] hotfix backward hook (#6053)
* [fp8] hotfix backward hook
* [fp8] hotfix pipeline loss accumulation

pull/6063/head
parent c54c4fcd15
commit 13946c4448
@@ -216,7 +216,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        with self._wait_all_gather():
+        with self._hook_context():
             return super().forward(*args, **kwargs)
 
     def unwrap(self):

@@ -229,12 +229,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         for p in self.module.parameters():
             wait_all_gather_handle(p)
 
-    def _wait_all_gather(self):
-        return (
-            ColoParamOpHookManager.use_hooks(*self.op_hooks)
-            if (self.overlap_allgather or self.use_fp8)
-            else nullcontext()
-        )
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
 
 
 def get_param_info(optim: Optimizer):

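Note on the change above: `_wait_all_gather` gated the hook scope on the `overlap_allgather`/`use_fp8` flags, while the renamed `_hook_context` simply checks whether any op hooks were collected at wrap time, so every registered hook (all-gather overlap, fp8, or future ones) shares one context. Below is a minimal, self-contained sketch of that pattern; the class and hook names are illustrative stand-ins, not the ColossalAI implementation.

from contextlib import contextmanager, nullcontext


class _DemoHookManager:
    """Stand-in for a param-op hook manager; records which hooks are active."""

    active = []

    @classmethod
    @contextmanager
    def use_hooks(cls, *hooks):
        cls.active = list(hooks)
        try:
            yield
        finally:
            cls.active = []


class _DemoWrappedModule:
    def __init__(self, overlap_allgather=False, use_fp8=False):
        # Hooks are collected once at wrap time; _hook_context only asks "is the list empty?"
        self.op_hooks = []
        if overlap_allgather:
            self.op_hooks.append("allgather_overlap_hook")
        if use_fp8:
            self.op_hooks.append("fp8_linear_hook")

    def _hook_context(self):
        # Same shape as the new code above: one expression, no per-feature flag checks.
        return _DemoHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


if __name__ == "__main__":
    m = _DemoWrappedModule(use_fp8=True)
    with m._hook_context():
        print("hooks active inside the scope:", _DemoHookManager.active)
    print("hooks active outside the scope:", _DemoHookManager.active)
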
@@ -317,6 +313,7 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         """
 
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)
 
         if self.model.require_grad_sync:

@@ -540,6 +537,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)
 
         if self.model.require_grad_sync:

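The two optimizer changes above wrap `super().backward(...)` in the model's hook context, so parameter operations triggered during the backward pass run under the same hooks as the forward pass. A hedged sketch of the shape of such a wrapper, using only plain PyTorch and assumed names (not the ColossalAI classes):

from contextlib import nullcontext

import torch


class DemoOptimizerWrapper:
    """Illustrative wrapper: backward() runs inside a context supplied by the model."""

    def __init__(self, optim, model_hook_context=None):
        self.optim = optim
        # Zero-argument factory returning a fresh context manager per call; may be None.
        self._model_hook_context = model_hook_context

    def backward(self, loss, retain_graph=False):
        ctx = self._model_hook_context() if self._model_hook_context is not None else nullcontext()
        with ctx:
            # Any autograd-triggered parameter op now executes inside the hook scope.
            loss.backward(retain_graph=retain_graph)


if __name__ == "__main__":
    model = torch.nn.Linear(4, 2)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    wrapper = DemoOptimizerWrapper(optim, model_hook_context=nullcontext)
    loss = model(torch.randn(3, 4)).sum()
    wrapper.backward(loss)
    optim.step()
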
@@ -683,6 +681,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ):
         self.model = model
         self.param_info = param_info

@@ -712,6 +711,8 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
            dp_process_group=dp_process_group,
            forced_dtype=forced_dtype,
            overlap_allgather=overlap_allgather,
+           fp8_communication=fp8_communication,
+           backward_context=model._hook_context,
        )
 
    def sync_dp_grads(self):

@@ -1206,6 +1207,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                partition_grad=(self.zero_stage == 2),
                forced_dtype=PRECISION_TORCH_TYPE[precision],
                overlap_allgather=overlap_allgather,
+               fp8_communication=fp8_communication,
            )
 
        self.max_norm = max_norm

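For context, a hedged usage sketch of how these flags are supplied at the plugin level. Only `overlap_allgather` and `fp8_communication` appear in this diff; the remaining constructor arguments and the launch call are assumptions for illustration and may differ by version.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes torchrun; older versions may require a config argument

plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=1,
    zero_stage=1,            # routes optimizer creation through HybridParallelZeroOptimizer above
    precision="bf16",
    overlap_allgather=True,  # registers the all-gather overlap op hook
    fp8_communication=True,  # now also forwarded down to the ZeRO optimizer
)
booster = Booster(plugin=plugin)
# model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(...)
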
@@ -1371,7 +1373,7 @@ class HybridParallelPlugin(PipelinePluginBase):
        # so we disable it, performing manual reduction instead.
        ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
 
-       with ctx, model._wait_all_gather():
+       with ctx, model._hook_context():
            outputs = self.schedule.forward_backward_step(
                model, data_iter, criterion, optimizer, return_loss, return_outputs
            )

@@ -100,14 +100,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
        if self.convert_fn is not None:
            args = tree_map(self.convert_fn, args)
            kwargs = tree_map(self.convert_fn, kwargs)
-       ctx = ColoParamOpHookManager.use_hooks(*self.op_hooks) if self.overlap_allgather else nullcontext()
-       with ctx:
+       with self._hook_context():
            return super().forward(*args, **kwargs)
 
    def _force_wait_all_gather(self):
        for p in self.module.parameters():
            wait_all_gather_handle(p)
 
+   def _hook_context(self):
+       return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()
+
 
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):

@@ -520,7 +522,7 @@ class LowLevelZeroPlugin(DPPluginBase):
 
        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
-               optimizer, **zero_optim_kwargs, verbose=self.verbose
+               optimizer, **zero_optim_kwargs, verbose=self.verbose, backward_context=model._hook_context
            )
            # inject update_master_params
            model.update_master_params = MethodType(optimizer.update_master_params, model)

@@ -9,6 +9,7 @@ import os
 # https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
 
+import torch
 import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator

@@ -64,6 +65,11 @@ def launch(
 
    set_seed(seed)
 
+   try:
+       torch._dynamo.config.optimize_ddp = world_size > 1
+   except AttributeError:
+       pass
+
    if verbose:
        logger = get_dist_logger()
        logger.info(f"Distributed environment is initialized, world size: {dist.get_world_size()}", ranks=[0])

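The new `try`/`except` guards a write to `torch._dynamo.config`, a private namespace that is absent on older torch builds; `optimize_ddp` controls dynamo's DDP graph splitting and is only meaningful when more than one rank is present. A minimal sketch of the same defensive pattern (the `world_size` value here is illustrative):

import torch

world_size = 8  # illustrative; in launch() this comes from the distributed arguments

try:
    # Skip dynamo's DDP-specific graph splitting when running on a single rank.
    torch._dynamo.config.optimize_ddp = world_size > 1
except AttributeError:
    # Older torch without torch._dynamo (or without this config field): ignore.
    pass

print("torch._dynamo available:", hasattr(torch, "_dynamo"))
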
@@ -318,7 +318,7 @@ class InterleavedSchedule(PipelineSchedule):
        if self.stage_manager.is_last_stage():
            loss = criterion(output_obj, micro_batch) / self.num_microbatch
            if accum_loss is not None:
-               accum_loss.add_(loss.detach())
+               accum_loss.add_(loss.data)
            if outputs is not None:
                outputs.append(tree_map(detach, output_obj))
            return loss

@@ -273,7 +273,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
            loss = criterion(output_obj, micro_batch) / self.num_microbatches
 
            if accum_loss is not None:
-               accum_loss.add_(loss.detach())
+               accum_loss.add_(loss.data)
            if outputs is not None:
                outputs.append(tree_map_hf(detach, output_obj))
            return loss

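Both schedules above now accumulate the reported loss with `loss.data` instead of `loss.detach()`; per the commit message this is part of the fp8 pipeline hotfix. Either form keeps the running total out of the autograd graph, which is all the sketch below demonstrates; the fp8-specific motivation is not visible in this diff.

import torch

num_microbatches = 4
accum_loss = torch.zeros(1)  # running total reported at the end of the schedule

for _ in range(num_microbatches):
    pred = torch.randn(3, requires_grad=True)
    loss = pred.sum() / num_microbatches  # per-microbatch scaled loss, as in the schedules
    accum_loss.add_(loss.data)            # detached view of the value, same effect as loss.detach()
    loss.backward()                       # backward for this microbatch still works

print(accum_loss, accum_loss.requires_grad)  # requires_grad stays False
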
@@ -1,6 +1,6 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from functools import partial
 from typing import Dict, Iterator, List, Optional, Tuple
 from weakref import proxy

@@ -88,6 +88,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        master_weights: bool = True,  # master weights
        overlap_allgather: bool = False,
        fp8_communication: bool = False,
+       backward_context=None,
    ):
        super(LowLevelZeroOptimizer, self).__init__(optim=optimizer)
 

@@ -130,6 +131,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        self._reduce_bucket_size = reduce_bucket_size
        self._communication_dtype = communication_dtype
        self._fp8_communication = fp8_communication
+       self._backward_context = backward_context
 
        # gradient clipping
        self._clip_grad_norm = clip_grad_norm

@@ -429,6 +431,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        if self.mixed_precision_mixin is not None:
            loss = self.mixed_precision_mixin.pre_backward(loss)
 
-       loss.backward(retain_graph=retain_graph)
+       ctx = nullcontext() if self._backward_context is None else self._backward_context()
+       with ctx:
+           loss.backward(retain_graph=retain_graph)
 
        if not self.require_grad_sync:

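Design note on `backward_context`: the plugins pass `model._hook_context` itself (a zero-argument callable), not an already-built context manager, and the optimizer calls it once per `backward()`. One reason the factory form is the right shape is that generator-based context managers are single-use. A small sketch, independent of ColossalAI, with illustrative names:

from contextlib import contextmanager


@contextmanager
def demo_hook_context():
    # Stand-in for model._hook_context(); the prints mark scope entry and exit.
    print("enter hook scope")
    try:
        yield
    finally:
        print("exit hook scope")


factory = demo_hook_context  # what the plugins pass: a callable, not an instance

for step in range(2):
    with factory():  # a fresh context manager for every backward pass
        print(f"backward for step {step}")

cm = demo_hook_context()
with cm:
    pass
try:
    with cm:  # reusing the same instance fails
        pass
except RuntimeError as exc:
    print("single-use instance cannot be re-entered:", exc)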