mirror of https://github.com/hpcaitech/ColossalAI
cast colo ddp v2 inputs/outputs (#1120)
parent 16302a5359
commit e127b4375b
@@ -20,6 +20,16 @@ def free_storage(data: torch.Tensor) -> None:
         data.storage().resize_(0)
 
 
+def _cast_float(args, dtype: torch.dtype):
+    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
+        args = args.to(dtype)
+    elif isinstance(args, (list, tuple)):
+        args = type(args)(_cast_float(t, dtype) for t in args)
+    elif isinstance(args, dict):
+        args = {k: _cast_float(v, dtype) for k, v in args.items()}
+    return args
+
+
 class ColoDDP(torch.nn.Module):
 
     def __init__(self, module: torch.nn.Module) -> None:
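The new _cast_float helper walks an arbitrary nesting of tensors, lists/tuples, and dicts, converting only floating-point tensors to the target dtype while leaving integer tensors and non-tensor values untouched. A minimal standalone sketch of that behavior (the helper is copied here so the snippet runs on its own; the sample batch is made up for illustration):

import torch

def _cast_float(args, dtype: torch.dtype):
    # Recursively cast floating-point tensors to `dtype`; everything else is returned unchanged.
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        args = args.to(dtype)
    elif isinstance(args, (list, tuple)):
        args = type(args)(_cast_float(t, dtype) for t in args)
    elif isinstance(args, dict):
        args = {k: _cast_float(v, dtype) for k, v in args.items()}
    return args

# Hypothetical nested input: fp32 tensors are cast to fp16, the int64 mask is left as-is.
batch = {'x': torch.randn(2, 4), 'mask': torch.ones(2, 4, dtype=torch.long), 'extra': [torch.zeros(3)]}
half_batch = _cast_float(batch, torch.half)
assert half_batch['x'].dtype == torch.half
assert half_batch['mask'].dtype == torch.long
assert half_batch['extra'][0].dtype == torch.half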
@@ -93,7 +103,7 @@ class ColoDDP(torch.nn.Module):
 class ColoDDPV2(ColoDDP):
 
     def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module)
+        super().__init__(module.half())
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
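ColoDDPV2 now converts the wrapped module to half precision before passing it to the parent constructor, so all floating-point parameters and buffers are stored in fp16. For reference, a tiny check of what nn.Module.half() does (standard PyTorch behavior, not specific to this patch):

import torch

layer = torch.nn.Linear(8, 8)
assert layer.weight.dtype == torch.float32
layer.half()                                  # converts floating-point parameters and buffers in place
assert layer.weight.dtype == torch.float16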
@@ -113,12 +123,13 @@ class ColoDDPV2(ColoDDP):
         self._logger = get_dist_logger()
 
     def forward(self, *args, **kwargs):
+        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter()
         with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
-        return outputs
+        return _cast_float(outputs, torch.float)
 
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
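Taken together, the changes make ColoDDPV2 a consistent fp16 wrapper: parameters are halved in __init__, inputs are cast to torch.half on entry to forward so they match the fp16 parameters, and outputs are cast back to torch.float so loss computation outside the wrapper stays in fp32. A self-contained sketch of that casting pattern with a plain nn.Module (no Gemini or chunk machinery; it reuses the _cast_float helper reproduced in the first sketch above, and the wrapper name is illustrative only, not the ColoDDPV2 implementation):

import torch

class HalfForwardWrapper(torch.nn.Module):
    # Illustrative stand-in for the input/output casting pattern added in this commit.
    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module.half()                      # parameters become fp16

    def forward(self, *args, **kwargs):
        args = _cast_float(args, torch.half)             # inputs to fp16
        kwargs = _cast_float(kwargs, torch.half)
        outputs = self.module(*args, **kwargs)
        return _cast_float(outputs, torch.float)         # outputs back to fp32 for the loss

class Scale(torch.nn.Module):
    # Element-wise scaling keeps the example runnable in fp16 on CPU.
    def __init__(self) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(4))

    def forward(self, x):
        return x * self.weight

model = HalfForwardWrapper(Scale())
out = model(torch.randn(3, 4))
assert out.dtype == torch.float32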