mirror of https://github.com/hpcaitech/ColossalAI
[zero] ZeroDDP supports controlling outputs' dtype (#1399)
parent 4e98e938ce
commit 04c9a86af8
@@ -202,12 +202,17 @@ class ZeroDDP(ColoDDP):
         module (torch.nn.Module): Module to apply ZeRO-DP.
         gemini_manager (GeminiManager): Manages the chunk manager and heterogeneous momery space.
             For more details, see the API reference of ``GeminiManager``.
+        force_outputs_fp32 (bool): If set to True, outputs will be fp32. Otherwise, outputs will be fp16. Defaults to False.
     """

-    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
+    def __init__(self,
+                 module: torch.nn.Module,
+                 gemini_manager: GeminiManager,
+                 force_outputs_fp32: bool = False) -> None:
         super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
+        self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = ZeROHookV2(gemini_manager)
         self.fp32_params: List[ColoParameter] = []
         self.overflow_counter = 0

@@ -235,7 +240,9 @@ class ZeroDDP(ColoDDP):
         with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
-        return _cast_float(outputs, torch.float)
+        if self.force_outputs_fp32:
+            return _cast_float(outputs, torch.float)
+        return outputs

     def _setup_grads_ptr(self):
         for p in self.module.parameters():
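The new flag only changes what ``forward`` returns: before this commit every output was cast to fp32 unconditionally, and afterwards the ``_cast_float`` cast happens only when ``force_outputs_fp32=True``; otherwise the module's (typically fp16) outputs are passed through unchanged. A minimal usage sketch, assuming a ``GeminiManager`` instance and a suitably initialized ``module`` already exist; everything other than the ``ZeroDDP`` signature shown in the diff above is illustrative:

# Sketch only: how `module` and `gemini_manager` are built is out of scope here
# and assumed to follow the usual ColossalAI/Gemini setup for this version.
model = ZeroDDP(module,                     # torch.nn.Module to apply ZeRO-DP to
                gemini_manager,             # GeminiManager driving the chunk manager
                force_outputs_fp32=True)    # new argument added by this commit

outputs = model(inputs)                     # outputs are cast to torch.float32
# With the default force_outputs_fp32=False, outputs keep the compute dtype (fp16).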