[zero] ZeroDDP supports controlling outputs' dtype (#1399)

pull/1396/head
ver217 2022-08-02 17:49:11 +08:00 committed by GitHub
parent 4e98e938ce
commit 04c9a86af8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 2 deletions

View File

@ -202,12 +202,17 @@ class ZeroDDP(ColoDDP):
module (torch.nn.Module): Module to apply ZeRO-DP.
gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
For more details, see the API reference of ``GeminiManager``.
force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
"""
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
def __init__(self,
module: torch.nn.Module,
gemini_manager: GeminiManager,
force_outputs_fp32: bool = False) -> None:
super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
self.gemini_manager = gemini_manager
self.chunk_manager = gemini_manager.chunk_manager
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = ZeROHookV2(gemini_manager)
self.fp32_params: List[ColoParameter] = []
self.overflow_counter = 0
@ -235,7 +240,9 @@ class ZeroDDP(ColoDDP):
with ParamOpHookManager.use_hooks(self.param_op_hook):
outputs = self.module(*args, **kwargs)
self.chunk_manager.exec_lazy_release()
return _cast_float(outputs, torch.float)
if self.force_outputs_fp32:
return _cast_float(outputs, torch.float)
return outputs
def _setup_grads_ptr(self):
for p in self.module.parameters():