mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix zero optim save/load state dict (#1381)
parent
b6fd165f66
commit
828b9e5e0d
|
@ -104,8 +104,8 @@ class ProcessGroup:
|
|||
def set_cpu_groups(self):
|
||||
if self.has_cpu_groups:
|
||||
return
|
||||
self.logger.info(
|
||||
f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
|
||||
# self.logger.info(
|
||||
# f'{self._rank} Gloo initialize TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')
|
||||
PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
|
||||
PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
|
||||
self._has_cpu_groups = True
|
||||
|
|
|
@ -8,6 +8,9 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
|||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils import get_current_device, disposable
|
||||
from collections import defaultdict, abc as container_abcs
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
|
@ -191,22 +194,105 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
self.chunk_manager.add_extern_static_tensor(val)
|
||||
|
||||
def state_dict(self):
|
||||
r"""Returns the state of the optimizer as a :class:`dict`. For DP rank != 0, this function returns None.
|
||||
|
||||
It contains two entries:
|
||||
|
||||
* state - a dict holding current optimization state. Its content
|
||||
differs between optimizer classes.
|
||||
* param_groups - a list containing all parameter groups where each
|
||||
parameter group is a dict
|
||||
"""
|
||||
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
|
||||
if not self.chunk_manager.enable_distributed_storage and not is_rank_0:
|
||||
return
|
||||
optim_state_dict = super().state_dict()
|
||||
scaler_state_dict = self.grad_scaler.state_dict()
|
||||
optim_state_dict['scaler'] = scaler_state_dict
|
||||
if not self.chunk_manager.enable_distributed_storage:
|
||||
return optim_state_dict
|
||||
local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0}
|
||||
if not self.chunk_manager.process_group.has_cpu_groups:
|
||||
self.chunk_manager.process_group.set_cpu_groups()
|
||||
dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
|
||||
output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())]
|
||||
dist.gather_object(local_state,
|
||||
output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
|
||||
dst=dst_rank,
|
||||
group=self.chunk_manager.process_group.cpu_dp_process_group())
|
||||
if not is_rank_0:
|
||||
return
|
||||
for state in output:
|
||||
optim_state_dict['state'].update(state)
|
||||
return optim_state_dict
|
||||
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
if 'scaler' not in args[0]:
|
||||
def load_state_dict(self, state_dict):
|
||||
r"""Loads the optimizer state.
|
||||
|
||||
Args:
|
||||
state_dict (dict): optimizer state. Should be an object returned
|
||||
from a call to :meth:`state_dict`.
|
||||
"""
|
||||
if 'scaler' not in state_dict:
|
||||
self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
|
||||
else:
|
||||
scaler_state_dict = args[0].pop('scaler')
|
||||
self.grad_scaler.load_state_dict(scaler_state_dict)
|
||||
super().load_state_dict(*args, **kwargs)
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
state = self.optim.state[p]
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
state[k] = v.to(dtype=self.fp16_param_to_fp32_param[p].dtype,
|
||||
device=self.fp16_param_to_fp32_param[p].device)
|
||||
self.grad_scaler.load_state_dict(deepcopy(state_dict['scaler']))
|
||||
|
||||
# Validate the state_dict
|
||||
groups = self.param_groups
|
||||
saved_groups = deepcopy(state_dict['param_groups'])
|
||||
|
||||
if len(groups) != len(saved_groups):
|
||||
raise ValueError("loaded state dict has a different number of "
|
||||
"parameter groups")
|
||||
param_lens = (len(g['params']) for g in groups)
|
||||
saved_lens = (len(g['params']) for g in saved_groups)
|
||||
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
|
||||
raise ValueError("loaded state dict contains a parameter group "
|
||||
"that doesn't match the size of optimizer's group")
|
||||
|
||||
# Update the state
|
||||
id_map = {
|
||||
old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups
|
||||
)), chain.from_iterable((g['params'] for g in groups)))
|
||||
}
|
||||
|
||||
def cast(param, value):
|
||||
r"""Make a deep copy of value, casting all tensors to device of param."""
|
||||
if isinstance(value, torch.Tensor):
|
||||
# Floating-point types are a bit special here. They are the only ones
|
||||
# that are assumed to always match the type of params.
|
||||
if param.is_floating_point():
|
||||
value = value.to(param.dtype)
|
||||
value = value.to(param.device)
|
||||
return value
|
||||
elif isinstance(value, dict):
|
||||
return {k: cast(param, v) for k, v in value.items()}
|
||||
elif isinstance(value, container_abcs.Iterable):
|
||||
return type(value)(cast(param, v) for v in value)
|
||||
else:
|
||||
return value
|
||||
|
||||
# Copy state assigned to params (and cast tensors to appropriate types).
|
||||
# State that is not assigned to params is copied as is (needed for
|
||||
# backward compatibility).
|
||||
state = defaultdict(dict)
|
||||
for k, v in state_dict['state'].items():
|
||||
if k in id_map:
|
||||
param = self.fp16_param_to_fp32_param[id_map[k]]
|
||||
if param.storage().size() > 0:
|
||||
state[param] = cast(param, deepcopy(v))
|
||||
else:
|
||||
state[k] = deepcopy(v)
|
||||
|
||||
# Update parameter groups, setting their 'params' value
|
||||
def update_group(group, new_group):
|
||||
new_group['params'] = group['params']
|
||||
return new_group
|
||||
|
||||
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
|
||||
self.__setstate__({'state': state, 'param_groups': param_groups})
|
||||
|
||||
|
||||
def convert_state_dict_to_cpu(state: Dict[str, torch.Tensor]):
|
||||
return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()}
|
||||
|
|
|
@ -1,100 +1,99 @@
|
|||
import pytest
|
||||
import colossalai
|
||||
import torch
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.gemini import ChunkManager
|
||||
from functools import partial
|
||||
from tests.test_tensor.common_utils import set_seed
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from colossalai.nn.parallel.data_parallel import ZeroDDP
|
||||
from colossalai.gemini import ChunkManager, GeminiManager
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.nn.parallel import ZeroDDP
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.tensor import ProcessGroup
|
||||
|
||||
|
||||
def init_zero(model, use_chunk, use_zero, placement_policy):
|
||||
def check_state(s1, s2):
|
||||
for v1, v2 in zip(s1.values(), s2.values()):
|
||||
if isinstance(v1, torch.Tensor):
|
||||
v1 = v1.to(v2.device)
|
||||
assert torch.equal(v1, v2), f'{torch.sum((v1-v2).abs())}'
|
||||
else:
|
||||
assert v1 == v2
|
||||
|
||||
|
||||
def check_load_state_dict(optim, torch_optim):
|
||||
for group, torch_group in zip(optim.optim.param_groups, torch_optim.param_groups):
|
||||
for p, torch_p in zip(group['params'], torch_group['params']):
|
||||
state = optim.optim.state[p]
|
||||
torch_state = torch_optim.state[torch_p]
|
||||
if p.storage().size() == 0:
|
||||
assert len(state) == 0
|
||||
check_state(state, torch_state)
|
||||
|
||||
|
||||
def check_state_dict(state_dict, torch_state_dict):
|
||||
for (k1, s1), (k2, s2) in zip(state_dict['state'].items(), torch_state_dict['state'].items()):
|
||||
assert k1 == k2
|
||||
check_state(s1, s2)
|
||||
|
||||
|
||||
@parameterize('use_chunk', [False, True])
|
||||
@parameterize('use_zero', [False, True])
|
||||
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
|
||||
def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
|
||||
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
model = model.cuda()
|
||||
torch_model = model_builder().cuda()
|
||||
|
||||
pg = ProcessGroup()
|
||||
|
||||
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
|
||||
chunk_manager = ChunkManager(chunk_size,
|
||||
pg,
|
||||
enable_distributed_storage=use_zero,
|
||||
init_device=GeminiManager.get_default_device(placement_policy))
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
return ZeroDDP(model, gemini_manager)
|
||||
|
||||
|
||||
def run_step(model, optim, criterion, data, label):
|
||||
optim.zero_grad()
|
||||
logits = model(data)
|
||||
loss = criterion(logits, label)
|
||||
optim.backward(loss)
|
||||
optim.step()
|
||||
|
||||
|
||||
def check_state_dict_eq(state_dict, other):
|
||||
for p, state in state_dict['state'].items():
|
||||
other_state = other['state'][p]
|
||||
for k, v in state.items():
|
||||
if isinstance(v, torch.Tensor):
|
||||
assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}'
|
||||
else:
|
||||
assert v == other_state[k]
|
||||
|
||||
|
||||
@parameterize('use_chunk', [False, True])
|
||||
@parameterize('use_zero', [False, True])
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
def run_nested_model(use_chunk, use_zero, placement_policy):
|
||||
get_components_func = non_distributed_component_funcs.get_callable('nested_model')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
|
||||
set_seed(42)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder()
|
||||
set_seed(42)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model_copy = model_builder()
|
||||
model = init_zero(model, use_chunk, use_zero, placement_policy)
|
||||
model_copy = init_zero(model_copy, use_chunk, use_zero, placement_policy)
|
||||
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
optim = HybridAdam(model.parameters(), lr=1e-3)
|
||||
optim = ZeroOptimizer(optim, model, initial_scale=32)
|
||||
optim_copy = HybridAdam(model_copy.parameters(), lr=1e-3)
|
||||
optim_copy = ZeroOptimizer(optim_copy, model_copy, initial_scale=32)
|
||||
optim = ZeroOptimizer(optim, model, initial_scale=1)
|
||||
|
||||
model.train()
|
||||
model_copy.train()
|
||||
set_seed(gpc.get_local_rank(ParallelMode.DATA))
|
||||
data_iter = iter(train_dataloader)
|
||||
torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
|
||||
|
||||
data, label = map(lambda x: x.cuda(), next(data_iter))
|
||||
run_step(model, optim, criterion, data, label)
|
||||
optim_copy.load_state_dict(optim.state_dict())
|
||||
check_state_dict_eq(optim.state_dict(), optim_copy.state_dict())
|
||||
for p in torch_model.parameters():
|
||||
p.grad = torch.rand_like(p)
|
||||
|
||||
data, label = map(lambda x: x.cuda(), next(data_iter))
|
||||
run_step(model_copy, optim_copy, criterion, data, label)
|
||||
torch_optim.step()
|
||||
torch_state_dict = torch_optim.state_dict()
|
||||
optim.load_state_dict(torch_state_dict)
|
||||
check_load_state_dict(optim, torch_optim)
|
||||
|
||||
state_dict = optim.state_dict()
|
||||
if pg.rank() == 0:
|
||||
check_state_dict(state_dict, torch_state_dict)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_nested_model()
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_zero_optim_state_dict()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_zero_optim_state_dist(world_size):
|
||||
def test_zero_optim_state_dict(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_zero_optim_state_dist(2)
|
||||
test_zero_optim_state_dict(2)
|
||||
|
|
Loading…
Reference in New Issue