mirror of https://github.com/hpcaitech/ColossalAI
[zero] hotfix master param sync (#4618)
* [zero] add method to update master params
* [zero] update zero plugin
* [plugin] update low level zero plugin

parent aaeb520ce3
commit 807e01a4ba
@@ -3,6 +3,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
+from types import MethodType
 from typing import Callable, Iterator, List, Optional, Tuple, Union

 import torch
@@ -25,9 +26,9 @@ from colossalai.checkpoint_io.utils import (
     sharded_optimizer_loading_epilogue,
     unwrap_optimizer,
 )
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
-from colossalai.zero import LowLevelZeroOptimizer, zero_model_wrapper, zero_optim_wrapper
+from colossalai.zero import LowLevelZeroOptimizer

 from .dp_plugin_base import DPPluginBase
 from .torch_ddp_plugin import TorchDDPCheckpointIO
@@ -44,6 +45,34 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
 SUPPORTED_PRECISION = ['fp16', 'bf16', 'fp32']


+class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
+
+    def __init__(self, module: nn.Module, precision: str) -> None:
+        super().__init__(module)
+        self.dtype = None
+        if precision == 'fp16':
+            self.dtype = torch.float16
+        elif precision == 'bf16':
+            self.dtype = torch.bfloat16
+        if self.dtype is not None:
+            module = module.to(self.dtype)
+        module = module.to(get_current_device())
+        self.module = module
+        self.convert_fn = None
+        if self.dtype is not None:
+            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
+
+    def forward(self, *args, **kwargs):
+        if self.convert_fn is not None:
+            args = tree_map(self.convert_fn, args)
+            kwargs = tree_map(self.convert_fn, kwargs)
+        return super().forward(*args, **kwargs)
+
+    def unwrap(self):
+        # TODO(ver217): this is a workaround for loading model
+        return self
+
+
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):

     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
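
The relocated LowLevelZeroModel no longer calls zero_model_wrapper; it only casts the module to the working dtype and converts floating-point inputs on every forward via _convert_floating_point and tree_map (from torch.utils._pytree). A minimal, ColossalAI-independent sketch of that cast-on-forward mechanism; HalfPrecisionWrapper is an illustrative name, and bfloat16 is used so the snippet also runs on CPU:

from functools import partial

import torch
import torch.nn as nn
from torch.utils._pytree import tree_map


def _convert_floating_point(x, dtype: torch.dtype = torch.bfloat16):
    # Only floating-point tensors are cast; ints, bools and non-tensors pass through.
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


class HalfPrecisionWrapper(nn.Module):
    """Illustrative stand-in for LowLevelZeroModel's cast-on-forward behaviour."""

    def __init__(self, module: nn.Module, dtype: torch.dtype = torch.bfloat16):
        super().__init__()
        self.module = module.to(dtype)
        self.convert_fn = partial(_convert_floating_point, dtype=dtype)

    def forward(self, *args, **kwargs):
        args = tree_map(self.convert_fn, args)
        kwargs = tree_map(self.convert_fn, kwargs)
        return self.module(*args, **kwargs)


model = HalfPrecisionWrapper(nn.Linear(4, 2))
out = model(torch.randn(3, 4))    # fp32 input is cast to bf16 before the linear layer
print(out.dtype)                  # torch.bfloat16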
@@ -165,30 +194,36 @@
             sharded_optimizer_loading_epilogue(optimizer)

-
-class LowLevelZeroModel(ModelWrapper):
-
-    def __init__(self, module: nn.Module, stage: int, precision: str) -> None:
-        super().__init__(module)
-        self.dtype = None
-        if precision == 'fp16':
-            self.dtype = torch.float16
-        elif precision == 'bf16':
-            self.dtype = torch.bfloat16
-        module = zero_model_wrapper(module, zero_stage=stage)
-        if self.dtype is not None:
-            module = module.to(self.dtype)
-        module = module.to(get_current_device())
-        self.module = module
-        self.convert_fn = None
-        if self.dtype is not None:
-            self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
-
-    def forward(self, *args, **kwargs):
-        if self.convert_fn is not None:
-            args = tree_map(self.convert_fn, args)
-            kwargs = tree_map(self.convert_fn, kwargs)
-        return super().forward(*args, **kwargs)
+    def save_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool,
+                             use_safetensors: bool):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
+
+    def save_sharded_model(self,
+                           model: nn.Module,
+                           checkpoint_path: str,
+                           gather_dtensor: bool = True,
+                           prefix: Optional[str] = None,
+                           max_shard_size: int = 1024,
+                           use_safetensors: bool = False):
+        assert isinstance(model, LowLevelZeroModel)
+        super().save_sharded_model(model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size,
+                                   use_safetensors)
+
+    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_unsharded_model(model.module, checkpoint, strict)
+        model.update_master_params()
+
+    def load_sharded_model(self,
+                           model: LowLevelZeroModel,
+                           checkpoint_index_file: Path,
+                           strict: bool = False,
+                           use_safetensors: bool = False,
+                           load_sub_module: bool = True):
+        assert isinstance(model, LowLevelZeroModel)
+        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()


 class LowLevelZeroPlugin(DPPluginBase):
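
Both load paths now finish with model.update_master_params(), so freshly loaded working weights are copied back into the optimizer's fp32 master shards. A hypothetical end-to-end sketch of how this is exercised through the booster API (the checkpoint path is a placeholder, and the snippet assumes a process group launched with torchrun):

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})

plugin = LowLevelZeroPlugin(stage=2, precision='fp16')
booster = Booster(plugin=plugin)

model = nn.Linear(8, 8)
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# Loading a model checkpoint now also refreshes the fp32 master shards from the
# fp16 working weights, so the next optimizer step starts from the loaded values.
booster.load_model(model, 'ckpt/model.pt')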
@@ -248,22 +283,24 @@
         super().__init__()
         assert stage in (1, 2), f'LowLevelZeroPlugin only supports stage 1/2 training'
         assert precision in SUPPORTED_PRECISION, f'LowLevelZeroPlugin only supports {SUPPORTED_PRECISION} training'

         assert norm_type == 2.0, f'LowLevelZeroPlugin only supports norm_type=2.0 now'
         self.stage = stage
         self.precision = precision
-        self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
-                                      communication_dtype=communication_dtype,
-                                      overlap_communication=overlap_communication,
-                                      cpu_offload=cpu_offload)
-        self.optim_kwargs = dict(initial_scale=initial_scale,
-                                 growth_factor=growth_factor,
-                                 backoff_factor=backoff_factor,
-                                 growth_interval=growth_interval,
-                                 hysteresis=hysteresis,
-                                 min_scale=min_scale,
-                                 max_scale=max_scale,
-                                 max_norm=max_norm,
-                                 norm_type=norm_type)
+        self.zero_optim_kwargs = dict(
+            initial_scale=initial_scale,
+            growth_factor=growth_factor,
+            backoff_factor=backoff_factor,
+            growth_interval=growth_interval,
+            hysteresis=hysteresis,
+            min_scale=min_scale,
+            max_scale=max_scale,
+            clip_grad_norm=max_norm,
+            reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,
+            communication_dtype=communication_dtype,
+            overlap_communication=overlap_communication,
+            cpu_offload=cpu_offload,
+            partition_grad=(stage == 2),
+        )
         self.verbose = verbose

         # set class name with stage, for better error message
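
The former zero_optim_config and optim_kwargs are folded into a single zero_optim_kwargs dict that is passed straight to LowLevelZeroOptimizer; note that max_norm becomes clip_grad_norm and partition_grad is derived from the stage. A hedged sketch of that argument mapping (the helper name and default values are illustrative, not part of the plugin):

# Illustrative mapping from LowLevelZeroPlugin arguments to LowLevelZeroOptimizer kwargs.
def build_zero_optim_kwargs(stage: int = 1,
                            reduce_bucket_size_in_m: int = 12,
                            max_norm: float = 0.0,
                            overlap_communication: bool = True,
                            cpu_offload: bool = False,
                            **grad_scaler_kwargs) -> dict:
    return dict(
        clip_grad_norm=max_norm,                                   # renamed from max_norm
        reduce_bucket_size=reduce_bucket_size_in_m * 1024 * 1024,  # MB -> bytes
        overlap_communication=overlap_communication,
        cpu_offload=cpu_offload,
        partition_grad=(stage == 2),                               # stage 2 also shards gradients
        **grad_scaler_kwargs,                                      # initial_scale, growth_factor, ...
    )


print(build_zero_optim_kwargs(stage=2, max_norm=1.0, initial_scale=2**16))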
@@ -294,15 +331,15 @@
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:

         if not isinstance(model, ModelWrapper):
-            model = LowLevelZeroModel(model, self.stage, self.precision)
+            model = LowLevelZeroModel(model, self.precision)

         if optimizer is not None and \
                 not isinstance(optimizer, OptimizerWrapper):
-            optimizer = zero_optim_wrapper(model.unwrap(),
-                                           optimizer,
-                                           optim_config=self.zero_optim_config,
-                                           **self.optim_kwargs,
-                                           verbose=self.verbose)
+            optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(optimizer,
+                                                                     **self.zero_optim_kwargs,
+                                                                     verbose=self.verbose)
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)

         return model, optimizer, criterion, dataloader, lr_scheduler
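
The MethodType injection binds the optimizer's update_master_params(model) to the model instance, so the checkpoint IO can call model.update_master_params() without holding a reference to the optimizer. A minimal, ColossalAI-independent sketch of that pattern (the Tiny* class names are made up for illustration):

from types import MethodType


class TinyOptimizer:
    """Illustrative stand-in for LowLevelZeroOptimizer."""

    def update_master_params(self, model) -> None:
        print(f"syncing master params of {type(model).__name__}")


class TinyModel:
    """Illustrative stand-in for LowLevelZeroModel; the default is a no-op."""

    def update_master_params(self) -> None:
        pass


model, optim = TinyModel(), TinyOptimizer()

# After injection, model.update_master_params() forwards to
# optim.update_master_params(model) -- the same wiring configure() sets up.
model.update_master_params = MethodType(optim.update_master_params, model)
model.update_master_params()    # prints: syncing master params of TinyModel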
@@ -1,4 +1,4 @@
-from .model import ModelWrapper
+from .model import AMPModelMixin, ModelWrapper
 from .optimizer import OptimizerWrapper

-__all__ = ['OptimizerWrapper', 'ModelWrapper']
+__all__ = ['OptimizerWrapper', 'ModelWrapper', 'AMPModelMixin']
@@ -23,3 +23,14 @@ class ModelWrapper(nn.Module):

     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
+
+
+class AMPModelMixin:
+    """This mixin class defines the interface for AMP training.
+    """
+
+    def update_master_params(self):
+        """
+        Update the master parameters for AMP training.
+        """
+        pass
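
Any wrapper that mixes in AMPModelMixin inherits a safe no-op update_master_params, which subclasses (or the plugin, via MethodType) can override. A hedged sketch of using the mixin outside the plugin, assuming a ColossalAI build that contains this commit; MyBF16Wrapper and its fp32 bookkeeping are illustrative only:

import torch
import torch.nn as nn

from colossalai.interface import AMPModelMixin, ModelWrapper


class MyBF16Wrapper(ModelWrapper, AMPModelMixin):
    """Keeps fp32 master copies next to bf16 working weights (illustrative only)."""

    def __init__(self, module: nn.Module):
        super().__init__(module.to(torch.bfloat16))
        self.master_params = [p.detach().clone().float() for p in self.module.parameters()]

    def update_master_params(self) -> None:
        # Refresh the fp32 masters from the (possibly just-loaded) bf16 working weights.
        for master, working in zip(self.master_params, self.module.parameters()):
            master.copy_(working.detach().float())


wrapper = MyBF16Wrapper(nn.Linear(4, 4))
wrapper.update_master_params()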
@@ -6,6 +6,7 @@ from typing import Dict, Iterator, Optional, Tuple

 import torch
 import torch.distributed as dist
+import torch.nn as nn
 from torch.distributed import ProcessGroup
 from torch.optim import Optimizer

@@ -600,3 +601,19 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             ret_block_size += current_block_size

         yield ret_block, ret_block_size
+
+    def update_master_params(self, model: nn.Module) -> None:
+        """Update master params from working params
+
+        Args:
+            model (nn.Module): The model to update master params
+        """
+        for p in model.parameters():
+            p_id = id(p)
+            if p_id in self._param_store.working_to_master_param:
+                master_param = self._param_store.working_to_master_param[p_id]
+                padding_size = self._param_store.get_param_padding_size(p)
+                working_param = p.data.view(-1)
+                if padding_size > 0:
+                    working_param = torch.nn.functional.pad(working_param, [0, padding_size])
+                master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
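
The copy re-applies the flat-tensor layout used by ZeRO: each working parameter is flattened, padded (here assumed to be rounded up to a multiple of the world size, matching get_param_padding_size), and the local rank's chunk overwrites the master shard. A self-contained sketch of that arithmetic for one parameter, with the world size and rank hard-coded for illustration:

import torch

world_size, local_rank = 4, 1
p = torch.arange(10, dtype=torch.float32).reshape(2, 5)   # a 10-element working parameter

flat = p.data.view(-1)
padding_size = (world_size - flat.numel() % world_size) % world_size   # 2 elements of padding
if padding_size > 0:
    flat = torch.nn.functional.pad(flat, [0, padding_size])

shards = flat.chunk(world_size)        # 4 shards of 3 elements each
master_shard = shards[local_rank]      # what update_master_params copies into the master param
print(padding_size, master_shard)      # 2  tensor([3., 4., 5.])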
@@ -14,6 +14,7 @@ from colossalai.testing import (
     rerun_if_address_is_in_use,
     spawn,
 )
+from colossalai.zero import LowLevelZeroOptimizer


 # stage 1 and 2 process the optimizer/mode the same way
@@ -50,6 +51,17 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):

     booster.load_model(new_model, model_ckpt_path)
     check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
+    # check master weight
+    assert isinstance(new_optimizer, LowLevelZeroOptimizer)
+    working_param_id_set = set(id(p) for p in new_model.parameters())
+    for p_id, master_param in new_optimizer._param_store.working_to_master_param.items():
+        assert p_id in working_param_id_set
+        working_param = new_optimizer._param_store.master_to_working_param[id(master_param)]
+        padding = new_optimizer._param_store.get_param_padding_size(working_param)
+        padded_param = torch.nn.functional.pad(working_param.data.view(-1), (0, padding))
+        working_shard = padded_param.chunk(dist.get_world_size())[dist.get_rank()]
+        assert torch.equal(working_shard,
+                           master_param.data.view(-1).to(dtype=padded_param.dtype, device=padded_param.device))

     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
     check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
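
The check runs inside a spawned distributed worker; a hedged sketch of the usual harness around it, using the helpers imported in this test file (the worker body and world size are illustrative):

import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn


def run_dist(rank, world_size, port):
    # Each spawned process joins the same group before exercising the checkpoint IO.
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port)
    # ... call check_low_level_zero_checkpointIO(stage, shard, offload) here ...


@rerun_if_address_is_in_use()
def test_low_level_zero_checkpoint_io():
    spawn(run_dist, 2)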