mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
150 lines
5.3 KiB
150 lines
5.3 KiB
2 years ago
|
import time
|
||
|
|
||
|
import torch
|
||
|
import datasets
|
||
|
import transformers
|
||
|
from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer
|
||
|
from transformers import get_linear_schedule_with_warmup
|
||
|
from transformers.utils.versions import require_version
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
import colossalai
|
||
|
from colossalai.nn.optimizer import HybridAdam
|
||
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||
|
from colossalai.tensor import ProcessGroup, ShardSpec
|
||
|
from colossalai.utils import get_current_device
|
||
|
from colossalai.zero import ColoInitContext
|
||
|
from colossalai.booster import Booster
|
||
|
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||
|
from colossalai.cluster import DistCoordinator
|
||
|
|
||
|
from args import parse_demo_args
|
||
|
from data import NetflixDataset, netflix_collator
|
||
|
|
||
|
# Check installed library versions when this module is imported; require_version
# raises with the given hint if a requirement is not satisfied.
require_version("datasets>=1.8.0", hint="To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", hint="To fix: pip install -r requirements.txt")
|
||
|
|
||
|
|
||
|
def move_to_cuda(batch, device):
    """Return a new dict whose values are those of *batch* moved to *device*.

    Every value must provide a ``.to(device)`` method (e.g. ``torch.Tensor``).
    Keys and their order are preserved; *batch* itself is not modified.
    """
    moved = {}
    for key, tensor in batch.items():
        moved[key] = tensor.to(device)
    return moved
|
||
|
|
||
|
|
||
|
def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
    """Run one training epoch over *dataloader*.

    For each batch: clear gradients, move the batch to the current CUDA
    device, run a forward pass with ``use_cache=False``, back-propagate the
    loss via ``booster.backward``, then step the optimizer and the LR
    scheduler. A tqdm progress bar (enabled only on the master rank) shows
    the latest batch loss.

    Args:
        epoch: Epoch index; displayed as ``epoch + 1`` in the progress bar.
        model: Model called as ``model(use_cache=False, **batch)``; its output
            must be indexable by ``'loss'``.
        optimizer: Optimizer zeroed and stepped once per batch.
        lr_scheduler: Scheduler stepped once per batch.
        dataloader: Iterable yielding dict batches (values support ``.to``).
        booster: Object whose ``backward(loss, optimizer)`` runs the backward pass.
        coordinator: Object whose ``is_master()`` decides progress-bar visibility.
    """
    torch.cuda.synchronize()
    model.train()

    show_progress = coordinator.is_master()
    with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not show_progress) as progress:
        for raw_batch in progress:
            # Forward pass
            optimizer.zero_grad()
            device_batch = move_to_cuda(raw_batch, torch.cuda.current_device())
            loss = model(use_cache=False, **device_batch)['loss']

            # Backward pass and parameter update
            booster.backward(loss, optimizer)
            optimizer.step()
            lr_scheduler.step()

            # Report the most recent batch loss
            progress.set_postfix({'loss': loss.item()})
|
||
|
|
||
|
|
||
|
def _build_model(args, world_size):
    """Load the pretrained OPT causal-LM named by ``args.model_name_or_path``.

    For the ``gemini`` plugin the model is created inside ``ColoInitContext``
    on CPU, with default dist spec ``ShardSpec([-1], [world_size])`` over a
    ``ProcessGroup(tp_degree=world_size)``; otherwise it is loaded normally.
    """
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    if args.plugin == 'gemini':
        shard_pg = ProcessGroup(tp_degree=world_size)
        default_dist_spec = ShardSpec([-1], [world_size])
        with ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=shard_pg):
            # Load pretrained weights (not a random init): this script fine-tunes.
            return OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
    return OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)


def _build_plugin(plugin_name):
    """Return ``(plugin, booster_kwargs)`` for the booster plugin *plugin_name*.

    Any name starting with ``torch_ddp`` selects ``TorchDDPPlugin``;
    ``torch_ddp_fp16`` additionally requests fp16 mixed precision.

    Raises:
        ValueError: if *plugin_name* matches no supported plugin.
    """
    booster_kwargs = {}
    if plugin_name == 'torch_ddp_fp16':
        booster_kwargs['mixed_precision'] = 'fp16'
    if plugin_name.startswith('torch_ddp'):
        plugin = TorchDDPPlugin()
    elif plugin_name == 'gemini':
        plugin = GeminiPlugin(device=get_current_device(),
                              placement_policy='cpu',
                              pin_memory=True,
                              strict_ddp_mode=True,
                              initial_scale=2**5)
    elif plugin_name == 'low_level_zero':
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    else:
        # Previously fell through and crashed later with UnboundLocalError.
        raise ValueError(f"Unsupported plugin: {plugin_name}")
    return plugin, booster_kwargs


def main():
    """Fine-tune a pretrained OPT model on the Netflix dataset with a ColossalAI Booster.

    Parses CLI args, launches the distributed environment, builds the model,
    plugin, dataloader, optimizer and LR scheduler, boosts them, trains for
    ``args.num_epoch`` epochs and saves the model to ``args.output_path``.
    """
    args = parse_demo_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size

    # Manage loggers: non-master ranks only report library errors
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Build OPT model from pretrained weights
    model = _build_model(args, world_size)
    logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
    model.gradient_checkpointing_enable()

    # Set plugin
    plugin, booster_kwargs = _build_plugin(args.plugin)
    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = NetflixDataset(tokenizer)
    dataloader = plugin.prepare_dataloader(dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           drop_last=True,
                                           collate_fn=netflix_collator)

    # Set optimizer; the base learning rate is multiplied by the world size
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size))

    # Set lr scheduler: warm up over warmup_ratio of all steps
    total_steps = len(dataloader) * args.num_epoch
    num_warmup_steps = int(args.warmup_ratio * total_steps)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer,
                                                   num_warmup_steps=num_warmup_steps,
                                                   num_training_steps=total_steps)

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
                                                                  optimizer=optimizer,
                                                                  dataloader=dataloader,
                                                                  lr_scheduler=lr_scheduler)

    # Start finetuning
    logger.info("Start finetuning", ranks=[0])
    for epoch in range(args.num_epoch):
        train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator)

    # Finish training and save the model (no evaluation is performed here)
    logger.info("Finish finetuning", ranks=[0])
    logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])
    booster.save_model(model, args.output_path)
|
||
|
|
||
|
|
||
|
# Run fine-tuning only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|