support pp training

pull/6007/head
Tong Li 2024-08-12 10:13:03 +00:00
parent ceb1e262e7
commit 38c84a1aa0
3 changed files with 82 additions and 52 deletions

View File

@@ -17,6 +17,7 @@ from coati.experience_maker import Experience
from torch.optim import Optimizer
from colossalai.booster import Booster
from colossalai.booster import Plugin
from .utils import is_rank_0
@@ -38,6 +39,7 @@ class SLTrainer(ABC):
        max_epochs: int,
        model: nn.Module,
        optimizer: Optimizer,
        plugin: Plugin,
        start_epoch: int = 0,
    ) -> None:
        super().__init__()
@@ -45,6 +47,7 @@ class SLTrainer(ABC):
        self.max_epochs = max_epochs
        self.model = model
        self.optimizer = optimizer
        self.plugin = plugin
        self.start_epoch = start_epoch

    @abstractmethod

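Note: the hunks above thread the booster plugin through the SLTrainer base class and store it as self.plugin, so concrete trainers can check at run time which parallelism the booster was configured with. A minimal sketch of that check, assuming only what the diff shows (HybridParallelPlugin exposes pp_size); the helper function itself is illustrative and not part of the change:

from typing import Optional

from colossalai.booster.plugin import HybridParallelPlugin, Plugin


def use_pipeline(plugin: Optional[Plugin]) -> bool:
    # The pipeline branch is only taken when a HybridParallelPlugin is
    # configured with more than one pipeline stage; any other plugin
    # (or plugin=None) falls back to the ordinary training loop.
    return isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1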
View File

@@ -6,14 +6,16 @@ import os
from typing import Optional
import torch
import torch.distributed as dist
from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import trange
from tqdm import tqdm, trange
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin, Plugin
from colossalai.cluster import DistCoordinator
from .base import SLTrainer
@@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
        optim: Optimizer,
        lr_scheduler: _LRScheduler,
        max_epochs: int = 2,
        plugin: Plugin = None,
        accumulation_steps: int = 8,
        apply_loss_mask: bool = True,
        start_epoch=0,
@@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
        save_dir: str = None,
        coordinator: Optional[DistCoordinator] = None,
    ) -> None:
        super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
        super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)

        self.accumulation_steps = accumulation_steps
        self.scheduler = lr_scheduler
@@ -94,60 +97,82 @@ class SFTTrainer(SLTrainer):
    def _train(self, epoch: int):
        self.model.train()
        step_bar = trange(
            len(self.train_dataloader) // self.accumulation_steps,
            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
            disable=not is_rank_0(),
        )
        for i, batch in enumerate(self.train_dataloader):
            batch = to_device(batch, torch.cuda.current_device())
            batch_size = batch["input_ids"].size(0)
            outputs = self.model(
                batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
            data_iter = iter(self.train_dataloader)
            step_bar = tqdm(
                range(len(self.train_dataloader)),
                desc="Step",
                disable=not (dist.get_rank() == dist.get_world_size() - 1),
            )
            loss = outputs.loss
            self.booster.backward(loss=loss, optimizer=self.optimizer)
            loss_mean = all_reduce_mean(tensor=loss)
            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
            # Gradient accumulation
            if (i + 1) % self.accumulation_steps == 0:
            for step in step_bar:
                outputs = self.booster.execute_pipeline(
                    data_iter,
                    self.model,
                    criterion=lambda outputs, inputs: outputs[0],
                    optimizer=self.optimizer,
                    return_loss=True,
                )
                loss = outputs["loss"]
                if dist.get_rank() == dist.get_world_size() - 1:
                    step_bar.set_postfix({"train/loss": loss.item()})
                    step_bar.update()
                self.optimizer.step()
                self.optimizer.zero_grad()
                self.scheduler.step()
        else:
            step_bar = trange(
                len(self.train_dataloader) // self.accumulation_steps,
                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
                disable=not is_rank_0(),
            )
            for i, batch in enumerate(self.train_dataloader):
                batch = to_device(batch, torch.cuda.current_device())
                batch_size = batch["input_ids"].size(0)
                outputs = self.model(
                    batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
                )
                loss = outputs.loss
                step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
                if self.writer:
                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
                self.num_train_step += 1
                self.accumulative_meter.reset()
                step_bar.update()
                self.booster.backward(loss=loss, optimizer=self.optimizer)
                # Save checkpoint
                if (
                    self.save_dir is not None
                    and self.save_interval is not None
                    and (self.num_train_step + 1) % self.save_interval == 0
                ):
                    save_checkpoint(
                        save_dir=self.save_dir,
                        booster=self.booster,
                        model=self.model,
                        optimizer=self.optimizer,
                        lr_scheduler=self.scheduler,
                        epoch=epoch,
                        step=self.num_train_step + 1,
                        batch_size=batch_size,
                        coordinator=self.coordinator,
                    )
                    self.coordinator.print_on_master(
                        f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
                    )
                loss_mean = all_reduce_mean(tensor=loss)
                self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
                # Gradient accumulation
                if (i + 1) % self.accumulation_steps == 0:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    self.scheduler.step()
                    step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
                    if self.writer:
                        self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
                        self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
                    self.num_train_step += 1
                    self.accumulative_meter.reset()
                    step_bar.update()
                    # Save checkpoint
                    if (
                        self.save_dir is not None
                        and self.save_interval is not None
                        and (self.num_train_step + 1) % self.save_interval == 0
                    ):
                        save_checkpoint(
                            save_dir=self.save_dir,
                            booster=self.booster,
                            model=self.model,
                            optimizer=self.optimizer,
                            lr_scheduler=self.scheduler,
                            epoch=epoch,
                            step=self.num_train_step + 1,
                            batch_size=batch_size,
                            coordinator=self.coordinator,
                        )
                        self.coordinator.print_on_master(
                            f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
                        )
        step_bar.close()

    def _eval(self, epoch: int):

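Note: the branch added to SFTTrainer._train is the heart of the change. When the plugin is a HybridParallelPlugin with pp_size > 1, the trainer no longer calls the model and booster.backward itself; booster.execute_pipeline pulls micro-batches from the data iterator, runs the scheduled forward/backward across the pipeline stages, and returns the loss on the last stage, which is also the only rank that updates the progress bar. Manual gradient accumulation, TensorBoard logging, and in-loop checkpointing remain only on the non-pipeline path. A condensed sketch of the pipeline branch, assuming the objects were already prepared by booster.boost with such a plugin (the free-standing function is illustrative, not part of the diff):

import torch.distributed as dist
from tqdm import tqdm


def train_epoch_pp(booster, model, optimizer, scheduler, dataloader):
    model.train()
    data_iter = iter(dataloader)
    is_last_rank = dist.get_rank() == dist.get_world_size() - 1
    step_bar = tqdm(range(len(dataloader)), desc="Step", disable=not is_last_rank)
    for _ in step_bar:
        # execute_pipeline drives the forward/backward schedule across the
        # pipeline stages, pulling micro-batches from data_iter; only the
        # last stage receives the aggregated loss.
        outputs = booster.execute_pipeline(
            data_iter,
            model,
            criterion=lambda outputs, inputs: outputs[0],
            optimizer=optimizer,
            return_loss=True,
        )
        if is_last_rank:
            step_bar.set_postfix({"train/loss": outputs["loss"].item()})
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

One consequence worth noting: this branch steps the learning-rate scheduler once per dataloader iteration, while the non-pipeline path steps it only every accumulation_steps iterations, so schedules tuned for one path may need adjusting for the other.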
View File

@@ -114,7 +114,7 @@ def train(args):
            parallel_output=False,
            max_norm=args.grad_clip,
            precision=args.mixed_precision,
            microbatch_size=args.batch_size,
            microbatch_size=args.microbatch_size,
        )
    else:
        raise ValueError(f"Unknown plugin {args.plugin}")
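Note: this hunk decouples the pipeline micro-batch size from the training batch size; HybridParallelPlugin now receives args.microbatch_size instead of args.batch_size. A sketch of the resulting plugin setup with illustrative values; only parallel_output, max_norm, precision and microbatch_size appear in the hunk, the remaining keyword arguments and all concrete numbers are assumptions, and torch.distributed must already be initialized (e.g. via colossalai.launch_from_torch) before the plugin can build its process groups:

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes the distributed environment has already been launched.
plugin = HybridParallelPlugin(
    tp_size=1,                # illustrative: no tensor parallelism
    pp_size=2,                # two pipeline stages -> SFTTrainer takes the pipeline branch
    zero_stage=0,
    parallel_output=False,
    max_norm=1.0,             # stands in for args.grad_clip
    precision="bf16",         # stands in for args.mixed_precision
    microbatch_size=1,        # now read from the new --microbatch_size flag
)
booster = Booster(plugin=plugin)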
@@ -269,6 +269,7 @@ def train(args):
        model=model,
        booster=booster,
        optim=optim,
        plugin=plugin,
        lr_scheduler=lr_scheduler,
        max_epochs=args.max_epochs,
        accumulation_steps=args.accumulation_steps,
@@ -344,6 +345,7 @@ if __name__ == "__main__":
    parser.add_argument("--use_wandb", default=False, action="store_true")
    parser.add_argument("--grad_checkpoint", default=False, action="store_true")
    parser.add_argument("--use_flash_attn", default=False, action="store_true")
    parser.add_argument("--microbatch_size", type=int, default=1)
    args = parser.parse_args()
    if args.config_file is not None:
        os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
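Note: the remaining hunks pass the plugin into SFTTrainer and add the --microbatch_size flag (default 1) that feeds the plugin configuration above. The value matters because the pipeline schedule splits each per-step batch into roughly batch_size // microbatch_size micro-batches, so reusing args.batch_size as the micro-batch size (the previous behaviour) left a single micro-batch per step and little opportunity for the stages to overlap. A small illustration with assumed numbers:

# Illustrative arithmetic only; the concrete values are assumptions.
batch_size = 8         # per-device batch handed to the pipeline each step
microbatch_size = 2    # value of the new --microbatch_size flag

num_microbatches = batch_size // microbatch_size
print(num_microbatches)  # 4 micro-batches can be in flight across the pipeline stages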