From 83328329dd3ef9437e02c1b13d84dac90d4b6b0a Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 28 Jul 2022 10:54:44 +0800
Subject: [PATCH] [hotfix] fix zero ddp buffer cast (#1376)

* fix zero ddp buffer cast

* fix zero ddp ignore params
---
 colossalai/nn/parallel/data_parallel.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index 31a9e5627..9aca524e9 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -192,7 +192,7 @@ class ZeroDDP(ColoDDP):
     """

     def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module.half(), process_group=gemini_manager.chunk_manager.process_group)
+        super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
@@ -204,13 +204,15 @@ class ZeroDDP(ColoDDP):
         # TODO: get param order and filter unused params
         for p in module.parameters():
             if getattr(p, '_ddp_to_ignore', False):
+                p.data = p.half()
                 continue
-            assert p.dtype == torch.half
             fp32_p = p.float().detach()
+            p.data = p.half()
             self.chunk_manager.append_tensor(p, 'fp16_param')
             self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
+        self._cast_buffers()
         self._logger = get_dist_logger()

     def forward(self, *args, **kwargs):
@@ -481,3 +483,9 @@ class ZeroDDP(ColoDDP):
             input_name = key[len(prefix):]
             if input_name not in local_state:
                 unexpected_keys.append(key)
+
+    def _cast_buffers(self):
+        for buffer in self.module.buffers():
+            buffer.data = buffer.cuda()
+            if torch.is_floating_point(buffer):
+                buffer.data = buffer.half()
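
A minimal sketch of the buffer-cast logic introduced above, assuming CUDA is available and using a plain nn.BatchNorm1d as a stand-in module (not taken from the patch itself): every buffer is moved to the GPU, but only floating-point buffers are halved, so BatchNorm's FP32 running stats become FP16 while the integer num_batches_tracked buffer keeps its dtype.

import torch
import torch.nn as nn

# Toy module with both floating-point buffers (running_mean, running_var)
# and an integer buffer (num_batches_tracked).
module = nn.BatchNorm1d(4)

# Same pattern as ZeroDDP._cast_buffers: move to GPU, halve only float buffers.
for buffer in module.buffers():
    buffer.data = buffer.cuda()
    if torch.is_floating_point(buffer):
        buffer.data = buffer.half()

print(module.running_mean.dtype)         # torch.float16
print(module.num_batches_tracked.dtype)  # torch.int64 (unchanged)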