From e127b4375b8ba6283dd84f9cb572511862f67d00 Mon Sep 17 00:00:00 2001
From: ver217
Date: Wed, 15 Jun 2022 15:57:04 +0800
Subject: [PATCH] cast colo ddp v2 inputs/outputs (#1120)

---
 colossalai/nn/parallel/data_parallel.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 823c355f4..7db3da498 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -20,6 +20,16 @@ def free_storage(data: torch.Tensor) -> None:
         data.storage().resize_(0)
 
 
+def _cast_float(args, dtype: torch.dtype):
+    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
+        args = args.to(dtype)
+    elif isinstance(args, (list, tuple)):
+        args = type(args)(_cast_float(t, dtype) for t in args)
+    elif isinstance(args, dict):
+        args = {k: _cast_float(v, dtype) for k, v in args.items()}
+    return args
+
+
 class ColoDDP(torch.nn.Module):
 
     def __init__(self, module: torch.nn.Module) -> None:
@@ -93,7 +103,7 @@ class ColoDDP(torch.nn.Module):
 class ColoDDPV2(ColoDDP):
 
     def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module)
+        super().__init__(module.half())
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
@@ -113,12 +123,13 @@ class ColoDDPV2(ColoDDP):
         self._logger = get_dist_logger()
 
     def forward(self, *args, **kwargs):
+        args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
         self.gemini_manager.pre_iter()
         with ParamOpHookManager.use_hooks(self.param_op_hook):
             outputs = self.module(*args, **kwargs)
         self.chunk_manager.exec_lazy_release()
-        return outputs
+        return _cast_float(outputs, torch.float)
 
     def _setup_grads_ptr(self):
         for p in self.module.parameters():
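
A quick standalone sketch of the casting contract this patch introduces: ColoDDPV2 now halves the wrapped module, casts floating-point inputs to fp16 on the way in, and casts outputs back to fp32 on the way out, so callers keep working in full precision. The `_cast_float` body below is copied verbatim from the diff; the demo batch, its shapes, and the assertions are illustrative assumptions, not part of the patch.

import torch


def _cast_float(args, dtype: torch.dtype):
    # Cast floating-point tensors to `dtype`; recurse into lists, tuples,
    # and dicts; leave non-float values (ints, bools, integer tensors) as-is.
    if isinstance(args, torch.Tensor) and torch.is_floating_point(args):
        args = args.to(dtype)
    elif isinstance(args, (list, tuple)):
        args = type(args)(_cast_float(t, dtype) for t in args)
    elif isinstance(args, dict):
        args = {k: _cast_float(v, dtype) for k, v in args.items()}
    return args


# Hypothetical input batch: a nested structure mixing float tensors,
# a bool mask, and a plain int.
batch = {
    "input": torch.randn(2, 4),                   # fp32 -> becomes fp16
    "mask": torch.ones(2, 4, dtype=torch.bool),   # non-float tensor -> untouched
    "extras": [torch.randn(2), 3],                # recursed; the int passes through
}
half_batch = _cast_float(batch, torch.half)
assert half_batch["input"].dtype == torch.half
assert half_batch["mask"].dtype == torch.bool
assert half_batch["extras"][0].dtype == torch.half
assert half_batch["extras"][1] == 3

# Outputs travel the opposite direction: back to fp32 for the caller,
# mirroring `return _cast_float(outputs, torch.float)` in forward().
outputs = _cast_float(half_batch, torch.float)
assert outputs["input"].dtype == torch.float

One caveat of the `type(args)(...)` reconstruction: it assumes the container's constructor accepts a single iterable, which holds for plain lists and tuples but not, for example, namedtuples.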