mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] shardformer support opt models (#4091)
* [shardformer] shardformer support opt models * [shardformer] shardformer support opt models, fix * [shardformer] shardformer support opt models, fix * [shardformer] shardformer support opt models, fixpull/4157/head
parent
d33a44e8c3
commit
ac80937138
|
@ -68,6 +68,16 @@ _POLICY_LIST = {
|
|||
PolicyLocation(file_name="gpt2", class_name="GPT2ForTokenClassificationPolicy"),
|
||||
"transformers.models.gpt2.modeling_gpt2.GPT2ForSequenceClassification":
|
||||
PolicyLocation(file_name="gpt2", class_name="GPT2ForSequenceClassificationPolicy"),
|
||||
|
||||
# OPT
|
||||
"transformers.models.opt.modeling_opt.OPTModel":
|
||||
PolicyLocation(file_name="opt", class_name="OPTModelPolicy"),
|
||||
"transformers.models.opt.modeling_opt.OPTForCausalLM":
|
||||
PolicyLocation(file_name="opt", class_name="OPTForCausalLMPolicy"),
|
||||
"transformers.models.opt.modeling_opt.OPTForSequenceClassification":
|
||||
PolicyLocation(file_name="opt", class_name="OPTForSequenceClassificationPolicy"),
|
||||
"transformers.models.opt.modeling_opt.OPTForQuestionAnswering":
|
||||
PolicyLocation(file_name="opt", class_name="OPTForQuestionAnsweringPolicy"),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
from transformers.models.opt.modeling_opt import (
|
||||
OPTAttention,
|
||||
OPTDecoder,
|
||||
OPTDecoderLayer,
|
||||
OPTForCausalLM,
|
||||
OPTForSequenceClassification,
|
||||
)
|
||||
|
||||
from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
|
||||
|
||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
||||
class OPTPolicy(Policy):
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
base_policy = {
|
||||
OPTDecoder:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=Embedding1D,
|
||||
)
|
||||
]),
|
||||
OPTDecoderLayer:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc1",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="fc2",
|
||||
target_module=Linear1D_Row,
|
||||
)
|
||||
]),
|
||||
OPTAttention:
|
||||
ModulePolicyDescription(attribute_replacement={
|
||||
"embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="q_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="k_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="v_proj",
|
||||
target_module=Linear1D_Col,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="out_proj",
|
||||
target_module=Linear1D_Row,
|
||||
),
|
||||
]),
|
||||
}
|
||||
if self.shard_config.fused_layernorm:
|
||||
base_policy[OPTDecoder].sub_module_replacement.append(
|
||||
SubModuleReplacementDescription(suffix="final_layer_norm",
|
||||
target_module=FusedLayerNorm,
|
||||
ignore_if_not_exist=True))
|
||||
base_policy[OPTDecoderLayer].sub_module_replacement.extend([
|
||||
SubModuleReplacementDescription(suffix="self_attn_layer_norm",
|
||||
target_module=FusedLayerNorm,
|
||||
ignore_if_not_exist=True),
|
||||
SubModuleReplacementDescription(suffix="final_layer_norm",
|
||||
target_module=FusedLayerNorm,
|
||||
ignore_if_not_exist=True)
|
||||
])
|
||||
return base_policy
|
||||
|
||||
def new_model_class(self):
|
||||
return None
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
||||
class OPTModelPolicy(OPTPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
class OPTForCausalLMPolicy(OPTPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
new_item = {
|
||||
OPTForCausalLM:
|
||||
ModulePolicyDescription(attribute_replacement={},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True))
|
||||
])
|
||||
}
|
||||
|
||||
policy.update(new_item)
|
||||
return policy
|
||||
|
||||
|
||||
class OPTForSequenceClassificationPolicy(OPTPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
||||
class OPTForQuestionAnsweringPolicy(OPTPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
|
@ -11,14 +11,47 @@ SEQ_LENGTH = 16
|
|||
|
||||
|
||||
def data_gen():
|
||||
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
|
||||
input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
|
||||
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
|
||||
return dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
|
||||
|
||||
output_transform_fn = lambda x: x
|
||||
def data_gen_for_causal_lm():
|
||||
# LM data gen
|
||||
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||
data = data_gen()
|
||||
labels = data['input_ids'].clone()
|
||||
data['labels'] = labels
|
||||
return data
|
||||
|
||||
config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
|
||||
|
||||
def data_gen_for_sequence_classification():
|
||||
# LM data gen
|
||||
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||
data = data_gen()
|
||||
labels = data['input_ids'].clone()
|
||||
data['labels'] = torch.tensor([1])
|
||||
return data
|
||||
|
||||
|
||||
def data_gen_for_question_answering():
|
||||
# LM data gen
|
||||
# the `labels` of LM is the token of the output, cause no padding, use `input_ids` as `labels`
|
||||
data = data_gen()
|
||||
data['start_positions'] = torch.tensor([0])
|
||||
data['end_positions'] = torch.tensor([1])
|
||||
return data
|
||||
|
||||
|
||||
output_transform_fn = lambda x: x
|
||||
loss_fn_for_opt_model = lambda x: x.last_hidden_state.mean()
|
||||
loss_fn_for_lm = lambda x: x.loss
|
||||
config = transformers.OPTConfig(
|
||||
hidden_size=128,
|
||||
num_hidden_layers=2,
|
||||
num_attention_heads=4,
|
||||
dropout=0,
|
||||
)
|
||||
|
||||
# register the following models
|
||||
# transformers.OPTModel,
|
||||
|
@ -27,9 +60,23 @@ model_zoo.register(name='transformers_opt',
|
|||
model_fn=lambda: transformers.OPTModel(config),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_opt_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_opt_for_causal_lm',
|
||||
model_fn=lambda: transformers.OPTForCausalLM(config),
|
||||
data_gen_fn=data_gen,
|
||||
data_gen_fn=data_gen_for_causal_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_opt_for_question_answering',
|
||||
model_fn=lambda: transformers.OPTForQuestionAnswering(config),
|
||||
data_gen_fn=data_gen_for_question_answering,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
model_zoo.register(name='transformers_opt_for_sequence_classification',
|
||||
model_fn=lambda: transformers.OPTForSequenceClassification(config),
|
||||
data_gen_fn=data_gen_for_sequence_classification,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
|
|
|
@ -11,10 +11,9 @@ from tests.kit.model_zoo import model_zoo
|
|||
@clear_cache_before_run()
|
||||
def test_opt():
|
||||
sub_registry = model_zoo.get_sub_registry('transformers_opt')
|
||||
|
||||
for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
|
||||
model = model_fn()
|
||||
trace_model_and_compare_output(model, data_gen_fn)
|
||||
trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'start_positions', 'end_positions'])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -25,7 +25,6 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
|
|||
# switch to train mode
|
||||
original_model.train()
|
||||
sharded_model.train()
|
||||
|
||||
# run forward
|
||||
org_output = original_model(**data)
|
||||
org_output = output_transform_fn(org_output)
|
||||
|
@ -34,5 +33,4 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
|
|||
shard_output = sharded_model(**data)
|
||||
shard_output = output_transform_fn(shard_output)
|
||||
shard_loss = loss_fn(shard_output)
|
||||
|
||||
return org_output, org_loss, shard_output, shard_loss
|
||||
return org_output, org_loss, shard_output, shard_loss
|
|
@ -0,0 +1,67 @@
|
|||
import copy
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
check_state_dict_equal,
|
||||
clear_cache_before_run,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
|
||||
output_transform_fn, loss_fn)
|
||||
assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)
|
||||
|
||||
# run backward
|
||||
org_loss.backward()
|
||||
shard_loss.backward()
|
||||
|
||||
# check grad
|
||||
if hasattr(org_model, 'model'):
|
||||
opt_model = org_model.model
|
||||
shard_opt_model = sharded_model.model
|
||||
else:
|
||||
opt_model = org_model
|
||||
shard_opt_model = sharded_model
|
||||
|
||||
org_grad = opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
|
||||
shard_grad = shard_opt_model.decoder.layers[0].self_attn.q_proj.weight.grad
|
||||
|
||||
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
|
||||
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||
assert torch.allclose(org_loss, shard_loss,
|
||||
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||
assert torch.allclose(org_grad, all_shard_grad,
|
||||
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
|
||||
|
||||
|
||||
def check_OPTModel(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(world_size, model_fn)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_OPTModel():
|
||||
spawn(check_OPTModel, 4)
|
Loading…
Reference in New Issue