mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (#4758)
* fix master param sync for hybrid plugin
* rewrite unwrap for ddp/fsdp
* rewrite unwrap for zero/gemini
* rewrite unwrap for hybrid plugin
* fix gemini unwrap
* fix bugs
parent 7b9b86441f
commit c0a033700c
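For orientation, the sketch below is not part of the commit: it shows the booster flow these changes target, assuming a standard low-level ZeRO setup launched with torchrun; the checkpoint directory and toy model are placeholders.

# Hypothetical end-to-end sketch; requires a distributed launch (torchrun) and a GPU.
import torch
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch(config={})
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters())
booster = Booster(plugin=LowLevelZeroPlugin(stage=2, precision="fp16"))

# After boost(), `model` is a ModelWrapper and `optimizer` an OptimizerWrapper;
# with this commit the plugin checkpoint IO receives the wrappers as-is,
# asserts their type ("Please boost the model before saving!") and unwraps internally.
model, optimizer, *_ = booster.boost(model, optimizer)

booster.save_model(model, "checkpoint_dir", shard=True)  # placeholder path
booster.load_model(model, "checkpoint_dir")              # after loading, the plugin calls
                                                         # model.update_master_params() to
                                                         # resync the fp32 master weights

The point of the rewrite is that unwrapping now happens inside each plugin's CheckpointIO instead of in CheckpointIO.load_model/save_model, and master-weight synchronization after a load is delegated to the optimizer through update_master_params.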
@@ -2,7 +2,7 @@ from typing import Dict, List
 
 import torch
 from torch import Tensor
-from torch.nn import Parameter
+from torch.nn import Module, Parameter
 from torch.optim import Optimizer
 
 from colossalai.interface import OptimizerWrapper
@@ -152,3 +152,18 @@ class MixedPrecisionOptimizer(OptimizerWrapper):
             if p is working_param:
                 continue
             working_param.data.copy_(p.data)
+
+    def update_master_params(self, model: Module):
+        # Update master params from working params
+        with torch.no_grad():
+            for p in model.parameters():
+                if (p is None) or (p not in self.working_to_master_map):
+                    continue
+                master_param = self.working_to_master_map[p]
+                master_param.data.copy_(p.data)
+
+    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
+        return {id(working_p): master_p for working_p, master_p in self.working_to_master_map.items()}
+
+    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+        return {id(master_p): working_p for master_p, working_p in self.master_to_working_map.items()}

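The new get_working_to_master_map / get_master_to_working_map accessors return dictionaries keyed by id() of the parameter. Below is a self-contained sketch (plain PyTorch, not ColossalAI code) of that convention and of what update_master_params does with it:

import torch

# fp16 "working" params and their fp32 "master" copies, as in mixed-precision training
working = [torch.randn(4, dtype=torch.float16) for _ in range(3)]
master = [w.detach().clone().float() for w in working]

# id()-keyed maps, mirroring what the new accessors return
working_to_master = {id(w): m for w, m in zip(working, master)}
master_to_working = {id(m): w for w, m in zip(working, master)}

# update_master_params(): after new values land in the working params
# (e.g. a checkpoint load), copy them back into the fp32 master copies.
for w in working:
    working_to_master[id(w)].copy_(w.float())

assert torch.allclose(master[0], working[0].float())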
@@ -139,7 +139,7 @@ class Booster:
 
         if self.plugin and not self.plugin.control_device():
             # transform model for accelerator
-            model = self.accelerator.configure(model)
+            model = self.accelerator.configure_model(model)
 
         if self.mixed_precision and (self.plugin is None or self.plugin and not self.plugin.control_precision()):
             # transform model for mixed precision

@@ -44,6 +44,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
         As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         state_dict = model.state_dict(only_rank_0=True)
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors)
@@ -53,24 +54,27 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Load model from checkpoint with automatic unwrapping.
         The model should be unwrapped in self.load_model via ModelWrapper.unwrap.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         super().load_unsharded_model(model, checkpoint, strict=strict)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
         """
         Save unsharded optimizer state dict to checkpoint.
         After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
         As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
         The saving process will only be executed by master rank.
         """
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
         state_dict = optimizer.state_dict()
         if self.coordinator.is_master():
             save_state_dict(state_dict, checkpoint, use_safetensors=False)
 
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str):
+    def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
         """
         Loading unsharded optimizer from checkpoint file.
         For each process, only loading optimizer states of parameters it controls.
         """
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         super().load_unsharded_optimizer(optimizer, checkpoint)
 
     def save_sharded_model(
@@ -86,6 +90,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         Save sharded model.
         As there is communication when getting state dict, model.state_dict() must be called on all processes.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
         if os.path.isfile(checkpoint_path):
             logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
             return
@@ -111,7 +116,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             index_file.append_meta_data("total_size", total_size)
             index_file.write_index_file(save_index_file)
-            save_config_file(model.module, checkpoint_path)
+            save_config_file(model.unwrap(), checkpoint_path)
             logging.info(
                 f"The model is split into checkpoint shards. "
                 f"You can find where each parameters has been saved in the "
@@ -124,17 +129,17 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
         """
         Load shard model, load model from multiple files.
         """
+        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
         return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)
 
     def save_sharded_optimizer(
-        self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
+        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
     ):
         """
         Save sharded optimizer state dict to checkpoint folder.
         As there is communication when getting state dict, this must be called on all processes.
         """
-
-        assert isinstance(optimizer, GeminiOptimizer)
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
 
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
@@ -176,12 +181,12 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
                 f"index located at {save_index_file}."
             )
 
-    def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str):
+    def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
         """
         Loading sharded optimizer from checkpoint folder, with index file given.
         For each process, only loading optimizer states of parameters it controls.
         """
-
+        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
         if not os.path.isfile(checkpoint_index_file):
             logging.error(f"Provided path ({checkpoint_index_file}) should be a file")
 
@@ -383,7 +388,7 @@ class GeminiPlugin(DPPluginBase):
 
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
-                optimizer, model.unwrap(), **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
+                optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
             )
 
         return model, optimizer, criterion, dataloader, lr_scheduler

@@ -1,6 +1,7 @@
 import random
 from contextlib import nullcontext
 from functools import partial
+from types import MethodType
 from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
 
 import numpy as np
@@ -165,6 +166,15 @@ class HybridParallelNaiveOptimizer(OptimizerWrapper):
             init_pipeline_optimizer(optim, model)
         super().__init__(optim)
 
+    def update_master_params(self, model: Module):
+        pass
+
+    def get_working_to_master_map(self):
+        return None
+
+    def get_master_to_working_map(self):
+        return None
+
 
 class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
     def __init__(
@@ -466,9 +476,6 @@ class HybridParallelPlugin(PipelinePluginBase):
                         max_norm=self.max_norm,
                         **self.amp_config,
                     )
-                    self.checkpoint_io.link_master_and_working_param(
-                        optimizer.working_to_master_map, optimizer.master_to_working_map
-                    )
                 else:
                     optimizer = HybridParallelNaiveOptimizer(
                         optimizer, model, use_pipeline=self.enable_pipeline_parallelism, param_info=param_info
@@ -488,10 +495,8 @@ class HybridParallelPlugin(PipelinePluginBase):
                     **self.zero_config,
                     **self.amp_config,
                 )
-                self.checkpoint_io.link_master_and_working_param(
-                    optimizer._param_store.working_to_master_param, optimizer._param_store.master_to_working_param
-                )
-
+            # inject update_master_params
+            model.update_master_params = MethodType(optimizer.update_master_params, model)
         return model, optimizer, criterion, dataloader, lr_scheduler
 
     def execute_pipeline(
@@ -567,8 +572,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
     def get_checkpoint_io(self) -> CheckpointIO:
-        self.checkpoint_io = HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
-        return self.checkpoint_io
+        return HybridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
 
     def no_sync(self, model: Module) -> Iterator[None]:
         raise NotImplementedError

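As shown in the hunks above, the hybrid plugin now exposes update_master_params on the boosted model by binding the optimizer's method onto the model instance with MethodType. A minimal stand-alone illustration of that injection pattern, using toy stand-in classes rather than ColossalAI types:

from types import MethodType

class ToyOptimizer:
    def update_master_params(self, model):
        # in the real plugin this copies working weights into the fp32 master params
        print(f"syncing master params for {type(model).__name__}")

class ToyModelWrapper:
    pass

optim = ToyOptimizer()
model = ToyModelWrapper()
model.update_master_params = MethodType(optim.update_master_params, model)
model.update_master_params()  # prints: syncing master params for ToyModelWrapper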
@@ -22,7 +22,6 @@ from colossalai.checkpoint_io.utils import (
     save_param_groups,
     save_state_dict,
     sharded_optimizer_loading_epilogue,
-    unwrap_optimizer,
 )
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.utils import get_current_device
@@ -65,10 +64,6 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         kwargs = tree_map(self.convert_fn, kwargs)
         return super().forward(*args, **kwargs)
 
-    def unwrap(self):
-        # TODO(ver217): this is a workaround for loading model
-        return self
-
 
 class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
     def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
@@ -79,7 +74,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             checkpoint (str): Path to save checkpoint
             gather_dtensor (bool): Whether to gather_dtensor, not used
         """
-
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
         # the `state_dict` in LowLevelZeroOptimizer has communication
         # if only the master rank collect state_dict and save,
         # the communication on each rank would not match
@@ -109,6 +104,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             prefix (str): Perfix of file to save
             size_per_shard (int): Max file size of each file that store state tensors
         """
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before saving!"
        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return
@@ -160,9 +156,8 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             index_file_path (str): Path to the index file
             prefix (str): Not used.
         """
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
+        assert isinstance(optimizer, LowLevelZeroOptimizer), "Please boost the optimizer before Loading!"
+        optimizer = optimizer.unwrap()
 
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
@@ -194,44 +189,23 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
                             v_list = v.split(v.numel() // self.coordinator.world_size)
                             state_dict[param_idx][k] = v_list[self.coordinator.rank].detach().clone()
             load_states_into_optimizer(optimizer, state_dict, id_map)
 
         sharded_optimizer_loading_epilogue(optimizer)
 
-    def save_unsharded_model(
-        self, model: LowLevelZeroModel, checkpoint: str, gather_dtensor: bool, use_safetensors: bool
-    ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().save_unsharded_model(model.module, checkpoint, gather_dtensor, use_safetensors)
-
-    def save_sharded_model(
-        self,
-        model: nn.Module,
-        checkpoint_path: str,
-        gather_dtensor: bool = True,
-        prefix: Optional[str] = None,
-        max_shard_size: int = 1024,
-        use_safetensors: bool = False,
-    ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().save_sharded_model(
-            model.module, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
-        )
-
-    def load_unsharded_model(self, model: LowLevelZeroModel, checkpoint: str, strict: bool = True):
-        assert isinstance(model, LowLevelZeroModel)
-        super().load_unsharded_model(model.module, checkpoint, strict)
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        super().load_unsharded_model(model, checkpoint, strict)
+        model.update_master_params()
 
     def load_sharded_model(
         self,
-        model: LowLevelZeroModel,
+        model: ModelWrapper,
         checkpoint_index_file: Path,
         strict: bool = False,
         use_safetensors: bool = False,
         load_sub_module: bool = True,
     ):
-        assert isinstance(model, LowLevelZeroModel)
-        super().load_sharded_model(model.module, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
+        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
+        model.update_master_params()

@@ -20,24 +20,33 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
         """
-        Load model from checkpoint with automatic unwrapping.
+        Load model from checkpoint.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
-        return super().load_unsharded_model(model, checkpoint, strict=strict)
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        super().load_unsharded_model(model.unwrap(), checkpoint, strict=strict)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
+            super().save_unsharded_model(model.unwrap(), checkpoint, gather_dtensor, use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str):
         """
+        Load optimizer from checkpoint.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+        super().load_unsharded_optimizer(optimizer, checkpoint)
+
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
+        """
         Save optimizer to checkpoint but only on master process.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
             super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
@@ -50,7 +59,7 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
 
     def save_sharded_model(
         self,
-        model: nn.Module,
+        model: ModelWrapper,
         checkpoint_path: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
@@ -60,22 +69,52 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         """
         Save model to checkpoint but only on master process.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_model(model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors)
+            super().save_sharded_model(
+                model.unwrap(), checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
+            )
+
+    def load_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint_index_file: str,
+        strict: bool = False,
+        use_safetensors: bool = False,
+        load_sub_module: bool = True,
+    ):
+        """
+        Load model from sharded checkpoint.
+        """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        super().load_sharded_model(model.unwrap(), checkpoint_index_file, strict, use_safetensors, load_sub_module)
 
     def save_sharded_optimizer(
         self,
-        optimizer: Optimizer,
+        optimizer: OptimizerWrapper,
         checkpoint: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
         size_per_shard: int = 1024,
     ):
         """
-        Save optimizer to checkpoint but only on master process.
+        Save optimizer to sharded checkpoint but only on master process.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if self.coordinator.is_master():
-            super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard)
+            super().save_sharded_optimizer(optimizer.unwrap(), checkpoint, gather_dtensor, prefix, size_per_shard)
+
+    def load_sharded_optimizer(
+        self,
+        optimizer: Optimizer,
+        index_file_path: str,
+        prefix: Optional[str] = None,
+    ):
+        """
+        Load optimizer from sharded checkpoint.
+        """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
+        super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
 
 
 class TorchDDPModel(ModelWrapper):

@@ -39,31 +39,35 @@ class TorchFSDPCheckpointIO(GeneralCheckpointIO):
         super().__init__()
         self.coordinator = DistCoordinator()
 
-    def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool):
+    def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool):
+        assert isinstance(model, TorchFSDPModel), "Please boost the model before loading!"
+        model = model.unwrap()
         checkpoint = utils.load_state_dict(checkpoint)
         model.load_state_dict(checkpoint)
 
-    def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: Path):
+    def load_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: Path):
+        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before loading!"
         checkpoint = utils.load_state_dict(checkpoint)
         fsdp_model = optimizer.unwrap_model()
         sharded_osd = FSDP.scatter_full_optim_state_dict(checkpoint, fsdp_model)
         optimizer.load_state_dict(sharded_osd)
 
-    def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
+    def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
         """
         Save model to checkpoint but only on master process.
         """
-        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
+        assert isinstance(model, TorchFSDPModel), "Please boost the model before saving!"
+        model = model.unwrap()
         cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
         with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
             full_model_state = model.state_dict()
         utils.save_state_dict(full_model_state, checkpoint_file_path=checkpoint, use_safetensors=use_safetensors)
 
-    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
+    def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool):
         """
         Save optimizer to checkpoint but only on master process.
         """
-        assert isinstance(optimizer, FSDPOptimizerWrapper)
+        assert isinstance(optimizer, FSDPOptimizerWrapper), "Please boost the optimizer before saving!"
         fsdp_model = optimizer.unwrap_model()
         full_optimizer_state = FSDP.full_optim_state_dict(fsdp_model, optim=optimizer, rank0_only=True)
         utils.save_state_dict(full_optimizer_state, checkpoint_file_path=checkpoint, use_safetensors=False)

@@ -87,9 +87,6 @@ class CheckpointIO(ABC):
         # return the origin model instead of the unwrapped model
         origin_model = model
 
-        if isinstance(model, ModelWrapper):
-            model = model.unwrap()
-
         if index_file_exists:
             self.load_sharded_model(model, index_file_path, strict)
         else:
@@ -134,9 +131,6 @@ class CheckpointIO(ABC):
             use_safetensors (bool): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved
         """
 
-        if isinstance(model, ModelWrapper):
-            model = model.unwrap()
-
         if shard:
             self.save_sharded_model(model, checkpoint, gather_dtensor, prefix, size_per_shard, use_safetensors)
         else:

@@ -8,8 +8,6 @@ from typing import Optional
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.interface import OptimizerWrapper
-
 from .checkpoint_io_base import CheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -28,7 +26,6 @@ from .utils import (
     shard_model_checkpoint,
     shard_optimizer_checkpoint,
     sharded_optimizer_loading_epilogue,
-    unwrap_optimizer,
 )
 
 __all__ = ["GeneralCheckpointIO"]
@@ -58,10 +55,6 @@ class GeneralCheckpointIO(CheckpointIO):
         Load sharded optimizer with the given path to index file.
         """
 
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
-
         # Read checkpoint index file.
         ckpt_index_file = CheckpointIndexFile.from_file(index_file_path)
 
@@ -98,10 +91,6 @@ class GeneralCheckpointIO(CheckpointIO):
            - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way
         """
 
-        # If optimizer is wrapped, unwrap it.
-        if isinstance(optimizer, OptimizerWrapper):
-            optimizer = unwrap_optimizer(optimizer)
-
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return

@@ -3,7 +3,7 @@ import logging
 import os
 from pathlib import Path
 from shutil import rmtree
-from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
+from typing import Dict, Iterator, Optional, OrderedDict, Tuple
 
 import torch
 import torch.distributed as dist
@@ -13,7 +13,7 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.cluster import DistCoordinator
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -71,8 +71,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         self.tp_size = dist.get_world_size(tp_group)
         self.use_zero = zero_stage > 0
         self.verbose = verbose
-        self.working_to_master_map = None
-        self.master_to_working_map = None
         self.coordinator = DistCoordinator()
 
     @staticmethod
@@ -159,7 +157,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
 
     def save_sharded_model(
         self,
-        model: nn.Module,
+        model: ModelWrapper,
         checkpoint: str,
         gather_dtensor: bool = True,
         prefix: Optional[str] = None,
@@ -184,6 +182,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
         """
 
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        model = model.unwrap()
+
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -279,7 +280,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 f"index located at {final_index_file_path}."
             )
 
-    def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
+    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
         """
         Load sharded model with the given path to index file of checkpoint folder.
 
@@ -289,6 +290,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             strict (bool, optional): For name matching during loading state_dict. Defaults to False.
                                      This argument should be manually set to False since params on same device might be stored in different files.
         """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model_before_wrapping = model  # backup for model before wrapping
+        model = model.unwrap()
 
         # Check whether the checkpoint uses safetensors.
         use_safetensors = False
@@ -347,23 +351,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             _load(extra_state_key)
 
         # Update master params if mixed-precision training is enabled.
-        with torch.no_grad():
-            if self.working_to_master_map is not None:
-                for param in model.parameters():
-                    if (param is None) or (id(param) not in self.working_to_master_map):
-                        continue
-                    master_param = self.working_to_master_map[id(param)]
-                    if self.use_zero:
-                        # master_param is sharded under Zero setting
-                        padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
-                        if padding_size > 0:
-                            padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
-                        else:
-                            padded_param = param.data.view(-1)
-                        sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
-                        master_param.data.copy_(sharded_param.data)
-                    else:
-                        master_param.data.copy_(param.data)
+        model_before_wrapping.update_master_params()
 
         if self.verbose:
             logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
@@ -392,6 +380,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             prefix (str): Perfix of file to save
             size_per_shard (int): Max file size of each file shard that store state tensors
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before saving!"
         if os.path.isfile(checkpoint):
             logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
             return
@@ -410,7 +399,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 use_zero=self.use_zero,
                 dp_group=self.dp_group,
                 tp_group=self.tp_group,
-                master_to_working_map=self.master_to_working_map,
+                master_to_working_map=optimizer.get_master_to_working_map(),
                 size_per_shard=size_per_shard,
             )
             states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
@@ -511,6 +500,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
            checkpoint_index_file (str): Path to the index file of checkpointing folder.
            prefix (str): Not used.
         """
+        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
 
         def _get_param_id_from_optimizer_param(
             param: torch.Tensor, master_to_working_map: Optional[Dict[int, torch.Tensor]] = None
@@ -525,9 +515,10 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # When Zero is used, the mapped parameter objects should be fp32 master parameters.
         # IDs should be obtained through saved param2id mapping earlier saved in optimizer.param_info.
         id_map = {}
+        master_to_working_map = optimizer.get_master_to_working_map()
         for pg in optimizer.optim.param_groups:
             for param in pg["params"]:
-                param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 id_map[param_id] = param
 
         # Read checkpoint index file.
@@ -560,7 +551,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             for param in pg["params"]:
                 if param is None:
                     continue
-                param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
+                param_id = _get_param_id_from_optimizer_param(param, master_to_working_map)
                 if param_id not in weight_map:
                     continue
                 filename = weight_map[param_id]
@@ -577,8 +568,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         # Then shard the loaded optimizer states if using tp/zero.
         for param, state in optimizer.optim.state.items():
             device = param.device
-            if self.master_to_working_map is not None:
-                working_param = self.master_to_working_map[id(param)]
+            if master_to_working_map is not None:
+                working_param = master_to_working_map[id(param)]
             else:
                 working_param = param
             original_shape = optimizer.param_info["param2shape"][id(working_param)]
@@ -614,42 +605,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
         if self.coordinator.is_master():
             super().save_lr_scheduler(lr_scheduler, checkpoint)
 
-    def link_master_and_working_param(
-        self,
-        working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-        master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor],
-    ):
-        """
-        Create mappings between working params (for forward/backward) and master params (for optimizer update) with passed in mappings.
-        This mapping can only be created when mixied precision is used.
-        The created mappings should be mappings from integer parameter addresses to parameter objects.
-
-        Args:
-            working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameters objects/addresses to master parameter objects.
-            master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameters objects/addresses to working parameter objects.
-        """
-        self.working_to_master_map = dict()
-        for k, v in working_to_master_map.items():
-            if isinstance(k, torch.Tensor):
-                self.working_to_master_map[id(k)] = v
-            elif isinstance(k, int):
-                self.working_to_master_map[k] = v
-            else:
-                raise ValueError(
-                    f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
-                )
-
-        self.master_to_working_map = dict()
-        for k, v in master_to_working_map.items():
-            if isinstance(k, torch.Tensor):
-                self.master_to_working_map[id(k)] = v
-            elif isinstance(k, int):
-                self.master_to_working_map[k] = v
-            else:
-                raise ValueError(
-                    f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!"
-                )
-
     @staticmethod
     def gather_from_sharded_optimizer_state(
         state: OrderedDict,

@@ -11,7 +11,6 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer
 
-from colossalai.interface import OptimizerWrapper
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
     is_distributed_tensor,
@@ -122,14 +121,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
 # ======================================
 # Helper classes and functions for saving shard file
 # ======================================
-def unwrap_optimizer(optimizer: OptimizerWrapper):
-    """
-    Unwrap a wrapped optimizer.
-    This method should be used before saving/loading it to/from sharded checkpoints.
-    """
-
-    unwrapped_optim = optimizer.optim
-    return unwrapped_optim
 
 
 class StateDictSharder:

@@ -186,10 +186,6 @@ class GeminiDDP(ModelWrapper):
         for p in params_to_ignore:
             p._ddp_to_ignore = True
 
-    def unwrap(self):
-        # as save/load state dict is overwrited, only return self
-        return self
-
     def _get_non_persistent_buffers_set(
         self, module, memo: Optional[Set[nn.Module]] = None, prefix: str = "", remove_duplicate: bool = True
     ):

@@ -648,3 +648,9 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if padding_size > 0:
                     working_param = torch.nn.functional.pad(working_param, [0, padding_size])
                 master_param.copy_(working_param.chunk(self._world_size)[self._local_rank])
+
+    def get_working_to_master_map(self) -> Dict[int, torch.Tensor]:
+        return self._param_store.working_to_master_param
+
+    def get_master_to_working_map(self) -> Dict[int, torch.Tensor]:
+        return self._param_store.master_to_working_param

@@ -61,9 +61,7 @@ def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per
     new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
 
     if plugin_type == "gemini":
-        check_state_dict_equal(
-            model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False
-        )
+        check_state_dict_equal(model.state_dict(only_rank_0=False), new_model.state_dict(only_rank_0=False), False)
     else:
         check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
     dist.barrier()
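For reference, a self-contained sketch (assumed world size and rank, plain PyTorch only) of the ZeRO-style master-shard refresh that update_master_params performs in the optimizer hunks above: pad the flattened working parameter to a multiple of the world size, then copy this rank's chunk into its fp32 master shard.

import torch

world_size, rank = 4, 1                      # assumed values for illustration
working_param = torch.randn(10)              # full (unsharded) working weight
padding = (-working_param.numel()) % world_size

padded = torch.nn.functional.pad(working_param.view(-1), [0, padding])
master_shard = torch.zeros(padded.numel() // world_size)
master_shard.copy_(padded.chunk(world_size)[rank])  # this rank's fp32 master shard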