mirror of https://github.com/hpcaitech/ColossalAI
update
parent
8ca8cf8ec3
commit
6b69f3085b
|
@ -1,15 +1,22 @@
|
|||
import json
|
||||
import os
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
|
||||
def move_to_cuda(batch, device):
|
||||
return {k: v.to(device) for k, v in batch.items()}
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def load_ckpt(ckpt_path: str, model, booster: Booster, optimizer = None):
|
||||
def load_model(ckpt_path: str, model, booster: Booster, optimizer=None):
|
||||
# pytorch ckpt
|
||||
if os.path.exists(os.path.join(ckpt_path, "model.safetensors.index.json")):
|
||||
ckpt_path = os.path.join(ckpt_path, "model.safetensors.index.json")
|
||||
|
@ -23,3 +30,73 @@ def load_ckpt(ckpt_path: str, model, booster: Booster, optimizer = None):
|
|||
if optimizer is not None:
|
||||
optimizer.sync_moe_master_param()
|
||||
optimizer.update_master_params(model)
|
||||
|
||||
|
||||
def load_json(file_path: Union[str, os.PathLike]) -> Dict[str, Any]:
|
||||
"""
|
||||
Load file in JSON format
|
||||
"""
|
||||
with open(file=file_path, mode="r", encoding="utf-8") as fp:
|
||||
return json.load(fp)
|
||||
|
||||
|
||||
def save_json(data: Dict[str, Any], file_path: Union[str, os.PathLike]) -> None:
|
||||
"""
|
||||
Save as JSON format
|
||||
"""
|
||||
with open(file=file_path, mode="w", encoding="utf-8") as fp:
|
||||
json.dump(data, fp=fp, ensure_ascii=False, indent=4)
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
save_dir: Union[str, os.PathLike],
|
||||
booster: Booster,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
epoch: int,
|
||||
step: int,
|
||||
batch_size: int,
|
||||
coordinator: DistCoordinator,
|
||||
) -> None:
|
||||
"""
|
||||
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
|
||||
"""
|
||||
|
||||
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
|
||||
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)
|
||||
|
||||
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
|
||||
booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
|
||||
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
|
||||
running_states = {
|
||||
"epoch": epoch,
|
||||
"step": step,
|
||||
"sample_start_index": step * batch_size,
|
||||
}
|
||||
if coordinator.is_master():
|
||||
save_json(running_states, os.path.join(save_dir, "running_states.json"))
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
load_dir: Union[str, os.PathLike],
|
||||
booster: Booster,
|
||||
model: torch.nn.Module,
|
||||
optimizer: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
) -> Tuple[int, int, int]:
|
||||
"""
|
||||
Load model checkpoint, optimizer, LR scheduler and intermedidate running states.
|
||||
"""
|
||||
|
||||
# Update booster params states.
|
||||
load_model(os.path.join(load_dir, "modeling"), model, booster, optimizer)
|
||||
booster.load_optimizer(optimizer=optimizer, checkpoint=os.path.join(load_dir, "optimizer"))
|
||||
booster.load_lr_scheduler(lr_scheduler=lr_scheduler, checkpoint=os.path.join(load_dir, "lr_scheduler"))
|
||||
|
||||
running_states = load_json(file_path=os.path.join(load_dir, "running_states.json"))
|
||||
return (
|
||||
running_states["epoch"],
|
||||
running_states["step"],
|
||||
running_states["sample_start_index"],
|
||||
)
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
|
||||
from colossal_moe.models.mixtral_layer import replace_moe_layer
|
||||
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||
from colossal_moe.utils import load_model
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
|
||||
|
||||
|
@ -16,7 +16,6 @@ from colossalai.cluster import DistCoordinator
|
|||
from colossalai.moe import MOE_MANAGER
|
||||
from colossalai.moe.utils import skip_init
|
||||
from colossalai.utils import get_current_device
|
||||
from colossal_moe.utils import load_ckpt
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -124,7 +123,7 @@ def main():
|
|||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
# load ckpt
|
||||
load_ckpt(args.model_name, model, booster)
|
||||
load_model(args.model_name, model, booster)
|
||||
coordinator.print_on_master(f"Finish load ckpt")
|
||||
|
||||
text = ["Hello my name is", "1+1=?"]
|
||||
|
|
|
@ -5,7 +5,7 @@ import torch.distributed as dist
|
|||
from colossal_moe.models.mixtral_checkpoint import MixtralMoECheckpointIO
|
||||
from colossal_moe.models.mixtral_layer import replace_moe_layer
|
||||
from colossal_moe.models.mixtral_policy import MixtralForCausalLMPolicy
|
||||
from colossal_moe.utils import load_ckpt, move_to_cuda
|
||||
from colossal_moe.utils import load_checkpoint, load_model, move_to_cuda, save_checkpoint
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer
|
||||
|
@ -58,6 +58,7 @@ def parse_args():
|
|||
default="mistralai/Mixtral-8x7B-v0.1",
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
|
||||
parser.add_argument(
|
||||
"--plugin",
|
||||
type=str,
|
||||
|
@ -235,8 +236,12 @@ def main():
|
|||
coordinator.print_on_master(f"Finish init booster")
|
||||
|
||||
# Load ckpt
|
||||
load_ckpt(args.model_name, model, booster, optimizer)
|
||||
coordinator.print_on_master(f"Finish load checkpoint")
|
||||
if args.load_checkpoint is None:
|
||||
load_model(args.model_name, model, booster, optimizer)
|
||||
coordinator.print_on_master(f"Finish load checkpoint")
|
||||
else:
|
||||
load_checkpoint(args.load_checkpoint, booster, model, optimizer, lr_scheduler)
|
||||
coordinator.print_on_master(f"Finish load optimizer")
|
||||
|
||||
# Start finetuning
|
||||
coordinator.print_on_master(f"Start finetuning")
|
||||
|
@ -291,7 +296,17 @@ def main():
|
|||
# save ckeckpoint
|
||||
if (step + 1) % args.save_interval == 0:
|
||||
coordinator.print_on_master(f"Saving model checkpoint to {args.output_path}")
|
||||
booster.save_model(model, args.output_path, shard=True)
|
||||
save_checkpoint(
|
||||
args.output_path,
|
||||
booster,
|
||||
model,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
epoch,
|
||||
step,
|
||||
args.batch_size,
|
||||
coordinator,
|
||||
)
|
||||
|
||||
# save checkpoint at the end of each epochs
|
||||
booster.save_model(model, args.output_path, shard=True, size_per_shard=5120)
|
||||
|
|
|
@ -400,29 +400,34 @@ class MoECheckpintIO(HybridParallelCheckpointIO):
|
|||
|
||||
file_path = os.path.join(ckpt_root_path, filename)
|
||||
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for pid, state in list(state_dict.items()):
|
||||
if pid in id_map:
|
||||
param = id_map[pid]
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif (
|
||||
hasattr(optimizer, "moe_master_to_working_map")
|
||||
and id(param) in optimizer.moe_master_to_working_map
|
||||
):
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
working_param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device="cpu",
|
||||
inplace=True,
|
||||
)
|
||||
state_dict[pid] = sharded_state
|
||||
|
||||
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
|
||||
loaded_file.add(filename)
|
||||
|
||||
# Then shard the loaded optimizer states if using tp/zero.
|
||||
for param, state in optimizer.optim.state.items():
|
||||
device = param.device
|
||||
if master_to_working_map is not None and id(param) in master_to_working_map:
|
||||
working_param = master_to_working_map[id(param)]
|
||||
elif hasattr(optimizer, "moe_master_to_working_map") and id(param) in optimizer.moe_master_to_working_map:
|
||||
working_param = optimizer.moe_master_to_working_map[id(param)]
|
||||
else:
|
||||
working_param = param
|
||||
original_shape = optimizer.param_info["param2shape"][id(working_param)]
|
||||
sharded_state = self.pre_load_optim(
|
||||
state,
|
||||
working_param,
|
||||
current_shape=working_param.shape,
|
||||
original_shape=original_shape,
|
||||
device=device,
|
||||
inplace=True,
|
||||
)
|
||||
optimizer.optim.state[param] = sharded_state
|
||||
|
||||
sharded_optimizer_loading_epilogue(optimizer.optim)
|
||||
if self.verbose and self.coordinator.is_master():
|
||||
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
|
||||
|
|
Loading…
Reference in New Issue