[zero] support shard optimizer state dict of zero (#4194)

* support sharded optimizer state dict for ZeRO

* polish code

* support sync grad manually
pull/4359/head
LuGY 2023-07-11 18:03:13 +08:00 committed by Hongxin Liu
parent dd7cc58299
commit 1a49a5ea00
4 changed files with 239 additions and 68 deletions


@ -1,5 +1,8 @@
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple, Union
import torch
@ -10,10 +13,16 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO
from colossalai.checkpoint_io.utils import (
get_optimizer_base_filenames,
get_shard_filename,
save_param_groups,
save_state_dict,
)
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
from .dp_plugin_base import DPPluginBase
from .torch_ddp_plugin import TorchDDPCheckpointIO
@ -32,21 +41,104 @@ SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
# TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
warnings.warn(
'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
warnings.warn(
'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
super().load_optimizer(optimizer, checkpoint)

def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
"""Save optimizer to checkpoint but only on master process.
Args:
optimizer (OptimizerWrapper): Optimizer to save state_dict
checkpoint (str): Path to save checkpoint
gather_dtensor (bool): Whether to gather_dtensor, not used
"""
# `state_dict` in LowLevelZeroOptimizer involves collective communication;
# if only the master rank collected the state_dict and saved it,
# the communication calls across ranks would not match
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def save_sharded_optimizer(self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = False,
prefix: str = None,
size_per_shard: int = 1024):
"""
Save sharded Zero-optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Prefix of the checkpoint files to save
size_per_shard (int): Max file size of each file that store state tensors
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# this state_dict provides only 'param_groups'
state_dict = optimizer.optim.state_dict()
# sharding of the optimizer states is handled by the low-level zero optimizer
sharded_state = optimizer.state_dict_shard(max_shard_size=size_per_shard)
# Preparing file paths and index file.
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
# Store the information of param groups to param_group_file.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(state_dict, group_file_path)
# Save shards of optimizer states.
total_size = 0
for idx, shard_pair in enumerate(sharded_state):
shard, current_size = shard_pair
shard_file = get_shard_filename(states_name, idx)
total_size = total_size + current_size
for param_id in shard.keys():
index_file.append_weight_map(str(param_id), shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
if self.coordinator.is_master():
save_state_dict(shard, checkpoint_file_path, use_safetensors=False)
# Wrap up index file.
index_file.append_meta_data("total_size", total_size)
if self.coordinator.is_master():
index_file.write_index_file(save_index_file)
logging.info(f"The optimizer is going to be split to checkpoint shards. "
f"You can find where each parameters has been saved in the "
f"index located at {save_index_file}.")
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, index_file_path: str, prefix: str):
"""Load sharded optimizer with the given path to index file.
Args:
optimizer (OptimizerWrapper): Optimizer to load state_dict
index_file_path (str): Path to the index file
prefix (str): Not used.
"""
super().load_sharded_optimizer(optimizer, index_file_path, prefix)
current_rank_state_dict = optimizer.optim.state_dict()['state']
for param_idx, state in current_rank_state_dict.items():
for k, v in state.items():
if isinstance(v, torch.Tensor) and k != 'step':
padding_size = (self.coordinator.world_size -
v.numel() % self.coordinator.world_size) % self.coordinator.world_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
v_list = v.split(v.numel() // self.coordinator.world_size)
current_rank_state_dict[param_idx][k] = v_list[self.coordinator.rank].detach()
class LowLevelZeroModel(ModelWrapper):
@ -74,36 +166,6 @@ class LowLevelZeroModel(ModelWrapper):
return super().forward(*args, **kwargs)
class LowLevelZeroOptimizer(OptimizerWrapper):
def __init__(self,
module: nn.Module,
optimizer: Optimizer,
zero_optim_config: dict,
optim_kwargs: dict,
verbose: bool = False) -> None:
optimizer = zero_optim_wrapper(module,
optimizer,
optim_config=zero_optim_config,
**optim_kwargs,
verbose=verbose)
super().__init__(optimizer)
def backward(self, loss: Tensor, *args, **kwargs):
self.optim.backward(loss)
def clip_grad_by_norm(self,
max_norm: Union[float, int],
norm_type: Union[float, int] = 2,
error_if_nonfinite: bool = False,
*args,
**kwargs) -> Tensor:
warnings.warn(f'LowLevelZero controls grad clipping by itself, so you should not use clip_grad_by_norm')
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('LowLevelZero does not support clip_grad_by_value')
class LowLevelZeroPlugin(DPPluginBase):
"""
Plugin for low level zero.
@ -211,8 +273,11 @@ class LowLevelZeroPlugin(DPPluginBase):
if optimizer is not None and \
not isinstance(optimizer, OptimizerWrapper):
optimizer = LowLevelZeroOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
self.verbose)
optimizer = zero_optim_wrapper(model.unwrap(),
optimizer,
optim_config=self.zero_optim_config,
**self.optim_kwargs,
verbose=self.verbose)
return model, optimizer, criterion, dataloader, lr_scheduler
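
For orientation, here is a minimal usage sketch of the sharded optimizer checkpointing wired up above; it mirrors the updated test at the end of this diff. It assumes an already-launched distributed run, and the model, optimizer, checkpoint paths and `size_per_shard` value are placeholders rather than part of this commit.

```
# Hedged sketch: saving and loading a sharded ZeRO optimizer checkpoint via the booster.
import torch
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

plugin = LowLevelZeroPlugin(stage=2, precision='fp16')
booster = Booster(plugin=plugin)

model = resnet18()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training: compute loss, booster.backward(loss, optimizer), optimizer.step() ...

# save the optimizer states as multiple shards plus an index file under the directory,
# then load them back through the same path
booster.save_optimizer(optimizer, './ckpt/optimizer', shard=True, size_per_shard=1024)
booster.load_optimizer(optimizer, './ckpt/optimizer')
```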


@ -2,7 +2,7 @@
import copy
from contextlib import contextmanager
from functools import partial
from typing import Optional
from typing import Dict, Iterator, Optional, Tuple
import torch
import torch.distributed as dist
@ -447,18 +447,23 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
# Gradient Synchronization #
############################
# this method is used to sync gradients manually
def sync_grad(self):
for group_id in range(self.num_param_groups):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.requires_grad and param.grad is not None:
self._add_to_bucket(param, group_id)
self._run_reduction()
def _reduce_grad(self, partition_grad):
# if communication is not overlapped (no reduction hook is attached) in zero1,
# we need to manually reduce these gradients
if not partition_grad and not self._overlap_communication:
for group_id in range(len(self._working_param_groups)):
param_group = self._working_param_groups[group_id]
for param in param_group:
if param.grad is not None:
self._add_to_bucket(param, group_id)
# run reduction
self._run_reduction()
self.sync_grad()
else:
self._run_reduction()
# this context comes from pytorch DDP
@contextmanager
@ -473,7 +478,8 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
##############
# State Dict #
##############
def _pack_state(self, state: dict) -> dict:
def _pack_state(self, state: Dict) -> Dict:
# comes from pytorch optimizer.state_dict()
param_mappings = {}
start_index = 0
@ -487,17 +493,17 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
start_index += len(packed['params'])
return packed
param_groups = [pack_group(g) for g in self.param_groups]
param_groups = [pack_group(g) for g in self.optim.param_groups]
# Remap state to use order indices as keys
packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}
return {'state': packed_state, 'param_groups': param_groups}
def state_dict(self) -> dict:
def state_dict(self) -> Dict:
"""Return a state_dict same with DDP
Returns:
dict: the pytorch form state_dict
Dict: the pytorch form state_dict
"""
zero_state = dict()
for param, state in self.optim.state.items():
@ -514,7 +520,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
return states_dict
def load_state_dict(self, state_dict: dict):
def load_state_dict(self, state_dict: Dict):
"""Load state dict, requires the state_dict be the pytorch form
Args:
@ -534,3 +540,46 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.optim.load_state_dict(zero_state_dict)
zero_state_dict = dict()
def state_dict_shard(self, max_shard_size: int = 1024) -> Iterator[Tuple[Dict, int]]:
"""Returns dictionaries containing a whole state of the module one by one. The max size of dictionary shard is specified by ``max_shard_size``.
Only include the 'state' in state_dict.
Args:
max_shard_size (int, optional): max size of state shard (in MB). Defaults to 1024.
Yields:
Iterator[Tuple[Dict, int]]: A generator over (state dict shard, shard size) pairs
"""
ret_block = dict()
ret_block_size = 0
local_states = self.optim.state_dict()['state']
for param_idx, states in local_states.items():
current_block_size = 0
current_block = copy.deepcopy(states)
# find the working param of current param_id
for group_id, pg in self._master_param_groups_of_current_rank.items():
if (group_id + 1) * len(pg) < param_idx:
continue
master_param = pg[param_idx - (group_id) * len(pg)]
working_param = self._param_store.master_to_working_param[id(master_param)]
for k, v in states.items():
if isinstance(v, torch.Tensor) and k != 'step':
state_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
dist.all_gather(state_tensor, v, group=self.dp_pg)
state_tensor = torch.stack(state_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
current_block_size += state_tensor.numel()
current_block[k] = state_tensor
if ret_block_size + current_block_size > max_shard_size and len(ret_block) > 0:
yield ret_block, ret_block_size
ret_block = dict()
ret_block_size = 0
ret_block[param_idx] = current_block
ret_block_size += current_block_size
yield ret_block, ret_block_size
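
For illustration, a minimal sketch of how a caller might drive this generator; the helper and file names are made up, not the ones used by the checkpoint IO above. Note that the generator all-gathers states internally, so it must be iterated on every rank even if only the master rank writes files.

```
import os

import torch


def save_optimizer_state_shards(optimizer, checkpoint_dir: str, is_master: bool,
                                max_shard_size: int = 1024) -> int:
    # iterate on every rank so the internal all_gather calls stay in sync
    os.makedirs(checkpoint_dir, exist_ok=True)
    total_size = 0
    for idx, (shard, shard_size) in enumerate(optimizer.state_dict_shard(max_shard_size=max_shard_size)):
        total_size += shard_size
        if is_master:
            # each shard maps packed param indices to their gathered optimizer states
            torch.save(shard, os.path.join(checkpoint_dir, f'optim_shard_{idx:05d}.bin'))
    return total_size
```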


@ -0,0 +1,54 @@
# Low Level ZeRO
>Low Level ZeRO == ZeRO-DP stages 1 and 2; we denote it as ZeRO below.
## Design:
### Notation
- `p32` denotes the parameter copy kept by the optimizer (master weights)
- `p` denotes the model parameter
- `g` denotes the gradient
### INIT
In low level zero (stages 1 and 2), `p32` is split. Different from the previous implementation, we split each `p32` evenly by world_size. Thus, rank 0 gets a param list `[p-00, p-10]`, rank 1 gets a param list `[p-01, p-11]`, etc.
<img width="840" alt="image" src="https://github.com/hpcaitech/ColossalAI/assets/74758262/f7758d7d-c5e5-44a4-a121-3aba8b05c904">
For the detailed implementation, we first pad `p` so that it can be split evenly by world_size if needed. Then we view it with the shape `[world_size, -1]`, and each rank gets its own `p32` part by cloning.
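
A rough sketch of this pad-view-clone step under those assumptions; the helper name is illustrative and not the actual implementation:

```
import torch


def shard_master_param(p: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # flatten an fp32 copy of the model param
    flat = p.detach().float().flatten()
    # pad so the numel is divisible by world_size
    padding = (world_size - flat.numel() % world_size) % world_size
    if padding > 0:
        flat = torch.nn.functional.pad(flat, [0, padding])
    # view as [world_size, -1]; each rank clones its own row as the local `p32`
    return flat.view(world_size, -1)[rank].clone()
```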
### BWD
To make communication efficient, a gradient is first added to a bucket. When the bucket is full, each `g` in it is reshaped to `[world_size, -1]`, and the parts belonging to each rank are grouped together.
The data structure looks like this:
```
{
0: [g-00, g-10],
1: [g-01, g-11],
2: [g-02, g-12]
}
```
After that, the gradients are flattened by rank, and the data structure looks like this:
```
# g-0 means flatten([g-00, g-10])
{
0: [g-0],
1: [g-1],
2: [g-2]
}
```
For zero1, we iterate over the dictionary and do `all_reduce`. For zero2, we can just do `reduce-scatter`.
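
A simplified sketch of this reduction for one full bucket, following the `partition_grad` flag used by `_reduce_grad` above (padding and bucket bookkeeping are omitted; gradients are assumed to already be divisible by world_size):

```
import torch
import torch.distributed as dist


def reduce_bucket(bucket_grads, world_size: int, rank: int, partition_grad: bool):
    # reshape every g to [world_size, -1] and flatten the per-rank slices together,
    # so flat_by_rank[i] == flatten([g-0i, g-1i, ...]) in the notation above
    per_rank = [g.reshape(world_size, -1) for g in bucket_grads]
    flat_by_rank = [torch.cat([g[i] for g in per_rank]) for i in range(world_size)]

    if not partition_grad:
        # zero1: every rank keeps full gradients, so all_reduce each flattened slice
        for flat in flat_by_rank:
            dist.all_reduce(flat)
        return flat_by_rank
    # zero2: each rank only needs its own slice, so one reduce_scatter is enough
    out = torch.empty_like(flat_by_rank[rank])
    dist.reduce_scatter(out, flat_by_rank)
    return out
```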
### Optim
Since each rank owns its own `p32` and the counterpart `g`, it is quite easy to do `optim.step()`.
However, we have to consider the situation of layer drop, where a layer is defined but never used in the forward pass, for instance:
```
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(128, 256)
self.drop_linear = nn.Linear(256, 256)    # defined but never used in forward()
self.linear2 = nn.Linear(256, 512)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
```
The solution is to build a mapping among `p32`, `p`, and `g`. Before `optim.step()`, we collect every `p` with `requires_grad=True` and `p.grad != None` as a real working param, and then select the counterpart `p32` and `g`.
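
A minimal sketch of collecting these working params before the step (illustrative only; the real mapping is kept by the zero optimizer's param store):

```
import torch


def collect_working_params(optim: torch.optim.Optimizer):
    working = []
    for group in optim.param_groups:
        for p in group['params']:
            # params of dropped layers (e.g. `drop_linear` above) never receive a grad
            if p.requires_grad and p.grad is not None:
                working.append(p)
    return working
```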


@ -38,9 +38,8 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
optimizer_ckpt_path = f"{tempdir}/optimizer"
# lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, so we can skip it here
booster.save_model(model, model_ckpt_path, shard=shard)
if not shard:
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
booster.save_optimizer(optimizer, optimizer_ckpt_path)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard)
dist.barrier()
new_model = resnet18()
@ -49,9 +48,9 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool):
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
if not shard:
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
def run_dist(rank, world_size, port):
@ -62,3 +61,7 @@ def run_dist(rank, world_size, port):
@rerun_if_address_is_in_use()
def test_low_level_zero_checkpointIO():
spawn(run_dist, 2)
if __name__ == "__main__":
test_low_level_zero_checkpointIO()