cast colo ddp v2 inputs/outputs (#1120)

pull/1123/head
ver217 2022-06-15 15:57:04 +08:00 committed by GitHub
parent 16302a5359
commit e127b4375b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 2 deletions

View File

@ -20,6 +20,16 @@ def free_storage(data: torch.Tensor) -> None:
data.storage().resize_(0)
def _cast_float(args, dtype: torch.dtype):
if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
args = args.to(dtype)
elif isinstance(args, (list, tuple)):
args = type(args)(_cast_float(t, dtype) for t in args)
elif isinstance(args, dict):
args = {k: _cast_float(v, dtype) for k, v in args.items()}
return args
class ColoDDP(torch.nn.Module):
def __init__(self, module: torch.nn.Module) -> None:
@ -93,7 +103,7 @@ class ColoDDP(torch.nn.Module):
class ColoDDPV2(ColoDDP):
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
super().__init__(module)
super().__init__(module.half())
self.gemini_manager = gemini_manager
self.chunk_manager = gemini_manager.chunk_manager
self.param_op_hook = ZeROHookV2(gemini_manager)
@ -113,12 +123,13 @@ class ColoDDPV2(ColoDDP):
self._logger = get_dist_logger()
def forward(self, *args, **kwargs):
    """Run the wrapped module on half-cast inputs and return float-cast outputs.

    Floating-point tensors found in ``args`` / ``kwargs`` are cast to
    ``torch.half`` before the call, and floating-point tensors in the result
    are cast to ``torch.float`` before returning.

    Args:
        *args: Positional inputs forwarded to ``self.module``.
        **kwargs: Keyword inputs forwarded to ``self.module``.

    Returns:
        The module outputs with floating-point tensors cast to ``torch.float``.
    """
    args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
    self.module.zero_grad(set_to_none=True)
    self.gemini_manager.pre_iter()
    # The parameter-op hook is only active while the wrapped module executes.
    with ParamOpHookManager.use_hooks(self.param_op_hook):
        outputs = self.module(*args, **kwargs)
    # NOTE(review): placed after the hook context on the assumption that the
    # lazy release is not part of the hooked region — confirm against the
    # original indentation.
    self.chunk_manager.exec_lazy_release()
    # Previously an early `return outputs` left this cast unreachable, so
    # callers received the half-precision outputs.
    return _cast_float(outputs, torch.float)
def _setup_grads_ptr(self):
for p in self.module.parameters():