ColossalAI/colossalai/shardformer/test/test.py

111 lines
4.1 KiB
Python

import os
import random
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler
import colossalai
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.utils import get_current_device, print_rank_0
# Set the flag that asks `transformers` to suppress its advisory warnings.
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"

# Module-level BERT tokenizer; load_data() uses it to tokenize the corpus and
# to configure the masked-LM data collator.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def get_args():
    """Parse the command-line options for this script.

    Starts from the parser returned by ``colossalai.get_default_parser()`` and
    registers two additional options:

    * ``--mode``: string, defaults to ``'inference'``. The ``__main__`` block
      dispatches on the values ``"train"`` and ``"inference"``.
    * ``--save_model``: boolean flag. When set, ``train`` writes a checkpoint
      whenever the test loss improves.

    Returns:
        The namespace produced by ``parse_args()``.
    """
    arg_parser = colossalai.get_default_parser()
    arg_parser.add_argument("--mode", default="inference", type=str)
    arg_parser.add_argument("--save_model", action="store_true")
    namespace = arg_parser.parse_args()
    return namespace


def load_data():
    """Build the train / eval dataloaders for masked-LM on WikiText-2.

    Loads the ``wikitext-2-raw-v1`` configuration of ``wikitext``, tokenizes
    the ``"text"`` column with the module-level ``tokenizer`` (truncating and
    padding to ``max_length``), removes the raw text column and switches the
    dataset to torch format. Batches are collated by a
    ``DataCollatorForLanguageModeling`` configured for MLM with a masking
    probability of 0.15.

    Returns:
        A ``(train_dataloader, eval_dataloader)`` tuple built from the
        ``"train"`` and ``"test"`` splits. Both loaders use a batch size of 16
        and shuffle their data.
    """

    def _tokenize(examples):
        # Batched tokenization of the raw text column.
        return tokenizer(examples["text"], truncation=True, padding="max_length")

    # Alternative kept for reference: load_dataset("yelp_review_full"), which
    # would additionally need its "label" column renamed to "labels".
    raw_splits = load_dataset("wikitext", "wikitext-2-raw-v1")
    encoded = raw_splits.map(_tokenize, batched=True).remove_columns(["text"])
    encoded.set_format("torch")

    mlm_collator = DataCollatorForLanguageModeling(
        tokenizer,
        mlm=True,
        mlm_probability=0.15,
        return_tensors="pt",
    )
    loader_options = dict(batch_size=16, shuffle=True, collate_fn=mlm_collator)
    train_dataloader = DataLoader(encoded["train"], **loader_options)
    eval_dataloader = DataLoader(encoded["test"], **loader_options)
    return train_dataloader, eval_dataloader


def inference(model: nn.Module, args):
    """Run one forward pass on a fixed sentence and print the model outputs.

    The model is switched to eval mode and moved to CUDA before the call.

    Args:
        model: The model to run (the ``__main__`` block passes the sharded
            model).
        args: Parsed command-line arguments. Not read by this function; it is
            accepted so the call signature matches ``train``.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    token = "Hello, my dog is cute"
    inputs = tokenizer(token, return_tensors="pt")
    inputs.to("cuda")
    model.eval()
    model.to("cuda")
    # Inference only: disable autograd so no graph is built for the forward
    # pass and the printed tensors do not carry a grad_fn.
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs)


def train(model: nn.Module, args, num_epoch: int = 3):
    """Train ``model`` on the masked-LM data and evaluate it after every epoch.

    Uses AdamW (lr 5e-5) with a linear schedule and no warm-up over
    ``num_epoch * len(train_dataloader)`` steps. After each epoch the mean loss
    over the evaluation dataloader is computed and printed together with the
    loss of the last training batch of that epoch.

    Args:
        model: The model to train; it is moved to CUDA.
        args: Parsed command-line arguments. Only ``args.save_model`` is read:
            when true, the state dict is saved to
            ``./checkpoints/best_model.pth`` every time the test loss improves.
        num_epoch: Number of passes over the training dataloader.

    Raises:
        AssertionError: If an evaluation batch produces a NaN loss (the
            offending batch is included in the message).
    """
    train_dataloader, eval_dataloader = load_data()
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    num_training = num_epoch * len(train_dataloader)
    progress_bar = tqdm(range(num_training))
    lr_scheduler = get_scheduler(name="linear",
                                 optimizer=optimizer,
                                 num_warmup_steps=0,
                                 num_training_steps=num_training)
    best_test_loss = float("inf")
    checkpoint_path = "./checkpoints/best_model.pth"
    model.to("cuda")

    for epoch in range(num_epoch):
        progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")

        # Re-enter train mode every epoch because evaluation below switches
        # the model to eval mode.
        model.train()
        for batch in train_dataloader:
            optimizer.zero_grad()
            batch = {k: v.to('cuda') for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            progress_bar.update(1)
        # Keep a plain float (loss of the last training batch) rather than a
        # tensor that still references the autograd graph.
        train_loss = loss.item()

        # Evaluate without dropout and without building autograd graphs.
        model.eval()
        test_loss_sum = 0.0
        with torch.no_grad():
            for batch in eval_dataloader:
                batch = {k: v.to('cuda') for k, v in batch.items()}
                outputs = model(**batch)
                assert not torch.isnan(outputs.loss), f"{batch}"
                test_loss_sum += outputs.loss.item()
        test_loss = test_loss_sum / len(eval_dataloader)
        print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")

        if args.save_model and test_loss < best_test_loss:
            best_test_loss = test_loss
            # Create the target directory on demand; torch.save does not.
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            # NOTE(review): this save is not rank-guarded, so every process
            # writes to the same path; confirm that is intended for a sharded
            # model.
            torch.save(model.state_dict(), checkpoint_path)


if __name__ == "__main__":
    # Script entry point: parse the CLI options, load pretrained BERT, launch
    # ColossalAI with the parsed config, shard the model, then run the
    # requested mode.
    args = get_args()
    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    colossalai.launch_from_torch(config=args.config)

    # The rank handed to ShardConfig is the text after the last ':' in the
    # current device's string form; the world size comes from WORLD_SIZE.
    device_name = str(get_current_device())
    shard_config = ShardConfig(
        rank=int(device_name.split(":")[-1]),
        world_size=int(os.environ["WORLD_SIZE"]),
    )
    sharded_model = shard_model(model, shard_config)

    # Dispatch on --mode; any other value is rejected.
    entry_points = {"train": train, "inference": inference}
    if args.mode not in entry_points:
        raise NotImplementedError
    entry_points[args.mode](sharded_model, args)