mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] Bert pipeline for shardformer and its tests (#4197)
* add pipeline forward
* complete pipeline forward check
* fix bert forward without pipeline
* fix comments
* discard useless line
* add todo
* clean prints
* fix distribute layers

pull/4445/head
parent
890774b2fb
commit
1094e0f0d3
|
@ -191,7 +191,7 @@ class Policy(ABC):
|
||||||
|
|
||||||
# deal with the rest layers
|
# deal with the rest layers
|
||||||
if remainder > 0:
|
if remainder > 0:
|
||||||
start_position = num_layers // 2 - remainder // 2
|
start_position = num_stages // 2 - remainder // 2
|
||||||
for i in range(start_position, start_position + remainder):
|
for i in range(start_position, start_position + remainder):
|
||||||
layers_per_stage[i] += 1
|
layers_per_stage[i] += 1
|
||||||
return layers_per_stage
|
return layers_per_stage
|
||||||
|
|
|
@ -13,6 +13,8 @@ from transformers.modeling_outputs import (
|
||||||
CausalLMOutputWithCrossAttentions,
|
CausalLMOutputWithCrossAttentions,
|
||||||
)
|
)
|
||||||
from transformers.models.bert.modeling_bert import (
|
from transformers.models.bert.modeling_bert import (
|
||||||
|
BertForMaskedLM,
|
||||||
|
BertForNextSentencePrediction,
|
||||||
BertForPreTraining,
|
BertForPreTraining,
|
||||||
BertForPreTrainingOutput,
|
BertForPreTrainingOutput,
|
||||||
BertLMHeadModel,
|
BertLMHeadModel,
|
||||||
|
@ -135,7 +137,6 @@ class BertPolicy(Policy):
|
||||||
],
|
],
|
||||||
policy=policy,
|
policy=policy,
|
||||||
target_key=BertLayer)
|
target_key=BertLayer)
|
||||||
|
|
||||||
# handle embedding layer
|
# handle embedding layer
|
||||||
self.append_or_create_submodule_replacement(
|
self.append_or_create_submodule_replacement(
|
||||||
description=[SubModuleReplacementDescription(
|
description=[SubModuleReplacementDescription(
|
||||||
|
@ -144,6 +145,7 @@ class BertPolicy(Policy):
|
||||||
)],
|
)],
|
||||||
policy=policy,
|
policy=policy,
|
||||||
target_key=BertEmbeddings)
|
target_key=BertEmbeddings)
|
||||||
|
|
||||||
return policy
|
return policy
|
||||||
|
|
||||||
def add_lm_head_policy(self, base_policy):
|
def add_lm_head_policy(self, base_policy):
|
||||||
|
@ -177,6 +179,15 @@ class BertModelPolicy(BertPolicy):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
def module_policy(self):
    """Return the parent policy, adding a pipeline forward for ``BertModel``.

    The policy produced by the parent class is returned as-is unless
    ``self.pipeline_stage_manager`` is truthy. In that case the entry for
    ``BertModel`` is set to a description whose ``forward`` replacement is
    ``bert_model_forward`` with the stage manager pre-bound via ``partial``.
    """
    policy = super().module_policy()
    # Imported locally, as in the original code.
    from transformers.models.bert.modeling_bert import BertModel

    if not self.pipeline_stage_manager:
        # No pipeline parallelism configured: keep the parent policy untouched.
        return policy

    pipeline_forward = partial(bert_model_forward, stage_manager=self.pipeline_stage_manager)
    policy[BertModel] = ModulePolicyDescription(method_replacement={'forward': pipeline_forward})
    return policy
|
||||||
|
|
||||||
def get_held_layers(self) -> List[Module]:
|
def get_held_layers(self) -> List[Module]:
|
||||||
"""Get pipeline layers for current stage."""
|
"""Get pipeline layers for current stage."""
|
||||||
module = self.model
|
module = self.model
|
||||||
|
@ -444,6 +455,13 @@ def bert_model_forward(
|
||||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||||
batch_size, seq_length = input_shape
|
batch_size, seq_length = input_shape
|
||||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
if token_type_ids is None:
|
||||||
|
if hasattr(self.embeddings, "token_type_ids"):
|
||||||
|
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
||||||
|
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
||||||
|
token_type_ids = buffered_token_type_ids_expanded
|
||||||
|
else:
|
||||||
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
else:
|
else:
|
||||||
input_shape = hidden_states.size()[:-1]
|
input_shape = hidden_states.size()[:-1]
|
||||||
batch_size, seq_length = input_shape
|
batch_size, seq_length = input_shape
|
||||||
|
@ -466,14 +484,6 @@ def bert_model_forward(
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
||||||
|
|
||||||
if token_type_ids is None:
|
|
||||||
if hasattr(self.embeddings, "token_type_ids"):
|
|
||||||
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
|
||||||
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
|
||||||
token_type_ids = buffered_token_type_ids_expanded
|
|
||||||
else:
|
|
||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
|
||||||
|
|
||||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||||
|
@ -778,3 +788,131 @@ def bert_lmhead_forward(self: BertLMHeadModel,
|
||||||
hidden_states = outputs.get('hidden_states')
|
hidden_states = outputs.get('hidden_states')
|
||||||
# intermediate stage always return dict
|
# intermediate stage always return dict
|
||||||
return {'hidden_states': hidden_states}
|
return {'hidden_states': hidden_states}
|
||||||
|
|
||||||
|
|
||||||
|
def bert_for_masked_lm_forward(
    self: BertForMaskedLM,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    hidden_states: Optional[torch.Tensor] = None,
    stage_manager: Optional[PipelineStageManager] = None,
):
    # Intended return type: Union[Tuple[torch.Tensor], MaskedLMOutput]
    r"""
    Placeholder for the pipeline-parallel forward of `BertForMaskedLM`; not implemented yet.

    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
        config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
        loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`

    Raises:
        NotImplementedError: always, until the pipeline forward is written.
    """
    # TODO: implement the pipeline forward for BertForMaskedLM.
    # Fail loudly instead of silently returning None, so a policy that wires
    # this function in by mistake is caught at the first call.
    raise NotImplementedError("bert_for_masked_lm_forward is not implemented for pipeline parallelism yet.")
|
||||||
|
|
||||||
|
|
||||||
|
def bert_for_next_sentence_prediction_forward(
    self: BertForNextSentencePrediction,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    hidden_states: Optional[torch.Tensor] = None,
    stage_manager: Optional[PipelineStageManager] = None,
    **kwargs,
):
    # Intended return type: Union[Tuple[torch.Tensor], NextSentencePredictorOutput]
    r"""
    Pipeline-aware forward for `BertForNextSentencePrediction`.

    The BERT backbone is run through `bert_model_forward`. On the last pipeline stage the pooled output
    (`outputs[1]`) is fed to `self.cls` and, when `labels` is given, a cross-entropy loss is computed. On every
    other stage a dict `{'hidden_states': ...}` is returned so it can be handed to the next stage.

    labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
        (see `input_ids` docstring). Indices should be in `[0, 1]`:

        - 0 indicates sequence B is a continuation of sequence A,
        - 1 indicates sequence B is a random sequence.
    hidden_states (`torch.Tensor`, *optional*):
        Passed through to `bert_model_forward` unchanged.
    stage_manager (`PipelineStageManager`, *optional*):
        Used to decide whether this is the last stage. The default `None` is not handled here:
        `stage_manager.is_last_stage()` is called unconditionally.

    Returns:
        Last stage: a tuple `([loss,] seq_relationship_scores, *outputs[2:])` when `return_dict` is falsy,
        otherwise a `NextSentencePredictorOutput`. Other stages: `{'hidden_states': hidden_states}`.

    Example:

    ```python
    >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
    >>> import torch

    >>> tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    >>> model = BertForNextSentencePrediction.from_pretrained("bert-base-uncased")

    >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
    >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

    >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
    >>> logits = outputs.logits
    >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
    ```
    """

    # Backward compatibility with the deprecated keyword name.
    if "next_sentence_label" in kwargs:
        warnings.warn(
            "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
            " `labels` instead.",
            FutureWarning,
        )
        labels = kwargs.pop("next_sentence_label")

    # Options the pipeline forward does not support are warned about and switched off.
    if output_attentions:
        logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
        output_attentions = False
    if output_hidden_states:
        logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
        output_hidden_states = False
    if return_dict:
        logger.warning_once('return_dict is not supported for pipeline models at the moment')
        return_dict = False
    # NOTE(review): because the config default is resolved after the guard above, `return_dict=None`
    # together with `config.use_return_dict=True` still selects the dict-style return path. The order is
    # kept as-is so the default return type does not change; confirm whether this is intended.
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = bert_model_forward(
        self.bert,
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        # Forward the pipeline state: the backbone is called directly here (not through the patched
        # `self.bert.forward`), so it would otherwise see neither the stage manager nor the hidden
        # states handed to this function.
        hidden_states=hidden_states,
        stage_manager=stage_manager,
    )

    if stage_manager.is_last_stage():
        pooled_output = outputs[1]
        seq_relationship_scores = self.cls(pooled_output)

        next_sentence_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if not return_dict:
            output = (seq_relationship_scores,) + outputs[2:]
            return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output

        return NextSentencePredictorOutput(
            loss=next_sentence_loss,
            logits=seq_relationship_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
    else:
        hidden_states = outputs.get('hidden_states')
        # intermediate stage always return dict
        return {'hidden_states': hidden_states}
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from types import MethodType
|
||||||
from typing import Any, Callable, Dict, List, Union
|
from typing import Any, Callable, Dict, List, Union
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
@ -134,7 +135,8 @@ class ModelSharder(object):
|
||||||
def _replace_method(self, module: nn.Module, method_replacement: Dict[str, Callable]):
    """Install replacement methods on ``module``.

    Args:
        module: the object whose attributes are overwritten.
        method_replacement: maps an attribute name (e.g. ``'forward'``) to the
            callable that should be bound to ``module`` under that name.
    """
    for method_name, new_method in method_replacement.items():
        # Bind with MethodType rather than ``new_method.__get__``: it accepts any
        # callable, including ``functools.partial`` objects, which do not
        # implement the descriptor protocol.
        bound_method = MethodType(new_method, module)
        setattr(module, method_name, bound_method)
|
||||||
|
|
||||||
def _replace_sub_module(
|
def _replace_sub_module(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -2,6 +2,7 @@ import copy
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
from colossalai.lazy import LazyInitContext
|
from colossalai.lazy import LazyInitContext
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,6 +22,28 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
|
||||||
return org_model.cuda(), sharded_model.cuda()
|
return org_model.cuda(), sharded_model.cuda()
|
||||||
|
|
||||||
|
|
||||||
|
def build_pipeline_model(model_fn,
                         stage_manager=None,
                         enable_fused_normalization=False,
                         enable_tensor_parallelism=False,
                         use_lazy_init: bool = False):
    """Build a reference model and a ShardFormer-optimized copy of it.

    Args:
        model_fn: zero-argument callable that creates the model.
        stage_manager: passed to ``ShardConfig`` as ``pipeline_stage_manager``.
        enable_fused_normalization: passed to ``ShardConfig``.
        enable_tensor_parallelism: passed to ``ShardConfig``.
        use_lazy_init: create the model under ``LazyInitContext`` and
            materialize the reference model afterwards.

    Returns:
        ``(org_model, sharded_model)``, each after ``.cuda()``.
    """
    if use_lazy_init:
        init_ctx = LazyInitContext()
    else:
        init_ctx = nullcontext()

    # NOTE(review): the deep copy is taken inside the init context, matching the
    # statement order of the original; confirm against the sibling build_model.
    with init_ctx:
        org_model = model_fn()
        model_copy = copy.deepcopy(org_model)
    if use_lazy_init:
        init_ctx.materialize(org_model)

    # Shard the copy; the reference model stays untouched.
    config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                         enable_tensor_parallelism=enable_tensor_parallelism,
                         pipeline_stage_manager=stage_manager)
    sharder = ShardFormer(shard_config=config)
    sharded_model, _ = sharder.optimize(model_copy)
    return org_model.cuda(), sharded_model.cuda()
|
||||||
|
|
||||||
|
|
||||||
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||||
# prepare input
|
# prepare input
|
||||||
data = data_gen_fn()
|
data = data_gen_fn()
|
||||||
|
|
|
@ -0,0 +1,85 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.cluster import ProcessGroupMesh
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
|
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||||
|
from colossalai.testing import (
|
||||||
|
assert_hf_output_close,
|
||||||
|
clear_cache_before_run,
|
||||||
|
parameterize,
|
||||||
|
rerun_if_address_is_in_use,
|
||||||
|
spawn,
|
||||||
|
)
|
||||||
|
from tests.kit.model_zoo import model_zoo
|
||||||
|
from tests.test_shardformer.test_model._utils import build_model, build_pipeline_model, run_forward
|
||||||
|
|
||||||
|
|
||||||
|
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    """Placeholder for the forward/backward comparison; performs no checks and returns ``None``."""
    # TODO: check forward
    return None
|
||||||
|
|
||||||
|
|
||||||
|
@parameterize('enable_fused_normalization', [False])
@parameterize('enable_tensor_parallelism', [False])
@parameterize('use_lazy_init', [False])
# TODO: merge this into test_shard_bert
def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
    """Run one pipeline stage of the sharded BERT model and check the output shape.

    Stage 0 is fed token ids and must return a dict whose ``'hidden_states'``
    has shape ``(2, 3, 128)``; any other stage is fed a hidden-state tensor and
    its ``output[0]`` must have that same shape.
    """
    DP_DIM, PP_DIM = 0, 1
    DP_SIZE, PP_SIZE = 2, 2
    # Layout tables for the 2 x 2 mesh; they are not read by the test.
    RANK_TO_COORDINATE = {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}
    PP_RANKS_IN_GROUP = {0: [0, 1], 1: [0, 1], 2: [2, 3], 3: [2, 3]}

    stage_manager = PipelineStageManager(ProcessGroupMesh(DP_SIZE, PP_SIZE), PP_DIM)

    sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
    x = torch.randint(0, 1000, (2, 3)).cuda()
    hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()

    for name, (model_fn, _, _, _, _) in sub_model_zoo.items():
        if name != 'transformers_bert':
            continue

        _, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
                                                enable_tensor_parallelism, use_lazy_init)

        if stage_manager.stage == 0:
            # First stage: starts from token ids, returns a dict.
            attention_mask = torch.ones_like(x).cuda()
            output = sharded_model(input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
            stage_output = output['hidden_states']
        else:
            # Later stage: starts from hidden states, output is indexed positionally.
            attention_mask = torch.ones((2, 3)).cuda()
            output = sharded_model(hidden_states=hidden_states,
                                   attention_mask=attention_mask,
                                   stage_manager=stage_manager)
            stage_output = output[0]
        assert stage_output.shape == (2, 3, 128)

    # NOTE(review): placed after the loop; the original nesting level is not
    # recoverable from the stripped indentation.
    torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
def check_bert(rank, world_size, port):
    """Per-process entry point: initialize colossalai on localhost with NCCL, then run the BERT test."""
    disable_existing_loggers()
    launch_kwargs = dict(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    colossalai.launch(**launch_kwargs)
    run_bert_test()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert():
    """Distributed test entry: hand ``check_bert`` to ``spawn`` with a size of 4."""
    # 4 matches DP_SIZE * PP_SIZE (2 x 2) used by run_bert_test;
    # presumably the number of processes spawn starts - confirm.
    num_processes = 4
    spawn(check_bert, num_processes)


if __name__ == "__main__":
    test_bert()
|
Loading…
Reference in New Issue