Browse Source

[example] llama2 add fine-tune example (#4673)

* [shardformer] update shardformer readme

[shardformer] update shardformer readme

[shardformer] update shardformer readme

* [shardformer] update llama2/opt finetune example and shardformer update to llama2

* [shardformer] update llama2/opt finetune example and shardformer update to llama2

* [shardformer] update llama2/opt finetune example and shardformer update to llama2

* [shardformer] change dataset

* [shardformer] change dataset

* [shardformer] fix CI

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

* [shardformer] fix

[example] update opt example

[example] resolve comments

fix

fix

* [example] llama2 add finetune example

* [example] llama2 add finetune example

* [example] llama2 add finetune example

* [example] llama2 add finetune example

* fix

* update llama2 example

* update llama2 example

* fix

* update llama2 example

* update llama2 example

* update llama2 example

* update llama2 example

* update llama2 example

* update llama2 example

* Update requirements.txt

* update llama2 example

* update llama2 example

* update llama2 example
pull/4593/merge
flybird11111 1 year ago committed by GitHub
parent
commit
4c4482f3ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
  2. 7
      examples/language/bert/finetune.py
  3. 39
      examples/language/llama2/README.md
  4. 295
      examples/language/llama2/finetune.py
  5. 79
      examples/language/llama2/pretrain.py
  6. 2
      examples/language/llama2/requirements.txt
  7. 7
      examples/language/opt/README.md
  8. 4
      examples/language/opt/requirements.txt

4
colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

@ -13,6 +13,7 @@ from torch.distributed import ProcessGroup
from torch.optim import Optimizer from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.cluster import DistCoordinator
from colossalai.interface import OptimizerWrapper from colossalai.interface import OptimizerWrapper
from .general_checkpoint_io import GeneralCheckpointIO from .general_checkpoint_io import GeneralCheckpointIO
@ -71,6 +72,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
self.verbose = verbose self.verbose = verbose
self.working_to_master_map = None self.working_to_master_map = None
self.master_to_working_map = None self.master_to_working_map = None
self.coordinator = DistCoordinator()
@staticmethod @staticmethod
def _model_sharder(model: nn.Module, def _model_sharder(model: nn.Module,
@ -655,7 +657,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dist.all_gather(gather_tensor, v, group=tp_group) dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim) v = torch.cat(gather_tensor, dim=partition_dim)
state_[k] = v.detach().clone().cpu() state_[k] = v.detach().clone().cpu()
return state_ return state_

7
examples/language/bert/finetune.py

@ -129,14 +129,13 @@ def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion:
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1 use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage() is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
total_step = len(train_dataloader) total_step = len(train_dataloader)
model.train() model.train()
optimizer.zero_grad() optimizer.zero_grad()
train_dataloader_iter = iter(train_dataloader) train_dataloader_iter = iter(train_dataloader)
with tqdm(range(total_step), with tqdm(range(total_step), desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not print_flag) as pbar:
desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
# Forward pass # Forward pass
for _ in pbar: for _ in pbar:
if use_pipeline: if use_pipeline:
@ -192,13 +191,13 @@ def main():
model_name = "albert-xxlarge-v2" model_name = "albert-xxlarge-v2"
else: else:
raise RuntimeError raise RuntimeError
# ============================== # ==============================
# Launch Distributed Environment # Launch Distributed Environment
# ============================== # ==============================
colossalai.launch_from_torch(config={}, seed=42) colossalai.launch_from_torch(config={}, seed=42)
coordinator = DistCoordinator() coordinator = DistCoordinator()
# local_batch_size = BATCH_SIZE // coordinator.world_size
lr = LEARNING_RATE * coordinator.world_size lr = LEARNING_RATE * coordinator.world_size
# ============================== # ==============================

39
examples/language/llama2/README.md

@ -92,7 +92,7 @@ Make sure master node can access all nodes (including itself) by ssh without pas
Here is details about CLI arguments: Here is details about CLI arguments:
- Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1, `7b`, `13b`, and `70b` are supported for LLaMA-2. - Model configuration: `-c`, `--config`. `7b`, `13b`, `30b` and `65b` are supported for LLaMA-1, `7b`, `13b`, and `70b` are supported for LLaMA-2.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins). - Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama. - Dataset path: `-d`, `--dataset`. The default dataset is `togethercomputer/RedPajama-Data-1T-Sample`. It support any dataset from `datasets` with the same data format as RedPajama.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1. - Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2. - Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
@ -195,3 +195,40 @@ If you run the above command successfully, you will get the following results:
year={2023} year={2023}
} }
``` ```
# Fine-tune Llama2
We also provide a example to fine-tune llama2 in `finetune.py`,
Make sure master node can access all nodes (including itself) by ssh without password.
Here is details about CLI arguments:
- Pretrained checkpoint path: `--model_path`, the path of your model checkpoint, it can be your local directory or a Hugging Face tag.
- Booster plugin: `-p`, `--plugin`. `gemini`, `gemini_auto`, `zero2`, `hybrid_parallel` and `zero2_cpu` are supported. For more details, please refer to [Booster plugins](https://colossalai.org/docs/basics/booster_plugins).
- Dataset path: `-d`, `--dataset`. The default dataset is `yizhongw/self_instruct`. It support any dataset from `datasets` with the same data format as `yizhongw/self_instruct`.
- task name: `--task_name`, the task to fine-tune, it's also related to the target of loading dataset, The default value is `super_natural_instructions`.
- Number of epochs: `-e`, `--num_epochs`. The default value is 1.
- Local batch size: `-b`, `--batch_size`. Batch size per GPU. The default value is 2.
- Learning rate: `--lr`. The default value is 3e-4.
- Weight decay: `-w`, `--weight_decay`. The default value is 0.1.
- Gradient checkpointing: `-g`, `--gradient_checkpoint`. The default value is `False`. This saves memory at the cost of speed. You'd better enable this option when training with a large batch size.
- Max length: `-l`, `--max_length`. The default value is 4096.
- Mixed precision: `-x`, `--mixed_precision`. The default value is "fp16". "fp16" and "bf16" are supported.
- Save interval: `-i`, `--save_interval`. The interval (steps) of saving checkpoints. The default value is 1000.
- Checkpoint directory: `-o`, `--save_dir`. The directoty path to save checkpoints. The default value is `checkpoint`.
- Checkpoint to load: `-f`, `--load`. The checkpoint path to load. The default value is `None`.
- Gradient clipping: `--gradient_clipping`. The default value is 1.0.
- Tensorboard log directory: `-t`, `--tensorboard_dir`. The directory path to save tensorboard logs. The default value is `tb_logs`.
- Flash attention: `-a`, `--flash_attention`. If you want to use flash attention, you must install `flash-attn`. The default value is `False`. This is helpful to accelerate training while saving memory. We recommend you always use flash attention.
```shell
torchrun --standalone --nproc_per_node 8 finetune.py \
--plugin "hybrid_parallel" \
--dataset "yizhongw/self_instruct" \
--model_path "/path/llama" \
--task_name "super_natural_instructions" \
--save_dir "/path/output"
```

295
examples/language/llama2/finetune.py

@ -0,0 +1,295 @@
import argparse
import math
import os
import resource
from contextlib import nullcontext
from functools import partial
from typing import Optional, Tuple
import torch
import torch.distributed as dist
import torch.nn as nn
from attn import SUPPORT_XFORMERS, replace_xformers
from data_utils import load_json, prepare_dataloader, save_json
from datasets import load_dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.llama.tokenization_llama import LlamaTokenizer
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def get_model_numel(model: nn.Module) -> int:
return sum(p.numel() for p in model.parameters())
def format_numel_str(numel: int) -> str:
B = 1024**3
M = 1024**2
K = 1024
if numel >= B:
return f'{numel / B:.2f} B'
elif numel >= M:
return f'{numel / M:.2f} M'
elif numel >= K:
return f'{numel / K:.2f} K'
else:
return f'{numel}'
def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
texts = [sample['prompt'] + sample['completion'] for sample in batch]
data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
data = {k: v.cuda() for k, v in data.items()}
data['labels'] = data['input_ids'].clone()
return data
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
tensor.div_(dist.get_world_size())
return tensor
def save(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler, epoch: int, step: int,
batch_size: int, coordinator: DistCoordinator, save_dir: str):
save_dir = os.path.join(save_dir, f'epoch{epoch}-step{step}')
os.makedirs(os.path.join(save_dir, 'model'), exist_ok=True)
booster.save_model(model, os.path.join(save_dir, 'model'), shard=True)
booster.save_optimizer(optimizer, os.path.join(save_dir, 'optimizer'), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, 'lr_scheduler'))
running_states = {
'epoch': epoch,
'step': step,
'sample_start_index': step * batch_size,
}
if coordinator.is_master():
save_json(running_states, os.path.join(save_dir, 'running_states.json'))
def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler: _LRScheduler,
load_dir: str) -> Tuple[int, int, int]:
booster.load_model(model, os.path.join(load_dir, 'model'))
booster.load_optimizer(optimizer, os.path.join(load_dir, 'optimizer'))
booster.load_lr_scheduler(lr_scheduler, os.path.join(load_dir, 'lr_scheduler'))
running_states = load_json(os.path.join(load_dir, 'running_states.json'))
return running_states['epoch'], running_states['step'], running_states['sample_start_index']
def _criterion(outputs, inputs):
return outputs.loss
def main():
# ==============================
# Parse Arguments
# ==============================
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help="pretrained checkpoint path, used with mode==finetune")
parser.add_argument('-p',
'--plugin',
choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
default='gemini',
help='Choose which plugin to use')
parser.add_argument('-d', '--dataset', type=str, default='yizhongw/self_instruct', help='Data set path')
parser.add_argument('--task_name', type=str, default="super_natural_instructions", help='task to run')
parser.add_argument('-e', '--num_epochs', type=int, default=1, help='Number of epochs')
parser.add_argument('-b', '--batch_size', type=int, default=2, help='Local batch size')
parser.add_argument('--lr', type=float, default=3e-4, help='Learning rate')
parser.add_argument('-w', '--weigth_decay', type=float, default=0.1, help='Weight decay')
parser.add_argument('-g', '--grad_checkpoint', action='store_true', help='Use gradient checkpointing')
parser.add_argument('-l', '--max_length', type=int, default=4096, help='Max sequence length')
parser.add_argument('-x', '--mixed_precision', default='fp16', choices=['fp16', 'bf16'], help='Mixed precision')
parser.add_argument('-i', '--save_interval', type=int, default=1000, help='Save interval')
parser.add_argument('-o', '--save_dir', type=str, default='checkpoint', help='Checkpoint directory')
parser.add_argument('-f', '--load', type=str, default=None, help='Load checkpoint')
parser.add_argument('--grad_clip', type=float, default=1.0, help='Gradient clipping')
parser.add_argument('-t', '--tensorboard_dir', type=str, default='tb_logs', help='Tensorboard directory')
parser.add_argument('-a', '--flash_attention', action='store_true', help='Use Flash Attention')
args = parser.parse_args()
# ==============================
# Initialize Distributed Training
# ==============================
colossalai.launch_from_torch({})
coordinator = DistCoordinator()
# ==============================
# Initialize Booster
# ==============================
if args.plugin == 'gemini':
plugin = GeminiPlugin(precision=args.mixed_precision, initial_scale=2**16, max_norm=args.grad_clip)
elif args.plugin == 'gemini_auto':
plugin = GeminiPlugin(precision=args.mixed_precision,
placement_policy='auto',
initial_scale=2**16,
max_norm=args.grad_clip)
elif args.plugin == 'zero2':
plugin = LowLevelZeroPlugin(stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
max_norm=args.grad_clip)
elif args.plugin == 'zero2_cpu':
plugin = LowLevelZeroPlugin(stage=2,
precision=args.mixed_precision,
initial_scale=2**16,
cpu_offload=True,
max_norm=args.grad_clip)
elif args.plugin == 'hybrid_parallel':
# modify the param accordingly, default configuration is for llama2-7b
plugin = HybridParallelPlugin(tp_size=4,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_jit_fused=False,
zero_stage=0,
precision='fp32',
initial_scale=1)
else:
raise ValueError(f'Unknown plugin {args.plugin}')
booster = Booster(plugin=plugin)
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
# ==============================
# Initialize Tensorboard
# ==============================
if print_flag:
os.makedirs(args.tensorboard_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
# ==============================
# Initialize Model, Optimizer and LR Scheduler
# ==============================
config = LlamaConfig.from_pretrained(args.model_path)
# use lazy init when using GeminiPlugin
init_ctx = LazyInitContext(
default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
with init_ctx:
model = LlamaForCausalLM(config)
# ==============================
# Initialize Tokenizer, Dataset and Dataloader
# ==============================
tokenizer = LlamaTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
# follows fast chat: https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py#L257
tokenizer.pad_token = tokenizer.unk_token
dataset = load_dataset(args.dataset, args.task_name)
train_ds = dataset['train']
dataloader = prepare_dataloader(train_ds,
batch_size=args.batch_size,
shuffle=True,
drop_last=True,
collate_fn=partial(tokenize_batch_for_finetune,
tokenizer=tokenizer,
max_length=args.max_length))
if args.grad_checkpoint:
model.gradient_checkpointing_enable()
if args.flash_attention:
assert SUPPORT_XFORMERS, 'Use flash attention while xfomers is not installed'
replace_xformers(model)
model_numel = get_model_numel(model)
coordinator.print_on_master(f'Model params: {format_numel_str(model_numel)}')
optimizer = HybridAdam(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weigth_decay)
total_step = args.num_epochs * len(dataloader)
lr_scheduler = CosineAnnealingWarmupLR(optimizer,
total_steps=total_step,
warmup_steps=math.ceil(total_step * 0.03),
eta_min=0.1 * args.lr)
default_dtype = torch.float16 if args.mixed_precision == 'fp16' else torch.bfloat16
torch.set_default_dtype(default_dtype)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model,
optimizer,
dataloader=dataloader,
lr_scheduler=lr_scheduler)
torch.set_default_dtype(torch.float)
booster.load_model(model, args.model_path)
coordinator.print_on_master(f'Booster init max CUDA memory: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
coordinator.print_on_master(
f'Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB')
# load checkpoint if specified
start_epoch = 0
start_step = 0
sampler_start_idx = 0
if args.load is not None:
coordinator.print_on_master('Loading checkpoint')
start_epoch, start_step, sampler_start_idx = load(booster, model, optimizer, lr_scheduler, args.load)
coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
num_steps_per_epoch = len(dataloader)
# if resume training, set the sampler start index to the correct value
dataloader.sampler.set_start_index(sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch)
step_nums = num_steps_per_epoch - start_step
dataloader_iter = iter(dataloader)
with tqdm(range(step_nums),
desc=f'Epoch {epoch}',
disable=not print_flag,
total=num_steps_per_epoch,
initial=start_step) as pbar:
for step in pbar:
if use_pipeline:
outputs = booster.execute_pipeline(dataloader_iter,
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=True)
loss = outputs["loss"]
else:
batch = next(dataloader_iter)
outputs = model(**batch)
loss = outputs[0]
booster.backward(loss, optimizer)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
if not use_pipeline:
all_reduce_mean(loss)
if print_flag:
pbar.set_postfix({'loss': loss.item()})
writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
if args.save_interval > 0 and (step + 1) % args.save_interval == 0:
coordinator.print_on_master(f'Saving checkpoint')
save(booster, model, optimizer, lr_scheduler, epoch, step + 1, args.batch_size, coordinator,
args.save_dir)
coordinator.print_on_master(f'Saved checkpoint at epoch {epoch} step {step + 1}')
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(0)
start_step = 0
coordinator.print_on_master(f'Max CUDA memory usage: {torch.cuda.max_memory_allocated()/1024**2:.2f} MB')
if __name__ == '__main__':
main()

79
examples/language/llama2/pretrain.py

@ -21,7 +21,7 @@ from transformers.models.llama.tokenization_llama import LlamaTokenizer
import colossalai import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext from colossalai.lazy import LazyInitContext
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@ -65,9 +65,10 @@ def format_numel_str(numel: int) -> str:
return f'{numel}' return f'{numel}'
def tokenize_batch(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048): def tokenize_batch_for_pretrain(batch, tokenizer: Optional[LlamaTokenizer] = None, max_length: int = 2048):
texts = [sample['text'] for sample in batch] texts = [sample['text'] for sample in batch]
data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length) data = tokenizer(texts, return_tensors="pt", padding='max_length', truncation=True, max_length=max_length)
data = {k: v.cuda() for k, v in data.items()}
data['labels'] = data['input_ids'].clone() data['labels'] = data['input_ids'].clone()
return data return data
@ -104,6 +105,10 @@ def load(booster: Booster, model: nn.Module, optimizer: Optimizer, lr_scheduler:
return running_states['epoch'], running_states['step'], running_states['sample_start_index'] return running_states['epoch'], running_states['step'], running_states['sample_start_index']
def _criterion(outputs, inputs):
return outputs.loss
def main(): def main():
# ============================== # ==============================
# Parse Arguments # Parse Arguments
@ -112,7 +117,7 @@ def main():
parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration') parser.add_argument('-c', '--config', type=str, default='7b', help='Model configuration')
parser.add_argument('-p', parser.add_argument('-p',
'--plugin', '--plugin',
choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu'], choices=['gemini', 'gemini_auto', 'zero2', 'zero2_cpu', 'hybrid_parallel'],
default='gemini', default='gemini',
help='Choose which plugin to use') help='Choose which plugin to use')
parser.add_argument('-d', parser.add_argument('-d',
@ -142,13 +147,6 @@ def main():
colossalai.launch_from_torch({}) colossalai.launch_from_torch({})
coordinator = DistCoordinator() coordinator = DistCoordinator()
# ==============================
# Initialize Tensorboard
# ==============================
if coordinator.is_master():
os.makedirs(args.tensorboard_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
# ============================== # ==============================
# Initialize Booster # Initialize Booster
# ============================== # ==============================
@ -170,11 +168,32 @@ def main():
initial_scale=2**16, initial_scale=2**16,
cpu_offload=True, cpu_offload=True,
max_norm=args.grad_clip) max_norm=args.grad_clip)
elif args.plugin == 'hybrid_parallel':
# modify the param accordingly, default configuration is for llama2-7b
plugin = HybridParallelPlugin(tp_size=4,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_jit_fused=False,
zero_stage=0,
precision='fp32',
initial_scale=1)
else: else:
raise ValueError(f'Unknown plugin {args.plugin}') raise ValueError(f'Unknown plugin {args.plugin}')
booster = Booster(plugin=plugin) booster = Booster(plugin=plugin)
use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
# ==============================
# Initialize Tensorboard
# ==============================
if print_flag:
os.makedirs(args.tensorboard_dir, exist_ok=True)
writer = SummaryWriter(args.tensorboard_dir)
# ============================== # ==============================
# Initialize Tokenizer, Dataset and Dataloader # Initialize Tokenizer, Dataset and Dataloader
# ============================== # ==============================
@ -188,12 +207,15 @@ def main():
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
collate_fn=partial(tokenize_batch, tokenizer=tokenizer, max_length=args.max_length)) collate_fn=partial(tokenize_batch_for_pretrain,
tokenizer=tokenizer,
max_length=args.max_length))
# ============================== # ==============================
# Initialize Model, Optimizer and LR Scheduler # Initialize Model, Optimizer and LR Scheduler
# ============================== # ==============================
config = MODEL_CONFIGS[args.config] config = MODEL_CONFIGS[args.config]
# use lazy init when using GeminiPlugin
init_ctx = LazyInitContext( init_ctx = LazyInitContext(
default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext() default_device=get_current_device()) if isinstance(plugin, GeminiPlugin) else nullcontext()
@ -236,27 +258,42 @@ def main():
coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}') coordinator.print_on_master(f'Loaded checkpoint {args.load} at epoch {start_epoch} step {start_step}')
num_steps_per_epoch = len(dataloader) num_steps_per_epoch = len(dataloader)
# if resume training, set the sampler start index to the correct value # if resume training, set the sampler start index to the correct value
dataloader.sampler.set_start_index(sampler_start_idx) dataloader.sampler.set_start_index(sampler_start_idx)
for epoch in range(start_epoch, args.num_epochs): for epoch in range(start_epoch, args.num_epochs):
dataloader.sampler.set_epoch(epoch) dataloader.sampler.set_epoch(epoch)
with tqdm(enumerate(dataloader), step_nums = num_steps_per_epoch - start_step
dataloader_iter = iter(dataloader)
with tqdm(range(step_nums),
desc=f'Epoch {epoch}', desc=f'Epoch {epoch}',
disable=not coordinator.is_master(), disable=not print_flag,
total=num_steps_per_epoch, total=num_steps_per_epoch,
initial=start_step) as pbar: initial=start_step) as pbar:
for step, batch in pbar: for step in pbar:
batch = {k: v.cuda() for k, v in batch.items()} if use_pipeline:
outputs = model(**batch) outputs = booster.execute_pipeline(dataloader_iter,
loss = outputs[0] model,
booster.backward(loss, optimizer) _criterion,
optimizer,
return_loss=True,
return_outputs=True)
loss = outputs["loss"]
else:
batch = next(dataloader_iter)
outputs = model(**batch)
loss = outputs[0]
booster.backward(loss, optimizer)
optimizer.step() optimizer.step()
lr_scheduler.step() lr_scheduler.step()
optimizer.zero_grad() optimizer.zero_grad()
all_reduce_mean(loss) if not use_pipeline:
pbar.set_postfix({'loss': loss.item()}) all_reduce_mean(loss)
if coordinator.is_master(): if print_flag:
pbar.set_postfix({'loss': loss.item()})
writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step) writer.add_scalar('loss', loss.item(), epoch * num_steps_per_epoch + step)
if args.save_interval > 0 and (step + 1) % args.save_interval == 0: if args.save_interval > 0 and (step + 1) % args.save_interval == 0:

2
examples/language/llama2/requirements.txt

@ -1,4 +1,4 @@
colossalai>=0.3.0 colossalai>=0.3.2
datasets datasets
numpy numpy
torch>=1.12.0,<=2.0.0 torch>=1.12.0,<=2.0.0

7
examples/language/opt/README.md

@ -23,9 +23,9 @@ The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI)
## Our Modifications ## Our Modifications
We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
the tokenization). the tokenization).
We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, and GeminiPlugin. We adapt the OPT training code to ColossalAI by leveraging [Boosting API](https://colossalai.org/docs/basics/booster_api) loaded with a chosen plugin, where each plugin corresponds to a specific kind of training strategy. This example supports plugins including TorchDDPPlugin, LowLevelZeroPlugin, HybridParallelPlugin and GeminiPlugin.
## Run Demo ## Run Demo
@ -48,6 +48,3 @@ You can run benchmark for OPT model by running the following script:
bash run_benchmark.sh bash run_benchmark.sh
``` ```
The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your set of hyperparameters for testing. The script will test performance (throughput & peak memory usage) for each combination of hyperparameters. You can also play with this script to configure your set of hyperparameters for testing.

4
examples/language/opt/requirements.txt

@ -1,4 +1,4 @@
colossalai >= 0.1.12 colossalai >= 0.3.2
torch >= 1.8.1 torch >= 1.8.1
datasets >= 1.8.0 datasets >= 1.8.0
transformers >= 4.20.0 transformers >= 4.30.2

Loading…
Cancel
Save