|
|
|
@ -6,12 +6,13 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
|
|
|
|
from colossalai.gemini.chunk import TensorState, Chunk |
|
|
|
|
from colossalai.tensor.param_op_hook import ParamOpHookManager |
|
|
|
|
from colossalai.gemini.gemini_mgr import GeminiManager |
|
|
|
|
from typing import Dict, Iterable, List, Optional |
|
|
|
|
from typing import Dict, Iterable, List, Optional, Set |
|
|
|
|
from colossalai.logging import get_dist_logger |
|
|
|
|
from collections import OrderedDict |
|
|
|
|
from colossalai.tensor.colo_parameter import ColoParameter |
|
|
|
|
from colossalai.tensor import ProcessGroup as ColoProcessGroup |
|
|
|
|
from .reducer import Reducer |
|
|
|
|
|
|
|
|
|
try: |
|
|
|
|
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys |
|
|
|
|
except ImportError: |
|
|
|
@ -84,6 +85,18 @@ class ColoDDP(torch.nn.Module):
|
|
|
|
|
def named_parameters(self, prefix: str = '', recurse: bool = True): |
|
|
|
|
return self.module.named_parameters(prefix, recurse) |
|
|
|
|
|
|
|
|
|
def named_buffers(self, prefix: str = '', recurse: bool = True): |
|
|
|
|
return self.module.named_buffers(prefix, recurse) |
|
|
|
|
|
|
|
|
|
def named_children(self): |
|
|
|
|
return self.module.named_children() |
|
|
|
|
|
|
|
|
|
def named_modules(self, |
|
|
|
|
memo: Optional[Set[torch.nn.Module]] = None, |
|
|
|
|
prefix: str = '', |
|
|
|
|
remove_duplicate: bool = True): |
|
|
|
|
return self.module.named_modules(memo, prefix, remove_duplicate) |
|
|
|
|
|
|
|
|
|
def forward(self, *args, **kwargs): |
|
|
|
|
self.module.zero_grad(set_to_none=True) |
|
|
|
|
return self.module(*args, **kwargs) |
|
|
|
@ -274,7 +287,7 @@ class ZeroDDP(ColoDDP):
|
|
|
|
|
for tensor in chunk.get_tensors(): |
|
|
|
|
self.grads_device[tensor] = device |
|
|
|
|
|
|
|
|
|
def state_dict(self, destination=None, prefix='', keep_vars=False): |
|
|
|
|
def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True): |
|
|
|
|
r"""Returns a dictionary containing a whole state of the module. |
|
|
|
|
|
|
|
|
|
Both parameters and persistent buffers (e.g. running averages) are |
|
|
|
@ -291,18 +304,22 @@ class ZeroDDP(ColoDDP):
|
|
|
|
|
['bias', 'weight'] |
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0 |
|
|
|
|
record_flag = (not only_rank_0) or is_rank_0 |
|
|
|
|
|
|
|
|
|
if destination is None: |
|
|
|
|
destination = OrderedDict() |
|
|
|
|
destination._metadata = OrderedDict() |
|
|
|
|
destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version) |
|
|
|
|
self._save_to_state_dict(destination, prefix, keep_vars) |
|
|
|
|
self._save_to_state_dict(destination, prefix, keep_vars, record_flag) |
|
|
|
|
|
|
|
|
|
for hook in self._state_dict_hooks.values(): |
|
|
|
|
hook_result = hook(self, destination, prefix, local_metadata) |
|
|
|
|
if hook_result is not None: |
|
|
|
|
destination = hook_result |
|
|
|
|
return destination |
|
|
|
|
|
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars): |
|
|
|
|
def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True): |
|
|
|
|
r"""Saves module state to `destination` dictionary, containing a state |
|
|
|
|
of the module, but not its descendants. This is called on every |
|
|
|
|
submodule in :meth:`~torch.nn.Module.state_dict`. |
|
|
|
@ -315,22 +332,36 @@ class ZeroDDP(ColoDDP):
|
|
|
|
|
prefix (str): the prefix for parameters and buffers used in this |
|
|
|
|
module |
|
|
|
|
""" |
|
|
|
|
chunks = self.chunk_manager.get_chunks(self.fp32_params) |
|
|
|
|
chunks_orig_device_type = [] |
|
|
|
|
for chunk in chunks: |
|
|
|
|
chunks_orig_device_type.append(chunk.device_type) |
|
|
|
|
# save parameters |
|
|
|
|
param_to_save_data = dict() |
|
|
|
|
chunk_list = self.chunk_manager.get_chunks(self.fp32_params) |
|
|
|
|
for chunk in chunk_list: |
|
|
|
|
# record the original device of the chunk |
|
|
|
|
org_chunk_dev_typ = chunk.device_type |
|
|
|
|
self.chunk_manager.access_chunk(chunk) |
|
|
|
|
|
|
|
|
|
for tensor in chunk.get_tensors(): |
|
|
|
|
rec_p = torch.empty([0]) |
|
|
|
|
if record_flag: |
|
|
|
|
rec_p = tensor.cpu() # move the whole tensor to CPU mem |
|
|
|
|
assert tensor not in param_to_save_data |
|
|
|
|
param_to_save_data[tensor] = rec_p |
|
|
|
|
# release the actual memory of the chunk |
|
|
|
|
self.chunk_manager.release_chunk(chunk) |
|
|
|
|
if not chunk.is_empty and org_chunk_dev_typ == 'cpu': |
|
|
|
|
self.chunk_manager.move_chunk(chunk, torch.device('cpu')) |
|
|
|
|
|
|
|
|
|
for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params): |
|
|
|
|
if p is not None: |
|
|
|
|
rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu() |
|
|
|
|
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name) |
|
|
|
|
rec_p = param_to_save_data[fp32_p] |
|
|
|
|
destination[prefix + name] = rec_p if keep_vars else rec_p.detach() |
|
|
|
|
for orig_dvice_type, chunk in zip(chunks_orig_device_type, chunks): |
|
|
|
|
self.chunk_manager.release_chunk(chunk) |
|
|
|
|
if not chunk.is_empty and orig_dvice_type == 'cpu': |
|
|
|
|
self.chunk_manager.move_chunk(chunk, torch.device('cpu')) |
|
|
|
|
|
|
|
|
|
# save all buffers |
|
|
|
|
for name, buf in self.named_buffers(): |
|
|
|
|
if buf is not None and name not in self._non_persistent_buffers_set: |
|
|
|
|
destination[prefix + name] = buf if keep_vars else buf.detach() |
|
|
|
|
# save extra states |
|
|
|
|
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX |
|
|
|
|
if getattr(self.__class__, "get_extra_state", |
|
|
|
|
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state: |
|
|
|
@ -368,7 +399,7 @@ class ZeroDDP(ColoDDP):
|
|
|
|
|
state_dict = state_dict.copy() |
|
|
|
|
if metadata is not None: |
|
|
|
|
# mypy isn't aware that "_metadata" exists in state_dict |
|
|
|
|
state_dict._metadata = metadata # type: ignore[attr-defined] |
|
|
|
|
state_dict._metadata = metadata # type: ignore[attr-defined] |
|
|
|
|
|
|
|
|
|
prefix = '' |
|
|
|
|
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) |
|
|
|
|