mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] Bert pipeline for shardformer and its tests (#4197)
* add pipeline forward * complete pipeline forward check * fix bert forward without pipeline * fix comments * discard useless line * add todo * clean prints * fix distribute layerspull/4445/head
parent
890774b2fb
commit
1094e0f0d3
|
@ -191,7 +191,7 @@ class Policy(ABC):
|
|||
|
||||
# deal with the rest layers
|
||||
if remainder > 0:
|
||||
start_position = num_layers // 2 - remainder // 2
|
||||
start_position = num_stages // 2 - remainder // 2
|
||||
for i in range(start_position, start_position + remainder):
|
||||
layers_per_stage[i] += 1
|
||||
return layers_per_stage
|
||||
|
|
|
@ -13,6 +13,8 @@ from transformers.modeling_outputs import (
|
|||
CausalLMOutputWithCrossAttentions,
|
||||
)
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertForMaskedLM,
|
||||
BertForNextSentencePrediction,
|
||||
BertForPreTraining,
|
||||
BertForPreTrainingOutput,
|
||||
BertLMHeadModel,
|
||||
|
@ -135,7 +137,6 @@ class BertPolicy(Policy):
|
|||
],
|
||||
policy=policy,
|
||||
target_key=BertLayer)
|
||||
|
||||
# handle embedding layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[SubModuleReplacementDescription(
|
||||
|
@ -144,6 +145,7 @@ class BertPolicy(Policy):
|
|||
)],
|
||||
policy=policy,
|
||||
target_key=BertEmbeddings)
|
||||
|
||||
return policy
|
||||
|
||||
def add_lm_head_policy(self, base_policy):
|
||||
|
@ -177,6 +179,15 @@ class BertModelPolicy(BertPolicy):
|
|||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
def module_policy(self):
|
||||
module_policy = super().module_policy()
|
||||
from transformers.models.bert.modeling_bert import BertModel
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
module_policy[BertModel] = ModulePolicyDescription(
|
||||
method_replacement={'forward': partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)})
|
||||
return module_policy
|
||||
|
||||
def get_held_layers(self) -> List[Module]:
|
||||
"""Get pipeline layers for current stage."""
|
||||
module = self.model
|
||||
|
@ -444,6 +455,13 @@ def bert_model_forward(
|
|||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
batch_size, seq_length = input_shape
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
||||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
||||
token_type_ids = buffered_token_type_ids_expanded
|
||||
else:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
else:
|
||||
input_shape = hidden_states.size()[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
|
@ -466,14 +484,6 @@ def bert_model_forward(
|
|||
if attention_mask is None:
|
||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||
|
||||
if token_type_ids is None:
|
||||
if hasattr(self.embeddings, "token_type_ids"):
|
||||
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
||||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
||||
token_type_ids = buffered_token_type_ids_expanded
|
||||
else:
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||
|
@ -778,3 +788,131 @@ def bert_lmhead_forward(self: BertLMHeadModel,
|
|||
hidden_states = outputs.get('hidden_states')
|
||||
# intermediate stage always return dict
|
||||
return {'hidden_states': hidden_states}
|
||||
|
||||
|
||||
def bert_for_masked_lm_forward(
|
||||
self: BertForMaskedLM,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
):
|
||||
#-> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
||||
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
|
||||
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def bert_for_next_sentence_prediction_forward(
|
||||
self: BertForNextSentencePrediction,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
labels: Optional[torch.Tensor] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
hidden_states: Optional[torch.Tensor] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
**kwargs,
|
||||
):
|
||||
#-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
|
||||
(see `input_ids` docstring). Indices should be in `[0, 1]`:
|
||||
|
||||
- 0 indicates sequence B is a continuation of sequence A,
|
||||
- 1 indicates sequence B is a random sequence.
|
||||
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, BertForNextSentencePrediction
|
||||
>>> import torch
|
||||
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||
>>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")
|
||||
|
||||
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
||||
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
||||
>>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
|
||||
|
||||
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
|
||||
>>> logits = outputs.logits
|
||||
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
||||
```
|
||||
"""
|
||||
|
||||
if "next_sentence_label" in kwargs:
|
||||
warnings.warn(
|
||||
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
|
||||
" `labels` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
labels = kwargs.pop("next_sentence_label")
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if return_dict:
|
||||
logger.warning_once('return_dict is not supported for pipeline models at the moment')
|
||||
return_dict = False
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = bert_model_forward(
|
||||
self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
if stage_manager.is_last_stage():
|
||||
pooled_output = outputs[1]
|
||||
seq_relationship_scores = self.cls(pooled_output)
|
||||
|
||||
next_sentence_loss = None
|
||||
if labels is not None:
|
||||
loss_fct = CrossEntropyLoss()
|
||||
next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
|
||||
|
||||
if not return_dict:
|
||||
output = (seq_relationship_scores,) + outputs[2:]
|
||||
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
|
||||
|
||||
return NextSentencePredictorOutput(
|
||||
loss=next_sentence_loss,
|
||||
logits=seq_relationship_scores,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
else:
|
||||
hidden_states = outputs.get('hidden_states')
|
||||
# intermediate stage always return dict
|
||||
return {'hidden_states': hidden_states}
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from types import MethodType
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
|
||||
import torch.nn as nn
|
||||
|
@ -134,7 +135,8 @@ class ModelSharder(object):
|
|||
def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
|
||||
for method_name, new_method in method_replacement.items():
|
||||
# bind the new method to the module
|
||||
setattr(module, method_name, new_method.__get__(module, module.__class__))
|
||||
bound_method = MethodType(new_method, module)
|
||||
setattr(module, method_name, bound_method)
|
||||
|
||||
def _replace_sub_module(
|
||||
self,
|
||||
|
|
|
@ -2,6 +2,7 @@ import copy
|
|||
from contextlib import nullcontext
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
||||
|
||||
|
@ -21,6 +22,28 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
|
|||
return org_model.cuda(), sharded_model.cuda()
|
||||
|
||||
|
||||
def build_pipeline_model(model_fn,
|
||||
stage_manager=None,
|
||||
enable_fused_normalization=False,
|
||||
enable_tensor_parallelism=False,
|
||||
use_lazy_init: bool = False):
|
||||
ctx = LazyInitContext() if use_lazy_init else nullcontext()
|
||||
with ctx:
|
||||
# create new model
|
||||
org_model = model_fn()
|
||||
model_copy = copy.deepcopy(org_model)
|
||||
if use_lazy_init:
|
||||
ctx.materialize(org_model)
|
||||
|
||||
# shard model
|
||||
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
|
||||
enable_tensor_parallelism=enable_tensor_parallelism,
|
||||
pipeline_stage_manager=stage_manager)
|
||||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
sharded_model, shared_params = shard_former.optimize(model_copy)
|
||||
return org_model.cuda(), sharded_model.cuda()
|
||||
|
||||
|
||||
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# prepare input
|
||||
data = data_gen_fn()
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
clear_cache_before_run,
|
||||
parameterize,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
# check forward
|
||||
pass
|
||||
|
||||
|
||||
@parameterize('enable_fused_normalization', [False])
|
||||
@parameterize('enable_tensor_parallelism', [False])
|
||||
@parameterize('use_lazy_init', [False])
|
||||
#TODO: merge this into test_shard_bert
|
||||
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
|
||||
x = torch.randint(0, 1000, (2, 3)).cuda()
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
if name == 'transformers_bert':
|
||||
org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
|
||||
enable_tensor_parallelism, use_lazy_init)
|
||||
|
||||
if stage_manager.stage == 0:
|
||||
attention_mask = torch.ones_like(x).cuda()
|
||||
output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
|
||||
# print(output['hidden_states'].shape)
|
||||
assert output['hidden_states'].shape == (2, 3, 128)
|
||||
else:
|
||||
attention_mask = torch.ones((2, 3)).cuda()
|
||||
output = sharded_model(hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
# print(output[0].shape)
|
||||
assert output[0].shape == (2, 3, 128)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def check_bert(rank, world_size, port):
|
||||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
run_bert_test()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
@clear_cache_before_run()
|
||||
def test_bert():
|
||||
spawn(check_bert, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_bert()
|
Loading…
Reference in New Issue