[pipeline] Bert pipeline for shardformer and its tests (#4197)

* add pipeline forward

* complete pipeline forward check

* fix bert forward without pipeline

* fix comments

* discard useless line

* add todo

* clean prints

* fix distribute layers
pull/4445/head
Jianghai 2023-07-10 13:58:58 +08:00 committed by Hongxin Liu
parent 890774b2fb
commit 1094e0f0d3
5 changed files with 259 additions and 11 deletions

View File

@ -191,7 +191,7 @@ class Policy(ABC):
# deal with the rest layers
if remainder > 0:
start_position = num_layers // 2 - remainder // 2
start_position = num_stages // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage

View File

@ -13,6 +13,8 @@ from transformers.modeling_outputs import (
CausalLMOutputWithCrossAttentions,
)
from transformers.models.bert.modeling_bert import (
BertForMaskedLM,
BertForNextSentencePrediction,
BertForPreTraining,
BertForPreTrainingOutput,
BertLMHeadModel,
@ -135,7 +137,6 @@ class BertPolicy(Policy):
],
policy=policy,
target_key=BertLayer)
# handle embedding layer
self.append_or_create_submodule_replacement(
description=[SubModuleReplacementDescription(
@ -144,6 +145,7 @@ class BertPolicy(Policy):
)],
policy=policy,
target_key=BertEmbeddings)
return policy
def add_lm_head_policy(self, base_policy):
@ -177,6 +179,15 @@ class BertModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
from transformers.models.bert.modeling_bert import BertModel
if self.pipeline_stage_manager:
# set None as default
module_policy[BertModel] = ModulePolicyDescription(
method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)})
return module_policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
module = self.model
@ -444,6 +455,13 @@ def bert_model_forward(
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
else:
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape
@ -466,14 +484,6 @@ def bert_model_forward(
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
@ -778,3 +788,131 @@ def bert_lmhead_forward(self: BertLMHeadModel,
hidden_states = outputs.get('hidden_states')
# intermediate stage always return dict
return {'hidden_states': hidden_states}
def bert_for_masked_lm_forward(
self: BertForMaskedLM,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
):
#-> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
"""
pass
def bert_for_next_sentence_prediction_forward(
self: BertForNextSentencePrediction,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
**kwargs,
):
#-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
(see `input_ids` docstring). Indices should be in `[0, 1]`:
- 0 indicates sequence B is a continuation of sequence A,
- 1 indicates sequence B is a random sequence.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, BertForNextSentencePrediction
>>> import torch
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
>>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
>>> logits = outputs.logits
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
```
"""
if "next_sentence_label" in kwargs:
warnings.warn(
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
" `labels` instead.",
FutureWarning,
)
labels = kwargs.pop("next_sentence_label")
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
seq_relationship_scores = self.cls(pooled_output)
next_sentence_loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
if not return_dict:
output = (seq_relationship_scores,) + outputs[2:]
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
return NextSentencePredictorOutput(
loss=next_sentence_loss,
logits=seq_relationship_scores,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
# intermediate stage always return dict
return {'hidden_states': hidden_states}

View File

@ -1,3 +1,4 @@
from types import MethodType
from typing import Any, Callable, Dict, List, Union
import torch.nn as nn
@ -134,7 +135,8 @@ class ModelSharder(object):
def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
for method_name, new_method in method_replacement.items():
# bind the new method to the module
setattr(module, method_name, new_method.__get__(module, module.__class__))
bound_method = MethodType(new_method, module)
setattr(module, method_name, bound_method)
def _replace_sub_module(
self,

View File

@ -2,6 +2,7 @@ import copy
from contextlib import nullcontext
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@ -21,6 +22,28 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
return org_model.cuda(), sharded_model.cuda()
def build_pipeline_model(model_fn,
stage_manager=None,
enable_fused_normalization=False,
enable_tensor_parallelism=False,
use_lazy_init: bool = False):
ctx = LazyInitContext() if use_lazy_init else nullcontext()
with ctx:
# create new model
org_model = model_fn()
model_copy = copy.deepcopy(org_model)
if use_lazy_init:
ctx.materialize(org_model)
# shard model
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
pipeline_stage_manager=stage_manager)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
return org_model.cuda(), sharded_model.cuda()
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# prepare input
data = data_gen_fn()

View File

@ -0,0 +1,85 @@
import pytest
import torch
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.logging import disable_existing_loggers
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# check forward
pass
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
#TODO: merge this into test_shard_bert
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
x = torch.randint(0, 1000, (2, 3)).cuda()
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
if name == 'transformers_bert':
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
enable_tensor_parallelism, use_lazy_init)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x).cuda()
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
# print(output['hidden_states'].shape)
assert output['hidden_states'].shape == (2, 3, 128)
else:
attention_mask = torch.ones((2, 3)).cuda()
output = sharded_model(hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
# print(output[0].shape)
assert output[0].shape == (2, 3, 128)
torch.cuda.empty_cache()
def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bert_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
spawn(check_bert, 4)
if __name__ == "__main__":
test_bert()