@@ -11,7 +11,8 @@ from transformers.utils.versions import require_version
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
+from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.nn.optimizer import HybridAdam
@@ -19,35 +20,54 @@ from colossalai.nn.optimizer import HybridAdam
 require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
 require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
 
 output_transform_fn = lambda x: x
 criterion = lambda x: x.loss
 
 
 def move_to_cuda(batch, device):
     return {k: v.to(device) for k, v in batch.items()}
 
 
-def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
+def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, booster, coordinator):
     torch.cuda.synchronize()
-    model.train()
 
-    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
-        for batch in pbar:
-            # Forward
-            optimizer.zero_grad()
-            batch = move_to_cuda(batch, torch.cuda.current_device())
+    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
+    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+    total_step = len(dataloader)
 
-            outputs = model(use_cache=False, **batch)
-            loss = outputs['loss']
+    model.train()
+    optimizer.zero_grad()
+    dataloader = iter(dataloader)
+    with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}]',
+              disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
+        # Forward pass
+        for _ in pbar:
+            if use_pipeline:
+                outputs = booster.execute_pipeline(dataloader,
+                                                   model,
+                                                   _criterion,
+                                                   optimizer,
+                                                   return_loss=True,
+                                                   return_outputs=True)
+                # Backward and optimize
+                if is_pp_last_stage:
+                    loss = outputs['loss']
+                    pbar.set_postfix({'loss': loss.item()})
+            else:
+                data = next(dataloader)
+                data = move_to_cuda(data, torch.cuda.current_device())
+                outputs = model(**data)
+                loss = _criterion(outputs, None)
+                # Backward
+                booster.backward(loss, optimizer)
+                pbar.set_postfix({'loss': loss.item()})
 
-            # Backward
-            booster.backward(loss, optimizer)
             optimizer.step()
+            optimizer.zero_grad()
             lr_scheduler.step()
 
-            # Print batch loss
-            pbar.set_postfix({'loss': loss.item()})
-
 
 def main():
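For context, a minimal sketch (not part of this diff) of the criterion contract the new train_epoch relies on: the pipeline schedule invokes the criterion as criterion(outputs, inputs), the non-pipeline branch calls the same wrapper as _criterion(outputs, None), and only the last pipeline stage ever holds the returned loss, which is why the progress bar is updated there.

# Illustrative only -- restates the contract assumed by the loop above.
output_transform_fn = lambda x: x    # identity transform, as in this example
criterion = lambda x: x.loss         # HuggingFace-style model outputs expose .loss

def _criterion(outputs, inputs):
    # `inputs` is accepted but unused, so the signature matches what the
    # pipeline schedule passes in.
    return criterion(output_transform_fn(outputs))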
@@ -86,6 +106,16 @@ def main():
         plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
     elif args.plugin == 'low_level_zero':
         plugin = LowLevelZeroPlugin(initial_scale=2**5)
+    elif args.plugin == 'hybrid_parallel':
+        # modify the param accordingly for finetuning test cases
+        plugin = HybridParallelPlugin(tp_size=2,
+                                      pp_size=2,
+                                      num_microbatches=2,
+                                      enable_all_optimization=True,
+                                      zero_stage=0,
+                                      precision='fp16',
+                                      initial_scale=1)
+
     logger.info(f"Set plugin as {args.plugin}", ranks=[0])
 
     # Prepare tokenizer and dataloader
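A hedged aside on the configuration above: with tp_size=2 and pp_size=2 the launch must provide a world size divisible by 4, and whatever factor remains becomes the data-parallel degree. A hypothetical sanity check (the helper name is made up; it assumes the distributed process group is already initialized, e.g. by colossalai.launch_from_torch):

import torch.distributed as dist

def check_hybrid_parallel_world_size(tp_size: int = 2, pp_size: int = 2) -> int:
    # Hypothetical helper, not part of the example: verify that the launch
    # size fits the chosen tp/pp degrees and return the data-parallel size.
    world_size = dist.get_world_size()
    assert world_size % (tp_size * pp_size) == 0, \
        f"world size {world_size} must be divisible by tp_size * pp_size = {tp_size * pp_size}"
    return world_size // (tp_size * pp_size)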
@@ -107,21 +137,28 @@ def main():
                                                    num_warmup_steps=num_warmup_steps,
                                                    num_training_steps=len(dataloader) * args.num_epoch)
 
+    # Define criterion
+    def _criterion(outputs, inputs):
+        outputs = output_transform_fn(outputs)
+        loss = criterion(outputs)
+        return loss
+
     # Set booster
     booster = Booster(plugin=plugin, **booster_kwargs)
-    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
-                                                                  optimizer=optimizer,
-                                                                  dataloader=dataloader,
-                                                                  lr_scheduler=lr_scheduler)
+    model, optimizer, _criterion, dataloader, lr_scheduler = booster.boost(model=model,
+                                                                           optimizer=optimizer,
+                                                                           dataloader=dataloader,
+                                                                           criterion=_criterion,
+                                                                           lr_scheduler=lr_scheduler)
 
     # Start finetuning
     logger.info(f"Start finetuning", ranks=[0])
     for epoch in range(args.num_epoch):
-        train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator)
+        train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, booster, coordinator)
 
     # Finish training and evaluate
     logger.info(f"Finish finetuning", ranks=[0])
-    booster.save_model(model, args.output_path)
+    booster.save_model(model, args.output_path, shard=True)
     logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])
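Finally, a hedged usage sketch not shown in the diff: with shard=True the booster writes the checkpoint as multiple weight shards plus an index file under args.output_path rather than a single state dict, and the boosted model can be restored from that same path through the booster.

# Sketch only: reload the sharded checkpoint written above.
# Assumes `model` has already been wrapped by booster.boost(...).
booster.load_model(model, args.output_path)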