mirror of https://github.com/hpcaitech/ColossalAI
support kit use for bert/gpt test (#4055)
* support kit use for bert test
* support kit test for gpt2

pull/4157/head
parent f22ddacef0
commit 7740c55c55
@@ -25,17 +25,19 @@ class PolicyLocation:
_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertModel":
        PolicyLocation(file_name="bert", class_name="BertPolicy"),
        PolicyLocation(file_name="bert", class_name="BertModelPolicy"),
    "transformers.models.bert.modeling_bert.BertForPreTraining":
        PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"),
    "transformers.models.bert.modeling_bert.BertForMaskedLM":
        PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
    "transformers.models.bert.modeling_bert.BertLMHeadModel":
        PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
    "transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
        PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
    "transformers.models.bert.modeling_bert.BertForMaskedLM":
        PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
    "transformers.models.bert.modeling_bert.BertForSequenceClassification":
        PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"),
    "transformers.models.bert.modeling_bert.BertForTokenClassification":
        PolicyLocation(file_name="bert", class_name="BertForTokenClassificationPolicy"),
    "transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
        PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
    "transformers.models.bert.modeling_bert.BertForMultipleChoice":
        PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),
@@ -58,6 +60,14 @@ _POLICY_LIST = {
    # GPT2
    "transformers.models.gpt2.modeling_gpt2.GPT2Model":
        PolicyLocation(file_name="gpt2", class_name="GPT2ModelPolicy"),
    "transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel":
        PolicyLocation(file_name="gpt2", class_name="GPT2LMHeadModelPolicy"),
    "transformers.models.gpt2.modeling_gpt2.GPT2DoubleHeadsModel":
        PolicyLocation(file_name="gpt2", class_name="GPT2DoubleHeadsModelPolicy"),
    "transformers.models.gpt2.modeling_gpt2.GPT2ForTokenClassification":
        PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
    "transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
        PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),
}
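The two hunks above extend `_POLICY_LIST`, which keys shardformer policies by the fully qualified transformers class name, so a policy can be resolved from a live model object and imported lazily. A minimal sketch of such a lookup, assuming the `_POLICY_LIST` and `PolicyLocation` shown above; the `get_policy` helper and the package path are illustrative, not necessarily the repository's exact API:

import importlib

def get_policy(model):
    # build the fully qualified class name, e.g.
    # "transformers.models.bert.modeling_bert.BertModel"
    qualname = f"{model.__class__.__module__}.{model.__class__.__qualname__}"
    location = _POLICY_LIST.get(qualname)
    if location is None:
        raise ValueError(f"no shard policy registered for {qualname}")
    # policies live in modules named after PolicyLocation.file_name (assumed path)
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)()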
@@ -131,37 +131,6 @@ class BertForPretrainingPolicy(BertPolicy):
        return self.model


# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        addon_module = {
            BertLMPredictionHead:
                ModulePolicyDescription(attribute_replacement={},
                                        param_replacement=[],
                                        sub_module_replacement=[
                                            SubModuleReplacementDescription(suffix="decoder",
                                                                            target_module=col_nn.Linear1D_Col,
                                                                            kwargs={"gather_output": True})
                                        ])
        }
        module_policy.update(addon_module)
        return module_policy

    def postprocess(self):
        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            param = nn.Parameter(param)
            setattr_(self.model, k, param)
            setattr_(self.model, v, param)
        return self.model


# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
@@ -193,15 +162,53 @@ class BertLMHeadModelPolicy(BertPolicy):
        return self.model


# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        addon_module = {
            BertLMPredictionHead:
                ModulePolicyDescription(attribute_replacement={},
                                        param_replacement=[],
                                        sub_module_replacement=[
                                            SubModuleReplacementDescription(suffix="decoder",
                                                                            target_module=col_nn.Linear1D_Col,
                                                                            kwargs={"gather_output": True})
                                        ])
        }
        module_policy.update(addon_module)
        return module_policy

    def postprocess(self):
        binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            param = nn.Parameter(param)
            setattr_(self.model, k, param)
            setattr_(self.model, v, param)
        return self.model


# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()


# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
# BertForTokenClassification
class BertForTokenClassificationPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()


# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):

    def __init__(self) -> None:
        super().__init__()
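The `postprocess` methods above re-tie weights that HuggingFace shares between the input word embedding and the output decoder: after sharding swaps modules out, the two attributes can end up pointing at different tensors, so the policy reads one parameter and assigns the very same `nn.Parameter` object to both dotted paths. A standalone sketch of the same idea; the path-walking helpers below mimic the `getattr_`/`setattr_` utilities used above and are simplified stand-ins, not the repository's implementations:

import torch.nn as nn

def getattr_by_path(obj, path):
    # walk a dotted attribute path such as "bert.embeddings.word_embeddings.weight"
    for name in path.split('.'):
        obj = getattr(obj, name)
    return obj

def setattr_by_path(obj, path, value):
    *parents, last = path.split('.')
    for name in parents:
        obj = getattr(obj, name)
    setattr(obj, last, value)

def tie_weights(model, binding_map):
    # make both dotted paths reference the exact same Parameter object
    for src, dst in binding_map.items():
        param = nn.Parameter(getattr_by_path(model, src))
        setattr_by_path(model, src, param)
        setattr_by_path(model, dst, param)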
@@ -1,7 +1,9 @@
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -82,7 +84,6 @@ class GPT2Policy(Policy):
        }

    def new_model_class(self):

        return self.model

    def postprocess(self):
@@ -94,3 +95,79 @@ class GPT2ModelPolicy(GPT2Policy):

    def __init__(self) -> None:
        super().__init__()


# GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        addon_module = {
            GPT2LMHeadModel:
                ModulePolicyDescription(attribute_replacement={},
                                        param_replacement=[],
                                        sub_module_replacement=[
                                            SubModuleReplacementDescription(suffix="lm_head",
                                                                            target_module=col_nn.Linear1D_Col,
                                                                            kwargs={"gather_output": True})
                                        ])
        }
        module_policy.update(addon_module)
        return module_policy

    def postprocess(self):
        binding_map = {"transformer.wte.weight": "lm_head.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            param = nn.Parameter(param)
            setattr_(self.model, k, param)
            setattr_(self.model, v, param)
        return self.model


# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):

    def __init__(self) -> None:
        super().__init__()

    def module_policy(self):
        module_policy = super().module_policy()
        addon_module = {
            GPT2DoubleHeadsModel:
                ModulePolicyDescription(attribute_replacement={},
                                        param_replacement=[],
                                        sub_module_replacement=[
                                            SubModuleReplacementDescription(suffix="lm_head",
                                                                            target_module=col_nn.Linear1D_Col,
                                                                            kwargs={"gather_output": True})
                                        ])
        }
        module_policy.update(addon_module)
        return module_policy

    def postprocess(self):
        binding_map = {"transformer.wte.weight": "lm_head.weight"}
        for k, v in binding_map.items():
            param = getattr_(self.model, k)
            param = nn.Parameter(param)
            setattr_(self.model, k, param)
            setattr_(self.model, v, param)
        return self.model


# GPT2ForTokenClassification
class GPT2ForTokenClassificationPolicy(GPT2Policy):

    def __init__(self) -> None:
        super().__init__()


# GPT2ForSequenceClassification
class GPT2ForSequenceClassificationPolicy(GPT2Policy):

    def __init__(self) -> None:
        super().__init__()
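Both LM-head policies above replace `lm_head` with a column-parallel linear and pass `gather_output=True`, so each rank computes logits for a slice of the vocabulary and the partial outputs are gathered back into full logits. A toy illustration of that math on a single process; plain tensors only, with the tensor-parallel split simulated by `torch.chunk` rather than a real distributed setup:

import torch

hidden, vocab, world_size = 8, 12, 2
x = torch.randn(4, hidden)          # a batch of hidden states
w = torch.randn(vocab, hidden)      # full lm_head weight

# column-parallel split: each "rank" owns vocab // world_size output rows
partial_logits = [x @ shard.t() for shard in torch.chunk(w, world_size, dim=0)]

# gather_output=True corresponds to concatenating the partial results
full_logits = torch.cat(partial_logits, dim=-1)
assert torch.allclose(full_logits, x @ w.t())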
@@ -6,83 +6,147 @@ from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence BERT
# ===============================
BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen_fn():
    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
# define data gen function
def data_gen():
    # Generated from the following code snippet
    #
    # from transformers import BertTokenizer
    # input = 'Hello, my dog is cute'
    # tokenized_input = tokenizer(input, return_tensors='pt')
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    # token_type_ids = tokenized_input['token_type_ids']
    input_ids = torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]], dtype=torch.int64)
    token_type_ids = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)


def data_gen_for_lm():
    # LM data gen
    # for LM, `labels` are the target output tokens; since there is no padding, reuse `input_ids` as `labels`
    data = data_gen()
    data['labels'] = data['input_ids'].clone()
    return data


def data_gen_for_pretraining():
    # pretraining data gen
    # `next_sentence_label` is the label for next sentence prediction, 0 or 1
    data = data_gen_for_lm()
    data['next_sentence_label'] = torch.tensor([1], dtype=torch.int64)
    return data


def data_gen_for_sequence_classification():
    # sequence classification data gen
    # `labels` is the label for sequence classification, 0 or 1
    data = data_gen()
    data['labels'] = torch.tensor([1], dtype=torch.int64)
    return data


def data_gen_for_token_classification():
    # token classification data gen
    # `labels` holds the per-token class id (0 or 1), not the token id
    data = data_gen()
    data['labels'] = torch.tensor([[1, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    return data


def data_gen_for_mcq():
    # multiple choice question data gen
    # Generated from the following code snippet
    #
    # tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    # prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    # choice0 = "It is eaten with a fork and a knife."
    # choice1 = "It is eaten while held in the hand."
    # data = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
    # data = {k: v.unsqueeze(0) for k, v in data.items()}
    # data['labels'] = torch.tensor([0], dtype=torch.int64)
    input_ids = torch.tensor([[[
        101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037, 4825, 1010, 2003, 3591,
        4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2007, 1037, 9292, 1998, 1037, 5442, 1012, 102
    ],
                              [
                                  101, 1999, 3304, 1010, 10733, 2366, 1999, 5337, 10906, 1010, 2107, 2004, 2012, 1037,
                                  4825, 1010, 2003, 3591, 4895, 14540, 6610, 2094, 1012, 102, 2009, 2003, 8828, 2096,
                                  2218, 1999, 1996, 2192, 1012, 102, 0
                              ]]])
    token_type_ids = torch.tensor(
        [[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
          [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
    attention_mask = torch.tensor(
        [[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
          [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0]]])
    labels = torch.tensor([0], dtype=torch.int64)

    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)


# define output transform function
output_transform_fn = lambda x: x

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
# define loss function
loss_fn_for_bert_model = lambda x: x.pooler_output.mean()
loss_fn = lambda x: x.loss

config = transformers.BertConfig(hidden_size=128,
                                 num_hidden_layers=2,
                                 num_attention_heads=4,
                                 intermediate_size=256,
                                 hidden_dropout_prob=0,
                                 attention_probs_dropout_prob=0)

# register the BERT variants
model_zoo.register(name='transformers_bert',
                   model_fn=lambda: transformers.BertModel(config),
                   data_gen_fn=data_gen_fn,
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_bert_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_pretraining',
                   model_fn=lambda: transformers.BertForPreTraining(config),
                   data_gen_fn=data_gen_fn,
                   data_gen_fn=data_gen_for_pretraining,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_lm_head_model',
                   model_fn=lambda: transformers.BertLMHeadModel(config),
                   data_gen_fn=data_gen_fn,
                   data_gen_fn=data_gen_for_lm,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_masked_lm',
                   model_fn=lambda: transformers.BertForMaskedLM(config),
                   data_gen_fn=data_gen_fn,
                   data_gen_fn=data_gen_for_lm,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_sequence_classification',
                   model_fn=lambda: transformers.BertForSequenceClassification(config),
                   data_gen_fn=data_gen_fn,
                   data_gen_fn=data_gen_for_sequence_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_token_classification',
                   model_fn=lambda: transformers.BertForTokenClassification(config),
                   data_gen_fn=data_gen_fn,
                   data_gen_fn=data_gen_for_token_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))


# ===============================
# Register multi-sentence BERT
# ===============================
def data_gen_for_next_sentence():
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    next_sentence = "The sky is blue due to the shorter wavelength of blue light."
    encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
    return encoding


def data_gen_for_mcq():
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    choice0 = "It is eaten with a fork and a knife."
    choice1 = "It is eaten while held in the hand."
    encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
    encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
    return encoding


# register the following models
model_zoo.register(name='transformers_bert_for_next_sentence',
                   model_fn=lambda: transformers.BertForNextSentencePrediction(config),
                   data_gen_fn=data_gen_for_next_sentence,
                   data_gen_fn=data_gen_for_sequence_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_mcq',
                   model_fn=lambda: transformers.BertForMultipleChoice(config),
                   data_gen_fn=data_gen_for_mcq,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
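Each `model_zoo.register` call above stores a set of builders that a test can replay end to end: construct the model, generate a batch, run it, and reduce the output to a scalar loss. A minimal sketch of how one registered BERT entry is exercised, written out by hand rather than through the registry; the tiny config mirrors the one above and the whole thing runs on CPU for illustration:

import torch
import transformers

config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
model = transformers.BertForSequenceClassification(config)

# the same shape of batch that data_gen_for_sequence_classification produces
data = dict(input_ids=torch.tensor([[101, 7592, 1010, 2026, 3899, 2003, 10140, 102]]),
            token_type_ids=torch.zeros(1, 8, dtype=torch.int64),
            attention_mask=torch.ones(1, 8, dtype=torch.int64),
            labels=torch.tensor([1]))

output = model(**data)
loss = output.loss    # what `loss_fn = lambda x: x.loss` extracts
loss.backward()       # gradients now exist for a forward/backward check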
@@ -11,47 +11,86 @@ SEQ_LENGTH = 16


def data_gen():
    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
    token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
    attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
    # Generated from the following code snippet
    #
    # from transformers import GPT2Tokenizer
    # input = 'Hello, my dog is cute'
    # tokenized_input = tokenizer(input, return_tensors='pt')
    # input_ids = tokenized_input['input_ids']
    # attention_mask = tokenized_input['attention_mask']
    input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]], dtype=torch.int64)
    return dict(input_ids=input_ids, attention_mask=attention_mask)


def seq_classification_data_gen():
    # batch sizes should be 1 if no padding token is defined.
    input_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
    token_type_ids = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
    attention_mask = torch.zeros((1, SEQ_LENGTH), dtype=torch.int64)
    return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def data_gen_for_lm():
    # LM data gen
    # for LM, `labels` are the target output tokens; since there is no padding, reuse `input_ids` as `labels`
    data = data_gen()
    data['labels'] = data['input_ids'].clone()
    return data


def data_gen_for_token_classification():
    # token classification data gen
    # `labels` holds the per-token class id (0 or 1), not the token id
    data = data_gen()
    data['labels'] = torch.tensor([[0, 0, 0, 0, 0, 0]], dtype=torch.int64)
    return data


def data_gen_for_sequence_classification():
    # sequence classification data gen
    data = data_gen()
    data['labels'] = torch.tensor([0], dtype=torch.int64)
    return data


# define output transform function
output_transform_fn = lambda x: x

config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
# define loss function
loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
loss_fn = lambda x: x.loss

config = transformers.GPT2Config(n_layer=2,
                                 n_head=4,
                                 vocab_size=50258,
                                 attn_pdrop=0,
                                 embd_pdrop=0,
                                 resid_pdrop=0,
                                 summary_first_dropout=0,
                                 hidden_dropout=0,
                                 problem_type="single_label_classification")

# register the following models
model_zoo.register(name='transformers_gpt',
                   model_fn=lambda: transformers.GPT2Model(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_gpt2_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_lm',
                   model_fn=lambda: transformers.GPT2LMHeadModel(config),
                   data_gen_fn=data_gen,
                   data_gen_fn=data_gen_for_lm,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_double_heads',
                   model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
                   data_gen_fn=data_gen,
                   data_gen_fn=data_gen_for_lm,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_token_classification',
                   model_fn=lambda: transformers.GPT2ForTokenClassification(config),
                   data_gen_fn=data_gen,
                   data_gen_fn=data_gen_for_token_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_sequence_classification',
                   model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
                   data_gen_fn=seq_classification_data_gen,
                   data_gen_fn=data_gen_for_sequence_classification,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
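A detail worth noting in `data_gen_for_lm` above: the labels are a straight clone of `input_ids` because HuggingFace causal-LM heads shift logits and labels internally when computing the loss, so no manual offset is needed. A small CPU-only check of that convention; the tiny config is illustrative, not the kit's exact one:

import torch
import transformers

config = transformers.GPT2Config(n_layer=2, n_head=4)
model = transformers.GPT2LMHeadModel(config)

input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779]])
out = model(input_ids=input_ids, labels=input_ids.clone())

# equivalent manual shift: logits at position t predict token t+1
logits = out.logits[:, :-1, :].reshape(-1, config.vocab_size)
targets = input_ids[:, 1:].reshape(-1)
manual = torch.nn.functional.cross_entropy(logits, targets)
assert torch.allclose(out.loss, manual, atol=1e-5)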
@@ -1,86 +1,30 @@
import copy
import os

import pytest
import torch
from transformers import (
    AutoTokenizer,
    BertConfig,
    BertForMaskedLM,
    BertForNextSentencePrediction,
    BertForPreTraining,
    BertForSequenceClassification,
    BertLMHeadModel,
    BertModel,
)

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def build_model(world_size, model_fn):
    config = BertConfig()
    config.hidden_dropout_prob = 0
    config.attention_probs_dropout_prob = 0
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
    assert_hf_output_close(org_output, shard_output)

    org_model = model_fn(config=config)
    org_model_forshard = copy.deepcopy(org_model)

    org_model.to('cuda')
    # TODO: no need to transfer to cuda
    org_model_forshard.to('cuda')
    shard_config = ShardConfig(tensor_parallel_size=world_size,)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')

    return org_model, sharded_model


def check_forward(org_model, sharded_model):
    input = 'Hello, my dog is cute'
    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')

    # original model
    org_model.eval()
    org_out = org_model(**tokenized_input)

    # sharded model
    sharded_model.eval()
    shard_out = sharded_model(**tokenized_input)

    assert torch.allclose(
        org_out[0], shard_out[0],
        atol=1e-5), f"shard model output is not equal to original model output\n{org_out[0]}\n{shard_out[0]}"


def check_backward(org_model, sharded_model):
    # prepare input
    input = 'Hello, my dog is cute'
    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
    labels = tokenized_input['input_ids'].clone()
    labels[labels == tokenizer.pad_token_id] = -100
    tokenized_input['labels'] = labels

    # original model
    org_model.train()
    org_out = org_model(**tokenized_input)
    org_loss = org_out.loss
    # do backward
    org_loss.backward()
    org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad

    # sharded model
    sharded_model.train()
    shard_out = sharded_model(**tokenized_input)
    shard_loss = shard_out.loss
    shard_loss.backward()
    shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad

    # check grad equality
    if org_model.__class__.__name__ == 'BertModel':
        org_grad = org_model.encoder.layer[0].attention.self.query.weight.grad
        shard_grad = sharded_model.encoder.layer[0].attention.self.query.weight.grad
    else:
        org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
        shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad

    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
@@ -89,36 +33,24 @@ def check_backward(org_model, sharded_model):
    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{shard_grad}"
                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"


def check_bert(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    forward_list = [
        BertForMaskedLM,
        BertForPreTraining,
        BertLMHeadModel,
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # TODO: these do not work yet
    # BertModel,
    # BertForSequenceClassification
    # BertForNextSentencePrediction,
    ]
    backward_list = [BertForMaskedLM, BertLMHeadModel]

    for model_fn in forward_list:
    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(world_size, model_fn)
        check_forward(org_model, sharded_model)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

        if model_fn in backward_list:
            check_backward(org_model, sharded_model)

        torch.cuda.empty_cache()
        torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
    spawn(check_bert, 2)
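The backward check above compares gradients across tensor-parallel ranks by all-gathering each rank's gradient shard and concatenating the pieces back into the full tensor before `torch.allclose`. A condensed sketch of that pattern; it assumes `torch.distributed` is already initialized with `world_size` ranks and that the parameter was split along the given dimension:

import torch
import torch.distributed as dist

def gather_full_grad(shard_grad, world_size, dim=0):
    # collect every rank's shard, then stitch the full gradient back together
    shard_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    dist.all_gather(shard_list, shard_grad)
    return torch.cat(shard_list, dim=dim)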
@@ -1,117 +1,61 @@
import copy
import os

import pytest
import torch
from transformers import AutoTokenizer, GPT2Config, GPT2Model

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def build_model(world_size, model_fn):
    config = GPT2Config()
    config.attn_pdrop = 0
    config.embd_pdrop = 0
    config.resid_pdrop = 0
    config.summary_first_dropout
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])

    org_model = model_fn(config=config)
    org_model_forshard = copy.deepcopy(org_model)

    org_model.to('cuda')
    # TODO: no need to transfer to cuda
    org_model_forshard.to('cuda')
    shard_config = ShardConfig(tensor_parallel_size=world_size,)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')

    return org_model, sharded_model


def check_forward(org_model, sharded_model):
    input = 'Hello, my dog is cute'
    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')

    # original model
    org_model.eval()
    org_out = org_model(**tokenized_input)

    # sharded model
    sharded_model.eval()
    shard_out = sharded_model(**tokenized_input)

    assert torch.allclose(
        org_out[0], shard_out[0],
        atol=1e-5), f"shard model output is not equal to original model output\n{org_out[0]}\n{shard_out[0]}"


def check_backward(org_model, sharded_model):
    # prepare input
    input = 'Hello, my dog is cute'
    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
    labels = tokenized_input['input_ids'].clone()
    labels[labels == tokenizer.pad_token_id] = -100
    # tokenized_input['labels'] = labels

    # original model
    org_model.train()
    org_out = org_model(**tokenized_input)
    org_loss = org_out.loss
    # do backward
    org_loss.backward()
    org_grad = org_model.h[0].attn.c_attn.weight.grad

    # sharded model
    sharded_model.train()
    shard_out = sharded_model(**tokenized_input)
    shard_loss = shard_out.loss
    shard_loss.backward()
    shard_grad = sharded_model.h[0].attn.c_attn.weight.grad

    # check grad equality
    if org_model.__class__.__name__ == 'GPT2Model':
        org_grad = org_model.h[0].attn.c_attn.weight.grad
        shard_grad = sharded_model.h[0].attn.c_attn.weight.grad.transpose(0, 1).contiguous()
    else:
        org_grad = org_model.transformer.h[0].mlp.c_fc.weight.grad
        shard_grad = sharded_model.transformer.h[0].mlp.c_fc.weight.grad.transpose(0, 1).contiguous()

    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
    all_shard_grad = torch.cat(shard_grad_list, dim=0)
    all_shard_grad = torch.cat(shard_grad_list, dim=1)

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{shard_grad}"
                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"


def check_bert(rank, world_size, port):
def check_gpt2(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    forward_list = [
        GPT2Model,
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    # TODO: these do not work yet
    # BertModel,
    # BertForSequenceClassification
    # BertForNextSentencePrediction,
    ]
    backward_list = []

    for model_fn in forward_list:
    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        print(name)
        # if name == 'transformers_gpt':
        #     continue
        org_model, sharded_model = build_model(world_size, model_fn)
        check_forward(org_model, sharded_model)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

        if model_fn in backward_list:
            check_backward(org_model, sharded_model)

        torch.cuda.empty_cache()
        torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2():
    spawn(check_bert, 2)
    spawn(check_gpt2, 2)


if __name__ == "__main__":