From eb4f2d90f900e4a9f96f12fb3692f0793ff9884d Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Tue, 6 Feb 2024 11:52:17 +0800
Subject: [PATCH] [llama] polish training script and fix optim ckpt (#5368)

---
 applications/Colossal-LLaMA-2/train.py             | 14 +++++++++++---
 .../checkpoint_io/hybrid_parallel_checkpoint_io.py |  5 +++--
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/applications/Colossal-LLaMA-2/train.py b/applications/Colossal-LLaMA-2/train.py
index ebb919979..1c1389b5c 100644
--- a/applications/Colossal-LLaMA-2/train.py
+++ b/applications/Colossal-LLaMA-2/train.py
@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.accelerator import get_accelerator
@@ -232,7 +232,7 @@ def main() -> None:
         else nullcontext()
     )
     with init_ctx:
-        model = LlamaForCausalLM.from_pretrained(args.pretrained)
+        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
         # Freeze part of parameters.
         if args.freeze_non_embeds_params:
             freeze_non_embeds_parameters(model=model)
@@ -277,6 +277,8 @@ def main() -> None:
         lr_scheduler=lr_scheduler,
         dataloader=dataloader,
     )
+    if args.load_checkpoint is None:
+        booster.load_model(model, args.pretrained)
 
     torch.set_default_dtype(torch.float)
 
@@ -329,7 +331,12 @@ def main() -> None:
 
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch=epoch)
-        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
+        pbar = tqdm(
+            desc=f"Epoch {epoch}",
+            disable=not coordinator.is_master(),
+            total=num_steps_per_epoch,
+            initial=start_step // args.accumulation_steps,
+        )
         total_loss = torch.tensor(0.0, device=get_current_device())
         for step, batch in enumerate(dataloader, start=start_step):
             batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
@@ -369,6 +376,7 @@ def main() -> None:
                     coordinator.print_on_master("Deactivate NEFTune before saving model.")
                     deactivate_neftune(model, handle)
 
+                    accelerator.empty_cache()
                     save_checkpoint(
                         save_dir=args.save_dir,
                         booster=booster,
diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
index 5f832f13c..36df30335 100644
--- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
+++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.utils import get_current_device
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                     tp_group=self.tp_group,
                     use_zero=self.use_zero,
                     inplace=False,
-                    device=torch.device("cuda"),
+                    device=get_current_device(),
                 )
 
                 if self.pp_size == 1:
@@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 if isinstance(v, torch.Tensor) and k != "step":
                     # First gather Zero shards.
                     if use_zero:
-                        v = v.cuda()
+                        v = v.to(get_current_device())
                     gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
                     dist.all_gather(gather_tensor, v, group=dp_group)
                     v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
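
Note on the main train.py change: the model is now instantiated from its
config inside the (possibly lazy) init context, and the pretrained weights
are loaded afterwards through the booster, but only when no training
checkpoint is being resumed. Below is a minimal sketch of that pattern; the
GeminiPlugin choice, the optimizer, and the stand-in args namespace are
illustrative assumptions, not taken from the patch:

    import torch
    from types import SimpleNamespace
    from transformers import LlamaConfig, LlamaForCausalLM

    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin

    # Hypothetical stand-in for the script's parsed CLI arguments.
    args = SimpleNamespace(pretrained="meta-llama/Llama-2-7b-hf", load_checkpoint=None)

    colossalai.launch_from_torch(config={})
    booster = Booster(plugin=GeminiPlugin())

    # Only the architecture is built here; no weight files are read, so this
    # stays cheap even when every rank runs it under a lazy-init context.
    model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # Fresh run: stream pretrained weights in through the booster so the
    # plugin can shard and place them; when resuming, the training checkpoint
    # restores the (possibly newer) weights instead.
    if args.load_checkpoint is None:
        booster.load_model(model, args.pretrained)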
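
Note on the checkpoint-IO fix: hard-coded CUDA placement is replaced with
colossalai.utils.get_current_device(), which resolves to the device of
whichever accelerator the process was launched on; in the same spirit,
train.py now calls accelerator.empty_cache() before checkpointing. A small
illustration of the device change (the tensor size is arbitrary):

    import torch
    from colossalai.utils import get_current_device

    v = torch.zeros(1024)

    # Before: assumes CUDA and fails on other accelerators (e.g. NPUs).
    # v = v.cuda()

    # After: portable across accelerators; on CUDA this still maps to the
    # process's current CUDA device.
    v = v.to(get_current_device())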