[shardformer] Align bert value (#3907)

* add bert align test, fix dist loss bug

* forward and backward align

* add ignore index

* add shardformer CI

* add gather_output optional for user in shardconfig

* update readme with optional gather_output

* add dist crossentropy loss test, remove unused files

* remove unused file

* remove unused file

* rename the file

* polish code
pull/4157/head
FoolPlayer 2023-06-09 14:36:54 +08:00 committed by Frank Lee
parent 79f8d5d54b
commit f1cb5ac6bf
11 changed files with 174 additions and 197 deletions

View File

@ -20,7 +20,7 @@
The sample API usage is given below:
``` python
from colossalai.shardformer import shard_model
from colossalai.shardformer import ShardConfig, shard_model
from transformers import BertForMaskedLM
# create huggingface model as normal
@ -28,7 +28,12 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased")
# parallelize the huggingface model with shard_model
# auto policy:
sharded_model = shard_model(model)
shardconfig = ShardConfig(
rank=rank,
world_size=world_size,
gather_output=True,
)
sharded_model = shard_model(model, config=shardconfig)
# custom policy:
from xxx import <POLICYCLASS>
@ -72,7 +77,7 @@ More details can be found in shardformer/policies/basepolicy.py
``` python
from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument
CustomPolicy(Policy):
class CustomPolicy(Policy):
@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
r"""
@ -235,7 +240,7 @@ CustomPolicy(Policy):
This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class.
CLASS `Col_Layer(Layer)`:
- gather_output (bool): Whether to gather the output of the layer
- gather_output (bool): Whether to gather the output of this layer. Typically only the last layer's output needs to be gathered; intermediate layers usually keep their sharded output.
This class inherits from `Layer` and represents a layer that will be sliced along the column dimension.
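To make the `gather_output` semantics concrete, here is a minimal, hypothetical sketch of a column-parallel linear layer with an optional output gather. It is not the actual `Linear1D_Col` implementation; it assumes `torch.distributed` is already initialized and ignores gradient flow through the gather:

```python
import torch
import torch.nn as nn
import torch.distributed as dist


class ColParallelLinearSketch(nn.Module):
    """Toy column-parallel linear: the weight is split along the output
    dimension, so each rank produces only a slice of the output features."""

    def __init__(self, in_features: int, out_features: int, gather_output: bool = False):
        super().__init__()
        self.world_size = dist.get_world_size()
        assert out_features % self.world_size == 0
        # each rank holds out_features // world_size output columns
        self.linear = nn.Linear(in_features, out_features // self.world_size)
        self.gather_output = gather_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        partial = self.linear(x)    # [..., out_features // world_size]
        if not self.gather_output:
            # intermediate layers usually keep the sharded output
            return partial
        # a "last" layer (e.g. the MLM decoder) rebuilds the full-size output
        pieces = [torch.empty_like(partial) for _ in range(self.world_size)]
        dist.all_gather(pieces, partial.contiguous())
        return torch.cat(pieces, dim=-1)
```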

View File

@ -0,0 +1 @@
from .shard import ShardConfig, shard_model

View File

@ -14,7 +14,7 @@ class DistCrossEntropy(Function):
"""
@staticmethod
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int):
r"""
Calculate the cross entropy loss before gathering; the original loss function is:
loss = -log(exp(x[class]) / sum(exp(x[i])))
@ -75,8 +75,8 @@ class DistCrossEntropy(Function):
# calculate the loss
# loss = log(sum(exp(x[i]))) - x[class]
loss = torch.log(sum_exp_logits) - pred_logits
loss = torch.sum(loss).div_(loss.numel())
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
# calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
@ -101,5 +101,5 @@ class DistCrossEntropy(Function):
return grad_logits, None, None
def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels)
def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index)
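For reference, a minimal single-process sketch of the ignore-index behaviour introduced above: per-token losses at ignored positions are zeroed and the sum is divided by the number of kept positions. The function name and shapes are illustrative only; the real `DistCrossEntropy` also handles the cross-rank reduction of the vocabulary-parallel logits, which this sketch skips entirely:

```python
import torch
import torch.nn.functional as F


def masked_mean_cross_entropy(logits: torch.Tensor,
                              target: torch.Tensor,
                              ignore_index: int = -100) -> torch.Tensor:
    """Cross entropy that skips ignored positions when averaging."""
    # per-position loss; positions whose target equals ignore_index get loss 0
    per_token = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                target.view(-1),
                                ignore_index=ignore_index,
                                reduction='none')
    kept = (target.view(-1) != ignore_index).sum()
    # the committed code divides by the number of non-zero losses instead,
    # which matches this as long as no kept position has exactly zero loss
    return per_token.sum() / kept
```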

View File

@ -141,7 +141,7 @@ class BertPolicy(Policy):
weight="decoder.weight",
bias="decoder.bias",
replace_layer=col_nn.Linear1D_Col,
# gather_output=True,
gather_output=True,
)
]
@ -155,7 +155,8 @@ class BertForMaskedLMPolicy(BertPolicy):
@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
return (BertForMaskedLM, BertForMaskedLM_)
# return (BertForMaskedLM, BertForMaskedLM_)
return None
class BertForSequenceClassificationPolicy(BertPolicy):

View File

@ -5,16 +5,14 @@ __all__ = ['ShardConfig']
@dataclass
class ShardConfig:
"""
The config for sharding the huggingface model for test
r"""
The config for sharding the huggingface model
Args:
rank (int): The rank of local process
world_size (int): The world size of the distributed process
gather_output (bool): Whether to gather the output of the model's last layer
"""
rank: int
fp16: bool = True
num_gpus: int = 2
world_size: int = 2
backend = "nccl"
verbose: str = 'simple'
seed: int = None
require_grad: bool = False
master_addr: str = "127.0.0.1"
master_port: int = 29500
gather_output: bool = True

View File

@ -65,6 +65,8 @@ class ModelSharder(object):
BertForMaskedLM.forward -> BertForMaskedLM_.forward
"""
inject_policy = self.policy.inject_policy()
if inject_policy is None:
return
@ -148,7 +150,7 @@ class ModelSharder(object):
n_cast = policy_layer.n_cast
reversed = policy_layer.reversed
if policy_layer.__class__.__name__ == "Col_Layer":
gather_output = policy_layer.gather_output
gather_output = policy_layer.gather_output and self.shard_config.gather_output
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
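In effect, after this change a column-sharded layer's output is rebuilt to full size only when both the layer policy and the user-facing `ShardConfig` request it. A trivial illustration of that AND semantics (the helper name is hypothetical):

```python
def effective_gather_output(layer_gather: bool, config_gather: bool) -> bool:
    """A Col_Layer's output is gathered only when both flags agree."""
    return layer_gather and config_gather


# BertPolicy now sets gather_output=True on the MLM decoder, so whether the full
# vocab logits are reassembled is ultimately decided by ShardConfig.gather_output
assert effective_gather_output(True, True)
assert not effective_gather_output(True, False)
```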

View File

@ -1 +0,0 @@
parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))

View File

@ -1,50 +0,0 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import colossalai
from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
from colossalai.shardformer.layer.dropout import Dropout1D
def get_args():
parser = colossalai.get_default_parser()
parser.add_argument("--module", type=str, default='distloss')
return parser.parse_args()
def test_dist_crossentropy():
pred = torch.randn(2, 4, 8, requires_grad=True)
labels = torch.randint(8, (1, 4)).repeat(2, 1)
pred_ = pred.view(-1, 8)
labels_ = labels.view(-1)
loss = F.cross_entropy(pred_, labels_)
loss.backward()
print(f"normal loss:{loss}")
pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])]
loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda'))
loss.backward()
print(f"dist loss:{loss}")
def test_dropout():
input = torch.randn(5, 4).to("cuda")
m = Dropout1D(p=0.2).to("cuda")
for i in range(2):
print(f"Output: {m(input)}")
print(torch.randn(1))
if __name__ == '__main__':
args = get_args()
colossalai.launch_from_torch(config={})
if args.module == 'distloss':
test_dist_crossentropy()
elif args.module == 'dropout':
test_dropout()
else:
print("not implemented yet")

View File

@ -1,124 +0,0 @@
import os
import random
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler
import colossalai
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.utils import get_current_device, print_rank_0
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
def get_args():
parser = colossalai.get_default_parser()
parser.add_argument("--mode", type=str, default='inference')
parser.add_argument("--save_model", action='store_true')
parser.add_argument("--model", type=str, default='bert-base-uncased')
return parser.parse_args()
def load_data(args):
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# tokenizer.pad_token_id = 0
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
# datasets=load_dataset("yelp_review_full")
tokenized_datasets = datasets.map(
lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
# tokenized_datasets=tokenized_datasets.rename_column("label","labels")
tokenized_datasets.set_format("torch")
train_dataset = tokenized_datasets["train"]
test_dataset = tokenized_datasets["test"]
datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
return train_dataloader, eval_dataloader
def inference(model: nn.Module, args):
print(model)
# print(model.wte.weight.shape)
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.pad_token_id = 0
token = "Hello, my dog is cute"
inputs = tokenizer(token, return_tensors="pt")
inputs.to("cuda")
model.eval()
model.to("cuda")
outputs = model(**inputs)
print(outputs[0])
def train(model: nn.Module, args, num_epoch: int = 3):
train_dataloader, eval_dataloader = load_data(args)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_training = num_epoch * len(train_dataloader)
progress_bar = tqdm(range(num_training))
lr_scheduler = get_scheduler(name="linear",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=num_training)
best_test_loss = float("inf")
model.to("cuda")
model.train()
for epoch in range(num_epoch):
progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")
for batch in train_dataloader:
optimizer.zero_grad()
batch = {k: v.to('cuda') for k, v in batch.items()}
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
lr_scheduler.step()
progress_bar.update(1)
train_loss = loss
loss = 0.0
for batch in eval_dataloader:
batch = {k: v.to('cuda') for k, v in batch.items()}
outputs = model(**batch)
# loss = outputs.loss
assert not torch.isnan(outputs.loss), f"{batch}"
loss += outputs.loss.item()
# loss = criterion(outputs.logits, batch["input_ids"])
test_loss = loss / len(eval_dataloader)
print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")
if args.save_model and test_loss < best_test_loss:
best_test_loss = test_loss
torch.save(model.state_dict(), "./checkpoints/best_model.pth")
if __name__ == "__main__":
args = get_args()
colossalai.launch_from_torch(config=args.config)
if args.model == 'bert-base-uncased':
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
elif args.model == 'gpt2':
model = GPT2LMHeadModel.from_pretrained("gpt2")
else:
raise AttributeError("model not supported")
shard_config = ShardConfig(
rank=int(str(get_current_device()).split(':')[-1]),
world_size=int(os.environ['WORLD_SIZE']),
)
sharded_model = shard_model(model, shard_config)
if args.mode == "train":
train(sharded_model, args)
elif args.mode == "inference":
inference(sharded_model, args)
else:
raise NotImplementedError

View File

@ -0,0 +1,103 @@
import os
import random
import pytest
import torch
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def build_model(rank, world_size):
config = BertConfig.from_pretrained('bert-base-uncased')
config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda')
shardconfig = ShardConfig(
rank=rank,
world_size=world_size,
gather_output=True,
)
sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
shardconfig).to('cuda')
return org_model, sharded_model
def check_forward(org_model, sharded_model):
input = 'Hello, my dog is cute'
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# original model
org_model.eval()
org_out = org_model(**tokenized_input)
#shard model
sharded_model.eval()
shard_out = sharded_model(**tokenized_input)
assert torch.allclose(
org_out[0], shard_out[0],
atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"
def check_backward(org_model, sharded_model):
# prepare input
input = 'Hello, my dog is cute'
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
labels = tokenized_input['input_ids'].clone()
labels[labels == tokenizer.pad_token_id] = -100
tokenized_input['labels'] = labels
# original model
org_model.train()
org_out = org_model(**tokenized_input)
org_loss = org_out.loss
org_loss.backward()
org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
#shard model
sharded_model.train()
shard_out = sharded_model(**tokenized_input)
shard_loss = shard_out.loss
shard_loss.backward()
shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
# all_gather fills shard_grad_list in place; its return value is not the gathered tensor
torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
org_model, sharded_model = build_model(rank, world_size)
check_forward(org_model, sharded_model)
check_backward(org_model, sharded_model)
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert():
spawn(check_bert, 2)
if __name__ == "__main__":
test_bert()

View File

@ -0,0 +1,42 @@
import pytest
import torch
import torch.nn.functional as F
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
from colossalai.testing import rerun_if_address_is_in_use, spawn
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
def check_dist_crossentropy(rank, world_size, port, ignore_index):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
# prepare data
pred = torch.randn(2, 4, 8, requires_grad=True)
labels = torch.randint(8, (2, 4))
# set one label to the ignore index to check that it is ignored
labels[0, -1] = ignore_index
org_pred = pred.view(-1, 8)
org_labels = labels.view(-1)
org_loss = F.cross_entropy(org_pred, org_labels)
dist_pred = pred.chunk(world_size, -1)[rank]
dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index)
assert torch.allclose(org_loss, dist_loss,
atol=1e-5), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_crossentropy():
ignore_index = -100
spawn(check_dist_crossentropy, 2, ignore_index=ignore_index)
if __name__ == '__main__':
test_dist_crossentropy()