import datasets
import torch
import transformers
from args import parse_demo_args
from data import NetflixDataset, netflix_collator
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup
from transformers.utils.versions import require_version

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam

require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
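
# Identity transform over model outputs and a loss extractor: the HF CausalLM forward
# returns an output object whose `.loss` is populated because the collated batches
# carry `labels`.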
output_transform_fn = lambda x: x
criterion = lambda x: x.loss


def move_to_cuda(batch, device):
    # Move every tensor in the collated batch onto the target device.
    return {k: v.to(device) for k, v in batch.items()}


def train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, booster, coordinator):
    torch.cuda.synchronize()

    use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
    total_step = len(dataloader)

    model.train()
    optimizer.zero_grad()
    dataloader = iter(dataloader)
    with tqdm(
        range(total_step), desc=f"Epoch [{epoch + 1}]", disable=not (coordinator.is_master() or is_pp_last_stage)
    ) as pbar:
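        # The progress bar is only rendered on the global master rank or, when pipeline
        # parallelism is active, on the last pipeline stage (the only rank that sees the loss).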
        # Forward pass
        for _ in pbar:
            if use_pipeline:
                # The booster drives forward and backward over micro-batches internally;
                # only the last pipeline stage receives the aggregated loss.
                outputs = booster.execute_pipeline(
                    dataloader, model, _criterion, optimizer, return_loss=True, return_outputs=True
                )
                # Backward and optimize
                if is_pp_last_stage:
                    loss = outputs["loss"]
                    pbar.set_postfix({"loss": loss.item()})
            else:
                data = next(dataloader)
                data = move_to_cuda(data, torch.cuda.current_device())
                outputs = model(**data)
                loss = _criterion(outputs, None)
                # Backward
                booster.backward(loss, optimizer)
                pbar.set_postfix({"loss": loss.item()})

            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()


def main():
    args = parse_demo_args()

    # Launch ColossalAI
    colossalai.launch_from_torch(config={}, seed=args.seed)
    coordinator = DistCoordinator()
    world_size = coordinator.world_size
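    # launch_from_torch reads rank, world size and the master address from the environment
    # variables set by `torchrun` / `colossalai run`, then initializes the distributed backend.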

    # Manage loggers
    disable_existing_loggers()
    logger = get_dist_logger()
    if coordinator.is_master():
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    # Build OPT model
    config = AutoConfig.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
    logger.info(f"Finished loading model from {args.model_name_or_path}", ranks=[0])

    # Enable gradient checkpointing
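    # Checkpointing recomputes activations in the backward pass instead of storing them,
    # trading extra compute for a smaller activation-memory footprint.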
    model.gradient_checkpointing_enable()

    # Set plugin
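    # In brief: TorchDDPPlugin wraps vanilla PyTorch DDP; GeminiPlugin manages parameters in
    # chunks with ZeRO-style sharding and, as configured below, offloads optimizer states to
    # CPU; LowLevelZeroPlugin shards optimizer states (and optionally gradients) ZeRO-1/2
    # style; HybridParallelPlugin combines tensor parallelism (tp_size) and pipeline
    # parallelism (pp_size) with fp16 mixed precision.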
    booster_kwargs = {}
    if args.plugin == "torch_ddp_fp16":
        booster_kwargs["mixed_precision"] = "fp16"
    if args.plugin.startswith("torch_ddp"):
        plugin = TorchDDPPlugin()
    elif args.plugin == "gemini":
        plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
    elif args.plugin == "low_level_zero":
        plugin = LowLevelZeroPlugin(initial_scale=2**5)
    elif args.plugin == "hybrid_parallel":
        # Modify these parallelism settings as needed for finetuning test cases
        plugin = HybridParallelPlugin(
            tp_size=2,
            pp_size=2,
            num_microbatches=2,
            enable_all_optimization=True,
            zero_stage=0,
            precision="fp16",
            initial_scale=1,
        )

    logger.info(f"Set plugin as {args.plugin}", ranks=[0])

    # Prepare tokenizer and dataloader
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    dataset = NetflixDataset(tokenizer)
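    # prepare_dataloader attaches a sampler that shards the dataset across data-parallel
    # ranks, so the effective global batch size is roughly batch_size * dp_size.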
    dataloader = plugin.prepare_dataloader(
        dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=netflix_collator
    )

    # Set optimizer
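    # HybridAdam is ColossalAI's fused Adam that can run on CPU or GPU, which is needed when
    # optimizer states are offloaded (e.g. by the Gemini plugin). The base learning rate is
    # scaled linearly with the number of ranks.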
    optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)

    # Set lr scheduler
    total_steps = len(dataloader) * args.num_epoch
    num_warmup_steps = int(args.warmup_ratio * total_steps)
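    # Linearly warm the LR up over the first num_warmup_steps updates, then decay it
    # linearly to zero over the remaining steps.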
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=total_steps
    )

    # Define criterion
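    # The booster/plugin calls criterion(outputs, inputs); `inputs` is unused here because
    # the HF CausalLM forward already computes the loss from the labels in each batch.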
    def _criterion(outputs, inputs):
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    # Set booster
    booster = Booster(plugin=plugin, **booster_kwargs)
    model, optimizer, _criterion, dataloader, lr_scheduler = booster.boost(
        model=model, optimizer=optimizer, dataloader=dataloader, criterion=_criterion, lr_scheduler=lr_scheduler
    )
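    # boost() wraps the model, optimizer, criterion, dataloader and scheduler with the
    # plugin's distributed and mixed-precision logic; only the returned objects should be
    # used from this point on.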

    # Start finetuning
    logger.info("Start finetuning", ranks=[0])
    for epoch in range(args.num_epoch):
        train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, dataloader, booster, coordinator)

    # Finish training and save the model checkpoint
    logger.info("Finished finetuning", ranks=[0])
    booster.save_model(model, args.output_path, shard=True)
    logger.info(f"Saving model checkpoint to {args.output_path}", ranks=[0])


if __name__ == "__main__":
    main()