mirror of https://github.com/hpcaitech/ColossalAI
[ColossalChat] Add PP support (#6001)
* support pp training
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support pp training
* update rm
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update test case
* fix
* change to 4
* fix eval
* test
* add pp
* hotfix
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update
* skip pp eval
* update all reduce
* update sft
* update ignore
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* update no cache
* add eval
* remove fi
* remove debug
* remove parentheses to avoid warning
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* Revert "add eval"
This reverts commit 3ab2f6fa32
.
* add all reduce
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
pull/6031/head
parent
0d3b0bd864
commit
39e2597426
|
@ -31,18 +31,18 @@ jobs:
|
||||||
|
|
||||||
- name: Install Colossal-AI
|
- name: Install Colossal-AI
|
||||||
run: |
|
run: |
|
||||||
BUILD_EXT=1 pip install -v -e .
|
BUILD_EXT=1 pip install --no-cache-dir -v -e .
|
||||||
|
|
||||||
- name: Install ChatGPT
|
- name: Install ChatGPT
|
||||||
run: |
|
run: |
|
||||||
cd applications/ColossalChat
|
cd applications/ColossalChat
|
||||||
pip install -v .
|
pip install --no-cache-dir -v .
|
||||||
export BUILD_EXT=1
|
export BUILD_EXT=1
|
||||||
pip install -r examples/requirements.txt
|
pip install --no-cache-dir -r examples/requirements.txt
|
||||||
|
|
||||||
- name: Install Transformers
|
- name: Install Transformers
|
||||||
run: |
|
run: |
|
||||||
pip install transformers==4.36.2
|
pip install --no-cache-dir transformers==4.36.2
|
||||||
|
|
||||||
- name: Execute Examples
|
- name: Execute Examples
|
||||||
run: |
|
run: |
|
||||||
|
|
|
@ -161,3 +161,9 @@ applications/ColossalChat/sft_data
|
||||||
applications/ColossalChat/prompt_data
|
applications/ColossalChat/prompt_data
|
||||||
applications/ColossalChat/preference_data
|
applications/ColossalChat/preference_data
|
||||||
applications/ColossalChat/temp
|
applications/ColossalChat/temp
|
||||||
|
|
||||||
|
# Testing data
|
||||||
|
/kto_data/
|
||||||
|
/preference_data/
|
||||||
|
/prompt_data/
|
||||||
|
/sft_data/
|
||||||
|
|
|
@ -16,7 +16,7 @@ from coati.experience_buffer import NaiveExperienceBuffer
|
||||||
from coati.experience_maker import Experience
|
from coati.experience_maker import Experience
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster, Plugin
|
||||||
|
|
||||||
from .utils import is_rank_0
|
from .utils import is_rank_0
|
||||||
|
|
||||||
|
@ -38,6 +38,7 @@ class SLTrainer(ABC):
|
||||||
max_epochs: int,
|
max_epochs: int,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
|
plugin: Plugin,
|
||||||
start_epoch: int = 0,
|
start_epoch: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -45,6 +46,7 @@ class SLTrainer(ABC):
|
||||||
self.max_epochs = max_epochs
|
self.max_epochs = max_epochs
|
||||||
self.model = model
|
self.model = model
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
|
self.plugin = plugin
|
||||||
self.start_epoch = start_epoch
|
self.start_epoch = start_epoch
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
|
|
@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster, Plugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
@ -50,6 +50,7 @@ class DPOTrainer(SLTrainer):
|
||||||
ref_model: Any,
|
ref_model: Any,
|
||||||
booster: Booster,
|
booster: Booster,
|
||||||
actor_optim: Optimizer,
|
actor_optim: Optimizer,
|
||||||
|
plugin: Plugin,
|
||||||
actor_lr_scheduler: _LRScheduler,
|
actor_lr_scheduler: _LRScheduler,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
max_epochs: int = 1,
|
max_epochs: int = 1,
|
||||||
|
@ -63,7 +64,9 @@ class DPOTrainer(SLTrainer):
|
||||||
save_dir: str = None,
|
save_dir: str = None,
|
||||||
coordinator: DistCoordinator = None,
|
coordinator: DistCoordinator = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
|
super().__init__(
|
||||||
|
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
|
||||||
|
)
|
||||||
self.ref_model = ref_model
|
self.ref_model = ref_model
|
||||||
self.actor_scheduler = actor_lr_scheduler
|
self.actor_scheduler = actor_lr_scheduler
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
|
@ -17,7 +17,7 @@ from torch.utils.data import DataLoader
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster, Plugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
@ -53,6 +53,7 @@ class KTOTrainer(SLTrainer):
|
||||||
ref_model: Any,
|
ref_model: Any,
|
||||||
booster: Booster,
|
booster: Booster,
|
||||||
actor_optim: Optimizer,
|
actor_optim: Optimizer,
|
||||||
|
plugin: Plugin,
|
||||||
actor_lr_scheduler: _LRScheduler,
|
actor_lr_scheduler: _LRScheduler,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
max_epochs: int = 1,
|
max_epochs: int = 1,
|
||||||
|
@ -66,7 +67,9 @@ class KTOTrainer(SLTrainer):
|
||||||
save_dir: str = None,
|
save_dir: str = None,
|
||||||
coordinator: DistCoordinator = None,
|
coordinator: DistCoordinator = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
|
super().__init__(
|
||||||
|
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
|
||||||
|
)
|
||||||
self.ref_model = ref_model
|
self.ref_model = ref_model
|
||||||
self.actor_scheduler = actor_lr_scheduler
|
self.actor_scheduler = actor_lr_scheduler
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
|
|
|
@ -16,7 +16,7 @@ from torch.utils.data import DataLoader
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster, Plugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
@ -48,6 +48,7 @@ class ORPOTrainer(SLTrainer):
|
||||||
actor: Any,
|
actor: Any,
|
||||||
booster: Booster,
|
booster: Booster,
|
||||||
actor_optim: Optimizer,
|
actor_optim: Optimizer,
|
||||||
|
plugin: Plugin,
|
||||||
actor_lr_scheduler: _LRScheduler,
|
actor_lr_scheduler: _LRScheduler,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
max_epochs: int = 1,
|
max_epochs: int = 1,
|
||||||
|
@ -59,7 +60,9 @@ class ORPOTrainer(SLTrainer):
|
||||||
save_dir: str = None,
|
save_dir: str = None,
|
||||||
coordinator: DistCoordinator = None,
|
coordinator: DistCoordinator = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, start_epoch=start_epoch)
|
super().__init__(
|
||||||
|
booster, max_epochs=max_epochs, model=actor, optimizer=actor_optim, plugin=plugin, start_epoch=start_epoch
|
||||||
|
)
|
||||||
self.actor_scheduler = actor_lr_scheduler
|
self.actor_scheduler = actor_lr_scheduler
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.odds_ratio_loss_fn = OddsRatioLoss()
|
self.odds_ratio_loss_fn = OddsRatioLoss()
|
||||||
|
|
|
@ -15,7 +15,7 @@ from torch.optim.lr_scheduler import _LRScheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster, Plugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
@ -48,6 +48,7 @@ class RewardModelTrainer(SLTrainer):
|
||||||
model: Any,
|
model: Any,
|
||||||
booster: Booster,
|
booster: Booster,
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
|
plugin: Plugin,
|
||||||
lr_scheduler: _LRScheduler,
|
lr_scheduler: _LRScheduler,
|
||||||
tokenizer: PreTrainedTokenizerBase,
|
tokenizer: PreTrainedTokenizerBase,
|
||||||
loss_fn: Optional[Callable] = None,
|
loss_fn: Optional[Callable] = None,
|
||||||
|
@ -59,7 +60,9 @@ class RewardModelTrainer(SLTrainer):
|
||||||
save_dir: str = None,
|
save_dir: str = None,
|
||||||
coordinator: DistCoordinator = None,
|
coordinator: DistCoordinator = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(booster, max_epochs=max_epochs, model=model, optimizer=optimizer, start_epoch=start_epoch)
|
super().__init__(
|
||||||
|
booster, max_epochs=max_epochs, model=model, optimizer=optimizer, plugin=plugin, start_epoch=start_epoch
|
||||||
|
)
|
||||||
self.actor_scheduler = lr_scheduler
|
self.actor_scheduler = lr_scheduler
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)
|
self.loss_fn = loss_fn if loss_fn is not None else LogSigLoss(beta=beta)
|
||||||
|
|
|
@ -6,14 +6,16 @@ import os
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from coati.trainer.utils import all_reduce_mean
|
from coati.trainer.utils import all_reduce_mean
|
||||||
from coati.utils import AccumulativeMeanMeter, save_checkpoint
|
from coati.utils import AccumulativeMeanMeter, save_checkpoint
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import _LRScheduler
|
from torch.optim.lr_scheduler import _LRScheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm import trange
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from colossalai.booster import Booster
|
from colossalai.booster import Booster
|
||||||
|
from colossalai.booster.plugin import HybridParallelPlugin, Plugin
|
||||||
from colossalai.cluster import DistCoordinator
|
from colossalai.cluster import DistCoordinator
|
||||||
|
|
||||||
from .base import SLTrainer
|
from .base import SLTrainer
|
||||||
|
@ -40,6 +42,7 @@ class SFTTrainer(SLTrainer):
|
||||||
optim: Optimizer,
|
optim: Optimizer,
|
||||||
lr_scheduler: _LRScheduler,
|
lr_scheduler: _LRScheduler,
|
||||||
max_epochs: int = 2,
|
max_epochs: int = 2,
|
||||||
|
plugin: Plugin = None,
|
||||||
accumulation_steps: int = 8,
|
accumulation_steps: int = 8,
|
||||||
apply_loss_mask: bool = True,
|
apply_loss_mask: bool = True,
|
||||||
start_epoch=0,
|
start_epoch=0,
|
||||||
|
@ -47,7 +50,7 @@ class SFTTrainer(SLTrainer):
|
||||||
save_dir: str = None,
|
save_dir: str = None,
|
||||||
coordinator: Optional[DistCoordinator] = None,
|
coordinator: Optional[DistCoordinator] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(booster, max_epochs, model, optim, start_epoch=start_epoch)
|
super().__init__(booster, max_epochs, model, optim, plugin, start_epoch=start_epoch)
|
||||||
|
|
||||||
self.accumulation_steps = accumulation_steps
|
self.accumulation_steps = accumulation_steps
|
||||||
self.scheduler = lr_scheduler
|
self.scheduler = lr_scheduler
|
||||||
|
@ -94,6 +97,31 @@ class SFTTrainer(SLTrainer):
|
||||||
|
|
||||||
def _train(self, epoch: int):
|
def _train(self, epoch: int):
|
||||||
self.model.train()
|
self.model.train()
|
||||||
|
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
|
||||||
|
data_iter = iter(self.train_dataloader)
|
||||||
|
step_bar = tqdm(
|
||||||
|
range(len(self.train_dataloader)),
|
||||||
|
desc="Step",
|
||||||
|
disable=not (dist.get_rank() == dist.get_world_size() - 1),
|
||||||
|
)
|
||||||
|
for step in step_bar:
|
||||||
|
outputs = self.booster.execute_pipeline(
|
||||||
|
data_iter,
|
||||||
|
self.model,
|
||||||
|
criterion=lambda outputs, inputs: outputs[0],
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
)
|
||||||
|
loss = outputs["loss"]
|
||||||
|
|
||||||
|
if self.booster.plugin.stage_manager.is_last_stage():
|
||||||
|
global_loss = all_reduce_mean(loss, self.plugin)
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
step_bar.set_postfix({"train/loss": global_loss.item()})
|
||||||
|
|
||||||
|
self.optimizer.step()
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
else:
|
||||||
step_bar = trange(
|
step_bar = trange(
|
||||||
len(self.train_dataloader) // self.accumulation_steps,
|
len(self.train_dataloader) // self.accumulation_steps,
|
||||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||||
|
@ -157,6 +185,41 @@ class SFTTrainer(SLTrainer):
|
||||||
self.accumulative_meter.reset()
|
self.accumulative_meter.reset()
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
if isinstance(self.plugin, HybridParallelPlugin) and self.plugin.pp_size > 1:
|
||||||
|
data_iter = iter(self.eval_dataloader)
|
||||||
|
step_bar = tqdm(
|
||||||
|
range(len(self.eval_dataloader)),
|
||||||
|
desc="Step",
|
||||||
|
disable=not (dist.get_rank() == dist.get_world_size() - 1),
|
||||||
|
)
|
||||||
|
for step in step_bar:
|
||||||
|
outputs = self.booster.execute_pipeline(
|
||||||
|
data_iter,
|
||||||
|
self.model,
|
||||||
|
criterion=lambda outputs, inputs: outputs[0],
|
||||||
|
optimizer=self.optimizer,
|
||||||
|
return_loss=True,
|
||||||
|
)
|
||||||
|
loss = outputs["loss"]
|
||||||
|
if self.booster.plugin.stage_manager.is_last_stage():
|
||||||
|
global_loss = all_reduce_mean(loss, self.plugin)
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
step_bar.set_postfix({"eval/loss": global_loss.item()})
|
||||||
|
self.accumulative_meter.add("loss", global_loss.item())
|
||||||
|
|
||||||
|
if dist.get_rank() == dist.get_world_size() - 1:
|
||||||
|
loss_mean = self.accumulative_meter.get("loss")
|
||||||
|
msg = "Evaluation Result:\n"
|
||||||
|
for tag in ["loss"]:
|
||||||
|
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||||
|
print(msg)
|
||||||
|
if self.save_dir is not None:
|
||||||
|
os.makedirs(self.save_dir, exist_ok=True)
|
||||||
|
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||||
|
f.write(msg)
|
||||||
|
step_bar.close()
|
||||||
|
|
||||||
|
else:
|
||||||
step_bar = trange(
|
step_bar = trange(
|
||||||
len(self.eval_dataloader),
|
len(self.eval_dataloader),
|
||||||
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
desc=f"Epoch {epoch + 1}/{self.max_epochs}",
|
||||||
|
@ -172,11 +235,13 @@ class SFTTrainer(SLTrainer):
|
||||||
loss_mean = all_reduce_mean(tensor=outputs.loss)
|
loss_mean = all_reduce_mean(tensor=outputs.loss)
|
||||||
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
|
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
|
||||||
step_bar.update()
|
step_bar.update()
|
||||||
|
|
||||||
loss_mean = self.accumulative_meter.get("loss")
|
loss_mean = self.accumulative_meter.get("loss")
|
||||||
msg = "Evaluation Result:\n"
|
msg = "Evaluation Result:\n"
|
||||||
for tag in ["loss"]:
|
for tag in ["loss"]:
|
||||||
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
msg = msg + f"{tag}: {self.accumulative_meter.get(tag)}\n"
|
||||||
self.coordinator.print_on_master(msg)
|
self.coordinator.print_on_master(msg)
|
||||||
|
if self.save_dir is not None:
|
||||||
os.makedirs(self.save_dir, exist_ok=True)
|
os.makedirs(self.save_dir, exist_ok=True)
|
||||||
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
with open(os.path.join(self.save_dir, f"eval_result_epoch{epoch}.txt"), "w") as f:
|
||||||
f.write(msg)
|
f.write(msg)
|
||||||
|
|
|
@ -9,6 +9,8 @@ import torch.distributed as dist
|
||||||
from torch.utils._pytree import tree_map
|
from torch.utils._pytree import tree_map
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from colossalai.booster import Plugin
|
||||||
|
|
||||||
|
|
||||||
class CycledDataLoader:
|
class CycledDataLoader:
|
||||||
"""
|
"""
|
||||||
|
@ -85,7 +87,7 @@ def to_device(x: Any, device: torch.device) -> Any:
|
||||||
return tree_map(_to, x)
|
return tree_map(_to, x)
|
||||||
|
|
||||||
|
|
||||||
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
def all_reduce_mean(tensor: torch.Tensor, plugin: Plugin = None) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Perform all-reduce operation on the given tensor and compute the mean across all processes.
|
Perform all-reduce operation on the given tensor and compute the mean across all processes.
|
||||||
|
|
||||||
|
@ -95,6 +97,11 @@ def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: The reduced tensor with mean computed across all processes.
|
torch.Tensor: The reduced tensor with mean computed across all processes.
|
||||||
"""
|
"""
|
||||||
|
# All reduce mean across DP group
|
||||||
|
if plugin is not None:
|
||||||
|
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM, group=plugin.dp_group)
|
||||||
|
tensor.div_(plugin.dp_size)
|
||||||
|
else:
|
||||||
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
|
||||||
tensor.div_(dist.get_world_size())
|
tensor.div_(dist.get_world_size())
|
||||||
return tensor
|
return tensor
|
||||||
|
|
|
@ -267,6 +267,7 @@ def train(args):
|
||||||
ref_model=ref_model,
|
ref_model=ref_model,
|
||||||
booster=booster,
|
booster=booster,
|
||||||
actor_optim=optim,
|
actor_optim=optim,
|
||||||
|
plugin=plugin,
|
||||||
actor_lr_scheduler=lr_scheduler,
|
actor_lr_scheduler=lr_scheduler,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_epochs=args.max_epochs,
|
max_epochs=args.max_epochs,
|
||||||
|
|
|
@ -286,6 +286,7 @@ def train(args):
|
||||||
ref_model=ref_model,
|
ref_model=ref_model,
|
||||||
booster=booster,
|
booster=booster,
|
||||||
actor_optim=optim,
|
actor_optim=optim,
|
||||||
|
plugin=plugin,
|
||||||
actor_lr_scheduler=lr_scheduler,
|
actor_lr_scheduler=lr_scheduler,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_epochs=args.max_epochs,
|
max_epochs=args.max_epochs,
|
||||||
|
|
|
@ -250,6 +250,7 @@ def train(args):
|
||||||
actor=model,
|
actor=model,
|
||||||
booster=booster,
|
booster=booster,
|
||||||
actor_optim=optim,
|
actor_optim=optim,
|
||||||
|
plugin=plugin,
|
||||||
actor_lr_scheduler=lr_scheduler,
|
actor_lr_scheduler=lr_scheduler,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
max_epochs=args.max_epochs,
|
max_epochs=args.max_epochs,
|
||||||
|
|
|
@ -262,6 +262,7 @@ def train(args):
|
||||||
model,
|
model,
|
||||||
booster,
|
booster,
|
||||||
optim,
|
optim,
|
||||||
|
plugin,
|
||||||
lr_scheduler,
|
lr_scheduler,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
loss_fn=loss_fn,
|
loss_fn=loss_fn,
|
||||||
|
|
|
@ -114,7 +114,7 @@ def train(args):
|
||||||
parallel_output=False,
|
parallel_output=False,
|
||||||
max_norm=args.grad_clip,
|
max_norm=args.grad_clip,
|
||||||
precision=args.mixed_precision,
|
precision=args.mixed_precision,
|
||||||
microbatch_size=args.batch_size,
|
microbatch_size=args.microbatch_size,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||||
|
@ -269,6 +269,7 @@ def train(args):
|
||||||
model=model,
|
model=model,
|
||||||
booster=booster,
|
booster=booster,
|
||||||
optim=optim,
|
optim=optim,
|
||||||
|
plugin=plugin,
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
max_epochs=args.max_epochs,
|
max_epochs=args.max_epochs,
|
||||||
accumulation_steps=args.accumulation_steps,
|
accumulation_steps=args.accumulation_steps,
|
||||||
|
@ -344,6 +345,7 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--use_wandb", default=False, action="store_true")
|
parser.add_argument("--use_wandb", default=False, action="store_true")
|
||||||
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
parser.add_argument("--grad_checkpoint", default=False, action="store_true")
|
||||||
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
parser.add_argument("--use_flash_attn", default=False, action="store_true")
|
||||||
|
parser.add_argument("--microbatch_size", type=int, default=1)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
if args.config_file is not None:
|
if args.config_file is not None:
|
||||||
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
os.makedirs(os.path.dirname(args.config_file), exist_ok=True)
|
||||||
|
|
|
@ -61,7 +61,7 @@ def test_overfit():
|
||||||
_, predicted = torch.max(outputs.data, 1)
|
_, predicted = torch.max(outputs.data, 1)
|
||||||
total = labels.size(0)
|
total = labels.size(0)
|
||||||
correct = (predicted == Y).sum().item()
|
correct = (predicted == Y).sum().item()
|
||||||
assert (correct / total > 0.95, "The model has not overfitted to the synthesized dataset")
|
assert correct / total > 0.95
|
||||||
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
|
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||||
}
|
}
|
||||||
|
|
||||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
set_n_least_used_CUDA_VISIBLE_DEVICES 4
|
||||||
|
|
||||||
set -xu
|
set -xu
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
|
||||||
MODELS_DIR=$TEMP_DIR/models_config
|
MODELS_DIR=$TEMP_DIR/models_config
|
||||||
# Skip those tests due to CI tests timeout
|
# Skip those tests due to CI tests timeout
|
||||||
MODELS=('llama')
|
MODELS=('llama')
|
||||||
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy
|
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu' 'pp' 'tp_pp')
|
||||||
PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
|
PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
|
||||||
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
|
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
|
||||||
LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
|
LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
|
||||||
|
@ -91,7 +91,7 @@ SKIPPED_TESTS=(
|
||||||
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
llama-gemini_auto-20 # gemini_auto plugin doesn't support lora
|
||||||
llama-gemini-20 # gemini doesn't support lora
|
llama-gemini-20 # gemini doesn't support lora
|
||||||
)
|
)
|
||||||
|
skip_eval=false
|
||||||
GRAD_CKPTS=('--grad_checkpoint')
|
GRAD_CKPTS=('--grad_checkpoint')
|
||||||
for lora_rank in ${LORA_RANK[@]}; do
|
for lora_rank in ${LORA_RANK[@]}; do
|
||||||
for model in ${MODELS[@]}; do
|
for model in ${MODELS[@]}; do
|
||||||
|
@ -129,15 +129,18 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
plugin='3d'
|
plugin='3d'
|
||||||
fi
|
fi
|
||||||
if [[ $plugin == "tp_pp" ]]; then
|
if [[ $plugin == "tp_pp" ]]; then
|
||||||
|
echo "Here"
|
||||||
tp='2'
|
tp='2'
|
||||||
bs='8'
|
bs='8'
|
||||||
pp='2'
|
pp='2'
|
||||||
plugin='3d'
|
plugin='3d'
|
||||||
|
skip_eval=true
|
||||||
fi
|
fi
|
||||||
if [[ $plugin == "pp" ]]; then
|
if [[ $plugin == "pp" ]]; then
|
||||||
bs='8'
|
bs='8'
|
||||||
pp='2'
|
pp='2'
|
||||||
plugin='3d'
|
plugin='3d'
|
||||||
|
skip_eval=true
|
||||||
fi
|
fi
|
||||||
if [[ $plugin == "sp_split_gather" ]]; then
|
if [[ $plugin == "sp_split_gather" ]]; then
|
||||||
enable_sequence_parallelism='--enable_sequence_parallelism'
|
enable_sequence_parallelism='--enable_sequence_parallelism'
|
||||||
|
@ -175,7 +178,31 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
for split in $(seq -f "%05g" 0 0); do
|
for split in $(seq -f "%05g" 0 0); do
|
||||||
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
dataset+=("$TEMP_DIR/rlhf_data/tokenized_${model}_sft/arrow/part-$split")
|
||||||
done
|
done
|
||||||
colossalai run --nproc_per_node 2 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
|
||||||
|
if [[ $skip_eval ]]; then
|
||||||
|
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
||||||
|
--pretrain $pretrain \
|
||||||
|
--tokenizer_dir $tokenizer_dir \
|
||||||
|
--dataset ${dataset[@]} \
|
||||||
|
--save_path $MODEL_SAVE_PATH \
|
||||||
|
--config_file $MODELS_DIR/config.jsonl \
|
||||||
|
$lora_config \
|
||||||
|
--plugin $plugin \
|
||||||
|
--batch_size $bs \
|
||||||
|
--max_epochs 1 \
|
||||||
|
--accumulation_steps $grad_accu \
|
||||||
|
--tp $tp \
|
||||||
|
--pp $pp \
|
||||||
|
--zero_stage $zero_stage \
|
||||||
|
--sp $sp \
|
||||||
|
--sp_mode $sp_mode \
|
||||||
|
$enable_sequence_parallelism \
|
||||||
|
--lr 2e-5 \
|
||||||
|
$grad_ckpt \
|
||||||
|
--max_len 400 \
|
||||||
|
--use_flash_attn
|
||||||
|
else
|
||||||
|
colossalai run --nproc_per_node 4 --master_port 31332 $EXAMPLES_DIR/training_scripts/train_sft.py \
|
||||||
--pretrain $pretrain \
|
--pretrain $pretrain \
|
||||||
--tokenizer_dir $tokenizer_dir \
|
--tokenizer_dir $tokenizer_dir \
|
||||||
--dataset ${dataset[@]} \
|
--dataset ${dataset[@]} \
|
||||||
|
@ -197,6 +224,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
$grad_ckpt \
|
$grad_ckpt \
|
||||||
--max_len 400 \
|
--max_len 400 \
|
||||||
--use_flash_attn
|
--use_flash_attn
|
||||||
|
fi
|
||||||
passed=$?
|
passed=$?
|
||||||
if [ $passed -eq 0 ]; then
|
if [ $passed -eq 0 ]; then
|
||||||
rm -rf ${MODEL_SAVE_PATH:?}/*
|
rm -rf ${MODEL_SAVE_PATH:?}/*
|
||||||
|
|
Loading…
Reference in New Issue