[zero] alleviate memory usage in ZeRODDP state_dict (#1398)

HELSON 2 years ago committed by GitHub
parent 4f5f8f77d1
commit 4e98e938ce
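In short: ZeroDDP.state_dict() gains an only_rank_0 flag and now builds the result chunk by chunk, copying each chunk's tensors to CPU and releasing the chunk before touching the next one, instead of keeping every accessed chunk resident while cloning all fp32 parameters; ColoDDP additionally delegates named_buffers, named_children and named_modules to the wrapped module. The snippet below is only a minimal, self-contained sketch of that copy-then-release idea; the names make_param_groups and gather_state_dict are illustrative and are not ColossalAI APIs.

import torch


def make_param_groups():
    # Stand-in for ChunkManager.get_chunks(self.fp32_params): yield one "chunk"
    # worth of named parameters at a time instead of holding them all at once.
    for i in range(3):
        yield {'layer{}.weight'.format(i): torch.randn(4, 4)}


def gather_state_dict(param_groups, record: bool = True) -> dict:
    # Collect a CPU state dict group by group; each group is dropped before the
    # next one is touched, so the peak extra memory is roughly one group.
    state = {}
    for group in param_groups:    # "access" one chunk
        for name, tensor in group.items():
            # A rank that should not record weights keeps an empty placeholder,
            # mirroring the record_flag / only_rank_0 logic in this commit.
            state[name] = tensor.detach().cpu() if record else torch.empty(0)
        # the group goes out of scope here, i.e. the "chunk" is released
    return state


if __name__ == '__main__':
    print(sorted(gather_state_dict(make_param_groups()).keys()))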

@@ -6,12 +6,13 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.gemini.chunk import TensorState, Chunk
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Set
 from colossalai.logging import get_dist_logger
 from collections import OrderedDict
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from .reducer import Reducer
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
 except ImportError:
@@ -84,6 +85,18 @@ class ColoDDP(torch.nn.Module):
     def named_parameters(self, prefix: str = '', recurse: bool = True):
         return self.module.named_parameters(prefix, recurse)

+    def named_buffers(self, prefix: str = '', recurse: bool = True):
+        return self.module.named_buffers(prefix, recurse)
+
+    def named_children(self):
+        return self.module.named_children()
+
+    def named_modules(self,
+                      memo: Optional[Set[torch.nn.Module]] = None,
+                      prefix: str = '',
+                      remove_duplicate: bool = True):
+        return self.module.named_modules(memo, prefix, remove_duplicate)
+
     def forward(self, *args, **kwargs):
         self.module.zero_grad(set_to_none=True)
         return self.module(*args, **kwargs)
@@ -274,7 +287,7 @@ class ZeroDDP(ColoDDP):
             for tensor in chunk.get_tensors():
                 self.grads_device[tensor] = device

-    def state_dict(self, destination=None, prefix='', keep_vars=False):
+    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
         r"""Returns a dictionary containing a whole state of the module.

         Both parameters and persistent buffers (e.g. running averages) are
@@ -291,18 +304,22 @@ class ZeroDDP(ColoDDP):
             ['bias', 'weight']

         """
+        is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
+        record_flag = (not only_rank_0) or is_rank_0
+
         if destination is None:
             destination = OrderedDict()
             destination._metadata = OrderedDict()
         destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
-        self._save_to_state_dict(destination, prefix, keep_vars)
+        self._save_to_state_dict(destination, prefix, keep_vars, record_flag)
+
         for hook in self._state_dict_hooks.values():
             hook_result = hook(self, destination, prefix, local_metadata)
             if hook_result is not None:
                 destination = hook_result
         return destination

-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True):
         r"""Saves module state to `destination` dictionary, containing a state
         of the module, but not its descendants. This is called on every
         submodule in :meth:`~torch.nn.Module.state_dict`.
@@ -315,22 +332,36 @@ class ZeroDDP(ColoDDP):
             prefix (str): the prefix for parameters and buffers used in this
                 module
         """
-        chunks = self.chunk_manager.get_chunks(self.fp32_params)
-        chunks_orig_device_type = []
-        for chunk in chunks:
-            chunks_orig_device_type.append(chunk.device_type)
+        # save parameters
+        param_to_save_data = dict()
+        chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
+        for chunk in chunk_list:
+            # record the original device of the chunk
+            org_chunk_dev_typ = chunk.device_type
             self.chunk_manager.access_chunk(chunk)
+            for tensor in chunk.get_tensors():
+                rec_p = torch.empty([0])
+                if record_flag:
+                    rec_p = tensor.cpu()    # move the whole tensor to CPU mem
+                assert tensor not in param_to_save_data
+                param_to_save_data[tensor] = rec_p
+
+            # release the actual memory of the chunk
+            self.chunk_manager.release_chunk(chunk)
+            if not chunk.is_empty and org_chunk_dev_typ == 'cpu':
+                self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
+
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
-                rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
+                assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+                rec_p = param_to_save_data[fp32_p]
                 destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
-        for orig_dvice_type, chunk in zip(chunks_orig_device_type, chunks):
-            self.chunk_manager.release_chunk(chunk)
-            if not chunk.is_empty and orig_dvice_type == 'cpu':
-                self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
+
+        # save all buffers
         for name, buf in self.named_buffers():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 destination[prefix + name] = buf if keep_vars else buf.detach()
+        # save extra states
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
         if getattr(self.__class__, "get_extra_state",
                    torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
@@ -368,7 +399,7 @@ class ZeroDDP(ColoDDP):
         state_dict = state_dict.copy()
         if metadata is not None:
             # mypy isn't aware that "_metadata" exists in state_dict
             state_dict._metadata = metadata    # type: ignore[attr-defined]

         prefix = ''
         local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})

@@ -1,3 +1,5 @@
+import copy
+
 import pytest
 import colossalai
 import torch
@@ -11,9 +13,9 @@ from functools import partial
 from tests.components_to_test.registry import non_distributed_component_funcs
 from colossalai.nn.parallel import ZeroDDP, ColoDDP
 from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Callable
 from collections import OrderedDict
 from colossalai.tensor import ProcessGroup, ColoParameter
+from colossalai.testing import parameterize


 def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
@@ -25,7 +27,27 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic
         else:
             temp_t2 = t2
-        assert torch.equal(t1, temp_t2)
+        assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2)
+
+
+def check_model_equal(model_a, model_b, allow_empty: bool = False, same_dtype: bool = True):
+    for (na, pa), (nb, pb) in zip(model_a.named_parameters(), model_b.named_parameters()):
+        assert na == nb
+
+        if not allow_empty:
+            assert pa.storage().size() > 0
+            assert pb.storage().size() > 0
+        else:
+            if pa.storage().size() == 0 or pb.storage().size() == 0:
+                continue
+
+        if same_dtype:
+            assert pa.dtype == pb.dtype
+            temp_pb = pb
+        else:
+            temp_pb = pb.to(pa.dtype)
+
+        assert torch.equal(pa, temp_pb), "Parameter '{}' is not equal.\n {} {}".format(na, pa, pb)


 def init_ddp(module: torch.nn.Module) -> ColoDDP:
@@ -33,22 +55,26 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
     return ColoDDP(module, process_group=pg)


-def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
+def init_ddpv2(module: torch.nn.Module,
+               use_chunk: bool = False,
+               use_zero: bool = False,
+               placement_policy: str = 'cuda') -> ZeroDDP:
     pg = ProcessGroup()
     chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
-    gemini_manager = GeminiManager('cuda', chunk_manager)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
     return ZeroDDP(module, gemini_manager)


-def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
-    get_components_func = non_distributed_component_funcs.get_callable('nested_model')
+def run_ddp_state_dict():
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     torch_model = model_builder().cuda()
     with ColoInitContext(device=get_current_device()):
         model = model_builder()
-    model = ddp_init_func(model)
+    model = init_ddp(model)
     torch_state_dict = torch_model.state_dict()
     for param in model.parameters():
         if isinstance(param, ColoParameter):
             assert param.get_process_group() is not None
@@ -62,13 +88,44 @@ def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
     check_state_dict_equal(torch_state_dict, state_dict)


+@parameterize('use_chunk', [False, True])
+@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('use_zero', [False, True])
+@parameterize('only_rank_0', [False, True])
+def run_zero_state_dict(use_chunk, placement_policy, use_zero, only_rank_0):
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    torch_model = model_builder().cuda()
+    org_torch_model = copy.deepcopy(torch_model)
+    torch_state_dict = torch_model.state_dict()
+
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+    model = init_ddpv2(model, use_chunk, use_zero, placement_policy)
+
+    for param in model.parameters():
+        if isinstance(param, ColoParameter):
+            assert param.get_process_group() is not None
+
+    model.load_state_dict(torch_state_dict, strict=False)
+    check_model_equal(model, torch_model, allow_empty=True, same_dtype=False)
+
+    for param in model.parameters():
+        if isinstance(param, ColoParameter):
+            assert param.get_process_group() is not None
+
+    pg = ProcessGroup()
+    state_dict = model.state_dict(only_rank_0=only_rank_0)
+    if not only_rank_0 or pg.dp_local_rank() == 0:
+        torch_model.load_state_dict(state_dict, strict=False)
+        check_model_equal(torch_model, org_torch_model, allow_empty=False, same_dtype=True)
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_state_dict(init_ddp)
-    run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=False))
-    run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=True))
-    run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=False))
-    run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=True))
+    run_ddp_state_dict()
+    run_zero_state_dict()


 @pytest.mark.dist
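
A rough usage sketch based on the test above, assuming a colossalai distributed context has already been launched and `model` is a ZeroDDP instance wrapped as in init_ddpv2; the checkpoint path is illustrative:

import torch
from colossalai.tensor import ProcessGroup

pg = ProcessGroup()
# With only_rank_0=True, only the dp-local rank 0 process materializes full
# CPU copies of the parameters; the other ranks receive empty placeholders.
state_dict = model.state_dict(only_rank_0=True)
if pg.dp_local_rank() == 0:
    torch.save(state_dict, 'checkpoint.pt')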
