mirror of https://github.com/hpcaitech/ColossalAI
[llama] polish training script and fix optim ckpt (#5368)
parent a5756a8720
commit eb4f2d90f9
@@ -23,7 +23,7 @@ from colossal_llama2.utils.froze import freeze_non_embeds_parameters
 from colossal_llama2.utils.neftune_patch import activate_neftune, deactivate_neftune
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
-from transformers import LlamaForCausalLM, LlamaTokenizer
+from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
 
 import colossalai
 from colossalai.accelerator import get_accelerator
@@ -232,7 +232,7 @@ def main() -> None:
         else nullcontext()
     )
     with init_ctx:
-        model = LlamaForCausalLM.from_pretrained(args.pretrained)
+        model = LlamaForCausalLM(LlamaConfig.from_pretrained(args.pretrained))
         # Freeze part of parameters.
         if args.freeze_non_embeds_params:
             freeze_non_embeds_parameters(model=model)
@@ -277,6 +277,8 @@ def main() -> None:
         lr_scheduler=lr_scheduler,
         dataloader=dataloader,
     )
+    if args.load_checkpoint is None:
+        booster.load_model(model, args.pretrained)
 
     torch.set_default_dtype(torch.float)
 
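The two hunks above work together: the model is now built from its LlamaConfig inside the init context, so no pretrained weights are read at construction time, and the pretrained weights are loaded afterwards through the booster, but only when no training checkpoint is being resumed. A minimal sketch of that flow, where the path and the nullcontext() stand-in for the plugin-dependent init context are placeholders, not values from the script:

# Minimal sketch; "path/to/pretrained" is a placeholder and the init context
# may be a lazy-init context in the real script depending on the plugin.
from contextlib import nullcontext
from transformers import LlamaConfig, LlamaForCausalLM

init_ctx = nullcontext()
with init_ctx:
    # Construct from config only: no checkpoint shards are read here.
    model = LlamaForCausalLM(LlamaConfig.from_pretrained("path/to/pretrained"))

# ... booster.boost(model=model, ...) wraps and shards the model ...

# Pretrained weights are then loaded through the booster, and only when we are
# not resuming from a training checkpoint that already contains model weights:
# if args.load_checkpoint is None:
#     booster.load_model(model, "path/to/pretrained")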
@@ -329,7 +331,12 @@ def main() -> None:
 
     for epoch in range(start_epoch, args.num_epochs):
         dataloader.sampler.set_epoch(epoch=epoch)
-        pbar = tqdm(desc=f"Epoch {epoch}", disable=not coordinator.is_master(), total=num_steps_per_epoch, initial=start_step // args.accumulation_steps)
+        pbar = tqdm(
+            desc=f"Epoch {epoch}",
+            disable=not coordinator.is_master(),
+            total=num_steps_per_epoch,
+            initial=start_step // args.accumulation_steps,
+        )
         total_loss = torch.tensor(0.0, device=get_current_device())
         for step, batch in enumerate(dataloader, start=start_step):
             batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
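The progress bar above is resumed in units of optimizer steps, while start_step counts micro-batches, so the initial value is divided by the accumulation factor. A toy illustration of that arithmetic; the concrete numbers are assumptions, not values from the script:

# Toy numbers: resuming after 120 micro-batches with accumulation of 4 means
# 30 optimizer steps are already complete, so the bar starts at 30, not 120.
accumulation_steps = 4
start_step = 120                               # micro-batch index restored from a checkpoint
initial = start_step // accumulation_steps     # -> 30
print(initial)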
@@ -369,6 +376,7 @@ def main() -> None:
                     coordinator.print_on_master("Deactivate NEFTune before saving model.")
                     deactivate_neftune(model, handle)
 
+                accelerator.empty_cache()
                 save_checkpoint(
                     save_dir=args.save_dir,
                     booster=booster,
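The added call empties the accelerator's cache right before checkpointing, so blocks freed after the forward/backward pass do not count against the memory needed to gather states for saving. A small sketch of the same call pattern, assuming the colossalai accelerator API that the training script imports:

# Sketch assuming the accelerator API used by the training script.
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()   # resolves to the active backend (CUDA, NPU, ...)
accelerator.empty_cache()         # analogous to torch.cuda.empty_cache() on CUDA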
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 
 from colossalai.cluster import DistCoordinator
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.utils import get_current_device
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
@@ -721,7 +722,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
                 tp_group=self.tp_group,
                 use_zero=self.use_zero,
                 inplace=False,
-                device=torch.device("cuda"),
+                device=get_current_device(),
             )
 
             if self.pp_size == 1:
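Replacing the hard-coded torch.device("cuda") with get_current_device() keeps the optimizer-state gathering device-agnostic, so the same code path works when the active accelerator is not CUDA. A small illustrative sketch; the tensor here is only an example, not the checkpoint code:

# Sketch: get_current_device() returns the device of the active accelerator.
import torch
from colossalai.utils import get_current_device

device = get_current_device()             # e.g. the rank's CUDA device, or an NPU equivalent
state_buffer = torch.zeros(1024, device=device)
print(state_buffer.device)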
@@ -854,7 +855,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
             if isinstance(v, torch.Tensor) and k != "step":
                 # First gather Zero shards.
                 if use_zero:
-                    v = v.cuda()
+                    v = v.to(get_current_device())
                     gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
                     dist.all_gather(gather_tensor, v, group=dp_group)
                     v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)
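The same device fix lands in the ZeRO path of the optimizer checkpoint: each rank's flat shard is moved to the current accelerator, gathered across the data-parallel group, concatenated, trimmed of padding, and reshaped back to the parameter's shape. A minimal sketch of that gather pattern, factored into a helper; the function name is hypothetical and the process-group setup is assumed, but the body mirrors the diff:

# Sketch of the shard-gathering pattern shown above (helper name is hypothetical).
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device

def gather_zero_shard(v: torch.Tensor, param: torch.Tensor, dp_group) -> torch.Tensor:
    # Move the local flat shard to the active accelerator instead of calling .cuda().
    v = v.to(get_current_device())
    dp_size = dist.get_world_size(group=dp_group)
    # Collect every rank's shard of this optimizer state.
    shards = [torch.zeros_like(v) for _ in range(dp_size)]
    dist.all_gather(shards, v, group=dp_group)
    # Concatenate, drop ZeRO padding, and restore the parameter's shape.
    flat = torch.stack(shards).view(-1)
    return flat[: param.numel()].reshape_as(param)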
|
|