[hotfix] fix zero optim save/load state dict (#1381)

Branch: pull/1384/head
Author: ver217, 2022-07-28 17:19:39 +08:00 (committed by GitHub)
Parent: b6fd165f66
Commit: 828b9e5e0d
3 changed files with 160 additions and 75 deletions

Changed file 1/3 — class ProcessGroup:

@@ -104,8 +104,8 @@ class ProcessGroup:
     def set_cpu_groups(self):
         if self.has_cpu_groups:
             return
-        self.logger.info(
-            f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
+        # self.logger.info(
+        #     f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
         PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
         PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
         self._has_cpu_groups = True
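The only change above silences a per-rank log line, but set_cpu_groups() itself is what the optimizer fix below relies on: torch.distributed.gather_object cannot run over an NCCL group, so the state-dict gather in ZeroOptimizer first makes sure gloo-backed CPU groups exist. The following is a minimal sketch of that pattern, not part of this commit; the helper name gather_local_state and the cpu_group variable are illustrative, not ColossalAI API.

# Minimal sketch (illustrative, not from this commit): gathering per-rank
# optimizer state over a gloo group, mirroring what state_dict() does below.
import torch
import torch.distributed as dist

def gather_local_state(local_state: dict, cpu_group, dst: int = 0):
    # Every rank contributes its shard; only the dst rank receives the full list.
    output = [None] * dist.get_world_size(cpu_group) if dist.get_rank() == dst else None
    dist.gather_object(local_state, output, dst=dst, group=cpu_group)
    return output  # list of per-rank dicts on dst, None elsewhere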

Changed file 2/3 — class ZeroOptimizer:

@@ -8,6 +8,9 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils import get_current_device, disposable
+from collections import defaultdict, abc as container_abcs
+from copy import deepcopy
+from itertools import chain


 class OptimState(Enum):
@@ -191,22 +194,105 @@ class ZeroOptimizer(ColossalaiOptimizer):
                         self.chunk_manager.add_extern_static_tensor(val)

     def state_dict(self):
+        r"""Returns the state of the optimizer as a :class:`dict`. For DP rank != 0, this function returns None.
+
+        It contains two entries:
+
+        * state - a dict holding current optimization state. Its content
+            differs between optimizer classes.
+        * param_groups - a list containing all parameter groups where each
+            parameter group is a dict
+        """
+        is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
+        if not self.chunk_manager.enable_distributed_storage and not is_rank_0:
+            return
         optim_state_dict = super().state_dict()
         scaler_state_dict = self.grad_scaler.state_dict()
         optim_state_dict['scaler'] = scaler_state_dict
+        if not self.chunk_manager.enable_distributed_storage:
+            return optim_state_dict
+        local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0}
+        if not self.chunk_manager.process_group.has_cpu_groups:
+            self.chunk_manager.process_group.set_cpu_groups()
+        dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
+        output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())]
+        dist.gather_object(local_state,
+                           output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
+                           dst=dst_rank,
+                           group=self.chunk_manager.process_group.cpu_dp_process_group())
+        if not is_rank_0:
+            return
+        for state in output:
+            optim_state_dict['state'].update(state)
         return optim_state_dict

-    def load_state_dict(self, *args, **kwargs):
-        if 'scaler' not in args[0]:
+    def load_state_dict(self, state_dict):
+        r"""Loads the optimizer state.
+
+        Args:
+            state_dict (dict): optimizer state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        if 'scaler' not in state_dict:
             self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
         else:
-            scaler_state_dict = args[0].pop('scaler')
-            self.grad_scaler.load_state_dict(scaler_state_dict)
-        super().load_state_dict(*args, **kwargs)
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                state = self.optim.state[p]
-                for k, v in state.items():
-                    if isinstance(v, torch.Tensor):
-                        state[k] = v.to(dtype=self.fp16_param_to_fp32_param[p].dtype,
-                                        device=self.fp16_param_to_fp32_param[p].device)
+            self.grad_scaler.load_state_dict(deepcopy(state_dict['scaler']))
+
+        # Validate the state_dict
+        groups = self.param_groups
+        saved_groups = deepcopy(state_dict['param_groups'])
+
+        if len(groups) != len(saved_groups):
+            raise ValueError("loaded state dict has a different number of "
+                             "parameter groups")
+        param_lens = (len(g['params']) for g in groups)
+        saved_lens = (len(g['params']) for g in saved_groups)
+        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+            raise ValueError("loaded state dict contains a parameter group "
+                             "that doesn't match the size of optimizer's group")
+
+        # Update the state
+        id_map = {
+            old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups)),
+                                           chain.from_iterable((g['params'] for g in groups)))
+        }
+
+        def cast(param, value):
+            r"""Make a deep copy of value, casting all tensors to device of param."""
+            if isinstance(value, torch.Tensor):
+                # Floating-point types are a bit special here. They are the only ones
+                # that are assumed to always match the type of params.
+                if param.is_floating_point():
+                    value = value.to(param.dtype)
+                value = value.to(param.device)
+                return value
+            elif isinstance(value, dict):
+                return {k: cast(param, v) for k, v in value.items()}
+            elif isinstance(value, container_abcs.Iterable):
+                return type(value)(cast(param, v) for v in value)
+            else:
+                return value
+
+        # Copy state assigned to params (and cast tensors to appropriate types).
+        # State that is not assigned to params is copied as is (needed for
+        # backward compatibility).
+        state = defaultdict(dict)
+        for k, v in state_dict['state'].items():
+            if k in id_map:
+                param = self.fp16_param_to_fp32_param[id_map[k]]
+                if param.storage().size() > 0:
+                    state[param] = cast(param, deepcopy(v))
+            else:
+                state[k] = deepcopy(v)
+
+        # Update parameter groups, setting their 'params' value
+        def update_group(group, new_group):
+            new_group['params'] = group['params']
+            return new_group
+
+        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+        self.__setstate__({'state': state, 'param_groups': param_groups})
+
+
+def convert_state_dict_to_cpu(state: Dict[str, torch.Tensor]):
+    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()}
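For context on the id_map above: by the standard PyTorch convention, a saved optimizer state dict keys its 'state' entries by integer parameter ids that index into param_groups[i]['params'], not by Parameter objects, which is why load_state_dict() has to rebuild the mapping from saved ids to the live fp32 master parameters. A small standalone illustration of that convention (plain torch.optim.Adam, nothing ColossalAI-specific):

# Background sketch (standard PyTorch behaviour, assumed by the id_map logic above).
import torch

p = torch.nn.Parameter(torch.zeros(2))
adam = torch.optim.Adam([p], lr=1e-3)
p.grad = torch.zeros_like(p)
adam.step()                                # creates exp_avg / exp_avg_sq state for p

sd = adam.state_dict()
print(sd['param_groups'][0]['params'])     # [0] -> parameters referenced by integer id
print(list(sd['state'].keys()))            # [0] -> state keyed by the same integer ids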

Changed file 3/3 — zero optimizer state dict test:

@@ -1,100 +1,99 @@
 import pytest
 import colossalai
 import torch
-from colossalai.context.parallel_mode import ParallelMode
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.core import global_context as gpc
+from colossalai.gemini import ChunkManager
 from functools import partial
-from tests.test_tensor.common_utils import set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
-from colossalai.nn.parallel.data_parallel import ZeroDDP
-from colossalai.gemini import ChunkManager, GeminiManager
-from colossalai.testing import parameterize
+from colossalai.nn.parallel import ZeroDDP
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.zero import ZeroOptimizer
+from colossalai.testing import parameterize
+from colossalai.gemini.gemini_mgr import GeminiManager
 from colossalai.tensor import ProcessGroup

-def init_zero(model, use_chunk, use_zero, placement_policy):
+def check_state(s1, s2):
+    for v1, v2 in zip(s1.values(), s2.values()):
+        if isinstance(v1, torch.Tensor):
+            v1 = v1.to(v2.device)
+            assert torch.equal(v1, v2), f'{torch.sum((v1-v2).abs())}'
+        else:
+            assert v1 == v2
+
+
+def check_load_state_dict(optim, torch_optim):
+    for group, torch_group in zip(optim.optim.param_groups, torch_optim.param_groups):
+        for p, torch_p in zip(group['params'], torch_group['params']):
+            state = optim.optim.state[p]
+            torch_state = torch_optim.state[torch_p]
+            if p.storage().size() == 0:
+                assert len(state) == 0
+            check_state(state, torch_state)
+
+
+def check_state_dict(state_dict, torch_state_dict):
+    for (k1, s1), (k2, s2) in zip(state_dict['state'].items(), torch_state_dict['state'].items()):
+        assert k1 == k2
+        check_state(s1, s2)
+
+
+@parameterize('use_chunk', [False, True])
+@parameterize('use_zero', [False, True])
+@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
+def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+    model = model.cuda()
+    torch_model = model_builder().cuda()
     pg = ProcessGroup()
     chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size,
                                  pg,
                                  enable_distributed_storage=use_zero,
                                  init_device=GeminiManager.get_default_device(placement_policy))
     gemini_manager = GeminiManager(placement_policy, chunk_manager)
-    return ZeroDDP(model, gemini_manager)
-
-
-def run_step(model, optim, criterion, data, label):
-    optim.zero_grad()
-    logits = model(data)
-    loss = criterion(logits, label)
-    optim.backward(loss)
-    optim.step()
-
-
-def check_state_dict_eq(state_dict, other):
-    for p, state in state_dict['state'].items():
-        other_state = other['state'][p]
-        for k, v in state.items():
-            if isinstance(v, torch.Tensor):
-                assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}'
-            else:
-                assert v == other_state[k]
-
-
-@parameterize('use_chunk', [False, True])
-@parameterize('use_zero', [False, True])
-@parameterize('placement_policy', ['cuda', 'cpu'])
-def run_nested_model(use_chunk, use_zero, placement_policy):
-    get_components_func = non_distributed_component_funcs.get_callable('nested_model')
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    set_seed(42)
-    with ColoInitContext(device=get_current_device()):
-        model = model_builder()
-    set_seed(42)
-    with ColoInitContext(device=get_current_device()):
-        model_copy = model_builder()
-    model = init_zero(model, use_chunk, use_zero, placement_policy)
-    model_copy = init_zero(model_copy, use_chunk, use_zero, placement_policy)
+    model = ZeroDDP(model, gemini_manager)

     optim = HybridAdam(model.parameters(), lr=1e-3)
-    optim = ZeroOptimizer(optim, model, initial_scale=32)
-    optim_copy = HybridAdam(model_copy.parameters(), lr=1e-3)
-    optim_copy = ZeroOptimizer(optim_copy, model_copy, initial_scale=32)
+    optim = ZeroOptimizer(optim, model, initial_scale=1)

-    model.train()
-    model_copy.train()
-    set_seed(gpc.get_local_rank(ParallelMode.DATA))
-    data_iter = iter(train_dataloader)
+    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)

-    data, label = map(lambda x: x.cuda(), next(data_iter))
-    run_step(model, optim, criterion, data, label)
-    optim_copy.load_state_dict(optim.state_dict())
-    check_state_dict_eq(optim.state_dict(), optim_copy.state_dict())
-
-    data, label = map(lambda x: x.cuda(), next(data_iter))
-    run_step(model_copy, optim_copy, criterion, data, label)
+    for p in torch_model.parameters():
+        p.grad = torch.rand_like(p)
+    torch_optim.step()
+    torch_state_dict = torch_optim.state_dict()
+    optim.load_state_dict(torch_state_dict)
+    check_load_state_dict(optim, torch_optim)
+
+    state_dict = optim.state_dict()
+    if pg.rank() == 0:
+        check_state_dict(state_dict, torch_state_dict)


 def run_dist(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_nested_model()
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    run_zero_optim_state_dict()


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
-def test_zero_optim_state_dist(world_size):
+def test_zero_optim_state_dict(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
-    test_zero_optim_state_dist(2)
+    test_zero_optim_state_dict(2)
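
With the rank-0 gather in place, a typical checkpoint round trip looks like the hedged sketch below. It is illustrative only and not part of this commit: save_checkpoint, load_checkpoint, and the 'zero_optim.pt' path are made-up names, and pg is a ProcessGroup as in the test above. The idea is to save the gathered dict on DP rank 0 only, then load the same full dict on every rank and let load_state_dict() keep just the states whose parameters are stored locally.

# Hedged usage sketch (illustrative, not ColossalAI API; helper names and path are assumptions).
import torch

def save_checkpoint(optim, pg, path='zero_optim.pt'):
    state_dict = optim.state_dict()      # gathered full dict on DP rank 0, None on other ranks
    if pg.rank() == 0:
        torch.save(state_dict, path)

def load_checkpoint(optim, path='zero_optim.pt'):
    # Every rank loads the same full dict; ZeroOptimizer.load_state_dict() only
    # materializes state for parameters whose storage lives on this rank.
    optim.load_state_dict(torch.load(path, map_location='cpu'))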