support kit use for bert/gpt test (#4055)

* support kit use for bert test

* support kit test for gpt2
pull/4157/head
FoolPlayer 2023-06-22 10:33:06 +08:00 committed by Frank Lee
parent f22ddacef0
commit 7740c55c55
7 changed files with 346 additions and 273 deletions

View File

@ -25,17 +25,19 @@ class PolicyLocation:
_POLICY_LIST = {
# BERT
"transformers.models.bert.modeling_bert.BertModel":
PolicyLocation(file_name="bert", class_name="BertPolicy"),
PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
"transformers.models.bert.modeling_bert.BertForPreTraining":
PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"),
"transformers.models.bert.modeling_bert.BertForMaskedLM":
PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
"transformers.models.bert.modeling_bert.BertLMHeadModel":
PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
"transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
"transformers.models.bert.modeling_bert.BertForMaskedLM":
PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
"transformers.models.bert.modeling_bert.BertForSequenceClassification":
PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"),
"transformers.models.bert.modeling_bert.BertForTokenClassification":
PolicyLocation(file_name="bert", class_name="BertForTokenClassificationPolicy"),
"transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
"transformers.models.bert.modeling_bert.BertForMultipleChoice":
PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
@ -58,6 +60,14 @@ _POLICY_LIST = {
# GPT2
"transformers.models.gpt2.modeling_gpt2.GPT2Model":
PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel":
PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),
}

View File

@ -131,37 +131,6 @@ class BertForPretrainingPolicy(BertPolicy):
return self.model
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
@ -193,15 +162,53 @@ class BertLMHeadModelPolicy(BertPolicy):
return self.model
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy
def postprocess(self):
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()

View File

@ -1,7 +1,9 @@
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -82,7 +84,6 @@ class GPT2Policy(Policy):
}
def new_model_class(self):
return self.model
def postprocess(self):
@ -94,3 +95,79 @@ class GPT2ModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
addon_module = {
GPT2LMHeadModel:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy
def postprocess(self):
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
# GPT22DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
addon_module = {
GPT2DoubleHeadsModel:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy
def postprocess(self):
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
param = nn.Parameter(param)
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model
# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()

View File

@ -6,83 +6,147 @@ from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence BERT
# ===============================
BATCH_SIZE = 2
SEQ_LENGTH = 16
def data_gen_fn():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
# define data gen function
def data_gen():
# Generated from following code snippet
#
# from transformers import BertTokenizer
# input = 'Hello, my dog is cute'
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
# token_type_ids = tokenized_input['token_type_ids']
input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def data_gen_for_lm():
# LM data gen
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
data = data_gen()
data['labels'] = data['input_ids'].clone()
return data
def data_gen_for_pretraining():
# pretraining data gen
# `next_sentence_label` is the label for next sentence prediction, 0 or 1
data = data_gen_for_lm()
data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64)
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
# `labels` is the label for sequence classification, 0 or 1
data = data_gen()
data['labels'] = torch.tensor([1], dtype=torch.int64)
return data
def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data
def data_gen_for_mcq():
# multiple choice question data gen
# Generated from following code snippet
#
# tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
# prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
# choice0 = "It is eaten with a fork and a knife."
# choice1 = "It is eaten while held in the hand."
# data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
# data = {k: v.unsqueeze(0) for k, v in encoding.items()}
# data['labels'] = torch.tensor([0], dtype=torch.int64)
input_ids = torch.tensor([[[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102
],
[
101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
2218, 1999, 1996, 2192, 1012, 102, 0
]]])
token_type_ids = torch.tensor(
[[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
attention_mask = torch.tensor(
[[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
labels = torch.tensor([0], dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
# define output transform function
output_transform_fn = lambda x: x
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
# define loss funciton
loss_fn_for_bert_model = lambda x: x.pooler_output.mean()
loss_fn = lambda x: x.loss
config = transformers.BertConfig(hidden_size=128,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=256,
hidden_dropout_prob=0,
attention_probs_dropout_prob=0)
# register the BERT variants
model_zoo.register(name='transformers_bert',
model_fn=lambda: transformers.BertModel(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_bert_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_pretraining',
model_fn=lambda: transformers.BertForPreTraining(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen_for_pretraining,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_lm_head_model',
model_fn=lambda: transformers.BertLMHeadModel(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_masked_lm',
model_fn=lambda: transformers.BertForMaskedLM(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_sequence_classification',
model_fn=lambda: transformers.BertForSequenceClassification(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_token_classification',
model_fn=lambda: transformers.BertForTokenClassification(config),
data_gen_fn=data_gen_fn,
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
# ===============================
# Register multi-sentence BERT
# ===============================
def data_gen_for_next_sentence():
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
next_sentence = "The sky is blue due to the shorter wavelength of blue light."
encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
return encoding
def data_gen_for_mcq():
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
choice0 = "It is eaten with a fork and a knife."
choice1 = "It is eaten while held in the hand."
encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
return encoding
# register the following models
model_zoo.register(name='transformers_bert_for_next_sentence',
model_fn=lambda: transformers.BertForNextSentencePrediction(config),
data_gen_fn=data_gen_for_next_sentence,
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_mcq',
model_fn=lambda: transformers.BertForMultipleChoice(config),
data_gen_fn=data_gen_for_mcq,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))

View File

@ -11,47 +11,86 @@ SEQ_LENGTH = 16
def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
# Generated from following code snippet
#
# from transformers import GPT2Tokenizer
# input = 'Hello, my dog is cute'
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def seq_classification_data_gen():
# batch sizes should be 1 if no padding token is defined.
input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def data_gen_for_lm():
# LM data gen
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
data = data_gen()
data['labels'] = data['input_ids'].clone()
return data
def data_gen_for_token_classification():
# token classification data gen
# `labels` is the type not the token id for token classification, 0 or 1
data = data_gen()
data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data['labels'] = torch.tensor([0], dtype=torch.int64)
return data
# define output transform function
output_transform_fn = lambda x: x
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
# define loss function
loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
loss_fn = lambda x: x.loss
config = transformers.GPT2Config(n_layer=2,
n_head=4,
vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,
resid_pdrop=0,
summary_first_dropout=0,
hidden_dropout=0,
problem_type="single_label_classification")
# register the following models
model_zoo.register(name='transformers_gpt',
model_fn=lambda: transformers.GPT2Model(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_gpt2_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_lm',
model_fn=lambda: transformers.GPT2LMHeadModel(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_double_heads',
model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_token_classification',
model_fn=lambda: transformers.GPT2ForTokenClassification(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_sequence_classification',
model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
data_gen_fn=seq_classification_data_gen,
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))

View File

@ -1,86 +1,30 @@
import copy
import os
import pytest
import torch
from transformers import (
AutoTokenizer,
BertConfig,
BertForMaskedLM,
BertForNextSentencePrediction,
BertForPreTraining,
BertForSequenceClassification,
BertLMHeadModel,
BertModel,
)
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
def build_model(world_size, model_fn):
config = BertConfig()
config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output)
org_model = model_fn(config=config)
org_model_forshard = copy.deepcopy(org_model)
org_model.to('cuda')
# TODO: no need to transfer to cuda
org_model_forshard.to('cuda')
shard_config = ShardConfig(tensor_parallel_size=world_size,)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
return org_model, sharded_model
def check_forward(org_model, sharded_model):
input = 'Hello, my dog is cute'
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
#orgin model
org_model.eval()
org_out = org_model(**tokenized_input)
#shard model
sharded_model.eval()
shard_out = sharded_model(**tokenized_input)
assert torch.allclose(
org_out[0], shard_out[0],
atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"
def check_backward(org_model, sharded_model):
# prepare input
input = 'Hello, my dog is cute'
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
labels = tokenized_input['input_ids'].clone()
labels[labels == tokenizer.pad_token_id] = -100
tokenized_input['labels'] = labels
#orgin model
org_model.train()
org_out = org_model(**tokenized_input)
org_loss = org_out.loss
# do backward
org_loss.backward()
org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
#shard model
sharded_model.train()
shard_out = sharded_model(**tokenized_input)
shard_loss = shard_out.loss
shard_loss.backward()
shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
# check grad equality
if org_model.__class__.__name__ == 'BertModel':
org_grad = org_model.encoder.layer[0].attention.self.query.weight.grad
shard_grad = sharded_model.encoder.layer[0].attention.self.query.weight.grad
else:
org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
@ -89,36 +33,24 @@ def check_backward(org_model, sharded_model):
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
forward_list = [
BertForMaskedLM,
BertForPreTraining,
BertLMHeadModel,
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# TODO: do not work yet
# BertModel,
# BertForSequenceClassification
# BertForNextSentencePrediction,
]
backward_lsit = [BertForMaskedLM, BertLMHeadModel]
for model_fn in forward_list:
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
check_forward(org_model, sharded_model)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
if model_fn in backward_lsit:
check_backward(org_model, sharded_model)
torch.cuda.empty_cache()
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
spawn(check_bert, 2)

View File

@ -1,117 +1,61 @@
import copy
import os
import pytest
import torch
from transformers import AutoTokenizer, GPT2Config, GPT2Model
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
def build_model(world_size, model_fn):
config = GPT2Config()
config.attn_pdrop = 0
config.embd_pdrop = 0
config.resid_pdrop = 0
config.summary_first_dropout
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
org_model = model_fn(config=config)
org_model_forshard = copy.deepcopy(org_model)
org_model.to('cuda')
# TODO: no need to transfer to cuda
org_model_forshard.to('cuda')
shard_config = ShardConfig(tensor_parallel_size=world_size,)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
return org_model, sharded_model
def check_forward(org_model, sharded_model):
input = 'Hello, my dog is cute'
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
#orgin model
org_model.eval()
org_out = org_model(**tokenized_input)
#shard model
sharded_model.eval()
shard_out = sharded_model(**tokenized_input)
assert torch.allclose(
org_out[0], shard_out[0],
atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"
def check_backward(org_model, sharded_model):
# prepare input
input = 'Hello, my dog is cute'
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
labels = tokenized_input['input_ids'].clone()
labels[labels == tokenizer.pad_token_id] = -100
# tokenized_input['labels'] = labels
#orgin model
org_model.train()
org_out = org_model(**tokenized_input)
org_loss = org_out.loss
# do backward
org_loss.backward()
org_grad = org_model.h[0].attn.c_attn.weight.grad
#shard model
sharded_model.train()
shard_out = sharded_model(**tokenized_input)
shard_loss = shard_out.loss
shard_loss.backward()
shard_grad = sharded_model.h[0].attn.c_attn.weight.grad
# check grad equality
if org_model.__class__.__name__ == 'GPT2Model':
org_grad = org_model.h[0].attn.c_attn.weight.grad
shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous()
else:
org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad
shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous()
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
all_shard_grad = torch.cat(shard_grad_list, dim=1)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{all_shard_grad}"
def check_bert(rank, world_size, port):
def check_gpt2(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
forward_list = [
GPT2Model,
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# TODO: do not work yet
# BertModel,
# BertForSequenceClassification
# BertForNextSentencePrediction,
]
backward_lsit = []
for model_fn in forward_list:
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
print(name)
# if name == 'transformers_gpt':
# continue
org_model, sharded_model = build_model(world_size, model_fn)
check_forward(org_model, sharded_model)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
if model_fn in backward_lsit:
check_backward(org_model, sharded_model)
torch.cuda.empty_cache()
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2():
spawn(check_bert, 2)
spawn(check_gpt2, 2)
if __name__ == "__main__":