@@ -216,7 +216,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         if self.convert_fn is not None:
             args = tree_map(self.convert_fn, args)
             kwargs = tree_map(self.convert_fn, kwargs)
-        with self._wait_all_gather():
+        with self._hook_context():
             return super().forward(*args, **kwargs)

     def unwrap(self):
@@ -229,12 +229,8 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         for p in self.module.parameters():
             wait_all_gather_handle(p)

-    def _wait_all_gather(self):
-        return (
-            ColoParamOpHookManager.use_hooks(*self.op_hooks)
-            if (self.overlap_allgather or self.use_fp8)
-            else nullcontext()
-        )
+    def _hook_context(self):
+        return ColoParamOpHookManager.use_hooks(*self.op_hooks) if len(self.op_hooks) > 0 else nullcontext()


 def get_param_info(optim: Optimizer):
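
# A minimal, self-contained sketch (standard library only) of the pattern the
# two hunks above converge on: one helper that yields the real hook context
# when any parameter-op hook is registered and a no-op otherwise, plus a
# forward pass that always enters it. `use_hooks` below is a hypothetical
# stand-in for ColoParamOpHookManager.use_hooks(*self.op_hooks), not the
# ColossalAI API.
from contextlib import contextmanager, nullcontext


@contextmanager
def use_hooks(*hooks):
    # Stand-in: activate the given hooks for the duration of the block.
    for hook in hooks:
        print(f"activate {hook}")
    try:
        yield
    finally:
        for hook in hooks:
            print(f"release {hook}")


def hook_context(op_hooks):
    # Mirrors _hook_context(): no registered hooks -> plain nullcontext().
    return use_hooks(*op_hooks) if len(op_hooks) > 0 else nullcontext()


def forward_with_hooks(module_fn, op_hooks, *args, **kwargs):
    # Mirrors HybridParallelModule.forward(): the whole call runs under the context.
    with hook_context(op_hooks):
        return module_fn(*args, **kwargs)
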
@@ -317,7 +313,8 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
         """

         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
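
# Sketch of why backward is wrapped as well: hooks that intercept parameter
# operations (e.g. all-gather overlap or FP8 casting) should also see the work
# triggered by autograd, so the optimizer re-enters the model's hook context
# around loss.backward(). `hook_context` here is assumed to be a zero-argument
# callable returning a context manager, such as the model's _hook_context.
from contextlib import nullcontext

import torch


def backward_under_hooks(loss: torch.Tensor, hook_context=nullcontext):
    # Mirrors the wrapped super().backward(loss, ...) call above.
    with hook_context():
        loss.backward()


# Usage with no hooks registered: behaves exactly like a plain backward pass.
loss = (torch.randn(4, requires_grad=True) ** 2).sum()
backward_under_hooks(loss)
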
@@ -540,7 +537,8 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
             None
         """
         # Call the superclass backward method to compute gradients.
-        super().backward(loss, *args, **kwargs)
+        with self.model._hook_context():
+            super().backward(loss, *args, **kwargs)

         if self.model.require_grad_sync:
             # If gradient synchronization is required, sync sequence parallelism gradients.
@@ -683,6 +681,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
         pp_process_group: Optional[ProcessGroup] = None,  # if using pp
         forced_dtype: Optional[torch.dtype] = None,
         overlap_allgather: bool = False,
+        fp8_communication: bool = False,
     ):
         self.model = model
         self.param_info = param_info
@@ -712,6 +711,8 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
             dp_process_group=dp_process_group,
             forced_dtype=forced_dtype,
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
+            backward_context=model._hook_context,
         )

     def sync_dp_grads(self):
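
# Hedged sketch (not the actual LowLevelZeroOptimizer code) of what passing
# backward_context=model._hook_context buys: the wrapped optimizer can re-enter
# the model's hook context around its own backward pass instead of every call
# site wrapping it by hand. TinyZeroLikeOptimizer is illustrative only.
from contextlib import nullcontext

import torch


class TinyZeroLikeOptimizer:
    def __init__(self, optim: torch.optim.Optimizer, backward_context=None):
        self.optim = optim
        # Keep the callable; fall back to a no-op context when none is supplied.
        self.backward_context = backward_context if backward_context is not None else nullcontext

    def backward(self, loss: torch.Tensor):
        with self.backward_context():
            loss.backward()

    def step(self):
        self.optim.step()
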
@@ -1206,6 +1207,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             partition_grad=(self.zero_stage == 2),
             forced_dtype=PRECISION_TORCH_TYPE[precision],
             overlap_allgather=overlap_allgather,
+            fp8_communication=fp8_communication,
         )

         self.max_norm = max_norm
@@ -1371,7 +1373,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         # so we disable it, performing manual reduction instead.
         ctx = optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

-        with ctx, model._wait_all_gather():
+        with ctx, model._hook_context():
             outputs = self.schedule.forward_backward_step(
                 model, data_iter, criterion, optimizer, return_loss, return_outputs
             )
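
# Sketch of the step-level pattern above: gradient synchronization is disabled
# for the pipeline micro-batches while the model's hook context stays active
# for the whole forward/backward schedule. `no_sync_ctx` and `hook_ctx` are
# assumed to be callables returning context managers (e.g. optimizer.no_sync
# and model._hook_context in the diff).
from contextlib import nullcontext


def run_pipeline_step(step_fn, no_sync_ctx=nullcontext, hook_ctx=nullcontext):
    # Both contexts wrap the entire scheduled step, mirroring
    # `with ctx, model._hook_context(): ...` in execute_pipeline.
    with no_sync_ctx(), hook_ctx():
        return step_fn()
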