import argparse
from contextlib import nullcontext
from typing import Callable, List, Union

import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, GPT2ForSequenceClassification, get_linear_schedule_with_warmup

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

# ==============================
# Prepare Hyperparameters
# ==============================
NUM_EPOCHS = 3
BATCH_SIZE = 32
LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1


def output_transform_fn(x):
    """Identity transform applied to model outputs before the loss is read."""
    return x


def criterion(x):
    """Return the ``loss`` attribute of the (transformed) model outputs."""
    return x.loss


def move_to_cuda(batch):
    """Return a new dict with every tensor of *batch* moved via ``Tensor.cuda()``."""
    return {k: v.cuda() for k, v in batch.items()}


def _str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` treats any non-empty string (including
    ``"False"``) as True, so values are parsed explicitly here.

    Accepts ``true/false``, ``t/f``, ``yes/no``, ``y/n`` and ``1/0``
    (case-insensitive); a ``bool`` is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse reports
            a usage error instead of silently picking a value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _logits_to_preds(logits, num_labels):
    """Turn classifier logits into the predictions fed to the GLUE metric.

    With ``num_labels > 1`` the arg-max over dim 1 is taken; with a single
    output (``num_labels == 1``) the logits are squeezed.

    Raises:
        ValueError: if ``num_labels`` is smaller than 1.
    """
    if num_labels > 1:
        return torch.argmax(logits, dim=1)
    if num_labels == 1:
        return logits.squeeze()
    raise ValueError(f"num_labels must be >= 1, got {num_labels}")


@torch.no_grad()
def evaluate_model(
    model: nn.Module,
    criterion,
    test_dataloader: Union[DataLoader, List[DataLoader]],
    num_labels: int,
    task_name: str,
    eval_splits: List[str],
    booster: Booster,
    coordinator: DistCoordinator,
):
    """Evaluate *model* with the GLUE metric for *task_name*.

    Args:
        model: the boosted model to evaluate (switched to eval mode here).
        criterion: loss callable handed to ``booster.execute_pipeline`` when
            pipeline parallelism is active.
        test_dataloader: a single loader, or one loader per entry of
            *eval_splits*.
        num_labels: number of classifier outputs (see ``_logits_to_preds``).
        task_name: GLUE task name passed to ``evaluate.load``.
        eval_splits: split names used to suffix result keys when
            *test_dataloader* is a list.
        booster: booster whose plugin decides whether pipelining is used.
        coordinator: provides rank / world size / master checks.

    Returns:
        For a single loader: the dict from ``metric.compute()``, with an added
        ``"loss"`` entry on the master rank; it may be ``None`` on other ranks.
        For a list of loaders: one dict whose keys are suffixed with
        ``_<split>``; splits for which this rank got no results are skipped.
    """
    metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
    model.eval()

    # The plugin does not change between splits, so decide once.
    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    device = get_accelerator().get_current_device()

    def evaluate_subset(dataloader: DataLoader):
        """Feed every batch of *dataloader* to the metric and return its results."""
        accum_loss = torch.zeros(1, device=device)
        if use_pipeline:
            # Loop-invariant process-group lookups, done once per split.
            pp_group = booster.plugin.pp_group
            current_pp_group_ranks = booster.plugin.pg_mesh.get_ranks_in_group(pp_group)
            current_rank = dist.get_rank()
            src_rank = current_pp_group_ranks[-1]

        for batch in dataloader:
            batch = move_to_cuda(batch)
            labels = batch["labels"]
            if use_pipeline:
                outputs = booster.execute_pipeline(
                    iter([batch]), model, criterion, return_loss=True, return_outputs=True
                )
                if is_pp_last_stage:
                    # Only the last stage reads the pipeline outputs; it then
                    # broadcasts preds and loss to the rest of its pipeline group.
                    val_loss = outputs["loss"]
                    accum_loss.add_(val_loss)
                    preds = _logits_to_preds(outputs["outputs"]["logits"], num_labels)
                    dist.broadcast_object_list([preds, val_loss], src=src_rank, group=pp_group)
                    metric.add_batch(predictions=preds, references=labels)
                elif current_rank in current_pp_group_ranks:
                    object_list = [None, None]
                    dist.broadcast_object_list(object_list, src=src_rank, group=pp_group)
                    metric.add_batch(predictions=object_list[0].to(device), references=labels)
                    accum_loss.add_(object_list[1].to(device))
            else:
                # `batch` was already moved above; no second move is needed.
                outputs = model(**batch)
                val_loss, logits = outputs[:2]
                accum_loss.add_(val_loss)
                metric.add_batch(predictions=_logits_to_preds(logits, num_labels), references=labels)

        results = metric.compute()
        # Sum the per-rank mean losses; the master divides by world size below.
        dist.all_reduce(accum_loss.div_(len(dataloader)))
        if coordinator.is_master() and results is not None:
            results["loss"] = accum_loss.item() / coordinator.world_size
        return results

    if isinstance(test_dataloader, DataLoader):
        return evaluate_subset(test_dataloader)

    assert len(test_dataloader) == len(eval_splits)
    final_results = {}
    for split, sub_loader in zip(eval_splits, test_dataloader):
        results = evaluate_subset(sub_loader)
        if results is None:
            # metric.compute() returned nothing on this rank (the master-only
            # guard in evaluate_subset shows this can happen); calling
            # .items() on None would crash, so skip this split.
            continue
        final_results.update({f"{k}_{split}": v for k, v in results.items()})
    return final_results


def train_epoch(
    epoch: int,
    model: nn.Module,
    optimizer: Optimizer,
    _criterion: Callable,
    lr_scheduler: LRScheduler,
    train_dataloader: DataLoader,
    booster: Booster,
    coordinator: DistCoordinator,
):
    """Run one training epoch over *train_dataloader*.

    With pipeline parallelism, forward/backward go through
    ``booster.execute_pipeline`` and only the last stage reports the loss;
    otherwise the model is called directly and ``booster.backward`` is used.
    Every step ends with ``optimizer.step()``, ``optimizer.zero_grad()`` and
    ``lr_scheduler.step()``.
    """
    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    total_step = len(train_dataloader)

    model.train()
    optimizer.zero_grad()
    train_dataloader_iter = iter(train_dataloader)
    with tqdm(
        range(total_step),
        desc=f"Epoch [{epoch + 1}/{NUM_EPOCHS}]",
        # Show progress on the master and on the rank that sees the loss.
        disable=not (coordinator.is_master() or is_pp_last_stage),
    ) as pbar:
        for _ in pbar:
            if use_pipeline:
                # Forward and backward happen inside execute_pipeline, which
                # pulls the next batch from the iterator itself.
                outputs = booster.execute_pipeline(
                    train_dataloader_iter, model, _criterion, optimizer, return_loss=True
                )
                if is_pp_last_stage:
                    loss = outputs["loss"]
                    pbar.set_postfix({"loss": loss.item()})
            else:
                data = move_to_cuda(next(train_dataloader_iter))
                outputs = model(**data)
                loss = _criterion(outputs, None)
                booster.backward(loss, optimizer)
                pbar.set_postfix({"loss": loss.item()})

            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()


def main():
    """Parse CLI flags, fine-tune GPT-2 on a GLUE task and evaluate each epoch.

    Raises:
        RuntimeError: if ``--model_type`` is anything other than ``gpt2``.
        AssertionError: on the master rank, when ``--target_f1`` is given, the
            final results contain an ``f1`` key and the score is below target.
    """
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run")
    parser.add_argument(
        "-p",
        "--plugin",
        type=str,
        default="torch_ddp",
        choices=["torch_ddp", "torch_ddp_fp16", "gemini", "low_level_zero", "hybrid_parallel"],
        help="plugin to use",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default="gpt2",
        help="only gpt2 now",
    )
    parser.add_argument("--target_f1", type=float, default=None, help="target f1 score. Raise exception if not reached")
    # Boolean flags: `--flag True/False` keeps working, and a bare `--flag` means True.
    parser.add_argument(
        "--use_lazy_init",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="for initiating lazy init context",
    )
    parser.add_argument(
        "--use_fp8_comm",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="for using fp8 during communication",
    )
    args = parser.parse_args()

    if args.model_type == "gpt2":
        model_name = "gpt2"
    else:
        raise RuntimeError(f"Unsupported model_type {args.model_type!r}; only 'gpt2' is supported")

    # ==============================
    # Launch Distributed Environment
    # ==============================
    colossalai.launch_from_torch(seed=42)
    coordinator = DistCoordinator()

    # Learning rate is scaled linearly with the number of ranks.
    lr = LEARNING_RATE * coordinator.world_size

    # ==============================
    # Instantiate Plugin and Booster
    # ==============================
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin(fp8_communication=args.use_fp8_comm)
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(initial_scale=2**5)
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif args.plugin == "hybrid_parallel":
        # modify the param accordingly for finetuning test cases
        plugin = HybridParallelPlugin(
            tp_size=1,
            pp_size=2,
            num_microbatches=None,
            microbatch_size=1,
            enable_all_optimization=True,
            zero_stage=1,
            precision="fp16",
            initial_scale=1,
            fp8_communication=args.use_fp8_comm,
        )

    booster = Booster(plugin=plugin, **booster_kwargs)

    # ==============================
    # Prepare Dataloader
    # ==============================
    data_builder = GLUEDataBuilder(
        model_name, plugin, args.task, train_batch_size=BATCH_SIZE, eval_batch_size=BATCH_SIZE
    )
    train_dataloader = data_builder.train_dataloader()
    test_dataloader = data_builder.test_dataloader()

    # ====================================
    # Prepare model, optimizer
    # ====================================
    # gpt2 pretrained model

    cfg = AutoConfig.from_pretrained(
        model_name,
        num_labels=data_builder.num_labels,
        pad_token=data_builder.tokenizer.pad_token,
        pad_token_id=data_builder.tokenizer.pad_token_id,
    )

    # NOTE(review): --use_lazy_init is parsed but not consulted here; lazy init
    # is chosen purely by plugin type. Confirm whether the flag should be wired in.
    init_ctx = (
        LazyInitContext(default_device=get_accelerator().get_current_device())
        if isinstance(plugin, (GeminiPlugin))
        else nullcontext()
    )

    with init_ctx:
        if model_name == "gpt2":
            model = GPT2ForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
        else:
            raise RuntimeError(f"Unsupported model name {model_name!r}")

    # optimizer: no weight decay on biases and LayerNorm weights
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": WEIGHT_DECAY,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]

    optimizer = HybridAdam(optimizer_grouped_parameters, lr=lr, eps=1e-8)

    # lr scheduler: linear warmup over the first WARMUP_FRACTION of all steps
    total_steps = len(train_dataloader) * NUM_EPOCHS
    num_warmup_steps = int(WARMUP_FRACTION * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=total_steps,
    )

    def _criterion(outputs, inputs):
        """Loss callable in the (outputs, inputs) form passed to the booster; inputs is unused."""
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    # ==============================
    # Boost with ColossalAI
    # ==============================
    model, optimizer, _criterion, _, lr_scheduler = booster.boost(
        model, optimizer, criterion=_criterion, lr_scheduler=lr_scheduler
    )

    # ==============================
    # Train model
    # ==============================
    for epoch in range(NUM_EPOCHS):
        train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)

    results = evaluate_model(
        model,
        _criterion,
        test_dataloader,
        data_builder.num_labels,
        args.task,
        data_builder.eval_splits,
        booster,
        coordinator,
    )

    if coordinator.is_master():
        print(results)
        if args.target_f1 is not None and "f1" in results:
            # Explicit raise (not `assert`) so the check survives `python -O`.
            if not results["f1"] >= args.target_f1:
                raise AssertionError(f'f1 score {results["f1"]} is lower than target {args.target_f1}')


if __name__ == "__main__":
    main()