mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix zero ddp buffer cast (#1376)
* fix zero ddp buffer cast
* fix zero ddp ignore params
parent 5d5031e946
commit 83328329dd
@@ -192,7 +192,7 @@ class ZeroDDP(ColoDDP):
     """

     def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
-        super().__init__(module.half(), process_group=gemini_manager.chunk_manager.process_group)
+        super().__init__(module, process_group=gemini_manager.chunk_manager.process_group)
         self.gemini_manager = gemini_manager
         self.chunk_manager = gemini_manager.chunk_manager
         self.param_op_hook = ZeROHookV2(gemini_manager)
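Context on the first hunk (not part of the diff): nn.Module.half() casts every floating point parameter and buffer of the module in place, so dropping it from the super().__init__ call means those casts have to be reproduced explicitly, parameters in the loop from the next hunk and buffers in the new _cast_buffers() helper. A minimal sketch of that equivalence, using a hypothetical toy module defined only for illustration:

import torch
import torch.nn as nn

# Hypothetical toy module, for illustration only: one parameter and one
# floating point buffer (as a BatchNorm layer would have).
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(4))
        self.register_buffer('running_mean', torch.zeros(4))

m = Toy()
m.half()    # casts BOTH the parameter and the buffer to torch.half
assert m.weight.dtype == torch.half and m.running_mean.dtype == torch.half

# The same effect done piecewise, mirroring what the patched __init__ now does:
m2 = Toy()
for p in m2.parameters():
    p.data = p.half()              # per-parameter cast (loop in the next hunk)
for b in m2.buffers():
    if torch.is_floating_point(b):
        b.data = b.half()          # buffer cast (_cast_buffers, minus the .cuda() move)
assert m2.weight.dtype == torch.half and m2.running_mean.dtype == torch.half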
@@ -204,13 +204,15 @@ class ZeroDDP(ColoDDP):
         # TODO: get param order and filter unused params
         for p in module.parameters():
             if getattr(p, '_ddp_to_ignore', False):
+                p.data = p.half()
                 continue
-            assert p.dtype == torch.half
             fp32_p = p.float().detach()
+            p.data = p.half()
             self.chunk_manager.append_tensor(p, 'fp16_param')
             self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
             self.fp32_params.append(fp32_p)
             self.grads_device[p] = self.gemini_manager.default_device
+        self._cast_buffers()
         self._logger = get_dist_logger()

     def forward(self, *args, **kwargs):
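A sketch of the pattern the constructor loop implements (again not part of the diff, simplified to a plain nn.Linear): every trainable parameter keeps a detached fp32 master copy for the optimizer while the live parameter is cast to fp16, and parameters flagged with _ddp_to_ignore are now also cast to fp16 but get no master copy or chunk registration. The attribute is set directly here purely for illustration:

import torch
import torch.nn as nn

# Hypothetical toy model; names are illustrative only.
model = nn.Linear(8, 8)

# The diff reads the `_ddp_to_ignore` attribute on each parameter, so the
# sketch sets it directly on the bias to mark it as ignored.
model.bias._ddp_to_ignore = True

fp32_params = []                      # master copies kept in fp32 for the optimizer
for p in model.parameters():
    if getattr(p, '_ddp_to_ignore', False):
        p.data = p.half()             # ignored params are still cast to fp16 ...
        continue                      # ... but get no fp32 master copy / chunk entry
    fp32_p = p.float().detach()       # fp32 master copy, detached from autograd
    p.data = p.half()                 # the live parameter becomes the fp16 copy
    fp32_params.append(fp32_p)

assert all(p.dtype == torch.half for p in model.parameters())
assert all(t.dtype == torch.float32 for t in fp32_params)
assert len(fp32_params) == 1          # only the weight; the bias was ignored

Keeping the master copies in fp32 is what lets the optimizer update in full precision while forward and backward run on the fp16 parameters.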
@@ -481,3 +483,9 @@ class ZeroDDP(ColoDDP):
             input_name = key[len(prefix):]
             if input_name not in local_state:
                 unexpected_keys.append(key)
+
+    def _cast_buffers(self):
+        for buffer in self.module.buffers():
+            buffer.data = buffer.cuda()
+            if torch.is_floating_point(buffer):
+                buffer.data = buffer.half()
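The torch.is_floating_point check in _cast_buffers matters because modules can hold integer buffers that must not be cast to fp16. A minimal sketch of the same guard (CPU-only, so the .cuda() move from the patch is left out) using BatchNorm, whose num_batches_tracked buffer is int64:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)   # float buffers (running_mean, running_var)
                         # plus an integer buffer (num_batches_tracked, int64)

for buf in bn.buffers():
    # The patch first moves each buffer with .cuda(); omitted here so the
    # sketch runs without a GPU.
    if torch.is_floating_point(buf):
        buf.data = buf.half()         # only floating point buffers become fp16

assert bn.running_mean.dtype == torch.half
assert bn.running_var.dtype == torch.half
assert bn.num_batches_tracked.dtype == torch.int64   # untouched, as intended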