mirror of https://github.com/hpcaitech/ColossalAI
fix zero ddp state dict (#1378)
parent
0c1a16ea5b
commit
5d5031e946
|
@ -314,14 +314,18 @@ class ZeroDDP(ColoDDP):
|
|||
module
|
||||
"""
|
||||
chunks = self.chunk_manager.get_chunks(self.fp32_params)
|
||||
chunks_orig_device_type = []
|
||||
for chunk in chunks:
|
||||
chunks_orig_device_type.append(chunk.device_type)
|
||||
self.chunk_manager.access_chunk(chunk)
|
||||
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
|
||||
if p is not None:
|
||||
rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
|
||||
destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
|
||||
for chunk in chunks:
|
||||
for orig_dvice_type, chunk in zip(chunks_orig_device_type, chunks):
|
||||
self.chunk_manager.release_chunk(chunk)
|
||||
if not chunk.is_empty and orig_dvice_type == 'cpu':
|
||||
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
|
||||
for name, buf in self.named_buffers():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||
|
|
Loading…
Reference in New Issue