mirror of https://github.com/hpcaitech/ColossalAI
[gemini] support save state dict in shards (#3581)
* [gemini] support state dict shard
* [gemini] add test state dict shard
* [gemini] polish docstr
* [gemini] fix merge
* [gemini] polish code
parent 7788e0b0a5
commit f313babd11
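The commit adds a ZeroDDP.state_dict_shard generator so large Gemini models can be checkpointed without ever materializing the full state dict at once. As a rough usage sketch (the save_shards helper, the file names, and the rank gating are illustrative assumptions, not part of this commit):

import torch
import torch.distributed as dist

def save_shards(model, checkpoint_dir: str, max_shard_size_mb: int = 1024) -> None:
    # Hypothetical helper: stream the sharded state dict to disk one file at a time.
    # With only_rank_0=True, non-zero ranks receive placeholder tensors, so only
    # rank 0 writes the shard files.
    for i, shard in enumerate(model.state_dict_shard(max_shard_size=max_shard_size_mb,
                                                     only_rank_0=True)):
        if not dist.is_initialized() or dist.get_rank() == 0:
            torch.save(shard, f"{checkpoint_dir}/shard_{i:05d}.bin")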
@@ -1,12 +1,13 @@
 import itertools
 from collections import OrderedDict
 from functools import partial
-from typing import Dict, List, Optional, Union
+from typing import Dict, Iterator, List, Optional, Union

 import torch
 import torch.distributed as dist
 import torch.nn as nn

+from colossalai.checkpoint_io.utils import calculate_tensor_size
 from colossalai.logging import get_dist_logger
 from colossalai.nn.parallel.data_parallel import ColoDDP, _cast_float, free_storage
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
@@ -228,6 +229,32 @@ class ZeroDDP(ColoDDP):
                 destination = hook_result
         return destination

+    def _get_chunk_to_save_data(self, chunk: Chunk, only_rank_0: bool) -> Dict:
+        """
+        get gathered chunk content.
+
+        Args:
+            chunk (Chunk): a chunk
+            only_rank_0 (bool): whether to only save data on rank 0
+
+        Returns:
+            Dict: a dict whose key is param name and value is param with correct payload
+        """
+        # save parameters
+        chunk_to_save_data = dict()
+        temp_chunk = get_temp_total_chunk_on_cuda(chunk)
+        for tensor, tensor_info in chunk.tensors_info.items():
+            record_tensor = torch.empty([0])
+            record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
+            if record_flag:
+                record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
+
+            assert tensor not in chunk_to_save_data
+            chunk_to_save_data[tensor] = record_tensor
+
+        del temp_chunk
+        return chunk_to_save_data
+
     def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
         """
         get param content from chunks.
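In _get_chunk_to_save_data above, record_flag gates the device-to-CPU copy: with only_rank_0=True only rank 0 of the chunk's process group materializes the payload, while other ranks keep a zero-element placeholder. A minimal standalone sketch of that gating rule (gather_or_placeholder and rank are hypothetical names used only for illustration):

import torch

def gather_or_placeholder(flat_chunk: torch.Tensor, offset: int, end: int,
                          shape: torch.Size, rank: int, only_rank_0: bool) -> torch.Tensor:
    # Copy real data only on ranks that will actually save it; otherwise
    # return an empty placeholder, mirroring the record_flag expression.
    record_tensor = torch.empty([0])
    if (not only_rank_0) or rank == 0:
        record_tensor = flat_chunk[offset:end].view(shape).cpu()
    return record_tensor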
@@ -243,18 +270,7 @@ class ZeroDDP(ColoDDP):
         param_to_save_data = dict()
         chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
-            temp_chunk = get_temp_total_chunk_on_cuda(chunk)
-
-            for tensor, tensor_info in chunk.tensors_info.items():
-                record_tensor = torch.empty([0])
-                record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
-                if record_flag:
-                    record_tensor = temp_chunk[tensor_info.offset:tensor_info.end].view(tensor.shape).cpu()
-
-                assert tensor not in param_to_save_data
-                param_to_save_data[tensor] = record_tensor
-
-            del temp_chunk
+            param_to_save_data.update(self._get_chunk_to_save_data(chunk, only_rank_0))
         return param_to_save_data

     def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
@@ -554,6 +570,93 @@ class ZeroDDP(ColoDDP):
             p.__class__ = ColoParameter
             p.__init__(p, requires_grad=requires_grad)

+    def state_dict_shard(self,
+                         prefix: str = '',
+                         keep_vars: bool = False,
+                         max_shard_size: int = 1024,
+                         only_rank_0: bool = True) -> Iterator[OrderedDict]:
+        """Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
+
+        Both parameters and persistent buffers (e.g. running averages) are included.
+        Keys are corresponding parameter and buffer names.
+        Parameters and buffers set to ``None`` are not included.
+
+        Args:
+            prefix (str, optional): the prefix for parameters and buffers used in this
+                module. Defaults to ''.
+            keep_vars (bool, optional): whether to keep variables. Defaults to False.
+            max_shard_size (int, optional): max size of state dict shard (in MB). Defaults to 1024.
+            only_rank_0 (bool, optional): only get data on rank 0. Defaults to True.
+
+        Yields:
+            Iterator[OrderedDict]: A generator of state dict shard
+        """
+        sharder = _StateDictSharder(max_shard_size)
+
+        # get the mapping between copies and fp16 parameters
+        fp16_to_fp32 = dict()
+        for p, fp32_p in zip(self.fp16_params, self.fp32_params):
+            fp16_to_fp32[p] = fp32_p
+
+        # key is fp32 param, and value is gathered param on CPU
+        gathered_param_buffer = dict()
+        for name, param in self.name2param.items():
+            if param is not None:
+                if is_ddp_ignored(param):
+                    # deal with ddp ignored parameters
+                    gathered_param = param if keep_vars else param.detach()
+                else:
+                    fp32_param = fp16_to_fp32[param]
+                    if fp32_param not in gathered_param_buffer:
+                        chunk = self.chunk_manager.get_chunk(fp32_param)
+                        gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
+                    gathered_param = gathered_param_buffer.pop(fp32_param)
+
+                block = sharder.append(prefix + name, gathered_param)
+                if block is not None:
+                    yield block
+
+        del fp16_to_fp32
+        del gathered_param_buffer
+
+        # save all buffers
+        for name, buf in self.named_buffers():
+            if buf is not None and name not in self._non_persistent_buffers_set:
+                buffer = buf if keep_vars else buf.detach()
+                block = sharder.append(prefix + name, buffer)
+                if block is not None:
+                    yield block
+        # save extra states
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state",
+                   torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
+            extra_state = self.get_extra_state()
+            block = sharder.append(extra_state_key, extra_state)
+            if block is not None:
+                yield block
+
+        yield sharder.current_block
+
+
+class _StateDictSharder:
+
+    def __init__(self, max_shard_size: int) -> None:
+        self.max_shard_size = max_shard_size
+        self.current_block = OrderedDict()
+        self.current_block_size = 0
+
+    def append(self, name: str, tensor: torch.Tensor) -> Optional[OrderedDict]:
+        tensor_size = calculate_tensor_size(tensor)
+        ret_block = None
+        if self.current_block_size + tensor_size > self.max_shard_size:
+            ret_block = self.current_block
+            self.current_block = OrderedDict()
+            self.current_block_size = 0
+        self.current_block[name] = tensor
+        self.current_block_size += tensor_size
+        return ret_block
+
+
 class GeminiDDP(ZeroDDP):
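_StateDictSharder packs tensors greedily: append flushes the current block as soon as the next tensor would push it past max_shard_size (a size in MB, matching the docstring and calculate_tensor_size), and the trailing `yield sharder.current_block` in state_dict_shard emits the final, possibly undersized block. A toy re-creation of the packing rule, with hard-coded sizes chosen only for illustration:

from collections import OrderedDict

def pack_greedily(named_sizes, max_shard_size):
    # Same rule as _StateDictSharder.append: flush the current block whenever
    # the next entry would exceed the size limit, then keep filling a new one.
    blocks, current, current_size = [], OrderedDict(), 0
    for name, size in named_sizes:
        if current_size + size > max_shard_size:
            blocks.append(current)
            current, current_size = OrderedDict(), 0
        current[name] = size
        current_size += size
    blocks.append(current)    # the final, possibly undersized block
    return blocks

# Three 400 MB tensors under a 1024 MB limit split into [a, b] and [c].
print(pack_greedily([("a", 400), ("b", 400), ("c", 400)], 1024))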
@@ -0,0 +1,56 @@
+import pytest
+import torch
+from torch.testing import assert_close
+
+import colossalai
+from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.utils.cuda import get_current_device
+from colossalai.zero import ColoInitContext, ZeroDDP
+from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
+from colossalai.zero.gemini.gemini_mgr import GeminiManager
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+
+@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('model_name', ['gpt2', 'bert'])
+def exam_state_dict(placement_policy, model_name: str):
+    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+
+    model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
+
+    config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
+    chunk_manager = ChunkManager(config_dict)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
+    model = ZeroDDP(model, gemini_manager)
+    model.train()
+
+    zero_dict = model.state_dict(only_rank_0=False)
+    accumulated_keys = set()
+    # ensure number of shards > 1
+    for shard in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
+        for key, value in shard.items():
+            assert key not in accumulated_keys, f"key `{key}` is duplicated."
+            accumulated_keys.add(key)
+            assert key in zero_dict, f"{key} not in ZeRO dictionary."
+            assert torch.equal(value, zero_dict[key]), f"{key} not equal."
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    exam_state_dict()
+
+
+@pytest.mark.dist
+@pytest.mark.parametrize('world_size', [1, 4])
+@rerun_if_address_is_in_use()
+def test_zero_ddp_state_dict_shard(world_size):
+    spawn(run_dist, world_size)
+
+
+if __name__ == '__main__':
+    test_zero_ddp_state_dict_shard(1)
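The test verifies that the union of all shards reproduces the full state dict exactly, with no duplicated keys. Restoring from shards follows the same idea in reverse; a minimal sketch, assuming the files were written by an earlier state_dict_shard save and that model is the ZeroDDP-wrapped module (the paths are hypothetical):

import torch

shard_paths = ["shard_00000.bin", "shard_00001.bin"]    # hypothetical files saved earlier
merged_state_dict = {}
for path in sorted(shard_paths):
    # shards carry disjoint keys, so a plain dict update reassembles the checkpoint
    merged_state_dict.update(torch.load(path, map_location="cpu"))
model.load_state_dict(merged_state_dict)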