From 6b69f3085bc3e2145feaf56f5fb237e4df8f368b Mon Sep 17 00:00:00 2001
From: Xuanlei Zhao
Date: Wed, 3 Jan 2024 15:37:59 +0800
Subject: [PATCH] update

---
 .../ColossalMoE/colossal_moe/utils.py | 81 ++++++++++++++++++-
 applications/ColossalMoE/infer.py     |  5 +-
 applications/ColossalMoE/train.py     | 23 +++++-
 colossalai/moe/checkpoint.py          | 45 ++++++-----
 4 files changed, 125 insertions(+), 29 deletions(-)

diff --git a/applications/ColossalMoE/colossal_moe/utils.py b/applications/ColossalMoE/colossal_moe/utils.py
index 89484df2c..70b827264 100644
--- a/applications/ColossalMoE/colossal_moe/utils.py
+++ b/applications/ColossalMoE/colossal_moe/utils.py
@@ -1,15 +1,22 @@
+import json
 import os
+from typing import Any, Dict, Tuple, Union
 
 import torch
-
 from huggingface_hub import snapshot_download
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.optimizer import Optimizer
+
 from colossalai.booster import Booster
+from colossalai.cluster import DistCoordinator
+
 
 def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}
 
+
 @torch.no_grad()
-def load_ckpt(ckpt_path: str, model, booster: Booster, optimizer = None):
+def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
     # pytorch ckpt
     if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
         ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
@@ -23,3 +30,73 @@ def load_ckpt(ckpt_path: str, model, booster: Booster, optimizer = None):
     if optimizer is not None:
         optimizer.sync_moe_master_param()
         optimizer.update_master_params(model)
+
+
+def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
+    """
+    Load file in JSON format
+    """
+    with open(file=file_path, mode="r", encoding="utf-8") as fp:
+        return json.load(fp)
+
+
+def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
+    """
+    Save as JSON format
+    """
+    with open(file=file_path, mode="w", encoding="utf-8") as fp:
+        json.dump(data, fp=fp, ensure_ascii=False, indent=4)
+
+
+def save_checkpoint(
+    save_dir: Union[str, os.PathLike],
+    booster: Booster,
+    model: torch.nn.Module,
+    optimizer: Optimizer,
+    lr_scheduler: _LRScheduler,
+    epoch: int,
+    step: int,
+    batch_size: int,
+    coordinator: DistCoordinator,
+) -> None:
+    """
+    Save model checkpoint, optimizer, LR scheduler and intermediate running states.
+    """
+
+    save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
+    os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
+
+    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
+    booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
+    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
+    running_states = {
+        "epoch": epoch,
+        "step": step,
+        "sample_start_index": step * batch_size,
+    }
+    if coordinator.is_master():
+        save_json(running_states, os.path.join(save_dir, "running_states.json"))
+
+
+def load_checkpoint(
+    load_dir: Union[str, os.PathLike],
+    booster: Booster,
+    model: torch.nn.Module,
+    optimizer: Optimizer,
+    lr_scheduler: _LRScheduler,
+) -> Tuple[int, int, int]:
+    """
+    Load model checkpoint, optimizer, LR scheduler and intermediate running states.
+    """
+
+    # Update booster params states.
+    load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
+    booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
+    booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
+
+    running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
+    return (
+        running_states["epoch"],
+        running_states["step"],
+        running_states["sample_start_index"],
+    )
diff --git a/applications/ColossalMoE/infer.py b/applications/ColossalMoE/infer.py
index 7989ba4e5..70ddff940 100644
--- a/applications/ColossalMoE/infer.py
+++ b/applications/ColossalMoE/infer.py
@@ -1,11 +1,11 @@
 import argparse
-import os
 
 import torch
 import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
 from colossal_moe.models.mixtral_layer import replace_moe_layer
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
+from colossal_moe.utils import load_model
 from transformers import AutoTokenizer
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 
@@ -16,7 +16,6 @@ from colossalai.cluster import DistCoordinator
 from colossalai.moe import MOE_MANAGER
 from colossalai.moe.utils import skip_init
 from colossalai.utils import get_current_device
-from colossal_moe.utils import load_ckpt
 
 
 def parse_args():
@@ -124,7 +123,7 @@ def main():
     coordinator.print_on_master(f"Finish init booster")
 
     # load ckpt
-    load_ckpt(args.model_name, model, booster)
+    load_model(args.model_name, model, booster)
     coordinator.print_on_master(f"Finish load ckpt")
 
     text = ["Hello my name is", "1+1=?"]
diff --git a/applications/ColossalMoE/train.py b/applications/ColossalMoE/train.py
index 7c8807c24..1d0441a5a 100644
--- a/applications/ColossalMoE/train.py
+++ b/applications/ColossalMoE/train.py
@@ -5,7 +5,7 @@ import torch.distributed as dist
 from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
 from colossal_moe.models.mixtral_layer import replace_moe_layer
 from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
-from colossal_moe.utils import load_ckpt, move_to_cuda
+from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
 from torch.utils.data import Dataset
 from tqdm import tqdm
 from transformers import AutoTokenizer
@@ -58,6 +58,7 @@ def parse_args():
         default="mistralai/Mixtral-8x7B-v0.1",
         help="Path to pretrained model or model identifier from huggingface.co/models.",
     )
+    parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
     parser.add_argument(
         "--plugin",
         type=str,
@@ -235,8 +236,12 @@ def main():
     coordinator.print_on_master(f"Finish init booster")
 
     # Load ckpt
-    load_ckpt(args.model_name, model, booster, optimizer)
-    coordinator.print_on_master(f"Finish load checkpoint")
+    if args.load_checkpoint is None:
+        load_model(args.model_name, model, booster, optimizer)
+        coordinator.print_on_master(f"Finish load checkpoint")
+    else:
+        load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
+        coordinator.print_on_master(f"Finish load optimizer")
 
     # Start finetuning
     coordinator.print_on_master(f"Start finetuning")
@@ -291,7 +296,17 @@ def main():
             # save ckeckpoint
             if (step + 1) % args.save_interval == 0:
                 coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
-                booster.save_model(model, args.output_path, shard=True)
+                save_checkpoint(
+                    args.output_path,
+                    booster,
+                    model,
+                    optimizer,
+                    lr_scheduler,
+                    epoch,
+                    step,
+                    args.batch_size,
+                    coordinator,
+                )
 
         # save checkpoint at the end of each epochs
         booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
diff --git a/colossalai/moe/checkpoint.py b/colossalai/moe/checkpoint.py
index 9928c801d..b37ffabea 100644
--- a/colossalai/moe/checkpoint.py
+++ b/colossalai/moe/checkpoint.py
@@ -400,29 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
             file_path = os.path.join(ckpt_root_path, filename)
             state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
+
+            # Then shard the loaded optimizer states if using tp/zero.
+            for pid, state in list(state_dict.items()):
+                if pid in id_map:
+                    param = id_map[pid]
+                    if master_to_working_map is not None and id(param) in master_to_working_map:
+                        working_param = master_to_working_map[id(param)]
+                    elif (
+                        hasattr(optimizer, "moe_master_to_working_map")
+                        and id(param) in optimizer.moe_master_to_working_map
+                    ):
+                        working_param = optimizer.moe_master_to_working_map[id(param)]
+                    else:
+                        working_param = param
+                    original_shape = optimizer.param_info["param2shape"][id(working_param)]
+                    sharded_state = self.pre_load_optim(
+                        state,
+                        working_param,
+                        current_shape=working_param.shape,
+                        original_shape=original_shape,
+                        device="cpu",
+                        inplace=True,
+                    )
+                    state_dict[pid] = sharded_state
+
             load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
             loaded_file.add(filename)
 
-        # Then shard the loaded optimizer states if using tp/zero.
-        for param, state in optimizer.optim.state.items():
-            device = param.device
-            if master_to_working_map is not None and id(param) in master_to_working_map:
-                working_param = master_to_working_map[id(param)]
-            elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
-                working_param = optimizer.moe_master_to_working_map[id(param)]
-            else:
-                working_param = param
-            original_shape = optimizer.param_info["param2shape"][id(working_param)]
-            sharded_state = self.pre_load_optim(
-                state,
-                working_param,
-                current_shape=working_param.shape,
-                original_shape=original_shape,
-                device=device,
-                inplace=True,
-            )
-            optimizer.optim.state[param] = sharded_state
-
         sharded_optimizer_loading_epilogue(optimizer.optim)
         if self.verbose and self.coordinator.is_master():
             logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
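
Editor's usage note (a sketch, not part of the patch above): the snippet below shows how the new save_checkpoint/load_checkpoint helpers from colossal_moe/utils.py are meant to pair up when a run is resumed, for example through the new --load_checkpoint flag added to train.py. It assumes a booster, model, optimizer, lr_scheduler and coordinator have already been set up as in train.py; the output directory name and the epoch/step/batch_size values are illustrative assumptions, not values taken from the repository.

    from colossal_moe.utils import load_checkpoint, save_checkpoint

    # Writes outputs/epoch-0_step-1000/ with the sharded model ("modeling/"), the sharded
    # optimizer and the LR scheduler; only the master rank writes running_states.json.
    save_checkpoint(
        "outputs", booster, model, optimizer, lr_scheduler,
        epoch=0, step=1000, batch_size=4, coordinator=coordinator,
    )

    # On resume, restore all states from that directory and recover the progress counters.
    epoch, step, sample_start_index = load_checkpoint(
        "outputs/epoch-0_step-1000", booster, model, optimizer, lr_scheduler
    )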