pull/5190/head
Xuanlei Zhao 2024-01-03 15:37:59 +08:00
parent 8ca8cf8ec3
commit 6b69f3085b
4 changed files with 125 additions and 29 deletions

View File

@@ -1,15 +1,22 @@
import json
import os
from typing import Any, Dict, Tuple, Union
import torch
from huggingface_hub import snapshot_download
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from colossalai.booster import Booster
from colossalai.cluster import DistCoordinator
def move_to_cuda(batch, device):
    return {k: v.to(device) for k, v in batch.items()}


@torch.no_grad()
def load_ckpt(ckpt_path: str, model, booster: Booster, optimizer = None):
def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
    # pytorch ckpt
    if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
        ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
@@ -23,3 +30,73 @@ def load_ckpt(ckpt_path: str, model, booster: Booster, optimizer = None):
    if optimizer is not None:
        optimizer.sync_moe_master_param()
        optimizer.update_master_params(model)
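
A note on usage: the snapshot_download import above suggests the elided body falls back to the Hugging Face hub when the path is a model id. A minimal, illustrative call (the model id is the scripts' default, used here as an example only):

# Restore pretrained weights into a boosted model; the optimizer sync above
# runs only when an optimizer is passed.
load_model("mistralai/Mixtral-8x7B-v0.1", model, booster, optimizer=optimizer)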
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
    """
    Load a file in JSON format.
    """
    with open(file=file_path, mode="r", encoding="utf-8") as fp:
        return json.load(fp)


def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
    """
    Save data in JSON format.
    """
    with open(file=file_path, mode="w", encoding="utf-8") as fp:
        json.dump(data, fp=fp, ensure_ascii=False, indent=4)
def save_checkpoint(
    save_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    epoch: int,
    step: int,
    batch_size: int,
    coordinator: DistCoordinator,
) -> None:
    """
    Save the model checkpoint, optimizer, LR scheduler and intermediate running states.
    """
    save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
    os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))

    running_states = {
        "epoch": epoch,
        "step": step,
        "sample_start_index": step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, "running_states.json"))
def load_checkpoint(
    load_dir: Union[str, os.PathLike],
    booster: Booster,
    model: torch.nn.Module,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
) -> Tuple[int, int, int]:
    """
    Load the model checkpoint, optimizer, LR scheduler and intermediate running states.
    """
    # Update booster params states.
    load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))

    running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
    return (
        running_states["epoch"],
        running_states["step"],
        running_states["sample_start_index"],
    )
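
Taken together, the two helpers are symmetric: each save produces one self-contained directory that load_checkpoint can consume. A sketch of the layout and the matching calls (paths and numbers are illustrative):

# save_checkpoint("ckpts", ..., epoch=1, step=200, batch_size=16, ...) writes:
#   ckpts/epoch-1_step-200/modeling/             <- sharded model weights
#   ckpts/epoch-1_step-200/optimizer/            <- sharded optimizer states
#   ckpts/epoch-1_step-200/lr_scheduler          <- scheduler state
#   ckpts/epoch-1_step-200/running_states.json   <- written by the master rank only
save_checkpoint("ckpts", booster, model, optimizer, lr_scheduler, 1, 200, 16, coordinator)
epoch, step, sample_start_index = load_checkpoint("ckpts/epoch-1_step-200", booster, model, optimizer, lr_scheduler)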

View File

@@ -1,11 +1,11 @@
import argparse
import os
import torch
import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_model
from transformers import AutoTokenizer
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
@@ -16,7 +16,6 @@ from colossalai.cluster import DistCoordinator
from colossalai.moe import MOE_MANAGER
from colossalai.moe.utils import skip_init
from colossalai.utils import get_current_device
from colossal_moe.utils import load_ckpt
def parse_args():
@@ -124,7 +123,7 @@ def main():
    coordinator.print_on_master(f"Finish init booster")

    # load ckpt
    load_ckpt(args.model_name, model, booster)
    load_model(args.model_name, model, booster)
    coordinator.print_on_master(f"Finish load ckpt")

    text = ["Hello my name is", "1+1=?"]

View File

@@ -5,7 +5,7 @@ import torch.distributed as dist
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
from colossal_moe.models.mixtral_layer import replace_moe_layer
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
from colossal_moe.utils import load_ckpt, move_to_cuda
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
@@ -58,6 +58,7 @@ def parse_args():
default="mistralai/Mixtral-8x7B-v0.1",
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
parser.add_argument(
"--plugin",
type=str,
@@ -235,8 +236,12 @@ def main():
    coordinator.print_on_master(f"Finish init booster")

    # Load ckpt
    load_ckpt(args.model_name, model, booster, optimizer)
    coordinator.print_on_master(f"Finish load checkpoint")
    if args.load_checkpoint is None:
        load_model(args.model_name, model, booster, optimizer)
        coordinator.print_on_master("Finished loading pretrained model weights")
    else:
        load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
        coordinator.print_on_master("Finished loading checkpoint (model, optimizer, LR scheduler)")

    # Start finetuning
    coordinator.print_on_master(f"Start finetuning")
@@ -291,7 +296,17 @@ def main():
                # save checkpoint
                if (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
                    booster.save_model(model, args.output_path, shard=True)
                    save_checkpoint(
                        args.output_path,
                        booster,
                        model,
                        optimizer,
                        lr_scheduler,
                        epoch,
                        step,
                        args.batch_size,
                        coordinator,
                    )

        # save checkpoint at the end of each epoch
        booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
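
When --load_checkpoint is given, the values returned by load_checkpoint can fast-forward the loop so no batch is trained twice. A minimal sketch with hypothetical names (args.num_epochs and train_dataloader are assumptions standing in for the script's actual variables):

start_epoch, start_step, _ = load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
for epoch in range(start_epoch, args.num_epochs):
    for step, batch in enumerate(train_dataloader):
        # Skip batches consumed before the checkpoint was written.
        if epoch == start_epoch and step <= start_step:
            continue
        ...  # forward/backward/step as usual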

View File

@@ -400,29 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)

# Then shard the loaded optimizer states if using tp/zero.
for pid, state in list(state_dict.items()):
    if pid in id_map:
        param = id_map[pid]
        if master_to_working_map is not None and id(param) in master_to_working_map:
            working_param = master_to_working_map[id(param)]
        elif (
            hasattr(optimizer, "moe_master_to_working_map")
            and id(param) in optimizer.moe_master_to_working_map
        ):
            working_param = optimizer.moe_master_to_working_map[id(param)]
        else:
            working_param = param
        original_shape = optimizer.param_info["param2shape"][id(working_param)]
        sharded_state = self.pre_load_optim(
            state,
            working_param,
            current_shape=working_param.shape,
            original_shape=original_shape,
            device="cpu",
            inplace=True,
        )
        state_dict[pid] = sharded_state

load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
    device = param.device
    if master_to_working_map is not None and id(param) in master_to_working_map:
        working_param = master_to_working_map[id(param)]
    elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
        working_param = optimizer.moe_master_to_working_map[id(param)]
    else:
        working_param = param
    original_shape = optimizer.param_info["param2shape"][id(working_param)]
    sharded_state = self.pre_load_optim(
        state,
        working_param,
        current_shape=working_param.shape,
        original_shape=original_shape,
        device=device,
        inplace=True,
    )
    optimizer.optim.state[param] = sharded_state

sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose and self.coordinator.is_master():
    logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
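
The working-parameter lookup in both loops follows the same precedence: the ZeRO master-to-working map wins, then the MoE-specific map, else the parameter maps to itself. Restated as a standalone helper (the function name is hypothetical; the logic mirrors the loop above):

def resolve_working_param(param, master_to_working_map, optimizer):
    # Precedence: ZeRO master map, then MoE master map, else identity.
    if master_to_working_map is not None and id(param) in master_to_working_map:
        return master_to_working_map[id(param)]
    if hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
        return optimizer.moe_master_to_working_map[id(param)]
    return param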