[ColossalChat] Add PP support (#6001)

* support pp training

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update rm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test case

* fix

* change to 4

* fix eval

* test

* add pp

* hotfix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* support pp training

* update rm

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* refactor

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update test case

* fix

* change to 4

* fix eval

* test

* add pp

* hotfix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* skip pp eval

* update all reduce

* update sft

* update ignore

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update no cache

* add eval

* remove fi

* remove debug

* remove parentheses to avoid warning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "add eval"

This reverts commit 3ab2f6fa32.

* add all reduce

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/6031/head
Tong Li 2024-08-21 10:47:39 +08:00 committed by GitHub
parent 0d3b0bd864
commit 39e2597426
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 241 additions and 115 deletions

View File

@ -31,18 +31,18 @@ jobs:
- name: Install Colossal-AI - name: Install Colossal-AI
run: | run: |
BUILD_EXT=1 pip install -v -e . BUILD_EXT=1 pip install --no-cache-dir -v -e .
- name: Install ChatGPT - name: Install ChatGPT
run: | run: |
cd applications/ColossalChat cd applications/ColossalChat
pip install -v . pip install --no-cache-dir -v .
export BUILD_EXT=1 export BUILD_EXT=1
pip install -r examples/requirements.txt pip install --no-cache-dir -r examples/requirements.txt
- name: Install Transformers - name: Install Transformers
run: | run: |
pip install transformers==4.36.2 pip install --no-cache-dir transformers==4.36.2
- name: Execute Examples - name: Execute Examples
run: | run: |

View File

@ -161,3 +161,9 @@ applications/ColossalChat/sft_data
applications/ColossalChat/prompt_data applications/ColossalChat/prompt_data
applications/ColossalChat/preference_data applications/ColossalChat/preference_data
applications/ColossalChat/temp applications/ColossalChat/temp
# Testing data
/kto_data/
/preference_data/
/prompt_data/
/sft_data/

View File

@ -16,7 +16,7 @@ from coati.experience_buffer import NaiveExperienceBuffer
from coati.experience_maker import Experience from coati.experience_maker import Experience
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.booster import Booster from colossalai.booster import Booster, Plugin
from .utils import is_rank_0 from .utils import is_rank_0
@ -38,6 +38,7 @@ class SLTrainer(ABC):
max_epochs: int, max_epochs: int,
model: nn.Module, model: nn.Module,
optimizer: Optimizer, optimizer: Optimizer,
plugin: Plugin,
start_epoch: int = 0, start_epoch: int = 0,
) -> None: ) -> None:
super().__init__() super().__init__()
@ -45,6 +46,7 @@ class SLTrainer(ABC):
self.max_epochs = max_epochs self.max_epochs = max_epochs
self.model = model self.model = model
self.optimizer = optimizer self.optimizer = optimizer
self.plugin = plugin
self.start_epoch = start_epoch self.start_epoch = start_epoch
@abstractmethod @abstractmethod

View File

@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
from tqdm import trange from tqdm import trange
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
@ -50,6 +50,7 @@ class DPOTrainer(SLTrainer):
ref_model: Any, ref_model: Any,
booster: Booster, booster: Booster,
actor_optim: Optimizer, actor_optim: Optimizer,
plugin: Plugin,
actor_lr_scheduler: _LRScheduler, actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1, max_epochs: int = 1,
@ -63,7 +64,9 @@ class DPOTrainer(SLTrainer):
save_dir: str = None, save_dir: str = None,
coordinator: DistCoordinator = None, coordinator: DistCoordinator = None,
) -> None: ) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) super().__init__(
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
)
self.ref_model = ref_model self.ref_model = ref_model
self.actor_scheduler = actor_lr_scheduler self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer self.tokenizer = tokenizer

View File

@ -17,7 +17,7 @@ from torch.utils.data import DataLoader
from tqdm import trange from tqdm import trange
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
@ -53,6 +53,7 @@ class KTOTrainer(SLTrainer):
ref_model: Any, ref_model: Any,
booster: Booster, booster: Booster,
actor_optim: Optimizer, actor_optim: Optimizer,
plugin: Plugin,
actor_lr_scheduler: _LRScheduler, actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1, max_epochs: int = 1,
@ -66,7 +67,9 @@ class KTOTrainer(SLTrainer):
save_dir: str = None, save_dir: str = None,
coordinator: DistCoordinator = None, coordinator: DistCoordinator = None,
) -> None: ) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) super().__init__(
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
)
self.ref_model = ref_model self.ref_model = ref_model
self.actor_scheduler = actor_lr_scheduler self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer self.tokenizer = tokenizer

View File

@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
from tqdm import trange from tqdm import trange
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
@ -48,6 +48,7 @@ class ORPOTrainer(SLTrainer):
actor: Any, actor: Any,
booster: Booster, booster: Booster,
actor_optim: Optimizer, actor_optim: Optimizer,
plugin: Plugin,
actor_lr_scheduler: _LRScheduler, actor_lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1, max_epochs: int = 1,
@ -59,7 +60,9 @@ class ORPOTrainer(SLTrainer):
save_dir: str = None, save_dir: str = None,
coordinator: DistCoordinator = None, coordinator: DistCoordinator = None,
) -> None: ) -> None:
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch) super().__init__(
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
)
self.actor_scheduler = actor_lr_scheduler self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.odds_ratio_loss_fn = OddsRatioLoss() self.odds_ratio_loss_fn = OddsRatioLoss()

View File

@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from colossalai.booster import Booster from colossalai.booster import Booster, Plugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
@ -48,6 +48,7 @@ class RewardModelTrainer(SLTrainer):
model: Any, model: Any,
booster: Booster, booster: Booster,
optimizer: Optimizer, optimizer: Optimizer,
plugin: Plugin,
lr_scheduler: _LRScheduler, lr_scheduler: _LRScheduler,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
loss_fn: Optional[Callable] = None, loss_fn: Optional[Callable] = None,
@ -59,7 +60,9 @@ class RewardModelTrainer(SLTrainer):
save_dir: str = None, save_dir: str = None,
coordinator: DistCoordinator = None, coordinator: DistCoordinator = None,
) -> None: ) -> None:
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch) super().__init__(
booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch
)
self.actor_scheduler = lr_scheduler self.actor_scheduler = lr_scheduler
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta) self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)

View File

@ -6,14 +6,16 @@ import os
from typing import Optional from typing import Optional
import torch import torch
import torch.distributed as dist
from coati.trainer.utils import all_reduce_mean from coati.trainer.utils import all_reduce_mean
from coati.utils import AccumulativeMeanMeter, save_checkpoint from coati.utils import AccumulativeMeanMeter, save_checkpoint
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from tqdm import trange from tqdm import tqdm, trange
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin, Plugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from .base import SLTrainer from .base import SLTrainer
@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
optim: Optimizer, optim: Optimizer,
lr_scheduler: _LRScheduler, lr_scheduler: _LRScheduler,
max_epochs: int = 2, max_epochs: int = 2,
plugin: Plugin = None,
accumulation_steps: int = 8, accumulation_steps: int = 8,
apply_loss_mask: bool = True, apply_loss_mask: bool = True,
start_epoch=0, start_epoch=0,
@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
save_dir: str = None, save_dir: str = None,
coordinator: Optional[DistCoordinator] = None, coordinator: Optional[DistCoordinator] = None,
) -> None: ) -> None:
super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch) super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
self.accumulation_steps = accumulation_steps self.accumulation_steps = accumulation_steps
self.scheduler = lr_scheduler self.scheduler = lr_scheduler
@ -94,60 +97,85 @@ class SFTTrainer(SLTrainer):
def _train(self, epoch: int): def _train(self, epoch: int):
self.model.train() self.model.train()
step_bar = trange( if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
len(self.train_dataloader) // self.accumulation_steps, data_iter = iter(self.train_dataloader)
desc=f"Epoch {epoch + 1}/{self.max_epochs}", step_bar = tqdm(
disable=not is_rank_0(), range(len(self.train_dataloader)),
) desc="Step",
for i, batch in enumerate(self.train_dataloader): disable=not (dist.get_rank() == dist.get_world_size() - 1),
batch = to_device(batch, torch.cuda.current_device())
batch_size = batch["input_ids"].size(0)
outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
) )
loss = outputs.loss for step in step_bar:
outputs = self.booster.execute_pipeline(
data_iter,
self.model,
criterion=lambda outputs, inputs: outputs[0],
optimizer=self.optimizer,
return_loss=True,
)
loss = outputs["loss"]
self.booster.backward(loss=loss, optimizer=self.optimizer) if self.booster.plugin.stage_manager.is_last_stage():
global_loss = all_reduce_mean(loss, self.plugin)
if dist.get_rank() == dist.get_world_size() - 1:
step_bar.set_postfix({"train/loss": global_loss.item()})
loss_mean = all_reduce_mean(tensor=loss)
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
# Gradient accumulation
if (i + 1) % self.accumulation_steps == 0:
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
self.scheduler.step() else:
step_bar = trange(
len(self.train_dataloader) // self.accumulation_steps,
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
batch_size = batch["input_ids"].size(0)
outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
)
loss = outputs.loss
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")}) self.booster.backward(loss=loss, optimizer=self.optimizer)
if self.writer:
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
self.num_train_step += 1
self.accumulative_meter.reset()
step_bar.update()
# Save checkpoint loss_mean = all_reduce_mean(tensor=loss)
if ( self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
self.save_dir is not None
and self.save_interval is not None # Gradient accumulation
and (self.num_train_step + 1) % self.save_interval == 0 if (i + 1) % self.accumulation_steps == 0:
): self.optimizer.step()
save_checkpoint( self.optimizer.zero_grad()
save_dir=self.save_dir, self.scheduler.step()
booster=self.booster,
model=self.model, step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
optimizer=self.optimizer, if self.writer:
lr_scheduler=self.scheduler, self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
epoch=epoch, self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
step=self.num_train_step + 1, self.num_train_step += 1
batch_size=batch_size, self.accumulative_meter.reset()
coordinator=self.coordinator, step_bar.update()
)
self.coordinator.print_on_master( # Save checkpoint
f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}" if (
) self.save_dir is not None
and self.save_interval is not None
and (self.num_train_step + 1) % self.save_interval == 0
):
save_checkpoint(
save_dir=self.save_dir,
booster=self.booster,
model=self.model,
optimizer=self.optimizer,
lr_scheduler=self.scheduler,
epoch=epoch,
step=self.num_train_step + 1,
batch_size=batch_size,
coordinator=self.coordinator,
)
self.coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
)
step_bar.close() step_bar.close()
def _eval(self, epoch: int): def _eval(self, epoch: int):
@ -157,27 +185,64 @@ class SFTTrainer(SLTrainer):
self.accumulative_meter.reset() self.accumulative_meter.reset()
self.model.eval() self.model.eval()
with torch.no_grad(): with torch.no_grad():
step_bar = trange( if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
len(self.eval_dataloader), data_iter = iter(self.eval_dataloader)
desc=f"Epoch {epoch + 1}/{self.max_epochs}", step_bar = tqdm(
disable=not is_rank_0(), range(len(self.eval_dataloader)),
) desc="Step",
for batch in self.eval_dataloader: disable=not (dist.get_rank() == dist.get_world_size() - 1),
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
) )
loss_mean = all_reduce_mean(tensor=outputs.loss) for step in step_bar:
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0)) outputs = self.booster.execute_pipeline(
step_bar.update() data_iter,
loss_mean = self.accumulative_meter.get("loss") self.model,
msg = "Evaluation Result:\n" criterion=lambda outputs, inputs: outputs[0],
for tag in ["loss"]: optimizer=self.optimizer,
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n" return_loss=True,
self.coordinator.print_on_master(msg) )
os.makedirs(self.save_dir, exist_ok=True) loss = outputs["loss"]
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f: if self.booster.plugin.stage_manager.is_last_stage():
f.write(msg) global_loss = all_reduce_mean(loss, self.plugin)
step_bar.close() if dist.get_rank() == dist.get_world_size() - 1:
step_bar.set_postfix({"eval/loss": global_loss.item()})
self.accumulative_meter.add("loss", global_loss.item())
if dist.get_rank() == dist.get_world_size() - 1:
loss_mean = self.accumulative_meter.get("loss")
msg = "Evaluation Result:\n"
for tag in ["loss"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
print(msg)
if self.save_dir is not None:
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()
else:
step_bar = trange(
len(self.eval_dataloader),
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
disable=not is_rank_0(),
)
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
)
loss_mean = all_reduce_mean(tensor=outputs.loss)
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
step_bar.update()
loss_mean = self.accumulative_meter.get("loss")
msg = "Evaluation Result:\n"
for tag in ["loss"]:
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
self.coordinator.print_on_master(msg)
if self.save_dir is not None:
os.makedirs(self.save_dir, exist_ok=True)
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
f.write(msg)
step_bar.close()

View File

@ -9,6 +9,8 @@ import torch.distributed as dist
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from colossalai.booster import Plugin
class CycledDataLoader: class CycledDataLoader:
""" """
@ -85,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
return tree_map(_to, x) return tree_map(_to, x)
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
""" """
Perform all-reduce operation on the given tensor and compute the mean across all processes. Perform all-reduce operation on the given tensor and compute the mean across all processes.
@ -95,8 +97,13 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
Returns: Returns:
torch.Tensor: The reduced tensor with mean computed across all processes. torch.Tensor: The reduced tensor with mean computed across all processes.
""" """
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM) # All reduce mean across DP group
tensor.div_(dist.get_world_size()) if plugin is not None:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
tensor.div_(plugin.dp_size)
else:
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor return tensor

View File

@ -267,6 +267,7 @@ def train(args):
ref_model=ref_model, ref_model=ref_model,
booster=booster, booster=booster,
actor_optim=optim, actor_optim=optim,
plugin=plugin,
actor_lr_scheduler=lr_scheduler, actor_lr_scheduler=lr_scheduler,
tokenizer=tokenizer, tokenizer=tokenizer,
max_epochs=args.max_epochs, max_epochs=args.max_epochs,

View File

@ -286,6 +286,7 @@ def train(args):
ref_model=ref_model, ref_model=ref_model,
booster=booster, booster=booster,
actor_optim=optim, actor_optim=optim,
plugin=plugin,
actor_lr_scheduler=lr_scheduler, actor_lr_scheduler=lr_scheduler,
tokenizer=tokenizer, tokenizer=tokenizer,
max_epochs=args.max_epochs, max_epochs=args.max_epochs,

View File

@ -250,6 +250,7 @@ def train(args):
actor=model, actor=model,
booster=booster, booster=booster,
actor_optim=optim, actor_optim=optim,
plugin=plugin,
actor_lr_scheduler=lr_scheduler, actor_lr_scheduler=lr_scheduler,
tokenizer=tokenizer, tokenizer=tokenizer,
max_epochs=args.max_epochs, max_epochs=args.max_epochs,

View File

@ -262,6 +262,7 @@ def train(args):
model, model,
booster, booster,
optim, optim,
plugin,
lr_scheduler, lr_scheduler,
tokenizer, tokenizer,
loss_fn=loss_fn, loss_fn=loss_fn,

View File

@ -114,7 +114,7 @@ def train(args):
parallel_output=False, parallel_output=False,
max_norm=args.grad_clip, max_norm=args.grad_clip,
precision=args.mixed_precision, precision=args.mixed_precision,
microbatch_size=args.batch_size, microbatch_size=args.microbatch_size,
) )
else: else:
raise ValueError(f"Unknown plugin {args.plugin}") raise ValueError(f"Unknown plugin {args.plugin}")
@ -269,6 +269,7 @@ def train(args):
model=model, model=model,
booster=booster, booster=booster,
optim=optim, optim=optim,
plugin=plugin,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
max_epochs=args.max_epochs, max_epochs=args.max_epochs,
accumulation_steps=args.accumulation_steps, accumulation_steps=args.accumulation_steps,
@ -344,6 +345,7 @@ if __name__ == "__main__":
parser.add_argument("--use_wandb", default=False, action="store_true") parser.add_argument("--use_wandb", default=False, action="store_true")
parser.add_argument("--grad_checkpoint", default=False, action="store_true") parser.add_argument("--grad_checkpoint", default=False, action="store_true")
parser.add_argument("--use_flash_attn", default=False, action="store_true") parser.add_argument("--use_flash_attn", default=False, action="store_true")
parser.add_argument("--microbatch_size", type=int, default=1)
args = parser.parse_args() args = parser.parse_args()
if args.config_file is not None: if args.config_file is not None:
os.makedirs(os.path.dirname(args.config_file), exist_ok=True) os.makedirs(os.path.dirname(args.config_file), exist_ok=True)

View File

@ -61,7 +61,7 @@ def test_overfit():
_, predicted = torch.max(outputs.data, 1) _, predicted = torch.max(outputs.data, 1)
total = labels.size(0) total = labels.size(0)
correct = (predicted == Y).sum().item() correct = (predicted == Y).sum().item()
assert (correct / total > 0.95, "The model has not overfitted to the synthesized dataset") assert correct / total > 0.95
assert (weight_to_compare - model.fc1.weight).sum() < 0.01 assert (weight_to_compare - model.fc1.weight).sum() < 0.01

View File

@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
} }
set_n_least_used_CUDA_VISIBLE_DEVICES 2 set_n_least_used_CUDA_VISIBLE_DEVICES 4
set -xu set -xu
@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
MODELS_DIR=$TEMP_DIR/models_config MODELS_DIR=$TEMP_DIR/models_config
# Skip those tests due to CI tests timeout # Skip those tests due to CI tests timeout
MODELS=('llama') MODELS=('llama')
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'pp' 'tp_pp')
PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json" LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
@ -91,7 +91,7 @@ SKIPPED_TESTS=(
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
llama-gemini-20 # gemini doesn't support lora llama-gemini-20 # gemini doesn't support lora
) )
skip_eval=false
GRAD_CKPTS=('--grad_checkpoint') GRAD_CKPTS=('--grad_checkpoint')
for lora_rank in ${LORA_RANK[@]}; do for lora_rank in ${LORA_RANK[@]}; do
for model in ${MODELS[@]}; do for model in ${MODELS[@]}; do
@ -129,15 +129,18 @@ for lora_rank in ${LORA_RANK[@]}; do
plugin='3d' plugin='3d'
fi fi
if [[ $plugin == "tp_pp" ]]; then if [[ $plugin == "tp_pp" ]]; then
echo "Here"
tp='2' tp='2'
bs='8' bs='8'
pp='2' pp='2'
plugin='3d' plugin='3d'
skip_eval=true
fi fi
if [[ $plugin == "pp" ]]; then if [[ $plugin == "pp" ]]; then
bs='8' bs='8'
pp='2' pp='2'
plugin='3d' plugin='3d'
skip_eval=true
fi fi
if [[ $plugin == "sp_split_gather" ]]; then if [[ $plugin == "sp_split_gather" ]]; then
enable_sequence_parallelism='--enable_sequence_parallelism' enable_sequence_parallelism='--enable_sequence_parallelism'
@ -175,28 +178,53 @@ for lora_rank in ${LORA_RANK[@]}; do
for split in $(seq -f "%05g" 0 0); do for split in $(seq -f "%05g" 0 0); do
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split") dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
done done
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
--pretrain $pretrain \ if [[ $skip_eval ]]; then
--tokenizer_dir $tokenizer_dir \ colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
--dataset ${dataset[@]} \ --pretrain $pretrain \
--eval_dataset ${dataset[@]} \ --tokenizer_dir $tokenizer_dir \
--save_path $MODEL_SAVE_PATH \ --dataset ${dataset[@]} \
--config_file $MODELS_DIR/config.jsonl \ --save_path $MODEL_SAVE_PATH \
$lora_config \ --config_file $MODELS_DIR/config.jsonl \
--plugin $plugin \ $lora_config \
--batch_size $bs \ --plugin $plugin \
--max_epochs 1 \ --batch_size $bs \
--accumulation_steps $grad_accu \ --max_epochs 1 \
--tp $tp \ --accumulation_steps $grad_accu \
--pp $pp \ --tp $tp \
--zero_stage $zero_stage \ --pp $pp \
--sp $sp \ --zero_stage $zero_stage \
--sp_mode $sp_mode \ --sp $sp \
$enable_sequence_parallelism \ --sp_mode $sp_mode \
--lr 2e-5 \ $enable_sequence_parallelism \
$grad_ckpt \ --lr 2e-5 \
--max_len 400 \ $grad_ckpt \
--use_flash_attn --max_len 400 \
--use_flash_attn
else
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
--pretrain $pretrain \
--tokenizer_dir $tokenizer_dir \
--dataset ${dataset[@]} \
--eval_dataset ${dataset[@]} \
--save_path $MODEL_SAVE_PATH \
--config_file $MODELS_DIR/config.jsonl \
$lora_config \
--plugin $plugin \
--batch_size $bs \
--max_epochs 1 \
--accumulation_steps $grad_accu \
--tp $tp \
--pp $pp \
--zero_stage $zero_stage \
--sp $sp \
--sp_mode $sp_mode \
$enable_sequence_parallelism \
--lr 2e-5 \
$grad_ckpt \
--max_len 400 \
--use_flash_attn
fi
passed=$? passed=$?
if [ $passed -eq 0 ]; then if [ $passed -eq 0 ]; then
rm -rf ${MODEL_SAVE_PATH:?}/* rm -rf ${MODEL_SAVE_PATH:?}/*