[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)

* fix master param sync for hybrid plugin

* rewrite unwrap for ddp/fsdp

* rewrite unwrap for zero/gemini

* rewrite unwrap for hybrid plugin

* fix gemini unwrap

* fix bugs
pull/4659/merge
Baizhou Zhang 2023-09-20 18:29:37 +08:00 committed by GitHub
parent 7b9b86441f
commit c0a033700c
14 changed files with 141 additions and 171 deletions

View File

@ -2,7 +2,7 @@ from typing import Dict, List
import torch
from torch import Tensor
from torch.nn import Parameter
from torch.nn import Module, Parameter
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
@ -152,3 +152,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
if p is working_param:
continue
working_param.data.copy_(p.data)
def update_master_params(self, model: Module):
# Update master params from working params
with torch.no_grad():
for p in model.parameters():
if (p is None) or (p not in self.working_to_master_map):
continue
master_param = self.working_to_master_map[p]
master_param.data.copy_(p.data)
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}
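The new accessors expose the optimizer's working/master parameter mappings keyed by id(), which is what the checkpoint IO now consumes instead of keeping its own copy. A minimal standalone sketch of that contract (toy tensors under assumed names, not the plugin's actual classes):

import torch

# Working params are the low-precision tensors used in forward/backward;
# master params are the fp32 copies the optimizer actually updates.
working = [torch.randn(4, dtype=torch.float16) for _ in range(2)]
working_to_master = {p: p.detach().float() for p in working}

def update_master_params(params):
    # Refresh the fp32 master copies from the (possibly just-loaded) working params.
    with torch.no_grad():
        for p in params:
            if p in working_to_master:
                working_to_master[p].data.copy_(p.data)

def get_working_to_master_map():
    # Checkpoint IO expects a map keyed by id(working_param).
    return {id(w): m for w, m in working_to_master.items()}

update_master_params(working)
assert torch.allclose(get_working_to_master_map()[id(working[0])], working[0].float())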

View File

@ -139,7 +139,7 @@ class Booster:
if self.plugin and not self.plugin.control_device():
# transform model for accelerator
model = self.accelerator.configure(model)
model = self.accelerator.configure_model(model)
if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
# transform model for mixed precision

View File

@ -44,6 +44,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
state_dict = model.state_dict(only_rank_0=True)
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors)
@ -53,24 +54,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
Load model from checkpoint with automatic unwrapping.
The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict=strict)
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
"""
Save unsharded optimizer state dict to checkpoint.
After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
The saving process will only be executed by master rank.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
state_dict = optimizer.state_dict()
if self.coordinator.is_master():
save_state_dict(state_dict, checkpoint, use_safetensors=False)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
"""
Loading unsharded optimizer from checkpoint file.
For each process, only loading optimizer states of parameters it controls.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
def save_sharded_model(
@ -86,6 +90,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
Save sharded model.
As there is communication when getting state dict, model.state_dict() must be called on all processes.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
if os.path.isfile(checkpoint_path):
logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
return
@ -111,7 +116,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.module, checkpoint_path)
save_config_file(model.unwrap(), checkpoint_path)
logging.info(
f"The model is split into checkpoint shards. "
f"You can find where each parameters has been saved in the "
@ -124,17 +129,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
"""
Load shard model, load model from multiple files.
"""
assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
def save_sharded_optimizer(
self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
):
"""
Save sharded optimizer state dict to checkpoint folder.
As there is communication when getting state dict, this must be called on all processes.
"""
assert isinstance(optimizer, GeminiOptimizer)
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@ -176,12 +181,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
f"index located at {save_index_file}."
)
def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
"""
Loading sharded optimizer from checkpoint folder, with index file given.
For each process, only loading optimizer states of parameters it controls.
"""
assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
if not os.path.isfile(checkpoint_index_file):
logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
@ -383,7 +388,7 @@ class GeminiPlugin(DPPluginBase):
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer = GeminiOptimizer(
optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
)
return model, optimizer, criterion, dataloader, lr_scheduler
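For context, the new asserts enforce the intended workflow: the raw model and optimizer must be boosted into GeminiDDP/GeminiOptimizer before any checkpoint call. A rough usage sketch (assuming a torchrun-launched process and the usual Booster API; paths and model are placeholders):

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch(config={})  # assumes torchrun has set the distributed env vars

model = nn.Linear(8, 8)
optimizer = Adam(model.parameters())

booster = Booster(plugin=GeminiPlugin())
# After boosting, `model` is a GeminiDDP and `optimizer` a GeminiOptimizer,
# so the isinstance asserts in GeminiCheckpointIO pass.
model, optimizer, *_ = booster.boost(model, optimizer)

booster.save_model(model, "gemini_ckpt", shard=True)           # save_sharded_model path
booster.save_optimizer(optimizer, "gemini_optim", shard=True)  # save_sharded_optimizer path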

View File

@ -1,6 +1,7 @@
import random
from contextlib import nullcontext
from functools import partial
from types import MethodType
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
import numpy as np
@ -165,6 +166,15 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
init_pipeline_optimizer(optim, model)
super().__init__(optim)
def update_master_params(self, model: Module):
pass
def get_working_to_master_map(self):
return None
def get_master_to_working_map(self):
return None
class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
def __init__(
@ -466,9 +476,6 @@ class HybridParallelPlugin(PipelinePluginBase):
max_norm=self.max_norm,
**self.amp_config,
)
self.checkpoint_io.link_master_and_working_param(
optimizer.working_to_master_map, optimizer.master_to_working_map
)
else:
optimizer = HybridParallelNaiveOptimizer(
optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
@ -488,10 +495,8 @@ class HybridParallelPlugin(PipelinePluginBase):
**self.zero_config,
**self.amp_config,
)
self.checkpoint_io.link_master_and_working_param(
optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param
)
# inject update_master_params
model.update_master_params = MethodType(optimizer.update_master_params, model)
return model, optimizer, criterion, dataloader, lr_scheduler
def execute_pipeline(
@ -567,8 +572,7 @@ class HybridParallelPlugin(PipelinePluginBase):
)
def get_checkpoint_io(self) -> CheckpointIO:
self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError
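The MethodType injection above binds the optimizer's update_master_params onto the boosted model instance, so the checkpoint IO can simply call model.update_master_params() without knowing which optimizer wrapper is in use. A self-contained sketch of the binding pattern (generic toy classes, not the plugin's):

from types import MethodType

class ToyOptimizer:
    def update_master_params(self, model):
        # In the real plugin this copies working weights into the fp32 master copies.
        print(f"syncing master params for {type(model).__name__}")

class ToyModel:
    pass

model, optim = ToyModel(), ToyOptimizer()

# Bind the optimizer's bound method to the model instance: calling
# model.update_master_params() ends up invoking optim.update_master_params(model).
model.update_master_params = MethodType(optim.update_master_params, model)
model.update_master_params()  # -> syncing master params for ToyModel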

View File

@ -22,7 +22,6 @@ from colossalai.checkpoint_io.utils import (
save_param_groups,
save_state_dict,
sharded_optimizer_loading_epilogue,
unwrap_optimizer,
)
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
@ -65,10 +64,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
def unwrap(self):
# TODO(ver217): this is a workaround for loading model
return self
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@ -79,7 +74,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
checkpoint (str): Path to save checkpoint
gather_dtensor (bool): Whether to gather_dtensor, not used
"""
assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
# the `state_dict` in LowLevelZeroOptimizer has communication
# if only the master rank collect state_dict and save,
# the communication on each rank would not match
@ -109,6 +104,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file that stores state tensors
"""
assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@ -160,9 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
index_file_path (str): Path to the index file
prefix (str): Not used.
"""
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = unwrap_optimizer(optimizer)
assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before loading!"
optimizer = optimizer.unwrap()
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@ -194,44 +189,23 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
v_list = v.split(v.numel() // self.coordinator.world_size)
state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
load_states_into_optimizer(optimizer, state_dict, id_map)
sharded_optimizer_loading_epilogue(optimizer)
def save_unsharded_model(
self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
):
assert isinstance(model, LowLevelZeroModel)
super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
def save_sharded_model(
self,
model: nn.Module,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
max_shard_size: int = 1024,
use_safetensors: bool = False,
):
assert isinstance(model, LowLevelZeroModel)
super().save_sharded_model(
model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)
def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel)
super().load_unsharded_model(model.module, checkpoint, strict)
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
super().load_unsharded_model(model, checkpoint, strict)
model.update_master_params()
def load_sharded_model(
self,
model: LowLevelZeroModel,
model: ModelWrapper,
checkpoint_index_file: Path,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
assert isinstance(model, LowLevelZeroModel)
super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
model.update_master_params()
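Calling model.update_master_params() right after loading is what propagates the restored working weights into the zero optimizer's fp32 copies; otherwise the next optimizer step would start from stale master values. A toy demonstration of the problem and the fix (assumed setup, not the plugin's internals):

import torch

working = torch.nn.Parameter(torch.ones(4, dtype=torch.float16))
master = working.detach().float().clone()            # fp32 copy held by the optimizer

# Simulate loading a checkpoint: only the working param is overwritten.
with torch.no_grad():
    working.copy_(torch.full_like(working, 2.0))

assert not torch.allclose(master, working.float())   # master copy is now stale

# update_master_params() closes the gap so optimizer.step() resumes from the loaded weights.
with torch.no_grad():
    master.copy_(working.float())
assert torch.allclose(master, working.float())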

View File

@ -20,24 +20,33 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
super().__init__()
self.coordinator = DistCoordinator()
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
"""
Load model from checkpoint with automatic unwrapping.
Load model from checkpoint.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
return super().load_unsharded_model(model, checkpoint, strict=strict)
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
"""
Load optimizer from checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_unsharded_optimizer(optimizer, checkpoint)
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
@ -50,7 +59,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
def save_sharded_model(
self,
model: nn.Module,
model: ModelWrapper,
checkpoint_path: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
@ -60,22 +69,52 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
"""
Save model to checkpoint but only on master process.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
if self.coordinator.is_master():
super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
super().save_sharded_model(
model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
)
def load_sharded_model(
self,
model: ModelWrapper,
checkpoint_index_file: str,
strict: bool = False,
use_safetensors: bool = False,
load_sub_module: bool = True,
):
"""
Load model from sharded checkpoint.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
def save_sharded_optimizer(
self,
optimizer: Optimizer,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
):
"""
Save optimizer to checkpoint but only on master process.
Save optimizer to sharded checkpoint but only on master process.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if self.coordinator.is_master():
super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
def load_sharded_optimizer(
self,
optimizer: Optimizer,
index_file_path: str,
prefix: Optional[str] = None,
):
"""
Load optimizer from sharded checkpoint.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
class TorchDDPModel(ModelWrapper):
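The pattern throughout this file is the same: assert that the caller passed the boosted wrapper, then unwrap exactly once at the IO boundary instead of relying on the base class to do it. A generic sketch of that contract (toy stand-ins for ModelWrapper/TorchDDPModel):

import torch
import torch.nn as nn

class ToyWrapper(nn.Module):
    # Stand-in for ModelWrapper: holds the real module and exposes unwrap().
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def unwrap(self) -> nn.Module:
        return self.module

def save_unsharded_model(model, checkpoint: str):
    # Mirror of the new contract: refuse raw modules, unwrap at the boundary.
    assert isinstance(model, ToyWrapper), "Please boost the model before saving!"
    torch.save(model.unwrap().state_dict(), checkpoint)

wrapped = ToyWrapper(nn.Linear(4, 4))
save_unsharded_model(wrapped, "model.pt")   # saved keys carry no "module." prefix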

View File

@ -39,31 +39,35 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
super().__init__()
self.coordinator = DistCoordinator()
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
model = model.unwrap()
checkpoint = utils.load_state_dict(checkpoint)
model.load_state_dict(checkpoint)
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
checkpoint = utils.load_state_dict(checkpoint)
fsdp_model = optimizer.unwrap_model()
sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
optimizer.load_state_dict(sharded_osd)
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
"""
Save model to checkpoint but only on master process.
"""
# the model should be unwrapped in self.load_model via ModelWrapper.unwrap
assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
model = model.unwrap()
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
full_model_state = model.state_dict()
utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
"""
Save optimizer to checkpoint but only on master process.
"""
assert isinstance(optimizer, FSDPOptimizerWrapper)
assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
fsdp_model = optimizer.unwrap_model()
full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)
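FSDP.full_optim_state_dict gathers the sharded optimizer state onto rank 0, and FSDP.scatter_full_optim_state_dict re-shards it on load; the new asserts only guarantee the wrapper types before these collectives run. A rough round-trip sketch (assumes an initialized process group and an already FSDP-wrapped module; helper names are illustrative):

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def save_fsdp_optim(fsdp_model: FSDP, optimizer: torch.optim.Optimizer, path: str):
    # Gather the full (unsharded) optimizer state; only rank 0 receives it.
    full_osd = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
    if dist.get_rank() == 0:
        torch.save(full_osd, path)

def load_fsdp_optim(fsdp_model: FSDP, optimizer: torch.optim.Optimizer, path: str):
    full_osd = torch.load(path)
    # Scatter the full state back into per-rank shards matching fsdp_model.
    sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, fsdp_model)
    optimizer.load_state_dict(sharded_osd)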

View File

@ -87,9 +87,6 @@ class CheckpointIO(ABC):
# return the origin model instead of the unwrapped model
origin_model = model
if isinstance(model, ModelWrapper):
model = model.unwrap()
if index_file_exists:
self.load_sharded_model(model, index_file_path, strict)
else:
@ -134,9 +131,6 @@ class CheckpointIO(ABC):
use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
"""
if isinstance(model, ModelWrapper):
model = model.unwrap()
if shard:
self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
else:

View File

@ -8,8 +8,6 @@ from typing import Optional
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
from .checkpoint_io_base import CheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
@ -28,7 +26,6 @@ from .utils import (
shard_model_checkpoint,
shard_optimizer_checkpoint,
sharded_optimizer_loading_epilogue,
unwrap_optimizer,
)
__all__ = ["GeneralCheckpointIO"]
@ -58,10 +55,6 @@ class GeneralCheckpointIO(CheckpointIO):
Load sharded optimizer with the given path to index file.
"""
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = unwrap_optimizer(optimizer)
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@ -98,10 +91,6 @@ class GeneralCheckpointIO(CheckpointIO):
- Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
"""
# If optimizer is wrapped, unwrap it.
if isinstance(optimizer, OptimizerWrapper):
optimizer = unwrap_optimizer(optimizer)
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return

View File

@ -3,7 +3,7 @@ import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
from typing import Dict, Iterator, Optional, OrderedDict, Tuple
import torch
import torch.distributed as dist
@ -13,7 +13,7 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.cluster import DistCoordinator
from colossalai.interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
@ -71,8 +71,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.tp_size = dist.get_world_size(tp_group)
self.use_zero = zero_stage > 0
self.verbose = verbose
self.working_to_master_map = None
self.master_to_working_map = None
self.coordinator = DistCoordinator()
@staticmethod
@ -159,7 +157,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
def save_sharded_model(
self,
model: nn.Module,
model: ModelWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
@ -184,6 +182,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
model = model.unwrap()
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@ -279,7 +280,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
f"index located at {final_index_file_path}."
)
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
@ -289,6 +290,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on same device might be stored in different files.
"""
assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
model_before_wrapping = model # backup for model before wrapping
model = model.unwrap()
# Check whether the checkpoint uses safetensors.
use_safetensors = False
@ -347,23 +351,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
_load(extra_state_key)
# Update master params if mixed-precision training is enabled.
with torch.no_grad():
if self.working_to_master_map is not None:
for param in model.parameters():
if (param is None) or (id(param) not in self.working_to_master_map):
continue
master_param = self.working_to_master_map[id(param)]
if self.use_zero:
# master_param is sharded under Zero setting
padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
if padding_size > 0:
padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
else:
padded_param = param.data.view(-1)
sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
master_param.data.copy_(sharded_param.data)
else:
master_param.data.copy_(param.data)
model_before_wrapping.update_master_params()
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@ -392,6 +380,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that stores state tensors
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
@ -410,7 +399,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
use_zero=self.use_zero,
dp_group=self.dp_group,
tp_group=self.tp_group,
master_to_working_map=self.master_to_working_map,
master_to_working_map=optimizer.get_master_to_working_map(),
size_per_shard=size_per_shard,
)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
@ -511,6 +500,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
def _get_param_id_from_optimizer_param(
param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
@ -525,9 +515,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
id_map = {}
master_to_working_map = optimizer.get_master_to_working_map()
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
id_map[param_id] = param
# Read checkpoint index file.
@ -560,7 +551,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
for param in pg["params"]:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
@ -577,8 +568,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if self.master_to_working_map is not None:
working_param = self.master_to_working_map[id(param)]
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info["param2shape"][id(working_param)]
@ -614,42 +605,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
def link_master_and_working_param(
self,
working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
):
"""
Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
This mapping can only be created when mixied precision is used.
The created mappings should be mappings from integer parameter addresses to parameter objects.
Args:
working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects.
master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects.
"""
self.working_to_master_map = dict()
for k, v in working_to_master_map.items():
if isinstance(k, torch.Tensor):
self.working_to_master_map[id(k)] = v
elif isinstance(k, int):
self.working_to_master_map[k] = v
else:
raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
)
self.master_to_working_map = dict()
for k, v in master_to_working_map.items():
if isinstance(k, torch.Tensor):
self.master_to_working_map[id(k)] = v
elif isinstance(k, int):
self.master_to_working_map[k] = v
else:
raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
)
@staticmethod
def gather_from_sharded_optimizer_state(
state: OrderedDict,
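With the mapping now pulled from the optimizer via get_master_to_working_map(), the checkpoint IO looks up working parameters by id(master_param), e.g. to recover the original shape stored in optimizer.param_info["param2shape"]. A toy sketch of that lookup (plain tensors, not the real optimizer):

import torch

# Assumed toy setup: fp32 master copies of two fp16 working params.
working = [torch.randn(3, 4, dtype=torch.float16), torch.randn(5, dtype=torch.float16)]
master = [w.detach().float() for w in working]

# Shape of the map the checkpoint IO expects from get_master_to_working_map():
# keyed by id(master_param), pointing at the working param object.
master_to_working = {id(m): w for m, w in zip(master, working)}

for m in master:
    w = master_to_working[id(m)]
    # The working param is what identifies metadata such as the unsharded shape.
    print(id(m), tuple(w.shape))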

View File

@ -11,7 +11,6 @@ import torch
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
@ -122,14 +121,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
# ======================================
# Helper classes and functions for saving shard file
# ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper):
"""
Unwrap a wrapped optimizer.
This method should be used before saving/loading it to/from sharded checkpoints.
"""
unwrapped_optim = optimizer.optim
return unwrapped_optim
class StateDictSharder:

View File

@ -186,10 +186,6 @@ class GeminiDDP(ModelWrapper):
for p in params_to_ignore:
p._ddp_to_ignore = True
def unwrap(self):
# as save/load state dict is overwrited, only return self
return self
def _get_non_persistent_buffers_set(
self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
):

View File

@ -648,3 +648,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
if padding_size > 0:
working_param = torch.nn.functional.pad(working_param, [0, padding_size])
master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.working_to_master_param
def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
return self._param_store.master_to_working_param
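The padding-and-chunk logic above is how each rank extracts its own flat shard of a working parameter before copying it into its fp32 master shard. A standalone sketch of that slicing (world size and rank are assumed values for illustration):

import torch

world_size, local_rank = 4, 1
working_param = torch.arange(10, dtype=torch.float32)   # numel not divisible by world_size

# Pad the flattened parameter so it splits evenly across ranks ...
padding_size = (world_size - working_param.numel() % world_size) % world_size
if padding_size > 0:
    padded = torch.nn.functional.pad(working_param.view(-1), [0, padding_size])
else:
    padded = working_param.view(-1)

# ... then keep only this rank's chunk as the fp32 master shard.
master_shard = padded.chunk(world_size)[local_rank].clone()
print(padding_size, master_shard)   # padding_size == 2, shard tensor([3., 4., 5.])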

View File

@ -61,9 +61,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
if plugin_type == "gemini":
check_state_dict_equal(
model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False
)
check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
else:
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
dist.barrier()