[zero] add state dict for low level zero (#4179)

* add state dict for zero

* fix unit test

* polish
pull/4359/head
LuGY 1 year ago committed by Hongxin Liu
parent c668801d36
commit dd7cc58299

@ -1,4 +1,5 @@
# this code is inspired by the DeepSpeed library and implemented with our own design from scratch
import copy
from contextlib import contextmanager
from functools import partial
from typing import Optional
@ -198,7 +199,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
params_current_rank = []
device = 'cpu' if self._cpu_offload else get_current_device()
for param in reversed(param_list):
for param in param_list:
padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
self._param_store.record_param_padding_size(param, padding_size)
@ -468,3 +469,68 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
yield
finally:
self.require_grad_sync = old_require_grad_sync
##############
# State Dict #
##############
def _pack_state(self, state: dict) -> dict:
# comes from pytorch optimizer.state_dict()
param_mappings = {}
start_index = 0
def pack_group(group):
nonlocal start_index
packed = {k: v for k, v in group.items() if k != 'params'}
param_mappings.update(
{id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
packed['params'] = [param_mappings[id(p)] for p in group['params']]
start_index += len(packed['params'])
return packed
param_groups = [pack_group(g) for g in self.param_groups]
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}
return {'state': packed_state, 'param_groups': param_groups}
def state_dict(self) -> dict:
"""Return a state_dict same with DDP
Returns:
dict: the pytorch form state_dict
"""
zero_state = dict()
for param, state in self.optim.state.items():
zero_state[param] = copy.deepcopy(state)
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
working_param = self._param_store.master_to_working_param[id(param)]
gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
dist.all_gather(gather_tensor, v, group=self.dp_pg)
param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
zero_state[param][k] = param_state
states_dict = self._pack_state(zero_state)
return states_dict
def load_state_dict(self, state_dict: dict):
"""Load state dict, requires the state_dict be the pytorch form
Args:
state_dict (dict): A pytorch form state_dict
"""
zero_state_dict = copy.deepcopy(state_dict)
for param_idx, state in zero_state_dict['state'].items():
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self._world_size)
zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
self.optim.load_state_dict(zero_state_dict)
zero_state_dict = dict()

@ -0,0 +1,121 @@
import copy
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(12, 24)
self.linear2 = nn.Linear(24, 12)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
def loose_close(a, b, dtype: torch.dtype = torch.float32):
rtol = None
atol = None
if dtype is torch.float16:
rtol = 5e-2
atol = 5e-4
elif dtype is torch.bfloat16:
rtol = 4e-3
atol = 4e-3
a = a.detach().to(dtype)
b = b.detach().to(dtype)
assert_close(a, b, rtol=rtol, atol=atol)
def exam_zero_1_torch_ddp_ckpt():
"""
We examine the state_dict of zero and DDP.
Moreover, we examine the zero's loading checkpoint of a torch ckpt.
"""
local_rank = torch.distributed.get_rank()
seed_all(1453)
# create models
torch_model = MlpModel().cuda()
zero_model = copy.deepcopy(torch_model)
torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()
# create optimizer
zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)
# we only test stage 1 here
# the state dicts of stage 1 and stage 2 are the same
zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
overlap_communication=True,
initial_scale=1,
reduce_bucket_size=262144)
torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)
seed_all(1453 + local_rank)
# create
input_data = torch.rand(4, 12).cuda()
# forward
zero_output = zero_model(input_data)
torch_output = torch_model(input_data)
# backward
zero_optimizer.backward(zero_output.mean().float())
torch_output.mean().backward()
# step
zero_optimizer.step()
torch_optimizer.step()
torch_state_dict = torch_optimizer.state_dict()
zero_state_dict = zero_optimizer.state_dict()
# examine the original state dict
for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()):
for t_v, z_v in zip(torch_state.values(), zero_state.values()):
loose_close(t_v, z_v)
# empty the optimzer state
zero_optimizer.optim.state = []
# zero load a torch checkpoint
zero_optimizer.load_state_dict(copy.deepcopy(torch_state_dict))
zero_state_dict = zero_optimizer.state_dict()
# examine the loaded state dict
for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()):
for t_v, z_v in zip(torch_state.values(), zero_state.values()):
loose_close(t_v, z_v)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_zero_1_torch_ddp_ckpt()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_ckpt():
spawn(run_dist, 2)
if __name__ == '__main__':
test_zero_ckpt()
Loading…
Cancel
Save