mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] Align bert value (#3907)
* add bert align test, fix dist loss bug
* forward and backward align
* add ignore index
* add shardformer CI
* add gather_output optional for user in shardconfig
* update readme with optional gather_output
* add dist crossentropy loss test, remove unused files
* remove unused file
* remove unused file
* rename the file
* polish code

parent 79f8d5d54b
commit f1cb5ac6bf
@@ -20,7 +20,7 @@
 The sample API usage is given below:
 
 ``` python
-from colossalai.shardformer import shard_model
+from colossalai.shardformer import ShardConfig, shard_model
 from transformers import BertForMaskedLM
 
 # create huggingface model as normal
@@ -28,7 +28,12 @@ model = BertForMaskedLM.from_pretrained("bert-base-uncased")
 
 # make the huggingface model parallelized to ShardModel
 # auto policy:
-sharded_model = shard_model(model)
+shardconfig = ShardConfig(
+    rank=rank,
+    world_size=world_size,
+    gather_output=True,
+)
+sharded_model = shard_model(model, config=shardconfig)
 
 # custom policy:
 from xxx import <POLICYCLASS>
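To make the new `gather_output` option concrete, here is a minimal single-process sketch (plain PyTorch, not the ColossalAI API) of what a column-parallel linear layer computes per rank and what gathering its output means:

``` python
# Sketch only: simulates 2 tensor-parallel ranks in one process.
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)                # (batch, hidden)
full_weight = torch.randn(16, 8)     # full (out_features, in_features)

# Column parallelism shards the output dimension across ranks.
shards = full_weight.chunk(2, dim=0)
partial_outputs = [x @ w.t() for w in shards]    # each rank's local slice

# gather_output=True corresponds to concatenating the slices back together.
gathered = torch.cat(partial_outputs, dim=-1)
assert torch.allclose(gathered, x @ full_weight.t())
```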
@@ -72,7 +77,7 @@ More details can be found in shardformer/policies/basepolicy.py
 ``` python
 from colossalai.shardformer.policies.basepolicy import Policy, Layer, Col_Layer, Row_Layer, Argument
 
-CustomPolicy(Policy):
+class CustomPolicy(Policy):
     @staticmethod
     def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
         r"""
@@ -235,7 +240,7 @@ CustomPolicy(Policy):
 This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class.
 
 CLASS `Col_Layer(Layer)`:
-  - gather_output (bool): Whether to gather the output of the layer
+  - gather_output (bool): Whether the output of this layer can be gathered; typically only the last layer's output needs to be gathered, while intermediate layers of the model do not.
 
 This class inherits from `Layer` and represents a layer that will be sliced along the column.
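For readers skimming the policy documentation above, a rough, framework-free sketch of the shape of a custom policy's `argument_policy` mapping. The `Argument` stand-in below is a placeholder dataclass, not the real `basepolicy.Argument`:

``` python
# Illustrative sketch only; the real types live in
# colossalai.shardformer.policies.basepolicy.
from dataclasses import dataclass
from typing import Dict, Type

import torch.nn as nn


@dataclass
class Argument:            # placeholder for basepolicy.Argument
    attr_dict: Dict[str, int]


class CustomPolicy:
    @staticmethod
    def argument_policy(model_config: dict, world_size: int) -> Dict[Type[nn.Module], Argument]:
        # Map a module class to the attributes that must shrink when the
        # layer is sharded across `world_size` tensor-parallel ranks.
        hidden = model_config["hidden_size"]
        return {nn.Linear: Argument(attr_dict={"out_features": hidden // world_size})}


print(CustomPolicy.argument_policy({"hidden_size": 768}, world_size=2))
```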
@@ -0,0 +1 @@
+from .shard import ShardConfig, shard_model
@@ -14,7 +14,7 @@ class DistCrossEntropy(Function):
     """
 
     @staticmethod
-    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor):
+    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int):
         r"""
         Calculate the cross entropy loss before gather; the original loss function is:
             loss = -log(exp(x[class]) / sum(exp(x[i])))
@@ -75,8 +75,8 @@ class DistCrossEntropy(Function):
 
         # calculate the loss
         # loss = log(sum(exp(x[i]))) - x[class]
-        loss = torch.log(sum_exp_logits) - pred_logits
-        loss = torch.sum(loss).div_(loss.numel())
+        loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
+        loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
 
         # calculate the softmax
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
@@ -101,5 +101,5 @@ class DistCrossEntropy(Function):
         return grad_logits, None, None
 
 
-def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels)
+def applyDistCrossEntropy(vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index)
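The `ignore_index` handling above is a masked mean: ignored positions contribute zero loss and are excluded from the denominator. A single-process sketch of the same arithmetic (assumed shapes; the real kernel operates on vocab-sharded logits):

``` python
# Sketch of the masked-mean trick; it matches F.cross_entropy's default
# mean reduction with ignore_index. Note the diff divides by the count of
# nonzero loss entries, which is intended as the count of non-ignored
# positions.
import torch
import torch.nn.functional as F

ignore_index = -100
logits = torch.randn(6, 8)                       # (tokens, vocab)
target = torch.tensor([1, 3, ignore_index, 0, 7, ignore_index])

sum_exp_logits = torch.exp(logits).sum(dim=-1)
safe_target = target.clamp(min=0)                # dummy index for ignored rows
pred_logits = logits[torch.arange(6), safe_target]

per_token = torch.log(sum_exp_logits) - pred_logits      # -log softmax[class]
loss = torch.where(target == ignore_index, torch.zeros_like(per_token), per_token)
loss = loss.sum() / (target != ignore_index).sum()

assert torch.allclose(loss, F.cross_entropy(logits, target, ignore_index=ignore_index))
```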
@@ -141,7 +141,7 @@ class BertPolicy(Policy):
                 weight="decoder.weight",
                 bias="decoder.bias",
                 replace_layer=col_nn.Linear1D_Col,
-                # gather_output=True,
+                gather_output=True,
             )
         ]
 
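Enabling `gather_output=True` on the masked-LM decoder head matters because that head is column-sharded along the vocabulary dimension, and the stock HuggingFace loss needs the full-vocab logits. An illustrative single-process sketch with hypothetical sizes:

``` python
# Sketch only: shows why vocab-sharded logits must be gathered before the
# stock cross entropy (a distributed loss could instead consume the shards).
import torch
import torch.nn.functional as F

tokens, hidden, vocab, world_size = 3, 4, 8, 2
hidden_states = torch.randn(tokens, hidden)
decoder_weight = torch.randn(vocab, hidden)      # the column-sharded head

# Each rank holds a vocab slice and produces partial logits.
shard_logits = [hidden_states @ w.t() for w in decoder_weight.chunk(world_size, dim=0)]

# gather_output=True reassembles the full-vocab logits for the HF loss.
full_logits = torch.cat(shard_logits, dim=-1)
labels = torch.randint(vocab, (tokens,))
print(F.cross_entropy(full_logits, labels))
```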
@@ -155,7 +155,8 @@ class BertForMaskedLMPolicy(BertPolicy):
 
     @staticmethod
     def inject_policy() -> Tuple[nn.Module, nn.Module]:
-        return (BertForMaskedLM, BertForMaskedLM_)
+        # return (BertForMaskedLM, BertForMaskedLM_)
+        return None
 
 
 class BertForSequenceClassificationPolicy(BertPolicy):
@@ -5,16 +5,14 @@ __all__ = ['ShardConfig']
 
 @dataclass
 class ShardConfig:
-    """
-    The config for sharding the huggingface model for test
+    r"""
+    The config for sharding the huggingface model
+
+    Args:
+        rank (int): The rank of local process
+        world_size (int): The world size of the distributed process
+        gather_output (bool): Whether to gather the output of the model of the last layer
     """
     rank: int
-    fp16: bool = True
-    num_gpus: int = 2
     world_size: int = 2
-    backend = "nccl"
-    verbose: str = 'simple'
-    seed: int = None
-    require_grad: bool = False
-    master_addr: str = "127.0.0.1"
-    master_port: int = 29500
+    gather_output: bool = True
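With the config pruned to the three fields the sharder actually reads, constructing it is minimal. A usage sketch at this commit; the RANK/WORLD_SIZE environment variables are an assumption about a torchrun-style launch:

``` python
# Minimal usage sketch; environment-variable names are assumptions.
import os

from colossalai.shardformer.shard import ShardConfig

config = ShardConfig(
    rank=int(os.environ.get("RANK", 0)),
    world_size=int(os.environ.get("WORLD_SIZE", 2)),
    gather_output=True,    # gather the last layer's output across ranks
)
```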
@@ -65,6 +65,8 @@ class ModelSharder(object):
             BertForMaskedLM.forward -> BertForMaskedLM_.forward
         """
         inject_policy = self.policy.inject_policy()
+        if inject_policy is None:
+            return
 
@@ -148,7 +150,7 @@ class ModelSharder(object):
         n_cast = policy_layer.n_cast
         reversed = policy_layer.reversed
         if policy_layer.__class__.__name__ == "Col_Layer":
-            gather_output = policy_layer.gather_output
+            gather_output = policy_layer.gather_output and self.shard_config.gather_output
 
         if weight_attr is not None:
             if hasattr_(org_layer, weight_attr):
@@ -1 +0,0 @@
-parallel = dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d'))
@@ -1,50 +0,0 @@
-import os
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import colossalai
-from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
-from colossalai.shardformer.layer.dropout import Dropout1D
-
-
-def get_args():
-    parser = colossalai.get_default_parser()
-    parser.add_argument("--module", type=str, default='distloss')
-    return parser.parse_args()
-
-
-def test_dist_crossentropy():
-    pred = torch.randn(2, 4, 8, requires_grad=True)
-    labels = torch.randint(8, (1, 4)).repeat(2, 1)
-
-    pred_ = pred.view(-1, 8)
-    labels_ = labels.view(-1)
-    loss = F.cross_entropy(pred_, labels_)
-    loss.backward()
-    print(f"normal loss:{loss}")
-
-    pred = pred.chunk(int(os.environ['WORLD_SIZE']), -1)[int(os.environ['RANK'])]
-    loss = applyDistCrossEntropy(pred.to('cuda'), labels.to('cuda'))
-    loss.backward()
-    print(f"dist loss:{loss}")
-
-
-def test_dropout():
-    input = torch.randn(5, 4).to("cuda")
-    m = Dropout1D(p=0.2).to("cuda")
-    for i in range(2):
-        print(f"Output: {m(input)}")
-        print(torch.randn(1))
-
-
-if __name__ == '__main__':
-    args = get_args()
-    colossalai.launch_from_torch(config={})
-    if args.module == 'distloss':
-        test_dist_crossentropy()
-    elif args.module == 'dropout':
-        test_dropout()
-    else:
-        print("not implemented yet")
@@ -1,124 +0,0 @@
-import os
-import random
-
-import torch
-import torch.nn as nn
-from datasets import load_dataset
-from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
-from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler
-
-import colossalai
-from colossalai.shardformer.shard import ShardConfig, shard_model
-from colossalai.utils import get_current_device, print_rank_0
-
-os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-
-
-def get_args():
-    parser = colossalai.get_default_parser()
-    parser.add_argument("--mode", type=str, default='inference')
-    parser.add_argument("--save_model", action='store_true')
-    parser.add_argument("--model", type=str, default='bert-base-uncased')
-    return parser.parse_args()
-
-
-def load_data(args):
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
-    if tokenizer.pad_token is None:
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        # tokenizer.pad_token_id = 0
-    datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
-    # datasets=load_dataset("yelp_review_full")
-    tokenized_datasets = datasets.map(
-        lambda examples: tokenizer(examples["text"], truncation=True, padding="max_length"), batched=True)
-    tokenized_datasets = tokenized_datasets.remove_columns(["text"])
-    # tokenized_datasets=tokenized_datasets.rename_column("label","labels")
-    tokenized_datasets.set_format("torch")
-
-    train_dataset = tokenized_datasets["train"]
-    test_dataset = tokenized_datasets["test"]
-
-    datacollector = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15, return_tensors="pt")
-    train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
-    eval_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True, collate_fn=datacollector)
-    return train_dataloader, eval_dataloader
-
-
-def inference(model: nn.Module, args):
-    print(model)
-    # print(model.wte.weight.shape)
-    tokenizer = AutoTokenizer.from_pretrained(args.model)
-    if tokenizer.pad_token is None:
-        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
-        tokenizer.pad_token_id = 0
-    token = "Hello, my dog is cute"
-    inputs = tokenizer(token, return_tensors="pt")
-    inputs.to("cuda")
-    model.eval()
-    model.to("cuda")
-    outputs = model(**inputs)
-    print(outputs[0])
-
-
-def train(model: nn.Module, args, num_epoch: int = 3):
-    train_dataloader, eval_dataloader = load_data(args)
-    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
-    num_training = num_epoch * len(train_dataloader)
-    progress_bar = tqdm(range(num_training))
-    lr_scheduler = get_scheduler(name="linear",
-                                 optimizer=optimizer,
-                                 num_warmup_steps=0,
-                                 num_training_steps=num_training)
-    best_test_loss = float("inf")
-    model.to("cuda")
-    model.train()
-    for epoch in range(num_epoch):
-        progress_bar.set_description(f"Rank {get_current_device()} epoch {epoch}")
-        for batch in train_dataloader:
-            optimizer.zero_grad()
-            batch = {k: v.to('cuda') for k, v in batch.items()}
-            outputs = model(**batch)
-            loss = outputs.loss
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-            progress_bar.update(1)
-            train_loss = loss
-
-        loss = 0.0
-        for batch in eval_dataloader:
-            batch = {k: v.to('cuda') for k, v in batch.items()}
-            outputs = model(**batch)
-            # loss = outputs.loss
-            assert not torch.isnan(outputs.loss), f"{batch}"
-            loss += outputs.loss.item()
-            # loss = criterion(outputs.logits, batch["input_ids"])
-        test_loss = loss / len(eval_dataloader)
-        print_rank_0(f"Train Loss: {train_loss:.4f} Test Loss:{test_loss:.4f}")
-        if args.save_model and test_loss < best_test_loss:
-            best_test_loss = test_loss
-            torch.save(model.state_dict(), "./checkpoints/best_model.pth")
-
-
-if __name__ == "__main__":
-    args = get_args()
-    colossalai.launch_from_torch(config=args.config)
-    if args.model == 'bert-base-uncased':
-        model = BertForMaskedLM.from_pretrained("bert-base-uncased")
-    elif args.model == 'gpt2':
-        model = GPT2LMHeadModel.from_pretrained("gpt2")
-    else:
-        raise AttributeError("model not supported")
-    shard_config = ShardConfig(
-        rank=int(str(get_current_device()).split(':')[-1]),
-        world_size=int(os.environ['WORLD_SIZE']),
-    )
-    sharded_model = shard_model(model, shard_config)
-
-    if args.mode == "train":
-        train(sharded_model, args)
-    elif args.mode == "inference":
-        inference(sharded_model, args)
-    else:
-        raise NotImplementedError
@@ -0,0 +1,103 @@
+import os
+import random
+
+import pytest
+import torch
+from transformers import AutoTokenizer, BertConfig, BertForMaskedLM
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
+tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+
+
+def build_model(rank, world_size):
+    config = BertConfig.from_pretrained('bert-base-uncased')
+    config.hidden_dropout_prob = 0
+    config.attention_probs_dropout_prob = 0
+
+    org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda')
+
+    shardconfig = ShardConfig(
+        rank=rank,
+        world_size=world_size,
+        gather_output=True,
+    )
+    sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config),
+                                shardconfig).to('cuda')
+
+    return org_model, sharded_model
+
+
+def check_forward(org_model, sharded_model):
+    input = 'Hello, my dog is cute'
+    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
+
+    # original model
+    org_model.eval()
+    org_out = org_model(**tokenized_input)
+
+    # sharded model
+    sharded_model.eval()
+    shard_out = sharded_model(**tokenized_input)
+
+    assert torch.allclose(
+        org_out[0], shard_out[0],
+        atol=1e-5), f"shard model output is not equal to original model output\n{org_out[0]}\n{shard_out[0]}"
+
+
+def check_backward(org_model, sharded_model):
+    # prepare input
+    input = 'Hello, my dog is cute'
+    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
+    labels = tokenized_input['input_ids'].clone()
+    labels[labels == tokenizer.pad_token_id] = -100
+    tokenized_input['labels'] = labels
+
+    # original model
+    org_model.train()
+    org_out = org_model(**tokenized_input)
+    org_loss = org_out.loss
+    org_loss.backward()
+    org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
+
+    # sharded model
+    sharded_model.train()
+    shard_out = sharded_model(**tokenized_input)
+    shard_loss = shard_out.loss
+    shard_loss.backward()
+    shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
+
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{shard_grad}"
+
+
+def check_bert(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    org_model, sharded_model = build_model(rank, world_size)
+    check_forward(org_model, sharded_model)
+    check_backward(org_model, sharded_model)
+
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_bert():
+    spawn(check_bert, 2)
+
+
+if __name__ == "__main__":
+    test_bert()
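One detail of `check_backward` worth spelling out: `torch.distributed.all_gather` fills the output list in place and returns `None` in the default synchronous case, so the rebinding of `shard_grad` is cosmetic; the gathered gradient lives in `shard_grad_list` / `all_shard_grad`. A sketch of the pattern, assuming an already-initialized process group:

``` python
import torch
import torch.distributed as dist


def gather_full_grad(local_grad: torch.Tensor) -> torch.Tensor:
    # all_gather writes into `buckets` in place; its return value is None
    # for the default synchronous call.
    buckets = [torch.empty_like(local_grad) for _ in range(dist.get_world_size())]
    dist.all_gather(buckets, local_grad)
    return torch.cat(buckets, dim=0)    # Col_Layer shards dim 0 of the weight
```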
@@ -0,0 +1,42 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.dist_crossentropy import applyDistCrossEntropy
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
+
+
+def check_dist_crossentropy(rank, world_size, port, ignore_index):
+    disable_existing_loggers()
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
+
+    # prepare data
+    pred = torch.randn(2, 4, 8, requires_grad=True)
+    labels = torch.randint(8, (2, 4))
+    # set some labels to ignore_index to test that they are excluded from the loss
+    labels[0, -1] = ignore_index
+
+    org_pred = pred.view(-1, 8)
+    org_labels = labels.view(-1)
+    org_loss = F.cross_entropy(org_pred, org_labels)
+
+    dist_pred = pred.chunk(world_size, -1)[rank]
+    dist_loss = applyDistCrossEntropy(dist_pred.to('cuda'), labels.to('cuda'), ignore_index=ignore_index)
+
+    assert torch.allclose(org_loss, dist_loss,
+                          atol=1e-5), f"dist cross entropy loss is not equal to original loss\n{org_loss}\n{dist_loss}"
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dist_crossentropy():
+    ignore_index = -100
+    spawn(check_dist_crossentropy, 2, ignore_index=ignore_index)
+
+
+if __name__ == '__main__':
+    test_dist_crossentropy()