[hotfix] fix zero ddp buffer cast (#1376)

* fix zero ddp buffer cast

* fix zero ddp ignore params
ver217 2022-07-28 10:54:44 +08:00 committed by GitHub
parent 5d5031e946
commit 83328329dd
1 changed file with 10 additions and 2 deletions

@@ -192,7 +192,7 @@ class ZeroDDP(ColoDDP):
     """

     def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module.half(), process_group=gemini_manager.chunk_manager.process_group)
+        super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
@@ -204,13 +204,15 @@ class ZeroDDP(ColoDDP):
         # TODO: get param order and filter unused params
         for p in module.parameters():
             if getattr(p, '_ddp_to_ignore', False):
+                p.data = p.half()
                 continue
-            assert p.dtype == torch.half
             fp32_p = p.float().detach()
+            p.data = p.half()
             self.chunk_manager.append_tensor(p, 'fp16_param')
             self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
+        self._cast_buffers()
         self._logger = get_dist_logger()

     def forward(self, *args, **kwargs):
@@ -481,3 +483,9 @@ class ZeroDDP(ColoDDP):
                 input_name = key[len(prefix):]
                 if input_name not in local_state:
                     unexpected_keys.append(key)
+
+    def _cast_buffers(self):
+        for buffer in self.module.buffers():
+            buffer.data = buffer.cuda()
+            if torch.is_floating_point(buffer):
+                buffer.data = buffer.half()
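
Since `module.half()` is no longer applied in `super().__init__`, the constructor now casts parameters one by one: non-ignored parameters keep a detached fp32 master copy while their storage is replaced with fp16, and parameters flagged with `_ddp_to_ignore` are still cast to fp16 even though they skip the chunk manager. The snippet below is a minimal standalone sketch of that pattern, not part of the commit; the `nn.Linear` stand-in module and the plain `fp32_params` list are placeholders for the wrapped model and `self.fp32_params`.

import torch
import torch.nn as nn

# Stand-in for the user's model; in ZeroDDP this is the wrapped module.
module = nn.Linear(8, 8)
fp32_params = []

for p in module.parameters():
    if getattr(p, '_ddp_to_ignore', False):
        # Ignored params are no longer halved by module.half(), so cast them here.
        p.data = p.half()
        continue
    fp32_p = p.float().detach()   # fp32 master copy kept for the optimizer
    p.data = p.half()             # the module itself now holds fp16 weights
    fp32_params.append(fp32_p)

print(all(p.dtype == torch.half for p in module.parameters()))  # True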
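
The new `_cast_buffers` moves every buffer to the GPU but only halves floating-point ones, so integer buffers such as BatchNorm's `num_batches_tracked` keep their dtype. Below is a small standalone sketch of the same logic, not part of the commit, applied to a plain `nn.BatchNorm1d`; the `torch.cuda.is_available()` guard is added here so the sketch also runs on CPU, whereas the method above assumes a CUDA device.

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)   # has fp32 running stats plus an int64 num_batches_tracked buffer

for buffer in bn.buffers():
    if torch.cuda.is_available():          # guard added for this sketch only
        buffer.data = buffer.cuda()
    if torch.is_floating_point(buffer):    # skip integer buffers
        buffer.data = buffer.half()

for name, buffer in bn.named_buffers():
    print(name, buffer.dtype)
# running_mean torch.float16
# running_var torch.float16
# num_batches_tracked torch.int64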