[pipeline] add bert_for_pretraining bert_lmhead forward and policy (#4172)

* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name conflict

* add bloom model and policy, revise the base class of policy

* revise

* revision

* add bert_for_pretraining

* add bert_for_pretraining forward and policy

* fix typos

* cancel warning

* change the intermediate output to a default dict

* change the default output of get_shared_params
pull/4445/head
Jianghai 2023-07-06 14:49:10 +08:00 committed by Hongxin Liu
parent d35bd7d0e6
commit c5ea728016
4 changed files with 498 additions and 115 deletions
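For orientation, the pattern these changes apply to BertModel, BertForPreTraining and BertLMHeadModel: the forward becomes stage-aware, the first stage runs the embeddings, intermediate stages consume `hidden_states` from the previous stage and return `{'hidden_states': ...}`, and only the last stage runs the head (and, where applicable, the loss). Below is a minimal sketch of that control flow using hypothetical stand-ins (`ToyModel`, `DummyStageManager`, `staged_forward`), not the real classes:

```python
# Minimal sketch (not the library code) of the stage-gated forward pattern added here.
# ToyModel and DummyStageManager are hypothetical stand-ins for BertModel and
# colossalai.pipeline.stage_manager.PipelineStageManager.
import torch
import torch.nn as nn


class DummyStageManager:
    def __init__(self, stage: int, num_stages: int):
        self.stage, self.num_stages = stage, num_stages

    def is_first_stage(self) -> bool:
        return self.stage == 0

    def is_last_stage(self) -> bool:
        return self.stage == self.num_stages - 1


class ToyModel(nn.Module):
    def __init__(self, vocab: int = 1000, hidden: int = 32, num_layers: int = 4):
        super().__init__()
        self.embeddings = nn.Embedding(vocab, hidden)
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))
        self.head = nn.Linear(hidden, vocab)


def staged_forward(model, stage_manager, layer_slice, input_ids=None, hidden_states=None):
    if stage_manager.is_first_stage():            # only stage 0 owns the embeddings
        hidden_states = model.embeddings(input_ids)
    for layer in model.layers[layer_slice]:       # each stage runs only its slice of layers
        hidden_states = layer(hidden_states)
    if stage_manager.is_last_stage():             # only the last stage owns the head
        return model.head(hidden_states)
    return {'hidden_states': hidden_states}       # intermediate stages return a dict


model = ToyModel()
x = torch.randint(0, 1000, (2, 3))
out0 = staged_forward(model, DummyStageManager(0, 2), slice(0, 2), input_ids=x)
out1 = staged_forward(model, DummyStageManager(1, 2), slice(2, 4),
                      hidden_states=out0['hidden_states'])
print(out1.shape)  # torch.Size([2, 3, 1000]) -- logits from the last stage
```

The real forwards in this diff follow the same branching on `stage_manager.is_first_stage()` / `is_last_stage()`, just with the full BERT plumbing (attention masks, head mask, cross-attention, losses).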


@@ -10,9 +10,15 @@ from transformers.modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
)
from transformers.models.bert.modeling_bert import BertForPreTraining, BertForPreTrainingOutput, BertModel
from transformers.utils import logging
from transformers.models.bert.modeling_bert import (
BertForPreTraining,
BertForPreTrainingOutput,
BertLMHeadModel,
BertModel,
)
from transformers.utils import ModelOutput, logging
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -22,24 +28,23 @@ logger = logging.get_logger(__name__)
def bert_model_forward(
self: BertModel,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
self: BertModel,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
# labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
# this is from the previous stage
hidden_states: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
):
# TODO: add explanation of the output here.
r"""
@@ -85,10 +90,6 @@ def bert_model_forward(
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
attention_mask = extended_attention_mask
else:
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape
@@ -119,22 +120,10 @@ def bert_model_forward(
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = hidden_states if hidden_states is not None else None
if stage_manager.is_first_stage():
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
attention_mask = extended_attention_mask
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
@@ -146,7 +135,24 @@ def bert_model_forward(
else:
encoder_extended_attention_mask = None
# inherit from bert_layer
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
hidden_states = hidden_states if hidden_states is not None else None
if stage_manager.is_first_stage():
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
# inherited from bert_layer; this should be changed when we add the feature to record hidden_states
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -221,34 +227,35 @@ def bert_model_forward(
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + layer_outputs[1:]
# return dict is not supported at this moment
else:
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
# output of non-first and non-last stages:
if not return_dict:
return tuple(v for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
] if v is not None)
# return dict is not supported at this moment
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
# output of non-first and non-last stages: must be a dict
else:
# intermediate stage always returns a dict
return {
'hidden_states': hidden_states,
}
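Note on the hunks above: the extended-attention-mask computation is moved out of the `input_ids`-only path so that it also runs on stages that start from `hidden_states`. For reference, this is what `get_extended_attention_mask` (standard `transformers` API) does to a plain 2-D padding mask, as a standalone example:

```python
# What get_extended_attention_mask does to a 2-D padding mask (standard transformers API).
import torch
from transformers import BertConfig, BertModel

model = BertModel(BertConfig())
attention_mask = torch.tensor([[1, 1, 0]])        # (batch, seq_len): last token is padding
ext = model.get_extended_attention_mask(attention_mask, attention_mask.shape)
print(ext.shape)      # torch.Size([1, 1, 1, 3]) -- broadcastable over heads and query positions
print(ext[0, 0, 0])   # 0 for kept tokens, a very large negative value for the padded one
```

Since every stage now expands the mask itself, each stage only needs the raw `(batch, seq_len)` mask, which lines up with the updated test at the bottom of this diff passing a plain `(2, 3)` mask instead of a pre-expanded 4-D one.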
# The layer partition policy for bertmodel
class BertModelPolicy(Policy):
def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
def __init__(
self,
stage_manager: PipelineStageManager,
num_layers: int,
):
super().__init__(stage_manager=stage_manager)
self.stage_manager = stage_manager
self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages)
def get_hold_layers(self, module: BertModel) -> List[Module]:
"""
@@ -266,10 +273,10 @@ class BertModelPolicy(Policy):
def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]:
'''no shared params in bertmodel'''
pass
return []
def replace_forward(self, module: Module) -> None:
module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model)
module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module)
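`replace_forward` swaps the module's bound `forward` for the pipeline-aware one by pre-filling `stage_manager` with `functools.partial` and re-binding the result with `types.MethodType`, so `self` still resolves to the module instance. A standalone illustration of that binding trick (`Greeter` and `staged_greet` are made up for the example):

```python
# Standalone illustration of the MethodType + partial re-binding used by replace_forward.
from functools import partial
from types import MethodType


class Greeter:
    def __init__(self, name: str):
        self.name = name

    def forward(self):
        return f"hello from {self.name}"


def staged_greet(self, *, stage: int):    # plays the role of bert_model_forward
    return f"stage {stage}: hello from {self.name}"


g = Greeter("bert")
# pre-fill the extra keyword argument, then re-bind so `self` is still the instance
g.forward = MethodType(partial(staged_greet, stage=1), g)
print(g.forward())    # -> "stage 1: hello from bert"
```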
def bert_for_pretraining_forward(
@@ -285,53 +292,74 @@ def bert_for_pretraining_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.LongTensor] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
# TODO: the recorded kv-value tensors are left as () or None for now; this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
sequence_output, pooled_output = outputs[:2]
outputs = bert_model_forward(self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states if hidden_states is not None else None)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
all_cross_attentions = None
if stage_manager.is_last_stage():
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
# the last stage for pretraining model
total_loss = None
if labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
total_loss = None
if labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
if not return_dict:
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
if not return_dict:
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return BertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
else:
hidden_states = outputs.get('hidden_states')
return BertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
# intermediate stage always returns a dict
return {
'hidden_states': hidden_states,
}
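On the last stage, `bert_for_pretraining_forward` combines the masked-LM loss over the vocabulary with the next-sentence loss over two classes, as in the upstream model. A standalone sketch of that arithmetic with dummy tensors (shapes follow the tests below):

```python
# Standalone sketch of the last-stage loss in bert_for_pretraining_forward:
# masked-LM loss over the vocabulary plus next-sentence loss over 2 classes.
import torch
from torch.nn import CrossEntropyLoss

batch, seq_len, vocab_size = 2, 3, 30522
prediction_scores = torch.randn(batch, seq_len, vocab_size)   # stands in for the self.cls(...) output
seq_relationship_score = torch.randn(batch, 2)
labels = torch.randint(0, vocab_size, (batch, seq_len))
next_sentence_label = torch.randint(0, 2, (batch,))

loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, vocab_size), labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
print(total_loss)   # scalar tensor; returned first by the last stage when labels are given
```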
class BertForPreTrainingPolicy(Policy):
def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
def __init__(self, stage_manager: PipelineStageManager, num_layers: int):
super().__init__(stage_manager=stage_manager)
self.stage_manager = stage_manager
self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages)
def get_hold_layers(self, module: BertForPreTraining) -> List[Module]:
"""
@@ -352,25 +380,144 @@ class BertForPreTrainingPolicy:
def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]:
'''no shared params in bertmodel'''
pass
return []
def replace_forward(self, module: Module) -> None:
module.model.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager),
module.model)
module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager),
module.forward)
def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
    """
    divide layers into stages
    """
    quotient = num_layers // num_stages
    remainder = num_layers % num_stages
    # calculate the num_layers per stage
    layers_per_stage = [quotient] * num_stages
    # deal with the rest layers
    if remainder > 0:
        start_position = num_layers // 2 - remainder // 2
        for i in range(start_position, start_position + remainder):
            layers_per_stage[i] += 1
    return layers_per_stage
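For reference, how the even split and the per-stage layer ranges work out for BERT-base (12 encoder layers) on two stages; `stage_index` below is a hypothetical stand-in for the policy's `get_stage_index` helper:

```python
# How layers_per_stage and the per-stage layer ranges work out for BERT-base on 2 stages.
# stage_index is a hypothetical stand-in for the policy's get_stage_index helper.
num_layers, num_stages = 12, 2
quotient, remainder = divmod(num_layers, num_stages)
layers_per_stage = [quotient] * num_stages     # remainder is 0 here, so no extra layers to spread


def stage_index(layers_per_stage, stage):
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]


assert layers_per_stage == [6, 6]                    # matches the assertions in the tests below
assert stage_index(layers_per_stage, 0) == (0, 6)    # stage 0: embeddings + encoder layers 0-5
assert stage_index(layers_per_stage, 1) == (6, 12)   # stage 1: encoder layers 6-11 + pooler/cls
```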
def bert_lmhead_forward(self: BertLMHeadModel,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.Tensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
use_cache = False
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if return_dict:
logger.warning_once('return_dict is not supported for pipeline models at the moment')
return_dict = False
outputs = bert_model_forward(self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states if hidden_states is not None else None)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
all_cross_attentions = None
if stage_manager.is_last_stage():
sequence_output = outputs[0]
prediction_scores = self.cls(sequence_output)
lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=lm_loss,
logits=prediction_scores,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
else:
hidden_states = outputs.get('hidden_states')
# intermediate stage always returns a dict
return {'hidden_states': hidden_states}
class BertLMHeadModelPolicy(Policy):
def __init__(self, stage_manager: PipelineStageManager, num_layers: int):
super().__init__(stage_manager=stage_manager)
self.stage_manager = stage_manager
self.layers_per_stage = self.distribute_layers(num_layers, stage_manager.num_stages)
def get_hold_layers(self, module: BertLMHeadModel) -> List[Module]:
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages
get pipeline layers for current stage
"""
hold_layers = []
if self.stage_manager.is_first_stage():
hold_layers.append(module.bert.embeddings)
start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
if self.stage_manager.is_last_stage():
hold_layers.append(module.bert.pooler)
hold_layers.append(module.cls)
return hold_layers
def get_shared_params(self, module: BertLMHeadModel) -> List[Dict[int, Tensor]]:
'''no shared params in bertmodel'''
return []
def replace_forward(self, module: Module) -> None:
module.forward = MethodType(partial(bert_lmhead_forward, stage_manager=self.stage_manager), module)
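Putting the pieces together, this is roughly how a launched pipeline rank would use the new policy. It assumes it runs inside a process set up by `colossalai.launch` on the 2x2 DP x PP mesh used by the tests below (the tests themselves call `bert_lmhead_forward` directly instead of going through `replace_forward`):

```python
# Rough usage sketch of the new policy, following the tests below. Assumed to run inside
# a process launched by colossalai.launch on a 2x2 (DP x PP) ProcessGroupMesh; this is
# not a standalone script.
import torch
from transformers.models.bert import BertConfig
from transformers.models.bert.modeling_bert import BertLMHeadModel

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.policy.bert import BertLMHeadModelPolicy
from colossalai.pipeline.stage_manager import PipelineStageManager

DP_SIZE, PP_SIZE, PP_DIM = 2, 2, 1
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)

model = BertLMHeadModel(BertConfig())
policy = BertLMHeadModelPolicy(stage_manager, len(model.bert.encoder.layer))
policy.replace_forward(model)            # binds bert_lmhead_forward with this stage_manager
held = policy.get_hold_layers(model)     # the sub-modules this pipeline stage is responsible for

if stage_manager.is_first_stage():
    x = torch.randint(0, 1000, (2, 3))
    out = model(input_ids=x, attention_mask=torch.ones_like(x))            # -> {'hidden_states': ...}
else:
    hidden = torch.randn(2, 3, model.config.hidden_size)                   # normally received from stage 0
    out = model(hidden_states=hidden, attention_mask=torch.ones((2, 3)))   # -> tuple, logits at index 0
```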


@@ -0,0 +1,118 @@
import pytest
import torch
import torch.distributed as dist
from transformers.models.bert import BertConfig
from transformers.models.bert.modeling_bert import BertForPreTraining
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.policy.bert import BertForPreTrainingPolicy, bert_for_pretraining_forward
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bert_for_pretraining_forward():
configuration = BertConfig()
model = BertForPreTraining(configuration)
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
# print(rank)
x = torch.randint(0, 1000, (2, 3))
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x)
output = bert_for_pretraining_forward(self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager)
print(output['hidden_states'].shape)
assert output['hidden_states'].shape == (2, 3, 768)
print('start the training')
else:
attention_mask = torch.ones((2, 3))
output = bert_for_pretraining_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
print(output[0].shape)
assert output[0].shape == (2, 3, 30522)
print('end the training')
print(output)
# assert output[1].shape == (2, 768)
def check_bert_for_pretraining_policy():
configuration = BertConfig()
model = BertForPreTraining(configuration)
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
model_policy = BertForPreTrainingPolicy(stage_manager, len(model.bert.encoder.layer))
assert model_policy.layers_per_stage == [6, 6]
layers = model_policy.get_hold_layers(model)
for layer in layers:
print(layer)
def run_dist_model(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_for_pretraining_forward()
def run_dist_policy(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_for_pretraining_policy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_for_pretraining_forward():
spawn(run_dist_model, 4)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_for_pretraining_policy():
spawn(run_dist_policy, 4)
if __name__ == "__main__":
"""test the bert for pretraining model forward and bert for pretraining model policy"""
test_bert_for_pretraining_forward()
test_bert_for_pretraining_policy()
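A quick sketch of the 2x2 mesh these tests build, assuming the row-major rank layout implied by `RANK_TO_COORDINATE` and `PP_RANKS_IN_GROUP` above; ranks 0 and 2 end up as pipeline stage 0 (the `if stage_manager.stage == 0` branch), while ranks 1 and 3 act as the last stage:

```python
# The 2x2 mesh the tests build: coordinate = (dp, pp), stage = coordinate[PP_DIM].
# Row-major rank layout assumed, matching RANK_TO_COORDINATE / PP_RANKS_IN_GROUP above.
DP_SIZE, PP_SIZE, PP_DIM = 2, 2, 1
for rank in range(DP_SIZE * PP_SIZE):
    dp, pp = rank // PP_SIZE, rank % PP_SIZE
    pipe_group = [dp * PP_SIZE + p for p in range(PP_SIZE)]
    print(f"rank {rank}: coord=({dp}, {pp})  stage={pp}  pipeline group={pipe_group}")
# rank 0: stage 0, group [0, 1]   rank 1: stage 1, group [0, 1]
# rank 2: stage 0, group [2, 3]   rank 3: stage 1, group [2, 3]
```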


@@ -0,0 +1,118 @@
import pytest
import torch
import torch.distributed as dist
from transformers.models.bert import BertConfig
from transformers.models.bert.modeling_bert import BertLMHeadModel
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.policy.bert import BertLMHeadModelPolicy, bert_lmhead_forward
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bert_lmhead_forward():
configuration = BertConfig()
model = BertLMHeadModel(configuration)
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
# print(rank)
x = torch.randint(0, 1000, (2, 3))
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x)
output = bert_lmhead_forward(self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager)
print(output['hidden_states'].shape)
assert output['hidden_states'].shape == (2, 3, 768)
print('start the training')
else:
attention_mask = torch.ones((2, 3))
output = bert_lmhead_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
print(output[0].shape)
assert output[0].shape == (2, 3, 30522)
print('end the training')
print(output)
# assert output[1].shape == (2, 768)
def check_bert_lmhead_policy():
configuration = BertConfig()
model = BertLMHeadModel(configuration)
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
model_policy = BertLMHeadModelPolicy(stage_manager, len(model.bert.encoder.layer))
assert model_policy.layers_per_stage == [6, 6]
layers = model_policy.get_hold_layers(model)
for layer in layers:
print(layer)
def run_dist_model(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_lmhead_forward()
def run_dist_policy(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_lmhead_policy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_lmhead_forward():
spawn(run_dist_model, 4)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_lmhead_policy():
spawn(run_dist_policy, 4)
if __name__ == "__main__":
"""test the bert for pretraining model forward and bert for pretraining model policy"""
test_bert_lmhead_forward()
test_bert_lmhead_policy()


@@ -39,11 +39,11 @@ def check_bert_model_forward():
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x)
output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
print(output[0].shape)
assert output[0].shape == (2, 3, 768)
print(output['hidden_states'].shape)
assert output['hidden_states'].shape == (2, 3, 768)
print('start the training')
else:
attention_mask = torch.ones((2, 12, 3, 3))
attention_mask = torch.ones((2, 3))
output = bert_model_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
@@ -78,7 +78,7 @@ check_bert_model_policy():
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer), 2)
model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer))
assert model_policy.layers_per_stage == [6, 6]
layers = model_policy.get_hold_layers(model)
for layer in layers: