@@ -9,6 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
 from colossalai.checkpoint_io import GeneralCheckpointIO
+from colossalai.interface import ModelWrapper
 
 from .accelerator import Accelerator
 from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -165,11 +166,11 @@ class Booster:
         assert self.plugin.support_no_sync, f'The plugin {self.plugin.__class__.__name__} does not support no_sync.'
         return self.plugin.no_sync(model)
 
-    def load_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
+    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True):
         """Load model from checkpoint.
 
         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
             strict (bool, optional): whether to strictly enforce that the keys
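For reference, a minimal sketch of how the no_sync context manager checked above is typically used for gradient accumulation. This is not part of the diff: booster, model, optimizer, criterion and micro_batches are hypothetical objects, booster.backward is the Booster's backward helper, and the chosen plugin is assumed to report support_no_sync = True.

# Skip gradient synchronization for every micro-batch except the last one.
with booster.no_sync(model):
    for batch in micro_batches[:-1]:
        loss = criterion(model(batch))
        booster.backward(loss, optimizer)
# The last micro-batch runs outside no_sync, so gradients are synchronized here.
loss = criterion(model(micro_batches[-1]))
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()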
@@ -179,24 +180,34 @@ class Booster:
         self.checkpoint_io.load_model(model, checkpoint, strict)
 
     def save_model(self,
-                   model: nn.Module,
+                   model: Union[nn.Module, ModelWrapper],
                    checkpoint: str,
-                   prefix: str = None,
                    shard: bool = False,
-                   size_per_shard: int = 1024):
+                   gather_dtensor: bool = True,
+                   prefix: Optional[str] = None,
+                   size_per_shard: int = 1024,
+                   use_safetensors: bool = False):
         """Save model to checkpoint.
 
         Args:
-            model (nn.Module): A model boosted by Booster.
+            model (nn.Module or ModelWrapper): A model boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
-            prefix (str, optional): A prefix added to parameter and buffer
-                names to compose the keys in state_dict. Defaults to None.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool, optional): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
+            use_safetensors (bool, optional): whether to use safe tensors. Default: False. If set to True, the checkpoint will be saved in safetensors format.
         """
-        self.checkpoint_io.save_model(model, checkpoint=checkpoint, shard=shard, size_per_shard=size_per_shard)
+        self.checkpoint_io.save_model(model,
+                                      checkpoint=checkpoint,
+                                      shard=shard,
+                                      gather_dtensor=gather_dtensor,
+                                      prefix=prefix,
+                                      size_per_shard=size_per_shard,
+                                      use_safetensors=use_safetensors)
 
     def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
         """Load optimizer from checkpoint.
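For context, a minimal usage sketch of the extended model checkpoint API introduced above. It is illustrative only and not part of the diff: booster, model and the ./ckpt path are placeholder names, and model is assumed to be the object returned by booster.boost().

# Sharded save: writes a directory of safetensors shards of at most ~1024 MB each,
# gathering distributed tensors to the first device before writing.
booster.save_model(model,
                   checkpoint="./ckpt",
                   shard=True,
                   gather_dtensor=True,
                   prefix="my_model",
                   size_per_shard=1024,
                   use_safetensors=True)

# load_model now accepts either a bare nn.Module or the ModelWrapper returned by boost().
booster.load_model(model, "./ckpt", strict=True)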
@@ -205,12 +216,21 @@ class Booster:
             optimizer (Optimizer): An optimizer boosted by Booster.
             checkpoint (str): Path to the checkpoint. It must be a local path.
                 It should be a directory path if the checkpoint is sharded. Otherwise, it should be a file path.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
+            size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
         self.checkpoint_io.load_optimizer(optimizer, checkpoint)
 
-    def save_optimizer(self, optimizer: Optimizer, checkpoint: str, shard: bool = False, size_per_shard: int = 1024):
-        """Save optimizer to checkpoint.
-        Warning: Saving sharded optimizer checkpoint is not supported yet.
+    def save_optimizer(self,
+                       optimizer: Optimizer,
+                       checkpoint: str,
+                       shard: bool = False,
+                       gather_dtensor: bool = True,
+                       prefix: Optional[str] = None,
+                       size_per_shard: int = 1024):
+        """
+        Save optimizer to checkpoint.
 
         Args:
             optimizer (Optimizer): An optimizer boosted by Booster.
@@ -218,9 +238,12 @@ class Booster:
                 It is a file path if ``shard=False``. Otherwise, it is a directory path.
             shard (bool, optional): Whether to save checkpoint a sharded way.
                 If true, the checkpoint will be a folder. Otherwise, it will be a single file. Defaults to False.
+            gather_dtensor (bool): whether to gather the distributed tensor to the first device. Default: True.
+            prefix (str, optional): A prefix added to parameter and buffer
+                names to compose the keys in state_dict. Defaults to None.
             size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024.
         """
-        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, size_per_shard)
+        self.checkpoint_io.save_optimizer(optimizer, checkpoint, shard, gather_dtensor, prefix, size_per_shard)
 
     def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
         """Save lr scheduler to checkpoint.
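Finally, a matching sketch for the optimizer checkpoint round trip. Again this is illustrative and not part of the diff: booster, optimizer and ./optim_ckpt are placeholders, and optimizer is assumed to be the one returned by booster.boost().

# Save the optimizer state as a sharded checkpoint directory using the new arguments.
booster.save_optimizer(optimizer,
                       checkpoint="./optim_ckpt",
                       shard=True,
                       gather_dtensor=True,
                       prefix="my_optim",
                       size_per_shard=1024)

# Restore the state later from the same directory.
booster.load_optimizer(optimizer, "./optim_ckpt")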