[llama] polish training script and fix optim ckpt (#5368)

pull/5377/head
Hongxin Liu 10 months ago committed by GitHub
parent a5756a8720
commit eb4f2d90f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
import colossalai import colossalai
from colossalai.accelerator import get_accelerator from colossalai.accelerator import get_accelerator
@@ -232,7 +232,7 @@ def main() -> None:
else nullcontext() else nullcontext()
) )
with init_ctx: with init_ctx:
model = LlamaForCausalLM.from_pretrained(args.pretrained) model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
# Freeze part of parameters. # Freeze part of parameters.
if args.freeze_non_embeds_params: if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model) freeze_non_embeds_parameters(model=model)
@@ -277,6 +277,8 @@ def main() -> None:
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
dataloader=dataloader, dataloader=dataloader,
) )
if args.load_checkpoint is None:
booster.load_model(model, args.pretrained)
torch.set_default_dtype(torch.float) torch.set_default_dtype(torch.float)
@@ -329,7 +331,12 @@ def main() -> None:
for epoch in range(start_epoch, args.num_epochs): for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch=epoch) dataloader.sampler.set_epoch(epoch=epoch)
pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps) pbar = tqdm(
desc=f"Epoch {epoch}",
disable=not coordinator.is_master(),
total=num_steps_per_epoch,
initial=start_step // args.accumulation_steps,
)
total_loss = torch.tensor(0.0, device=get_current_device()) total_loss = torch.tensor(0.0, device=get_current_device())
for step, batch in enumerate(dataloader, start=start_step): for step, batch in enumerate(dataloader, start=start_step):
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)} batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
@@ -369,6 +376,7 @@ def main() -> None:
coordinator.print_on_master("Deactivate NEFTune before saving model.") coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle) deactivate_neftune(model, handle)
accelerator.empty_cache()
save_checkpoint( save_checkpoint(
save_dir=args.save_dir, save_dir=args.save_dir,
booster=booster, booster=booster,

@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from .general_checkpoint_io import GeneralCheckpointIO from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile from .index_file import CheckpointIndexFile
@@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
tp_group=self.tp_group, tp_group=self.tp_group,
use_zero=self.use_zero, use_zero=self.use_zero,
inplace=False, inplace=False,
device=torch.device("cuda"), device=get_current_device(),
) )
if self.pp_size == 1: if self.pp_size == 1:
@@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if isinstance(v, torch.Tensor) and k != "step": if isinstance(v, torch.Tensor) and k != "step":
# First gather Zero shards. # First gather Zero shards.
if use_zero: if use_zero:
v = v.cuda() v = v.to(get_current_device())
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)] gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group) dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param) v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)

Loading…
Cancel
Save