mirror of https://github.com/hpcaitech/ColossalAI
* [shardformer] update shardformer readme
* [shardformer] update llama2/opt finetune example and shardformer update to llama2
* [shardformer] change dataset
* [shardformer] fix CI
* [shardformer] fix
* [example] update opt example
* [example] resolve comments
* [example] llama2 add finetune example
* update llama2 example
* Update requirements.txt

pull/4593/merge
flybird11111 authored 1 year ago, committed by GitHub
8 changed files with 402 additions and 35 deletions
@@ -0,0 +1,295 @@
import argparse
import math
import os
import resource
from contextlib import nullcontext
from functools import partial
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
from attn import SUPPORT_XFORMERS, replace_xformers
from data_utils import load_json, prepare_dataloader, save_json
from datasets import load_dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.llama.tokenization_llama import LlamaTokenizer

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device


def get_model_numel(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


def format_numel_str(numel: int) -> str:
    B = 1024**3
    M = 1024**2
    K = 1024
    if numel >= B:
        return f'{numel / B:.2f} B'
    elif numel >= M:
        return f'{numel / M:.2f} M'
    elif numel >= K:
        return f'{numel / K:.2f} K'
    else:
        return f'{numel}'


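# Collate function for supervised finetuning: each sample's prompt and completion are
# concatenated, padded/truncated to max_length, and labels are a copy of input_ids, so
# the causal LM loss is computed over the whole sequence.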
def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
    texts = [sample['prompt'] + sample['completion'] for sample in batch]
    data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
    data = {k: v.cuda() for k, v in data.items()}
    data['labels'] = data['input_ids'].clone()
    return data


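# Average a scalar tensor (here, the loss) across all ranks for logging.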
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor


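# Save a sharded checkpoint under <save_dir>/epoch{epoch}-step{step}/: sharded model and
# optimizer states, the LR scheduler state, and a running_states.json with the
# bookkeeping needed to resume training.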
def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int,
         batch_size: int, coordinator: DistCoordinator, save_dir: str):
    save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}')
    os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True)

    booster.save_model(model, os.path.join(save_dir, 'model'), shard=True)
    booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True)
    booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler'))
    running_states = {
        'epoch': epoch,
        'step': step,
        'sample_start_index': step * batch_size,
    }
    if coordinator.is_master():
        save_json(running_states, os.path.join(save_dir, 'running_states.json'))


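# Counterpart of save(): restore model/optimizer/LR-scheduler state and return
# (epoch, step, sample_start_index) so training can resume where it stopped.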
def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler,
         load_dir: str) -> Tuple[int, int, int]:
    booster.load_model(model, os.path.join(load_dir, 'model'))
    booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer'))
    booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler'))
    running_states = load_json(os.path.join(load_dir, 'running_states.json'))
    return running_states['epoch'], running_states['step'], running_states['sample_start_index']


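# Loss function handed to booster.execute_pipeline(); HF causal-LM models expose the
# loss on their output object when labels are provided.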
def _criterion(outputs, inputs):
    return outputs.loss


def main():
    # ==============================
    # Parse Arguments
    # ==============================
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, help="pretrained checkpoint path, used with mode==finetune")
    parser.add_argument('-p',
                        '--plugin',
                        choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
                        default='gemini',
                        help='Choose which plugin to use')
    parser.add_argument('-d', '--dataset', type=str, default='yizhongw/self_instruct', help='Data set path')
    parser.add_argument('--task_name', type=str, default="super_natural_instructions", help='task to run')
    parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs')
    parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size')
    parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
    parser.add_argument('-w', '--weight_decay', type=float, default=0.1, help='Weight decay')
    parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
    parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
    parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision')
    parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval')
    parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory')
    parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint')
    parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping')
    parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory')
    parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention')
    args = parser.parse_args()

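    # Launch sketch (finetune.py is an assumed file name; paths and values are
    # illustrative): colossalai.launch_from_torch() below reads the rank/world-size
    # environment variables set by torchrun or `colossalai run`, e.g.
    #   torchrun --standalone --nproc_per_node 8 finetune.py \
    #       --model_path /path/to/llama-2-7b-hf --plugin gemini -b 2 -x bf16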
    # ==============================
    # Initialize Distributed Training
    # ==============================
    colossalai.launch_from_torch({})
    coordinator = DistCoordinator()

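    # DistCoordinator provides the rank helpers (is_master(), print_on_master(), ...)
    # used for logging and master-only work throughout the rest of the script.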
    # ==============================
    # Initialize Booster
    # ==============================
    if args.plugin == 'gemini':
        plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
    elif args.plugin == 'gemini_auto':
        plugin = GeminiPlugin(precision=args.mixed_precision,
                              placement_policy='auto',
                              initial_scale=2**16,
                              max_norm=args.grad_clip)
    elif args.plugin == 'zero2':
        plugin = LowLevelZeroPlugin(stage=2,
                                    precision=args.mixed_precision,
                                    initial_scale=2**16,
                                    max_norm=args.grad_clip)
    elif args.plugin == 'zero2_cpu':
        plugin = LowLevelZeroPlugin(stage=2,
                                    precision=args.mixed_precision,
                                    initial_scale=2**16,
                                    cpu_offload=True,
                                    max_norm=args.grad_clip)
    elif args.plugin == 'hybrid_parallel':
        # modify the parameters accordingly; the default configuration is for llama2-7b
        plugin = HybridParallelPlugin(tp_size=4,
                                      pp_size=2,
                                      num_microbatches=None,
                                      microbatch_size=1,
                                      enable_jit_fused=False,
                                      zero_stage=0,
                                      precision='fp32',
                                      initial_scale=1)
    else:
        raise ValueError(f'Unknown plugin {args.plugin}')

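    # Rough plugin guide: 'gemini'/'gemini_auto' use ColossalAI's chunk-based,
    # ZeRO-3-style memory management, 'zero2'/'zero2_cpu' use ZeRO stage 2 (optionally
    # offloading to CPU), and 'hybrid_parallel' combines tensor and pipeline
    # parallelism; with the defaults tp_size=4 and pp_size=2 it assumes a world size
    # that is a multiple of 8.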
    booster = Booster(plugin=plugin)

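    # Only one rank logs: the global master for non-pipeline plugins, or the last
    # pipeline stage (where the loss is available) when pipeline parallelism is enabled.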
    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)

    # ==============================
    # Initialize Tensorboard
    # ==============================
    if print_flag:
        os.makedirs(args.tensorboard_dir, exist_ok=True)
        writer = SummaryWriter(args.tensorboard_dir)

    # ==============================
    # Initialize Model, Optimizer and LR Scheduler
    # ==============================

    config = LlamaConfig.from_pretrained(args.model_path)
    # use lazy init when using GeminiPlugin
    init_ctx = LazyInitContext(
        default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()

    with init_ctx:
        model = LlamaForCausalLM(config)

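    # Note: the model is constructed from the config only; the pretrained weights are
    # loaded after booster.boost() via booster.load_model(model, args.model_path).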
    # ==============================
    # Initialize Tokenizer, Dataset and Dataloader
    # ==============================
    tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
    # follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
    tokenizer.pad_token = tokenizer.unk_token

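    # The default dataset is yizhongw/self_instruct (task 'super_natural_instructions')
    # from the Hugging Face Hub; prepare_dataloader comes from the example's data_utils
    # and, judging from the sampler calls below, wraps the dataset with a resumable
    # distributed sampler.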
    dataset = load_dataset(args.dataset, args.task_name)
    train_ds = dataset['train']
    dataloader = prepare_dataloader(train_ds,
                                    batch_size=args.batch_size,
                                    shuffle=True,
                                    drop_last=True,
                                    collate_fn=partial(tokenize_batch_for_finetune,
                                                       tokenizer=tokenizer,
                                                       max_length=args.max_length))

    if args.grad_checkpoint:
        model.gradient_checkpointing_enable()
    if args.flash_attention:
        assert SUPPORT_XFORMERS, 'flash_attention is requested but xformers is not installed'
        replace_xformers(model)

    model_numel = get_model_numel(model)
    coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')

    optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay)
    total_step = args.num_epochs * len(dataloader)
    lr_scheduler = CosineAnnealingWarmupLR(optimizer,
                                           total_steps=total_step,
                                           warmup_steps=math.ceil(total_step * 0.03),
                                           eta_min=0.1 * args.lr)
    default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16
    torch.set_default_dtype(default_dtype)
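    # The default dtype is switched to fp16/bf16 so parameters are created in the
    # chosen half precision while booster.boost() wraps the model; it is restored to
    # fp32 immediately afterwards.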
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model,
                                                                  optimizer,
                                                                  dataloader=dataloader,
                                                                  lr_scheduler=lr_scheduler)
    torch.set_default_dtype(torch.float)

    booster.load_model(model, args.model_path)

    coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
    coordinator.print_on_master(
        f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')

    # load checkpoint if specified
    start_epoch = 0
    start_step = 0
    sampler_start_idx = 0
    if args.load is not None:
        coordinator.print_on_master('Loading checkpoint')
        start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
        coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')

    num_steps_per_epoch = len(dataloader)

    # if resuming training, set the sampler start index to the correct value
    dataloader.sampler.set_start_index(sampler_start_idx)
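    # Training loop. In pipeline mode, forward and backward both run inside
    # booster.execute_pipeline(), so booster.backward() is only called on the
    # non-pipeline path.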
    for epoch in range(start_epoch, args.num_epochs):
        dataloader.sampler.set_epoch(epoch)
        step_nums = num_steps_per_epoch - start_step
        dataloader_iter = iter(dataloader)

        with tqdm(range(step_nums),
                  desc=f'Epoch {epoch}',
                  disable=not print_flag,
                  total=num_steps_per_epoch,
                  initial=start_step) as pbar:
            for step in pbar:
                if use_pipeline:
                    outputs = booster.execute_pipeline(dataloader_iter,
                                                       model,
                                                       _criterion,
                                                       optimizer,
                                                       return_loss=True,
                                                       return_outputs=True)
                    loss = outputs["loss"]
                else:
                    batch = next(dataloader_iter)
                    outputs = model(**batch)
                    loss = outputs[0]
                    booster.backward(loss, optimizer)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                if not use_pipeline:
                    all_reduce_mean(loss)
                if print_flag:
                    pbar.set_postfix({'loss': loss.item()})
                    writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)

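                # Periodically snapshot model/optimizer/scheduler state so an
                # interrupted run can be resumed with --load.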
                if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
                    coordinator.print_on_master(f'Saving checkpoint')
                    save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator,
                         args.save_dir)
                    coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}')
        # subsequent epochs are not resumed from a checkpoint, so reset the sampler start index and start step
        dataloader.sampler.set_start_index(0)
        start_step = 0

    coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')


if __name__ == '__main__':
    main()