mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
181 lines
6.1 KiB
181 lines
6.1 KiB
import argparse
|
|
import math
|
|
from typing import Any, List, Union
|
|
|
|
import evaluate
|
|
import torch
|
|
import torch.distributed as dist
|
|
from data import GLUEDataBuilder
|
|
from torch import nn
|
|
from torch.optim import Adam, Optimizer
|
|
from torch.utils._pytree import tree_map
|
|
from torch.utils.data import DataLoader
|
|
from tqdm import tqdm
|
|
from transformers import BertConfig, BertForSequenceClassification, get_linear_schedule_with_warmup
|
|
|
|
import colossalai
|
|
from colossalai.cluster import DistCoordinator
|
|
from colossalai.shardformer import ShardConfig, ShardFormer
|
|
|
|
|
|
def to_device(x: Any, device: torch.device) -> Any:
    """Move every tensor found in a (possibly nested) structure onto ``device``.

    Args:
        x: A tensor, or a container handled by ``tree_map`` whose leaves may
            be tensors.
        device: Destination passed to ``Tensor.to`` for each tensor leaf.

    Returns:
        The result of ``tree_map`` over ``x``: tensor leaves are replaced by
        ``leaf.to(device)``, every other leaf is returned as-is.
    """

    def _move_leaf(leaf: Any) -> Any:
        # Only tensors are relocated; ints, strings, etc. pass through.
        return leaf.to(device) if isinstance(leaf, torch.Tensor) else leaf

    return tree_map(_move_leaf, x)
|
|
|
|
|
|
def train(args):
    """Fine-tune a sequence-classification model on a GLUE task, then evaluate it.

    The function launches the ColossalAI runtime, builds the GLUE dataloaders,
    loads the pretrained model, shards it with ShardFormer when more than one
    process is running, trains with Adam plus a linear warmup schedule, and
    finally prints the evaluation results on the master process.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``pretrain``, ``task``, ``batch_size``, ``model``, ``lr``,
            ``accumulation_steps``, ``max_epochs``, ``warmup_fraction`` and
            ``target_f1``.

    Raises:
        ValueError: If ``args.model`` is not ``"bert"`` (the only model this
            script knows how to build).
        AssertionError: On the master process, if ``args.target_f1`` is set,
            the results contain an ``"f1"`` entry, and it is below the target.
    """
    colossalai.launch_from_torch(config={}, seed=42)
    coordinator = DistCoordinator()

    # prepare for data and dataset
    data_builder = GLUEDataBuilder(
        model_name_or_path=args.pretrain,
        task_name=args.task,
        train_batch_size=args.batch_size,
        eval_batch_size=args.batch_size,
    )
    train_dataloader = data_builder.train_dataloader()
    test_dataloader = data_builder.test_dataloader()

    if args.model == "bert":
        cfg = BertConfig.from_pretrained(args.pretrain, num_labels=data_builder.num_labels)
        model = BertForSequenceClassification.from_pretrained(args.pretrain, config=cfg)
    else:
        # Without this branch `model` would be unbound and the next statement
        # would fail with an opaque NameError instead of a clear message.
        raise ValueError(f"Unsupported model {args.model!r}; only 'bert' is supported")

    model.to(torch.cuda.current_device())

    # if multiple GPUs, shard the model
    if dist.get_world_size() > 1:
        tp_group = dist.new_group(backend="nccl")
        shard_config = ShardConfig(
            tensor_parallel_process_group=tp_group, enable_tensor_parallelism=True, enable_all_optimization=True
        )
        shard_former = ShardFormer(shard_config=shard_config)
        model, _ = shard_former.optimize(model)

    optim = Adam(model.parameters(), lr=args.lr)

    # One optimizer update happens every `accumulation_steps` batches, so the
    # scheduler horizon is counted in updates, not in batches.
    num_update_steps_per_epoch = len(train_dataloader) // args.accumulation_steps
    max_steps = math.ceil(args.max_epochs * num_update_steps_per_epoch)
    lr_scheduler = get_linear_schedule_with_warmup(
        optim,
        num_warmup_steps=math.ceil(max_steps * args.warmup_fraction),
        num_training_steps=max_steps,
    )

    fit(
        model,
        optim,
        lr_scheduler,
        train_dataloader,
        args.max_epochs,
        args.accumulation_steps,
        args.batch_size,
        coordinator,
    )

    results = evaluate_model(
        model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, coordinator
    )

    # Only the master process reports results and enforces the optional target.
    if coordinator.is_master():
        print(results)
        if args.target_f1 is not None and "f1" in results:
            assert results["f1"] >= args.target_f1, f'f1 score {results["f1"]} is lower than target {args.target_f1}'
|
|
|
|
|
|
def fit(
    model: nn.Module,
    optimizer: Optimizer,
    scheduler,
    train_dataloader,
    max_epochs,
    accumulation_steps,
    batch_size,
    coordinator,
):
    """Run the training loop with gradient accumulation.

    Args:
        model: Module called as ``model(**batch)``; its output must expose ``.loss``.
        optimizer: Optimizer stepped once every ``accumulation_steps`` batches.
        scheduler: LR scheduler stepped together with the optimizer.
        train_dataloader: Iterable of batches; must support ``len()``.
        max_epochs: Number of passes over ``train_dataloader``.
        accumulation_steps: Number of batches whose gradients are accumulated
            before each optimizer update.
        batch_size: Divisor applied to the accumulated loss shown in the
            progress bar.
        coordinator: Object whose ``is_master()`` decides whether the progress
            bar is displayed.

    Note:
        Gradients and the running loss are reset only after an optimizer
        update, so batches left over at the end of an epoch (when the loader
        length is not a multiple of ``accumulation_steps``) carry into the
        next epoch's first update.
    """
    updates_per_epoch = len(train_dataloader) // accumulation_steps
    step_bar = tqdm(
        range(updates_per_epoch * max_epochs),
        desc="steps",
        disable=not coordinator.is_master(),
    )

    running_loss = 0
    for epoch in range(max_epochs):
        model.train()
        # Count batches from 1 so the modulo test below reads naturally.
        for seen, batch in enumerate(train_dataloader, start=1):
            batch = to_device(batch, torch.cuda.current_device())
            # Scale so the accumulated gradient matches one averaged update.
            scaled_loss = model(**batch).loss / accumulation_steps
            scaled_loss.backward()
            running_loss += scaled_loss.item()

            if seen % accumulation_steps != 0:
                # Still inside an accumulation window: no update yet.
                continue

            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            step_bar.set_postfix(
                {"epoch": epoch, "loss": running_loss / batch_size, "lr": scheduler.get_last_lr()[0]}
            )
            running_loss = 0
            step_bar.update()
|
|
|
|
|
|
# evaluate
|
|
@torch.no_grad()
def evaluate_model(
    model: nn.Module,
    test_dataloader: Union[DataLoader, List[DataLoader]],
    num_labels: int,
    task_name: str,
    eval_splits: List[str],
    coordinator: DistCoordinator,
):
    """Evaluate ``model`` on one or several GLUE evaluation splits.

    Args:
        model: Module called as ``model(**batch)``; the first two items of its
            output are used as ``(loss, logits)``.
        test_dataloader: A single ``DataLoader`` or a list with one loader per
            entry of ``eval_splits``.
        num_labels: ``> 1`` selects argmax over dim 1 of the logits;
            ``== 1`` uses the squeezed logits directly as predictions.
        task_name: GLUE configuration name passed to ``evaluate.load``.
        eval_splits: Split names used to suffix result keys when several
            loaders are given.
        coordinator: Supplies ``rank``, ``world_size`` and ``is_master()``.

    Returns:
        For a single loader: whatever ``metric.compute()`` returns, with a
        ``"loss"`` entry added on the master process. For a list of loaders:
        a dict whose keys are ``"<metric>_<split>"``; splits for which
        ``metric.compute()`` yields ``None`` contribute nothing, so the dict
        may be empty on processes that receive no results.

    Raises:
        AssertionError: If a list of loaders is given and its length differs
            from ``len(eval_splits)``.
    """
    metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
    model.eval()

    def evaluate_subset(dataloader: DataLoader):
        """Score one dataloader and return the metric results for it."""
        accum_loss = torch.zeros(1, device=torch.cuda.current_device())
        for batch in dataloader:
            batch = to_device(batch, torch.cuda.current_device())
            outputs = model(**batch)
            val_loss, logits = outputs[:2]
            accum_loss.add_(val_loss)

            if num_labels > 1:
                preds = torch.argmax(logits, dim=1)
            elif num_labels == 1:
                preds = logits.squeeze()

            labels = batch["labels"]
            metric.add_batch(predictions=preds, references=labels)

        results = metric.compute()
        # Only the master process attaches the loss to the results.
        if coordinator.is_master():
            results["loss"] = accum_loss.item() / (len(dataloader) * dataloader.batch_size)
        return results

    if isinstance(test_dataloader, DataLoader):
        return evaluate_subset(test_dataloader)
    else:
        assert len(test_dataloader) == len(eval_splits)
        final_results = {}
        for split, sub_loader in zip(eval_splits, test_dataloader):
            results = evaluate_subset(sub_loader)
            if results is None:
                # A distributed metric may hand results to one process only;
                # calling `.items()` on None here would crash the others.
                continue
            final_results.update({f"{k}_{split}": v for k, v in results.items()})
        return final_results
|
|
|
|
|
|
def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` in argparse calls ``bool()`` on the raw string, so any
    non-empty text (including ``"False"``) becomes ``True``. This converter
    interprets the text instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``"true"``, ``"False"``, ``"1"``, ``"no"`` (case-insensitive).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays falsy, matching what bool("") used to give.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    # Command-line entry point: parse the options and start training.
    parser = argparse.ArgumentParser()
    parser.add_argument("-t", "--task", default="mrpc", help="GLUE task to run")
    parser.add_argument("--model", type=str, default="bert")
    parser.add_argument("--pretrain", type=str, default="bert-base-uncased")
    parser.add_argument("--max_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=2.4e-5)
    # str2bool instead of bool so that "--fused_layernorm False" is really False.
    parser.add_argument("--fused_layernorm", type=str2bool, default=False)
    parser.add_argument("--accumulation_steps", type=int, default=8)
    parser.add_argument("--warmup_fraction", type=float, default=0.03)
    parser.add_argument("--target_f1", type=float, default=None)
    args = parser.parse_args()
    train(args)
|