[gemini] support save state dict in shards (#3581)

* [gemini] support state dict shard

* [gemini] add test state dict shard

* [gemini] polish docstr

* [gemini] fix merge

* [gemini] polish code
Hongxin Liu 2023-04-17 17:11:09 +08:00 committed by GitHub
parent 7788e0b0a5
commit f313babd11
2 changed files with 172 additions and 13 deletions


@@ -1,12 +1,13 @@
 import itertools
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 
+from colossalai.checkpoint_io.utils import calculate_tensor_size
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -228,6 +229,32 @@ class ZeroDDP(ColoDDP):
             destination = hook_result
         return destination
 
+    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
+        """
+        get gathered chunk content.
+
+        Args:
+            chunk (Chunk): a chunk
+            only_rank_0 (bool): whether to only save data on rank 0
+
+        Returns:
+            Dict: a dict whose key is param name and value is param with correct payload
+        """
+        # save parameters
+        chunk_to_save_data = dict()
+        temp_chunk = get_temp_total_chunk_on_cuda(chunk)
+
+        for tensor, tensor_info in chunk.tensors_info.items():
+            record_tensor = torch.empty([0])
+            record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
+            if record_flag:
+                record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
+
+            assert tensor not in chunk_to_save_data
+            chunk_to_save_data[tensor] = record_tensor
+
+        del temp_chunk
+        return chunk_to_save_data
+
     def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
         """
         get param content from chunks.
@@ -243,18 +270,7 @@ class ZeroDDP(ColoDDP):
         param_to_save_data = dict()
         chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
-            temp_chunk = get_temp_total_chunk_on_cuda(chunk)
-
-            for tensor, tensor_info in chunk.tensors_info.items():
-                record_tensor = torch.empty([0])
-                record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
-                if record_flag:
-                    record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
-
-                assert tensor not in param_to_save_data
-                param_to_save_data[tensor] = record_tensor
-
-            del temp_chunk
+            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
         return param_to_save_data
 
     def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
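
The gather-then-slice pattern that `_get_chunk_to_save_data` factors out can be sketched in isolation: a chunk is a flat 1-D buffer holding several parameters back to back, and each parameter is recovered by slicing `[offset:end]` and reshaping. The `FlatChunk` class below is a hypothetical toy, not part of ColossalAI; it only mirrors the offset bookkeeping of `chunk.tensors_info`.

    import torch


    class FlatChunk:
        """Hypothetical toy chunk: a flat 1-D buffer plus per-tensor (offset, end, shape) records."""

        def __init__(self):
            self.flat = torch.empty(0)
            self.info = {}    # name -> (offset, end, shape)

        def add(self, name: str, tensor: torch.Tensor):
            offset = self.flat.numel()
            self.flat = torch.cat([self.flat, tensor.detach().flatten()])
            self.info[name] = (offset, self.flat.numel(), tensor.shape)

        def extract(self, name: str) -> torch.Tensor:
            # the same slice-and-view done per tensor in _get_chunk_to_save_data
            offset, end, shape = self.info[name]
            return self.flat[offset:end].view(shape).clone()


    chunk = FlatChunk()
    chunk.add('weight', torch.randn(4, 3))
    chunk.add('bias', torch.randn(3))
    assert chunk.extract('weight').shape == (4, 3)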
@@ -554,6 +570,93 @@ class ZeroDDP(ColoDDP):
         p.__class__ = ColoParameter
         p.__init__(p, requires_grad=requires_grad)
 
+    def state_dict_shard(self,
+                         prefix: str = '',
+                         keep_vars: bool = False,
+                         max_shard_size: int = 1024,
+                         only_rank_0: bool = True) -> Iterator[OrderedDict]:
+        """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
+
+        Both parameters and persistent buffers (e.g. running averages) are included.
+        Keys are corresponding parameter and buffer names.
+        Parameters and buffers set to ``None`` are not included.
+
+        Args:
+            prefix (str, optional): the prefix for parameters and buffers used in this
+                module. Defaults to ''.
+            keep_vars (bool, optional): whether to keep variables. Defaults to False.
+            max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
+            only_rank_0 (bool, optional): only get data on rank0. Defaults to True.
+
+        Yields:
+            Iterator[OrderedDict]: A generator of state dict shard
+        """
+        sharder = _StateDictSharder(max_shard_size)
+
+        # get the mapping between copies and fp16 parameters
+        fp16_to_fp32 = dict()
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+            fp16_to_fp32[p] = fp32_p
+
+        # key is fp32 param, and value is gathered param on CPU
+        gathered_param_buffer = dict()
+        for name, param in self.name2param.items():
+            if param is not None:
+                if is_ddp_ignored(param):
+                    # deal with ddp ignored parameters
+                    gathered_param = param if keep_vars else param.detach()
+                else:
+                    fp32_param = fp16_to_fp32[param]
+                    if fp32_param not in gathered_param_buffer:
+                        chunk = self.chunk_manager.get_chunk(fp32_param)
+                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+                    gathered_param = gathered_param_buffer.pop(fp32_param)
+
+                block = sharder.append(prefix + name, gathered_param)
+                if block is not None:
+                    yield block
+
+        del fp16_to_fp32
+        del gathered_param_buffer
+
+        # save all buffers
+        for name, buf in self.named_buffers():
+            if buf is not None and name not in self._non_persistent_buffers_set:
+                buffer = buf if keep_vars else buf.detach()
+                block = sharder.append(prefix + name, buffer)
+                if block is not None:
+                    yield block
+
+        # save extra states
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state",
+                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+            extra_state = self.get_extra_state()
+            block = sharder.append(extra_state_key, extra_state)
+            if block is not None:
+                yield block
+
+        yield sharder.current_block
+
+
+class _StateDictSharder:
+
+    def __init__(self, max_shard_size: int) -> None:
+        self.max_shard_size = max_shard_size
+        self.current_block = OrderedDict()
+        self.current_block_size = 0
+
+    def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
+        tensor_size = calculate_tensor_size(tensor)
+        ret_block = None
+        if self.current_block_size + tensor_size > self.max_shard_size:
+            ret_block = self.current_block
+            self.current_block = OrderedDict()
+            self.current_block_size = 0
+        self.current_block[name] = tensor
+        self.current_block_size += tensor_size
+        return ret_block
+
+
 class GeminiDDP(ZeroDDP):
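
A typical consumer of `state_dict_shard` streams each yielded shard straight to disk instead of materializing one huge state dict. The helper below is a hypothetical sketch, not part of this commit; the folder layout, file naming, and 512 MB default are arbitrary choices, and `model` is assumed to be a `ZeroDDP` instance set up inside a `colossalai.launch`-ed process group.

    import torch
    import torch.distributed as dist


    def save_sharded(model, folder: str, max_shard_size_mb: int = 512) -> dict:
        """Hypothetical helper: write each shard yielded by ZeroDDP.state_dict_shard to its own file.

        Returns a mapping from parameter/buffer name to the file that stores it.
        """
        index = {}
        for i, shard in enumerate(model.state_dict_shard(max_shard_size=max_shard_size_mb, only_rank_0=True)):
            # with only_rank_0=True, non-zero ranks receive empty placeholder tensors,
            # so only rank 0 actually writes files
            if dist.is_initialized() and dist.get_rank() != 0:
                continue
            filename = f"{folder}/model_shard_{i:05d}.bin"
            torch.save(shard, filename)
            index.update({key: filename for key in shard.keys()})
        return index

Because shards are yielded lazily and each gathered chunk is popped from `gathered_param_buffer` as soon as it is consumed, peak host memory for saving stays roughly bounded by one shard plus one gathered chunk, which is the point of the generator-based API.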


@@ -0,0 +1,56 @@
import pytest
import torch
from torch.testing import assert_close

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.cuda import get_current_device
from colossalai.zero import ColoInitContext, ZeroDDP
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from tests.components_to_test.registry import non_distributed_component_funcs


@parameterize('placement_policy', ['cuda', 'cpu'])
@parameterize('model_name', ['gpt2', 'bert'])
def exam_state_dict(placement_policy, model_name: str):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    with ColoInitContext(device=get_current_device()):
        model = model_builder()

    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2

    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    chunk_manager = ChunkManager(config_dict)
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    model.train()

    zero_dict = model.state_dict(only_rank_0=False)
    accumulated_keys = set()
    # ensure number of shards > 1
    for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
        for key, value in shard.items():
            assert key not in accumulated_keys, f"key `{key}` is duplicated."
            accumulated_keys.add(key)
            assert key in zero_dict, f"{key} not in ZeRO dictionary."
            assert torch.equal(value, zero_dict[key]), f"{key} not equal."


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_state_dict()


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_zero_ddp_state_dict_shard(world_size):
    spawn(run_dist, world_size)


if __name__ == '__main__':
    test_zero_ddp_state_dict_shard(1)
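
For completeness, shards written with the hypothetical `save_sharded` helper above can be merged back and handed to the module's existing `load_state_dict`; the sketch below is an assumption-labeled example, not part of this commit.

    import torch


    def load_from_shards(model, shard_files):
        """Hypothetical counterpart to save_sharded: merge shard files and load them back."""
        merged = {}
        for path in shard_files:
            # each file holds one OrderedDict produced by state_dict_shard
            merged.update(torch.load(path, map_location='cpu'))
        model.load_state_dict(merged)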