[pipeline]add pipeline policy and bert forward (#4130)

* add pipeline policy and bert forward (to be completed)

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name conflict
Jianghai 2023-07-04 13:46:16 +08:00 committed by Hongxin Liu
parent f51ce1bc8e
commit e8e7e49243
6 changed files with 786 additions and 1 deletion


@@ -0,0 +1,22 @@
from typing import Any, Dict, List, Optional, Tuple, Type
from torch import Tensor
from torch.nn import Module, Parameter
from colossalai.pipeline.stage_manager import PipelineStageManager
from .base import Policy
from .bert import BertModel, BertModelPolicy
POLICY_MAP: Dict[Type[Module], Type[Policy]] = {
BertModel: BertModelPolicy,
}
def pipeline_parallelize(
model: Module,
stage_manager: PipelineStageManager) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
if type(model) not in POLICY_MAP:
raise NotImplementedError(f"Policy for {type(model)} not implemented")
policy = POLICY_MAP[type(model)](stage_manager)
return policy.parallelize_model(model)
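
For reference, a minimal usage sketch of the intended flow (illustrative only; it assumes a distributed environment already launched via colossalai.launch and mirrors the setup used in the test file at the end of this diff). Note that BertModelPolicy's constructor currently also takes num_layers and num_stages, so the sketch builds the policy directly rather than going through the one-argument lookup above.

from transformers.models.bert.modeling_bert import BertModel

from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.policy.bert import BertModelPolicy
from colossalai.pipeline.stage_manager import PipelineStageManager

DP_SIZE, PP_SIZE, PP_DIM = 2, 2, 1
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)            # 4 ranks: 2 data-parallel x 2 pipeline
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)

model = BertModel.from_pretrained('bert-base-uncased')
policy = BertModelPolicy(stage_manager, len(model.encoder.layer), PP_SIZE)
# Each rank keeps only the parameters/buffers of its own stage and gets a
# stage-aware forward bound to the module.
hold_params, hold_buffers, shared_params = policy.parallelize_model(model)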


@@ -0,0 +1,108 @@
from typing import Any, Dict, List, Optional, Tuple
from colossalai.lazy import LazyTensor
from torch import Tensor
from torch.nn import Module, Parameter
from colossalai.pipeline.stage_manager import PipelineStageManager
class Policy:
def __init__(self, stage_manager: PipelineStageManager) -> None:
self.stage_manager = stage_manager
def setup_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor]]:
"""Setup model for pipeline parallel
Args:
module (Module): Module to be setup
Returns:
Tuple[Dict[str, Parameter], Dict[str, Tensor]]: Held parameters and held buffers
"""
hold_params = set()
hold_buffers = set()
def init_layer(layer: Module):
for p in layer.parameters():
if isinstance(p, LazyTensor):
p.materialize()
p.data = p.cuda()
hold_params.add(p)
for b in layer.buffers():
if isinstance(b, LazyTensor):
b.materialize()
b.data = b.cuda()
hold_buffers.add(b)
hold_layers = self.get_hold_layers(module)
for layer in hold_layers:
init_layer(layer)
hold_params_dict = {}
hold_buffers_dict = {}
# release other tensors
for n, p in module.named_parameters():
if p in hold_params:
hold_params_dict[n] = p
else:
if isinstance(p, LazyTensor):
p.materialize()
p.data = p.cuda()
p.storage().resize_(0)
for n, b in module.named_buffers():
if b in hold_buffers:
hold_buffers_dict[n] = b
else:
if isinstance(b, LazyTensor):
b.materialize()
b.data = b.cuda()
# FIXME(ver217): use meta tensor may be better
b.storage().resize_(0)
return hold_params_dict, hold_buffers_dict
def replace_forward(self, module: Module) -> None:
"""Replace module forward in place. This method should be implemented by subclass. The output of internal layers must be a dict
Args:
module (Module): _description_
"""
raise NotImplementedError
def get_hold_layers(self, module: Module) -> List[Module]:
"""Get layers that should be hold in current stage. This method should be implemented by subclass.
Args:
module (Module): Module to be setup
Returns:
List[Module]: List of layers that should be held in the current stage
"""
raise NotImplementedError
def get_shared_params(self, module: Module) -> List[Dict[int, Tensor]]:
"""Get parameters that should be shared across stages. This method should be implemented by subclass.
Args:
module (Module): Module to be setup
Returns:
List[Dict[int, Tensor]]: List of mappings from stage index to the shared parameter held on that stage. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
"""
raise NotImplementedError
def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
"""Parallelize model for pipeline parallel
Args:
module (Module): Module to be setup
Returns:
Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]: Held parameters, held buffers, and shared parameters
"""
hold_params, hold_buffers = self.setup_model(module)
self.replace_forward(module)
shared_params = self.get_shared_params(module)
return hold_params, hold_buffers, shared_params
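
As an illustration of the contract that parallelize_model relies on, here is a hypothetical minimal subclass (ToyModel and ToyPolicy are illustrative names, not part of this commit):

import torch.nn as nn

class ToyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.block0 = nn.Linear(16, 16)    # held by the first stage
        self.block1 = nn.Linear(16, 16)    # held by the last stage

class ToyPolicy(Policy):

    def get_hold_layers(self, module: ToyModel) -> List[Module]:
        # the first stage keeps block0, every other stage keeps block1
        return [module.block0] if self.stage_manager.is_first_stage() else [module.block1]

    def get_shared_params(self, module: ToyModel) -> List[Dict[int, Tensor]]:
        return []    # no cross-stage weight tying in this toy model

    def replace_forward(self, module: ToyModel) -> None:
        # a real policy rebinds module.forward to a stage-aware function here
        pass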


@@ -0,0 +1,390 @@
from functools import partial
from types import MethodType
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
)
from transformers.models.bert.modeling_bert import BertForPreTraining, BertForPreTrainingOutput, BertModel
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from .base import Policy
logger = logging.get_logger(__name__)
def bert_model_forward(
self: BertModel,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
#labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage
):
# TODO: add an explanation of the output here.
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
# preprocess: resolve config-driven defaults
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
attention_mask = extended_attention_mask
else:
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
# for non-first stages, hidden_states holds the activation handed over from the previous stage
if stage_manager.is_first_stage():
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
#inherit from bert_layer
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
if self.encoder.gradient_checkpointing and self.encoder.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
next_decoder_cache = () if use_cache else None
# compute the range of encoder layers handled by this stage (even split across stages)
num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages
start_layer = stage_manager.stage * num_layers_per_stage
end_layer = (stage_manager.stage + 1) * num_layers_per_stage
# layer_outputs starts from the hidden states handed over by the previous stage (if any)
layer_outputs = hidden_states if hidden_states is not None else None
for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer):
if stage_manager.is_first_stage() and idx == 0:
encoder_attention_mask = encoder_extended_attention_mask
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[idx] if head_mask is not None else None
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.encoder.gradient_checkpointing and self.encoder.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
#end of a stage loop
sequence_output = layer_outputs[0] if layer_outputs is not None else None
if stage_manager.is_last_stage():
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + layer_outputs[1:]
# tuple output for non-last stages:
if not return_dict:
return tuple(v for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
] if v is not None)
# dict-style output; full dict support for pipeline stages is still partial
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
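
To make the per-stage hand-off concrete without a distributed launch, here is a hedged single-process sketch. _FakeStageManager is a stand-in defined only for this illustration (it mimics the attributes the forward reads: stage, num_stages, is_first_stage, is_last_stage), and the tiny BertConfig is likewise illustrative.

import torch
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertModel

class _FakeStageManager:    # stand-in for PipelineStageManager, for this sketch only

    def __init__(self, stage, num_stages):
        self.stage, self.num_stages = stage, num_stages

    def is_first_stage(self):
        return self.stage == 0

    def is_last_stage(self):
        return self.stage == self.num_stages - 1

config = BertConfig(num_hidden_layers=4, hidden_size=64, num_attention_heads=4,
                    intermediate_size=128, vocab_size=1000)
bert = BertModel(config).eval()
ids = torch.randint(0, 1000, (2, 6))

# Stage 0 consumes input_ids and emits hidden states for the next stage.
stage0_out = bert_model_forward(bert, input_ids=ids, attention_mask=torch.ones_like(ids),
                                stage_manager=_FakeStageManager(0, 2), return_dict=False)
hidden = stage0_out[0]    # (2, 6, 64)

# Stage 1 consumes those hidden states; as in the test below, the mask for
# later stages is already in extended (batch, heads, seq, seq) form.
stage1_out = bert_model_forward(bert, hidden_states=hidden,
                                attention_mask=torch.ones(2, 4, 6, 6),
                                stage_manager=_FakeStageManager(1, 2), return_dict=False)
assert stage1_out[0].shape == (2, 6, 64)    # sequence output from the last stage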
# The layer partition policy for BertModel
class BertModelPolicy(Policy):
def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
self.stage_manager = stage_manager
self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
def get_hold_layers(self, module: BertModel) -> List[Module]:
"""
get pipeline layers for current stage
"""
hold_layers = []
if self.stage_manager.is_first_stage():
hold_layers.append(module.embeddings)
num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage - 1] if self.stage_manager.stage > 0 else 0
end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage]
hold_layers.extend(module.encoder.layer[start_idx:end_idx])
if self.stage_manager.is_last_stage():
hold_layers.append(module.pooler)
return hold_layers
def get_shared_params(self, module: BertModel) -> List[Dict[int, Tensor]]:
    """No shared parameters in BertModel."""
    return []
def replace_forward(self, module: Module) -> None:
    module.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module)
def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
"""
divide layers into stages
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages
# deal with the rest layers
if remainder > 0:
start_position = num_stages // 2 - remainder // 2    # centre the extra layers on the middle stages
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage
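
To sanity-check the partition arithmetic without a distributed launch, here is a standalone re-derivation (split_layers is defined only for this check; it mirrors distribute_layers above, with the leftover layers centred on the middle stages):

def split_layers(num_layers: int, num_stages: int) -> list:
    quotient, remainder = divmod(num_layers, num_stages)
    layers = [quotient] * num_stages
    start = num_stages // 2 - remainder // 2    # centre the leftover layers
    for i in range(start, start + remainder):
        layers[i] += 1
    return layers

assert split_layers(12, 2) == [6, 6]              # bert-base over 2 stages, as in the test below
assert split_layers(12, 5) == [2, 3, 3, 2, 2]     # the two extra layers land on the middle stages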
def bert_for_pretraining_forward(
self: BertForPreTraining,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
next_sentence_label: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
total_loss = None
if labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss()
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
if not return_dict:
output = (prediction_scores, seq_relationship_score) + outputs[2:]
return ((total_loss,) + output) if total_loss is not None else output
return BertForPreTrainingOutput(
loss=total_loss,
prediction_logits=prediction_scores,
seq_relationship_logits=seq_relationship_score,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
class BertForPreTrainingPolicy(Policy):
def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
self.stage_manager = stage_manager
self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
def get_hold_layers(self, module: BertForPreTraining) -> List[Module]:
"""
get pipeline layers for current stage
"""
hold_layers = []
if self.stage_manager.is_first_stage():
hold_layers.append(module.bert.embeddings)
num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
start_idx = num_layers_per_stage_accumulated[self.stage_manager.stage - 1] if self.stage_manager.stage > 0 else 0
end_idx = num_layers_per_stage_accumulated[self.stage_manager.stage]
hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
if self.stage_manager.is_last_stage():
hold_layers.append(module.cls)
return hold_layers
def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]:
    """No shared parameters handled for BertForPreTraining yet."""
    return []
def replace_forward(self, module: Module) -> None:
    module.forward = MethodType(partial(bert_for_pretraining_forward, stage_manager=self.stage_manager), module)
def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
"""
divide layers into stages
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages
# deal with the rest layers
if remainder > 0:
start_position = num_stages // 2 - remainder // 2    # centre the extra layers on the middle stages
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage
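
The base Policy docstring shows the expected shared-parameter format; for BertForPreTraining, where the word-embedding and MLM decoder weights are tied, a filled-in get_shared_params could eventually look like the hedged sketch below (not part of this commit, which deliberately returns no shared parameters yet; the stage indices are illustrative):

def get_shared_params(self, module: BertForPreTraining) -> List[Dict[int, Tensor]]:
    # hypothetical sketch: the tied embedding/decoder weight is held on the
    # first and last stages respectively
    last_stage = self.stage_manager.num_stages - 1
    return [{
        0: module.bert.embeddings.word_embeddings.weight,
        last_stage: module.cls.predictions.decoder.weight,
    }]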


@@ -0,0 +1,153 @@
from functools import partial
from types import MethodType
import warnings
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bloom.modeling_bloom import BloomModel
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from .base import Policy

logger = logging.get_logger(__name__)
def bloom_model_forward(
self: BloomModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
alibi,
causal_mask,
layer_past,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
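
A minimal single-process sanity check of this forward (illustrative only; it uses a tiny random-weight BloomConfig so it runs without downloading a checkpoint, and assumes the transformers version this file is written against):

import torch
from transformers import BloomConfig
from transformers.models.bloom.modeling_bloom import BloomModel

config = BloomConfig(vocab_size=1000, hidden_size=64, n_layer=2, n_head=4)
tiny_bloom = BloomModel(config).eval()
input_ids = torch.randint(0, 1000, (2, 8))

out = bloom_model_forward(tiny_bloom, input_ids=input_ids, return_dict=True)
assert out.last_hidden_state.shape == (2, 8, 64)    # (batch, seq_len, hidden_size)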


@@ -0,0 +1,112 @@
import pytest
import torch
import torch.distributed as dist
from transformers.models.bert.modeling_bert import BertModel
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.policy.bert import BertModelPolicy, bert_model_forward
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bert_model_forward():
model = BertModel.from_pretrained('bert-base-uncased')
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
#print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
# print(rank)
x = torch.randint(0, 1000, (2, 3))
hidden_states = torch.randint(0, 1000, (2, 3, 768)).to(torch.float32)
if stage_manager.stage == 0:
attention_mask = torch.ones_like(x)
output = bert_model_forward(self=model, input_ids=x, attention_mask=attention_mask, stage_manager=stage_manager)
print(output[0].shape)
assert output[0].shape == (2, 3, 768)
print('first-stage forward finished')
else:
attention_mask = torch.ones((2, 12, 3, 3))
output = bert_model_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
print(output[0].shape)
assert output[0].shape == (2, 3, 768)
print('later-stage forward finished')
print(output)
# assert output[1].shape == (2, 768)
def check_bert_model_policy():
model = BertModel.from_pretrained('bert-base-uncased')
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
#print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
model_policy = BertModelPolicy(stage_manager, len(model.encoder.layer), 2)
assert model_policy.layers_per_stage == [6, 6]
layers = model_policy.get_hold_layers(model)
for layer in layers:
print(layer)
def run_dist_model(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_model_forward()
def run_dist_policy(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bert_model_policy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_model_forward():
spawn(run_dist_model, 4)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bert_model_policy():
spawn(run_dist_policy, 4)
if __name__ == "__main__":
"""test the bert model forward and bert model policy"""
test_bert_model_forward()
test_bert_model_policy()