mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'main' into feature/fp8_comm
commit
caab4a307f
|
@ -31,18 +31,18 @@ jobs:
|
|||
|
||||
- name: Install Colossal-AI
|
||||
run: |
|
||||
BUILD_EXT=1 pip install -v -e .
|
||||
BUILD_EXT=1 pip install --no-cache-dir -v -e .
|
||||
|
||||
- name: Install ChatGPT
|
||||
run: |
|
||||
cd applications/ColossalChat
|
||||
pip install -v .
|
||||
pip install --no-cache-dir -v .
|
||||
export BUILD_EXT=1
|
||||
pip install -r examples/requirements.txt
|
||||
pip install --no-cache-dir -r examples/requirements.txt
|
||||
|
||||
- name: Install Transformers
|
||||
run: |
|
||||
pip install transformers==4.36.2
|
||||
pip install --no-cache-dir transformers==4.36.2
|
||||
|
||||
- name: Execute Examples
|
||||
run: |
|
||||
|
|
|
@ -161,3 +161,9 @@ applications/ColossalChat/sft_data
|
|||
applications/ColossalChat/prompt_data
|
||||
applications/ColossalChat/preference_data
|
||||
applications/ColossalChat/temp
|
||||
|
||||
# Testing data
|
||||
/kto_data/
|
||||
/preference_data/
|
||||
/prompt_data/
|
||||
/sft_data/
|
||||
|
|
|
@ -16,7 +16,7 @@ from coati.experience_buffer import NaiveExperienceBuffer
|
|||
from coati.experience_maker import Experience
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster import Booster, Plugin
|
||||
|
||||
from .utils import is_rank_0
|
||||
|
||||
|
@ -38,6 +38,7 @@ class SLTrainer(ABC):
|
|||
max_epochs: int,
|
||||
model: nn.Module,
|
||||
optimizer: Optimizer,
|
||||
plugin: Plugin,
|
||||
start_epoch: int = 0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
@ -45,6 +46,7 @@ class SLTrainer(ABC):
|
|||
self.max_epochs = max_epochs
|
||||
self.model = model
|
||||
self.optimizer = optimizer
|
||||
self.plugin = plugin
|
||||
self.start_epoch = start_epoch
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
|
|||
from tqdm import trange
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster import Booster, Plugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
@ -50,6 +50,7 @@ class DPOTrainer(SLTrainer):
|
|||
ref_model: Any,
|
||||
booster: Booster,
|
||||
actor_optim: Optimizer,
|
||||
plugin: Plugin,
|
||||
actor_lr_scheduler: _LRScheduler,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_epochs: int = 1,
|
||||
|
@ -63,7 +64,9 @@ class DPOTrainer(SLTrainer):
|
|||
save_dir: str = None,
|
||||
coordinator: DistCoordinator = None,
|
||||
) -> None:
|
||||
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
|
||||
super().__init__(
|
||||
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
|
||||
)
|
||||
self.ref_model = ref_model
|
||||
self.actor_scheduler = actor_lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
|
|
|
@ -17,7 +17,7 @@ from torch.utils.data import DataLoader
|
|||
from tqdm import trange
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster import Booster, Plugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
@ -53,6 +53,7 @@ class KTOTrainer(SLTrainer):
|
|||
ref_model: Any,
|
||||
booster: Booster,
|
||||
actor_optim: Optimizer,
|
||||
plugin: Plugin,
|
||||
actor_lr_scheduler: _LRScheduler,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_epochs: int = 1,
|
||||
|
@ -66,7 +67,9 @@ class KTOTrainer(SLTrainer):
|
|||
save_dir: str = None,
|
||||
coordinator: DistCoordinator = None,
|
||||
) -> None:
|
||||
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
|
||||
super().__init__(
|
||||
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
|
||||
)
|
||||
self.ref_model = ref_model
|
||||
self.actor_scheduler = actor_lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
|
|
|
@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
|
|||
from tqdm import trange
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster import Booster, Plugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
@ -48,6 +48,7 @@ class ORPOTrainer(SLTrainer):
|
|||
actor: Any,
|
||||
booster: Booster,
|
||||
actor_optim: Optimizer,
|
||||
plugin: Plugin,
|
||||
actor_lr_scheduler: _LRScheduler,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
max_epochs: int = 1,
|
||||
|
@ -59,7 +60,9 @@ class ORPOTrainer(SLTrainer):
|
|||
save_dir: str = None,
|
||||
coordinator: DistCoordinator = None,
|
||||
) -> None:
|
||||
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
|
||||
super().__init__(
|
||||
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
|
||||
)
|
||||
self.actor_scheduler = actor_lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
self.odds_ratio_loss_fn = OddsRatioLoss()
|
||||
|
|
|
@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
|
|||
from torch.utils.data import DataLoader
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster import Booster, Plugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
|
@ -48,6 +48,7 @@ class RewardModelTrainer(SLTrainer):
|
|||
model: Any,
|
||||
booster: Booster,
|
||||
optimizer: Optimizer,
|
||||
plugin: Plugin,
|
||||
lr_scheduler: _LRScheduler,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
loss_fn: Optional[Callable] = None,
|
||||
|
@ -59,7 +60,9 @@ class RewardModelTrainer(SLTrainer):
|
|||
save_dir: str = None,
|
||||
coordinator: DistCoordinator = None,
|
||||
) -> None:
|
||||
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch)
|
||||
super().__init__(
|
||||
booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch
|
||||
)
|
||||
self.actor_scheduler = lr_scheduler
|
||||
self.tokenizer = tokenizer
|
||||
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)
|
||||
|
|
|
@ -6,14 +6,16 @@ import os
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.trainer.utils import all_reduce_mean
|
||||
from coati.utils import AccumulativeMeanMeter, save_checkpoint
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import trange
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import HybridParallelPlugin, Plugin
|
||||
from colossalai.cluster import DistCoordinator
|
||||
|
||||
from .base import SLTrainer
|
||||
|
@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
|
|||
optim: Optimizer,
|
||||
lr_scheduler: _LRScheduler,
|
||||
max_epochs: int = 2,
|
||||
plugin: Plugin = None,
|
||||
accumulation_steps: int = 8,
|
||||
apply_loss_mask: bool = True,
|
||||
start_epoch=0,
|
||||
|
@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
|
|||
save_dir: str = None,
|
||||
coordinator: Optional[DistCoordinator] = None,
|
||||
) -> None:
|
||||
super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
|
||||
super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
|
||||
|
||||
self.accumulation_steps = accumulation_steps
|
||||
self.scheduler = lr_scheduler
|
||||
|
@ -94,60 +97,85 @@ class SFTTrainer(SLTrainer):
|
|||
|
||||
def _train(self, epoch: int):
|
||||
self.model.train()
|
||||
step_bar = trange(
|
||||
len(self.train_dataloader) // self.accumulation_steps,
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
for i, batch in enumerate(self.train_dataloader):
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
batch_size = batch["input_ids"].size(0)
|
||||
outputs = self.model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
|
||||
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
|
||||
data_iter = iter(self.train_dataloader)
|
||||
step_bar = tqdm(
|
||||
range(len(self.train_dataloader)),
|
||||
desc="Step",
|
||||
disable=not (dist.get_rank() == dist.get_world_size() - 1),
|
||||
)
|
||||
loss = outputs.loss
|
||||
for step in step_bar:
|
||||
outputs = self.booster.execute_pipeline(
|
||||
data_iter,
|
||||
self.model,
|
||||
criterion=lambda outputs, inputs: outputs[0],
|
||||
optimizer=self.optimizer,
|
||||
return_loss=True,
|
||||
)
|
||||
loss = outputs["loss"]
|
||||
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
global_loss = all_reduce_mean(loss, self.plugin)
|
||||
if dist.get_rank() == dist.get_world_size() - 1:
|
||||
step_bar.set_postfix({"train/loss": global_loss.item()})
|
||||
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||
|
||||
# Gradient accumulation
|
||||
if (i + 1) % self.accumulation_steps == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.scheduler.step()
|
||||
else:
|
||||
step_bar = trange(
|
||||
len(self.train_dataloader) // self.accumulation_steps,
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
for i, batch in enumerate(self.train_dataloader):
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
batch_size = batch["input_ids"].size(0)
|
||||
outputs = self.model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
|
||||
)
|
||||
loss = outputs.loss
|
||||
|
||||
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
|
||||
if self.writer:
|
||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
|
||||
self.num_train_step += 1
|
||||
self.accumulative_meter.reset()
|
||||
step_bar.update()
|
||||
self.booster.backward(loss=loss, optimizer=self.optimizer)
|
||||
|
||||
# Save checkpoint
|
||||
if (
|
||||
self.save_dir is not None
|
||||
and self.save_interval is not None
|
||||
and (self.num_train_step + 1) % self.save_interval == 0
|
||||
):
|
||||
save_checkpoint(
|
||||
save_dir=self.save_dir,
|
||||
booster=self.booster,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.scheduler,
|
||||
epoch=epoch,
|
||||
step=self.num_train_step + 1,
|
||||
batch_size=batch_size,
|
||||
coordinator=self.coordinator,
|
||||
)
|
||||
self.coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
|
||||
)
|
||||
loss_mean = all_reduce_mean(tensor=loss)
|
||||
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).item())
|
||||
|
||||
# Gradient accumulation
|
||||
if (i + 1) % self.accumulation_steps == 0:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
self.scheduler.step()
|
||||
|
||||
step_bar.set_postfix({"train/loss": self.accumulative_meter.get("loss")})
|
||||
if self.writer:
|
||||
self.writer.add_scalar("train/loss", self.accumulative_meter.get("loss"), self.num_train_step)
|
||||
self.writer.add_scalar("train/lr", self.scheduler.get_last_lr()[0], self.num_train_step)
|
||||
self.num_train_step += 1
|
||||
self.accumulative_meter.reset()
|
||||
step_bar.update()
|
||||
|
||||
# Save checkpoint
|
||||
if (
|
||||
self.save_dir is not None
|
||||
and self.save_interval is not None
|
||||
and (self.num_train_step + 1) % self.save_interval == 0
|
||||
):
|
||||
save_checkpoint(
|
||||
save_dir=self.save_dir,
|
||||
booster=self.booster,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.scheduler,
|
||||
epoch=epoch,
|
||||
step=self.num_train_step + 1,
|
||||
batch_size=batch_size,
|
||||
coordinator=self.coordinator,
|
||||
)
|
||||
self.coordinator.print_on_master(
|
||||
f"Saved checkpoint at epoch {epoch} step {self.num_train_step} at folder {self.save_dir}"
|
||||
)
|
||||
step_bar.close()
|
||||
|
||||
def _eval(self, epoch: int):
|
||||
|
@ -157,27 +185,64 @@ class SFTTrainer(SLTrainer):
|
|||
self.accumulative_meter.reset()
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
step_bar = trange(
|
||||
len(self.eval_dataloader),
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
for batch in self.eval_dataloader:
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
outputs = self.model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
|
||||
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
|
||||
data_iter = iter(self.eval_dataloader)
|
||||
step_bar = tqdm(
|
||||
range(len(self.eval_dataloader)),
|
||||
desc="Step",
|
||||
disable=not (dist.get_rank() == dist.get_world_size() - 1),
|
||||
)
|
||||
loss_mean = all_reduce_mean(tensor=outputs.loss)
|
||||
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
|
||||
step_bar.update()
|
||||
loss_mean = self.accumulative_meter.get("loss")
|
||||
msg = "Evaluation Result:\n"
|
||||
for tag in ["loss"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
self.coordinator.print_on_master(msg)
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
step_bar.close()
|
||||
for step in step_bar:
|
||||
outputs = self.booster.execute_pipeline(
|
||||
data_iter,
|
||||
self.model,
|
||||
criterion=lambda outputs, inputs: outputs[0],
|
||||
optimizer=self.optimizer,
|
||||
return_loss=True,
|
||||
)
|
||||
loss = outputs["loss"]
|
||||
if self.booster.plugin.stage_manager.is_last_stage():
|
||||
global_loss = all_reduce_mean(loss, self.plugin)
|
||||
if dist.get_rank() == dist.get_world_size() - 1:
|
||||
step_bar.set_postfix({"eval/loss": global_loss.item()})
|
||||
self.accumulative_meter.add("loss", global_loss.item())
|
||||
|
||||
if dist.get_rank() == dist.get_world_size() - 1:
|
||||
loss_mean = self.accumulative_meter.get("loss")
|
||||
msg = "Evaluation Result:\n"
|
||||
for tag in ["loss"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
print(msg)
|
||||
if self.save_dir is not None:
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
step_bar.close()
|
||||
|
||||
else:
|
||||
step_bar = trange(
|
||||
len(self.eval_dataloader),
|
||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||
disable=not is_rank_0(),
|
||||
)
|
||||
for batch in self.eval_dataloader:
|
||||
batch = to_device(batch, torch.cuda.current_device())
|
||||
outputs = self.model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
|
||||
)
|
||||
loss_mean = all_reduce_mean(tensor=outputs.loss)
|
||||
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
|
||||
step_bar.update()
|
||||
|
||||
loss_mean = self.accumulative_meter.get("loss")
|
||||
msg = "Evaluation Result:\n"
|
||||
for tag in ["loss"]:
|
||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||
self.coordinator.print_on_master(msg)
|
||||
if self.save_dir is not None:
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||
f.write(msg)
|
||||
step_bar.close()
|
||||
|
|
|
@ -9,6 +9,8 @@ import torch.distributed as dist
|
|||
from torch.utils._pytree import tree_map
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.booster import Plugin
|
||||
|
||||
|
||||
class CycledDataLoader:
|
||||
"""
|
||||
|
@ -85,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
|
|||
return tree_map(_to, x)
|
||||
|
||||
|
||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||
def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
|
||||
"""
|
||||
Perform all-reduce operation on the given tensor and compute the mean across all processes.
|
||||
|
||||
|
@ -95,8 +97,13 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
|||
Returns:
|
||||
torch.Tensor: The reduced tensor with mean computed across all processes.
|
||||
"""
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
tensor.div_(dist.get_world_size())
|
||||
# All reduce mean across DP group
|
||||
if plugin is not None:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
|
||||
tensor.div_(plugin.dp_size)
|
||||
else:
|
||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||
tensor.div_(dist.get_world_size())
|
||||
return tensor
|
||||
|
||||
|
||||
|
|
|
@ -267,6 +267,7 @@ def train(args):
|
|||
ref_model=ref_model,
|
||||
booster=booster,
|
||||
actor_optim=optim,
|
||||
plugin=plugin,
|
||||
actor_lr_scheduler=lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
max_epochs=args.max_epochs,
|
||||
|
|
|
@ -286,6 +286,7 @@ def train(args):
|
|||
ref_model=ref_model,
|
||||
booster=booster,
|
||||
actor_optim=optim,
|
||||
plugin=plugin,
|
||||
actor_lr_scheduler=lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
max_epochs=args.max_epochs,
|
||||
|
|
|
@ -250,6 +250,7 @@ def train(args):
|
|||
actor=model,
|
||||
booster=booster,
|
||||
actor_optim=optim,
|
||||
plugin=plugin,
|
||||
actor_lr_scheduler=lr_scheduler,
|
||||
tokenizer=tokenizer,
|
||||
max_epochs=args.max_epochs,
|
||||
|
|
|
@ -262,6 +262,7 @@ def train(args):
|
|||
model,
|
||||
booster,
|
||||
optim,
|
||||
plugin,
|
||||
lr_scheduler,
|
||||
tokenizer,
|
||||
loss_fn=loss_fn,
|
||||
|
|
|
@ -114,7 +114,7 @@ def train(args):
|
|||
parallel_output=False,
|
||||
max_norm=args.grad_clip,
|
||||
precision=args.mixed_precision,
|
||||
microbatch_size=args.batch_size,
|
||||
microbatch_size=args.microbatch_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
@ -269,6 +269,7 @@ def train(args):
|
|||
model=model,
|
||||
booster=booster,
|
||||
optim=optim,
|
||||
plugin=plugin,
|
||||
lr_scheduler=lr_scheduler,
|
||||
max_epochs=args.max_epochs,
|
||||
accumulation_steps=args.accumulation_steps,
|
||||
|
@ -344,6 +345,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||
parser.add_argument("--microbatch_size", type=int, default=1)
|
||||
args = parser.parse_args()
|
||||
if args.config_file is not None:
|
||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||
|
|
|
@ -61,7 +61,7 @@ def test_overfit():
|
|||
_, predicted = torch.max(outputs.data, 1)
|
||||
total = labels.size(0)
|
||||
correct = (predicted == Y).sum().item()
|
||||
assert (correct / total > 0.95, "The model has not overfitted to the synthesized dataset")
|
||||
assert correct / total > 0.95
|
||||
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
|
||||
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
|||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||
|
||||
set -xu
|
||||
|
||||
|
@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
|
|||
MODELS_DIR=$TEMP_DIR/models_config
|
||||
# Skip those tests due to CI tests timeout
|
||||
MODELS=('llama')
|
||||
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy
|
||||
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'pp' 'tp_pp')
|
||||
PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
|
||||
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
|
||||
LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
|
||||
|
@ -91,7 +91,7 @@ SKIPPED_TESTS=(
|
|||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||
llama-gemini-20 # gemini doesn't support lora
|
||||
)
|
||||
|
||||
skip_eval=false
|
||||
GRAD_CKPTS=('--grad_checkpoint')
|
||||
for lora_rank in ${LORA_RANK[@]}; do
|
||||
for model in ${MODELS[@]}; do
|
||||
|
@ -129,15 +129,18 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
plugin='3d'
|
||||
fi
|
||||
if [[ $plugin == "tp_pp" ]]; then
|
||||
echo "Here"
|
||||
tp='2'
|
||||
bs='8'
|
||||
pp='2'
|
||||
plugin='3d'
|
||||
skip_eval=true
|
||||
fi
|
||||
if [[ $plugin == "pp" ]]; then
|
||||
bs='8'
|
||||
pp='2'
|
||||
plugin='3d'
|
||||
skip_eval=true
|
||||
fi
|
||||
if [[ $plugin == "sp_split_gather" ]]; then
|
||||
enable_sequence_parallelism='--enable_sequence_parallelism'
|
||||
|
@ -175,28 +178,53 @@ for lora_rank in ${LORA_RANK[@]}; do
|
|||
for split in $(seq -f "%05g" 0 0); do
|
||||
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||
done
|
||||
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--eval_dataset ${dataset[@]} \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--pp $pp \
|
||||
--zero_stage $zero_stage \
|
||||
--sp $sp \
|
||||
--sp_mode $sp_mode \
|
||||
$enable_sequence_parallelism \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
|
||||
if [[ $skip_eval ]]; then
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--pp $pp \
|
||||
--zero_stage $zero_stage \
|
||||
--sp $sp \
|
||||
--sp_mode $sp_mode \
|
||||
$enable_sequence_parallelism \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
else
|
||||
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
||||
--pretrain $pretrain \
|
||||
--tokenizer_dir $tokenizer_dir \
|
||||
--dataset ${dataset[@]} \
|
||||
--eval_dataset ${dataset[@]} \
|
||||
--save_path $MODEL_SAVE_PATH \
|
||||
--config_file $MODELS_DIR/config.jsonl \
|
||||
$lora_config \
|
||||
--plugin $plugin \
|
||||
--batch_size $bs \
|
||||
--max_epochs 1 \
|
||||
--accumulation_steps $grad_accu \
|
||||
--tp $tp \
|
||||
--pp $pp \
|
||||
--zero_stage $zero_stage \
|
||||
--sp $sp \
|
||||
--sp_mode $sp_mode \
|
||||
$enable_sequence_parallelism \
|
||||
--lr 2e-5 \
|
||||
$grad_ckpt \
|
||||
--max_len 400 \
|
||||
--use_flash_attn
|
||||
fi
|
||||
passed=$?
|
||||
if [ $passed -eq 0 ]; then
|
||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||
|
|
|
@ -349,7 +349,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
verbose: bool = False,
|
||||
cast_inputs: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
use_fp8: bool = False
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
|
||||
|
|
Loading…
Reference in New Issue