mirror of https://github.com/hpcaitech/ColossalAI
cast colo ddp v2 inputs/outputs (#1120)
parent 16302a5359
commit e127b4375b
@@ -20,6 +20,16 @@ def free_storage(data: torch.Tensor) -> None:
         data.storage().resize_(0)
 
 
+def _cast_float(args, dtype: torch.dtype):
+    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
+        args = args.to(dtype)
+    elif isinstance(args, (list, tuple)):
+        args = type(args)(_cast_float(t, dtype) for t in args)
+    elif isinstance(args, dict):
+        args = {k: _cast_float(v, dtype) for k, v in args.items()}
+    return args
+
+
 class ColoDDP(torch.nn.Module):
 
     def __init__(self, module: torch.nn.Module) -> None:
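For reference, the new _cast_float helper recurses through lists, tuples, and dicts, casting only floating-point tensors and leaving everything else untouched. A minimal standalone demonstration (the example inputs are illustrative, not part of the commit):

import torch

def _cast_float(args, dtype: torch.dtype):
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        args = args.to(dtype)
    elif isinstance(args, (list, tuple)):
        args = type(args)(_cast_float(t, dtype) for t in args)
    elif isinstance(args, dict):
        args = {k: _cast_float(v, dtype) for k, v in args.items()}
    return args

# Floating-point tensors are cast even inside nested containers;
# bool/int tensors and non-tensor values pass through unchanged.
batch = {"x": torch.randn(2, 3), "mask": torch.ones(2, 3, dtype=torch.bool), "meta": ("id", 7)}
half_batch = _cast_float(batch, torch.half)
print(half_batch["x"].dtype)     # torch.float16
print(half_batch["mask"].dtype)  # torch.bool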
@@ -93,7 +103,7 @@ class ColoDDP(torch.nn.Module):
 class ColoDDPV2(ColoDDP):
 
     def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module)
+        super().__init__(module.half())
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
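Calling module.half() here converts the wrapped module's floating-point parameters and buffers to torch.float16, which is why the forward pass below must also cast its inputs. An illustrative sketch of that behavior with a plain nn.Linear (not ColossalAI code):

import torch

net = torch.nn.Linear(4, 4)
print(net.weight.dtype)              # torch.float32
net = net.half()                     # parameters are now fp16
print(net.weight.dtype)              # torch.float16
# Feeding a float32 input would raise a dtype-mismatch error in the matmul,
# so inputs must be cast to half as well:
out = net(torch.randn(2, 4).half())
print(out.dtype)                     # torch.float16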
@@ -113,12 +123,13 @@ class ColoDDPV2(ColoDDP):
         self._logger = get_dist_logger()
 
     def forward(self, *args, **kwargs):
+        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter()
         with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
-        return outputs
+        return _cast_float(outputs, torch.float)
 
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
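Taken together, the changes give ColoDDPV2.forward an fp32-in / fp16-compute / fp32-out contract: callers keep passing and receiving float32 while the module runs in half precision. A minimal self-contained sketch of that contract (HalfWrapper is a hypothetical stand-in, not ColossalAI API, specialized to a single tensor argument):

import torch

class HalfWrapper(torch.nn.Module):
    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module.half()      # parameters live in fp16

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.module(x.half())      # cast fp32 input down, compute in fp16
        return out.float()               # hand fp32 back to the caller

wrapped = HalfWrapper(torch.nn.Linear(4, 2))
y = wrapped(torch.randn(3, 4))           # caller passes and receives fp32
print(y.dtype)                           # torch.float32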