[zero] hotfix master param sync (#4618)

* [zero] add method to update master params

* [zero] update zero plugin

* [plugin] update low level zero plugin
Hongxin Liu 2023-09-05 15:04:02 +08:00 committed by GitHub
parent aaeb520ce3
commit 807e01a4ba
5 changed files with 122 additions and 45 deletions

File 1 of 5: low-level zero booster plugin

@@ -3,6 +3,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
+from types import MethodType
 from typing import Callable, Iterator, List, Optional, Tuple, Union
 
 import torch
@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
     unwrap_optimizer,
 )
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer
 
 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 
 SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']
 
+
+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+    def __init__(self, module: nn.Module, precision: str) -> None:
+        super().__init__(module)
+        self.dtype = None
+        if precision == 'fp16':
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        if self.dtype is not None:
+            module = module.to(self.dtype)
+        module = module.to(get_current_device())
+        self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
+    def unwrap(self):
+        # TODO(ver217): this is a workaround for loading model
+        return self
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
 
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
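The forward override above casts only floating-point tensors in args/kwargs to the wrapper's dtype. A standalone sketch of that casting, assuming tree_map is torch.utils._pytree.tree_map and restating a _convert_floating_point helper like the one named in the hunk header so the snippet is self-contained:

    import torch
    from functools import partial
    from torch.utils._pytree import tree_map

    def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
        # cast floating-point tensors only; integer tensors, bool masks and non-tensors pass through
        if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
            return x.to(dtype)
        return x

    convert_fn = partial(_convert_floating_point, dtype=torch.float16)
    batch = {'input': torch.randn(2, 4), 'mask': torch.ones(2, 4, dtype=torch.bool)}
    batch = tree_map(convert_fn, batch)    # 'input' becomes fp16, 'mask' is left untouched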
@@ -165,30 +194,36 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
         sharded_optimizer_loading_epilogue(optimizer)
 
-
-class LowLevelZeroModel(ModelWrapper):
-
-    def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
-        super().__init__(module)
-        self.dtype = None
-        if precision == 'fp16':
-            self.dtype = torch.float16
-        elif precision == 'bf16':
-            self.dtype = torch.bfloat16
-        module = zero_model_wrapper(module, zero_stage=stage)
-        if self.dtype is not None:
-            module = module.to(self.dtype)
-        module = module.to(get_current_device())
-        self.module = module
-        self.convert_fn = None
-        if self.dtype is not None:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
-
-    def forward(self, *args, **kwargs):
-        if self.convert_fn is not None:
-            args = tree_map(self.convert_fn, args)
-            kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+    def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+                             use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
+
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = True,
+                           prefix: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+                                   use_safetensors)
+
+    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_unsharded_model(model.module, checkpoint, strict)
+        model.update_master_params()
+
+    def load_sharded_model(self,
+                           model: LowLevelZeroModel,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()
 
 
 class LowLevelZeroPlugin(DPPluginBase):
@@ -248,22 +283,24 @@ class LowLevelZeroPlugin(DPPluginBase):
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
         assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'
+        assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
 
         self.stage = stage
         self.precision = precision
-        self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
-                                      communication_dtype=communication_dtype,
-                                      overlap_communication=overlap_communication,
-                                      cpu_offload=cpu_offload)
-        self.optim_kwargs = dict(initial_scale=initial_scale,
-                                 growth_factor=growth_factor,
-                                 backoff_factor=backoff_factor,
-                                 growth_interval=growth_interval,
-                                 hysteresis=hysteresis,
-                                 min_scale=min_scale,
-                                 max_scale=max_scale,
-                                 max_norm=max_norm,
-                                 norm_type=norm_type)
+        self.zero_optim_kwargs = dict(
+            initial_scale=initial_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            min_scale=min_scale,
+            max_scale=max_scale,
+            clip_grad_norm=max_norm,
+            reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            cpu_offload=cpu_offload,
+            partition_grad=(stage == 2),
+        )
         self.verbose = verbose
 
         # set class name with stage, for better error message
@@ -294,15 +331,15 @@ class LowLevelZeroPlugin(DPPluginBase):
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
 
         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.stage, self.precision)
+            model = LowLevelZeroModel(model, self.precision)
 
         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = zero_optim_wrapper(model.unwrap(),
-                                           optimizer,
-                                           optim_config=self.zero_optim_config,
-                                           **self.optim_kwargs,
-                                           verbose=self.verbose)
+            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+                                                                     **self.zero_optim_kwargs,
+                                                                     verbose=self.verbose)
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)
 
         return model, optimizer, criterion, dataloader, lr_scheduler
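With the configure() change above, the wrapped model gains an update_master_params method bound to the LowLevelZeroOptimizer via MethodType, and the checkpoint IO calls it after every model load. A rough usage sketch under the usual Booster front end; the model, optimizer, dataloader and checkpoint path below are placeholders:

    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin

    plugin = LowLevelZeroPlugin(stage=2, precision='fp16')
    booster = Booster(plugin=plugin)

    # model, optimizer, criterion and dataloader are assumed to be constructed beforehand
    model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)

    # loading weights now also refreshes the optimizer-held master shards via update_master_params()
    booster.load_model(model, 'path/to/checkpoint')

    # if working params are modified by hand, the sync can be triggered explicitly
    model.update_master_params()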

File 2 of 5: interface package exports

@@ -1,4 +1,4 @@
-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
 from .optimizer import OptimizerWrapper
 
-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']

File 3 of 5: model wrapper interface

@@ -23,3 +23,14 @@ class ModelWrapper(nn.Module):
 
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
+
+
+class AMPModelMixin:
+    """This mixin class defines the interface for AMP training.
+    """
+
+    def update_master_params(self):
+        """
+        Update the master parameters for AMP training.
+        """
+        pass
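AMPModelMixin only fixes the method name; the default is a no-op, so wrappers that keep no separate master copies can inherit it unchanged. A minimal illustration of a wrapper opting in; the class and its master_params attribute are made up for the example:

    from colossalai.interface import AMPModelMixin, ModelWrapper

    class ToyAMPModel(ModelWrapper, AMPModelMixin):
        """Hypothetical wrapper holding an fp32 master copy per working param."""

        def __init__(self, module, master_params):
            super().__init__(module)
            self.master_params = master_params    # e.g. {id(p): fp32 clone of p}

        def update_master_params(self):
            # copy current working values (possibly fp16/bf16) into the fp32 masters
            for p in self.module.parameters():
                if id(p) in self.master_params:
                    self.master_params[id(p)].copy_(p.data)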

File 4 of 5: low-level zero optimizer

@@ -6,6 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple
 
 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer
@@ -600,3 +601,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 ret_block_size += current_block_size
 
         yield ret_block, ret_block_size
+
+    def update_master_params(self, model: nn.Module) -> None:
+        """Update master params from working params
+
+        Args:
+            model (nn.Module): The model to update master params
+        """
+        for p in model.parameters():
+            p_id = id(p)
+            if p_id in self._param_store.working_to_master_param:
+                master_param = self._param_store.working_to_master_param[p_id]
+                padding_size = self._param_store.get_param_padding_size(p)
+                working_param = p.data.view(-1)
+                if padding_size > 0:
+                    working_param = torch.nn.functional.pad(working_param, [0, padding_size])
+                master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
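The method maps a full working parameter onto this rank's flat master slice by padding to a multiple of the world size and chunking. A standalone numeric sketch with assumed values (world_size=4, local_rank=1, a 10-element parameter); here the padding is computed inline, whereas the optimizer reads it from its param store:

    import torch
    import torch.nn.functional as F

    world_size, local_rank = 4, 1                      # assumed for the illustration
    working = torch.arange(10, dtype=torch.float16)    # flattened working param

    padding_size = (-working.numel()) % world_size     # 2, so the padded length 12 splits evenly
    padded = F.pad(working.view(-1), [0, padding_size])
    shard = padded.chunk(world_size)[local_rank]       # this rank keeps elements 3, 4, 5

    master = torch.zeros_like(shard, dtype=torch.float32)
    master.copy_(shard)                                # copy_ upcasts fp16 -> fp32, mirroring the code above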

File 5 of 5: low-level zero checkpoint IO test

@@ -14,6 +14,7 @@ from colossalai.testing import (
     rerun_if_address_is_in_use,
     spawn,
 )
+from colossalai.zero import LowLevelZeroOptimizer
 
 # stage 1 and 2 process the optimizer/mode the same way
@@ -50,6 +51,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
 
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+        # check master weight
+        assert isinstance(new_optimizer, LowLevelZeroOptimizer)
+        working_param_id_set = set(id(p) for p in new_model.parameters())
+        for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+            assert p_id in working_param_id_set
+            working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
+            padding = new_optimizer._param_store.get_param_padding_size(working_param)
+            padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
+            working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
+            assert torch.equal(working_shard,
+                               master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device))
 
         booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
         check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)