Fix ZeroDDP state dict (#1378)

pull/1382/head
ver217 2022-07-28 09:31:42 +08:00 committed by GitHub
parent 0c1a16ea5b
commit 5d5031e946
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 5 additions and 1 deletion

View File

@ -314,14 +314,18 @@ class ZeroDDP(ColoDDP):
module
"""
chunks = self.chunk_manager.get_chunks(self.fp32_params)
chunks_orig_device_type = []
for chunk in chunks:
chunks_orig_device_type.append(chunk.device_type)
self.chunk_manager.access_chunk(chunk)
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
if p is not None:
rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
for chunk in chunks:
for orig_dvice_type, chunk in zip(chunks_orig_device_type, chunks):
self.chunk_manager.release_chunk(chunk)
if not chunk.is_empty and orig_dvice_type == 'cpu':
self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()