mirror of https://github.com/hpcaitech/ColossalAI
support pp training
parent ceb1e262e7
commit 38c84a1aa0
coati/trainer/base.py

@@ -17,6 +17,7 @@ from coati.experience_maker import Experience
 from torch.optim import Optimizer
 
 from colossalai.booster import Booster
+from colossalai.booster import Plugin
 
 from .utils import is_rank_0
 
@@ -38,6 +39,7 @@ class SLTrainer(ABC):
         max_epochs: int,
         model: nn.Module,
         optimizer: Optimizer,
+        plugin: Plugin,
         start_epoch: int = 0,
     ) -> None:
         super().__init__()
@@ -45,6 +47,7 @@ class SLTrainer(ABC):
         self.max_epochs = max_epochs
         self.model = model
         self.optimizer = optimizer
+        self.plugin = plugin
         self.start_epoch = start_epoch
 
     @abstractmethod
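Note: SLTrainer now stores the booster plugin so subclasses can choose a training path from it. A minimal sketch of the check this enables, mirroring the condition SFTTrainer._train uses below (sketch only, not part of the patch):

    from colossalai.booster.plugin import HybridParallelPlugin, Plugin

    def uses_pipeline(plugin: Plugin) -> bool:
        # Pipeline-parallel training is active only when the booster was built with a
        # HybridParallelPlugin whose pipeline dimension is larger than one.
        return isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1
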
coati/trainer/sft.py

@@ -6,14 +6,16 @@ import os
 from typing import Optional
 
 import torch
+import torch.distributed as dist
 from coati.trainer.utils import all_reduce_mean
 from coati.utils import AccumulativeMeanMeter, save_checkpoint
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 from torch.utils.data import DataLoader
-from tqdm import trange
+from tqdm import tqdm, trange
 
 from colossalai.booster import Booster
+from colossalai.booster.plugin import HybridParallelPlugin, Plugin
 from colossalai.cluster import DistCoordinator
 
 from .base import SLTrainer
@@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
         optim: Optimizer,
         lr_scheduler: _LRScheduler,
         max_epochs: int = 2,
+        plugin: Plugin = None,
         accumulation_steps: int = 8,
         apply_loss_mask: bool = True,
         start_epoch=0,
@@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
         save_dir: str = None,
         coordinator: Optional[DistCoordinator] = None,
     ) -> None:
-        super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
+        super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
 
         self.accumulation_steps = accumulation_steps
         self.scheduler = lr_scheduler
@@ -94,60 +97,82 @@ class SFTTrainer(SLTrainer):
 
     def _train(self, epoch: int):
         self.model.train()
-        step_bar = trange(
-            len(self.train_dataloader) // self.accumulation_steps,
-            desc=f"Epoch {epoch + 1}/{self.max_epochs}",
-            disable=not is_rank_0(),
-        )
-        for i, batch in enumerate(self.train_dataloader):
-            batch = to_device(batch, torch.cuda.current_device())
-            batch_size = batch["input_ids"].size(0)
-            outputs = self.model(
-                batch["input_ids"],
-                attention_mask=batch["attention_mask"],
-                labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
-            )
-            loss = outputs.loss
-
-            self.booster.backward(loss=loss, optimizer=self.optimizer)
-
-            loss_mean = all_reduce_mean(tensor=loss)
-            self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
-
-            # Gradient accumulation
-            if (i + 1) % self.accumulation_steps == 0:
-                self.optimizer.step()
-                self.optimizer.zero_grad()
-                self.scheduler.step()
-
-                step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
-                if self.writer:
-                    self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
-                    self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
-                self.num_train_step += 1
-                self.accumulative_meter.reset()
-                step_bar.update()
-
-                # Save checkpoint
-                if (
-                    self.save_dir is not None
-                    and self.save_interval is not None
-                    and (self.num_train_step + 1) % self.save_interval == 0
-                ):
-                    save_checkpoint(
-                        save_dir=self.save_dir,
-                        booster=self.booster,
-                        model=self.model,
-                        optimizer=self.optimizer,
-                        lr_scheduler=self.scheduler,
-                        epoch=epoch,
-                        step=self.num_train_step + 1,
-                        batch_size=batch_size,
-                        coordinator=self.coordinator,
-                    )
-                    self.coordinator.print_on_master(
-                        f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
-                    )
+        if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
+            data_iter = iter(self.train_dataloader)
+            step_bar = tqdm(
+                range(len(self.train_dataloader)),
+                desc="Step",
+                disable=not (dist.get_rank() == dist.get_world_size() - 1),
+            )
+            for step in step_bar:
+                outputs = self.booster.execute_pipeline(
+                    data_iter,
+                    self.model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=self.optimizer,
+                    return_loss=True,
+                )
+                loss = outputs["loss"]
+                if dist.get_rank() == dist.get_world_size() - 1:
+                    step_bar.set_postfix({"train/loss": loss.item()})
+                step_bar.update()
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+                self.scheduler.step()
+        else:
+            step_bar = trange(
+                len(self.train_dataloader) // self.accumulation_steps,
+                desc=f"Epoch {epoch + 1}/{self.max_epochs}",
+                disable=not is_rank_0(),
+            )
+            for i, batch in enumerate(self.train_dataloader):
+                batch = to_device(batch, torch.cuda.current_device())
+                batch_size = batch["input_ids"].size(0)
+                outputs = self.model(
+                    batch["input_ids"],
+                    attention_mask=batch["attention_mask"],
+                    labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
+                )
+                loss = outputs.loss
+                self.booster.backward(loss=loss, optimizer=self.optimizer)
+
+                loss_mean = all_reduce_mean(tensor=loss)
+                self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
+
+                # Gradient accumulation
+                if (i + 1) % self.accumulation_steps == 0:
+                    self.optimizer.step()
+                    self.optimizer.zero_grad()
+                    self.scheduler.step()
+
+                    step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
+                    if self.writer:
+                        self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
+                        self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
+                    self.num_train_step += 1
+                    self.accumulative_meter.reset()
+                    step_bar.update()
+
+                    # Save checkpoint
+                    if (
+                        self.save_dir is not None
+                        and self.save_interval is not None
+                        and (self.num_train_step + 1) % self.save_interval == 0
+                    ):
+                        save_checkpoint(
+                            save_dir=self.save_dir,
+                            booster=self.booster,
+                            model=self.model,
+                            optimizer=self.optimizer,
+                            lr_scheduler=self.scheduler,
+                            epoch=epoch,
+                            step=self.num_train_step + 1,
+                            batch_size=batch_size,
+                            coordinator=self.coordinator,
+                        )
+                        self.coordinator.print_on_master(
+                            f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
+                        )
         step_bar.close()
 
     def _eval(self, epoch: int):
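The new pipeline branch follows the generic Booster.execute_pipeline pattern: the dataloader is consumed through a shared iterator, the criterion picks the loss out of the model outputs, and only the last pipeline stage holds a usable loss value. A condensed sketch of a single optimizer step under those assumptions (the function and variable names are illustrative, not from the patch):

    import torch.distributed as dist

    def pipeline_step(booster, model, optimizer, lr_scheduler, data_iter):
        # Forward/backward for one batch across all pipeline stages.
        outputs = booster.execute_pipeline(
            data_iter,
            model,
            criterion=lambda outputs, inputs: outputs[0],  # element 0 of a causal-LM output is the loss
            optimizer=optimizer,
            return_loss=True,
        )
        loss = outputs["loss"]  # meaningful only on the last pipeline stage
        if dist.get_rank() == dist.get_world_size() - 1:
            print(f"train/loss: {loss.item()}")
        optimizer.step()
        optimizer.zero_grad()
        lr_scheduler.step()
        return loss

Note that this branch steps the optimizer on every iteration; the accumulation_steps logic still applies only to the non-pipeline path.
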
SFT training script

@@ -114,7 +114,7 @@ def train(args):
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
-            microbatch_size=args.batch_size,
+            microbatch_size=args.microbatch_size,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
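With pipeline parallelism enabled, microbatch_size controls how each per-step batch is sliced into micro-batches for the pipeline schedule, so it is now decoupled from args.batch_size and exposed as its own flag (added in the argparse hunk below). A hypothetical helper, not part of this commit, that makes the assumed relationship explicit:

    def num_microbatches(batch_size: int, microbatch_size: int) -> int:
        # Hypothetical: number of pipeline micro-batches per optimizer step,
        # assuming the batch must split evenly.
        assert batch_size % microbatch_size == 0, "batch_size should be a multiple of microbatch_size"
        return batch_size // microbatch_size
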
@@ -269,6 +269,7 @@ def train(args):
         model=model,
         booster=booster,
         optim=optim,
+        plugin=plugin,
         lr_scheduler=lr_scheduler,
         max_epochs=args.max_epochs,
         accumulation_steps=args.accumulation_steps,
@ -344,6 +345,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
parser.add_argument("--microbatch_size", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
if args.config_file is not None:
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||