[pipeline] build bloom model and policy, revise the base class of policy (#4161)

* add pipeline policy and bert forward to be done

* add bertmodel pipeline forward and make tests

* add Bert_Policy and test for policy

* update formatting

* update formatting

* update the code

* fix bugs

* fix name conflict

* add bloom model and policy, revise the base class of policy

* revise

* revision

* add bert_for_pretraining
pull/4445/head
Jianghai 2023-07-05 10:52:53 +08:00 committed by Hongxin Liu
parent c552cefa93
commit 90a65ea682
5 changed files with 286 additions and 80 deletions

View File

@@ -1,13 +1,15 @@
from typing import Any, Dict, List, Optional, Tuple
from colossalai.lazy import LazyTensor
import numpy as np
from torch import Tensor
from torch.nn import Module, Parameter
from colossalai.lazy import LazyTensor
from colossalai.pipeline.stage_manager import PipelineStageManager
class Policy:
def __init__(self, stage_manager: PipelineStageManager) -> None:
self.stage_manager = stage_manager
@@ -93,7 +95,8 @@ class Policy:
"""
raise NotImplementedError
def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
def parallelize_model(self,
module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
"""Parallelize model for pipeline parallel
Args:
@@ -106,3 +109,33 @@ class Policy:
self.replace_forward(module)
shared_params = self.get_shared_params(module)
return hold_params, hold_buffers, shared_params
@staticmethod
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
"""
divide layers into stages
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages
# deal with the remaining layers
if remainder > 0:
start_position = num_stages // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage
@staticmethod
def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
"""
get the start index and end index of layers for the given stage.
"""
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
start_idx = num_layers_per_stage_accumulated[stage]
end_idx = num_layers_per_stage_accumulated[stage + 1]
return [start_idx, end_idx]
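
For reference, a standalone sketch of how these two helpers compose. It re-implements the same arithmetic outside the Policy class so it runs without a PipelineStageManager; the 24-layer / 5-stage sizes are made up for illustration.

import numpy as np
from typing import List

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # evenly split the layers, placing the remainder on the middle stages
    quotient, remainder = divmod(num_layers, num_stages)
    layers_per_stage = [quotient] * num_stages
    if remainder > 0:
        start = num_stages // 2 - remainder // 2
        for i in range(start, start + remainder):
            layers_per_stage[i] += 1
    return layers_per_stage

def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
    # prefix sums give the [start, end) slice of layers owned by this stage
    accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
    return [accumulated[stage], accumulated[stage + 1]]

per_stage = distribute_layers(24, 5)          # [5, 5, 5, 5, 4]
start, end = get_stage_index(per_stage, 2)    # stage 2 runs encoder layers 10..14
print(per_stage, start, end)

As the comments in the method above suggest, the intent of centering the remainder is to bias the extra layers toward the middle of the pipeline.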

View File

@@ -22,25 +22,26 @@ logger = logging.get_logger(__name__)
def bert_model_forward(
self: BertModel,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
#labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage
self: BertModel,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
# labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
# this is from the previous stage
hidden_states: Optional[torch.FloatTensor] = None,
):
#TODO: add explanation of the output here.
# TODO: add explanation of the output here.
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
@@ -93,6 +94,7 @@ def bert_model_forward(
batch_size, seq_length = input_shape
device = hidden_states.device
# TODO: recording of kv-value tensors is left as () or None for now; this feature may be added in the future.
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
@@ -144,7 +146,7 @@ def bert_model_forward(
else:
encoder_extended_attention_mask = None
#inherit from bert_layer
# inherit from bert_layer
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
@@ -156,12 +158,12 @@ def bert_model_forward(
use_cache = False
next_decoder_cache = () if use_cache else None
#calculate the num_layers
# calculate the num_layers
num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages
start_layer = stage_manager.stage * num_layers_per_stage
end_layer = (stage_manager.stage + 1) * num_layers_per_stage
#layer_outputs
# layer_outputs
layer_outputs = hidden_states if hidden_states is not None else None
for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer):
if stage_manager.is_first_stage() and idx == 0:
@@ -206,12 +208,13 @@ def bert_model_forward(
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
all_cross_attentions = all_cross_attentions + \
(layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
#end of a stage loop
# end of a stage loop
sequence_output = layer_outputs[0] if layer_outputs is not None else None
if stage_manager.is_last_stage():
@@ -219,7 +222,7 @@ def bert_model_forward(
if not return_dict:
return (sequence_output, pooled_output) + layer_outputs[1:]
#output of non-first and non-last stages:
# output of non-first and non-last stages:
if not return_dict:
return tuple(v for v in [
hidden_states,
@@ -229,7 +232,7 @@ def bert_model_forward(
all_cross_attentions,
] if v is not None)
#return dict is not supported at this moment
# return dict is not supported at this moment
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
@@ -243,6 +246,7 @@ def bert_model_forward(
class BertModelPolicy(Policy):
def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
super().__init__(stage_manager=stage_manager)
self.stage_manager = stage_manager
self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
@@ -253,11 +257,8 @@ class BertModelPolicy(Policy):
hold_layers = []
if self.stage_manager.is_first_stage():
hold_layers.append(module.embeddings)
num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \
[self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0:
num_layers_per_stage_accumulated[self.stage_manager.stage]])
start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
hold_layers.extend(module.encoder.layer[start_idx:end_idx])
if self.stage_manager.is_last_stage():
hold_layers.append(module.pooler)
@@ -270,23 +271,6 @@ class BertModelPolicy(Policy):
def replace_forward(self, module: Module) -> None:
module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model)
def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
"""
divide layers into stages
"""
quotient = num_layers // num_stages
remainder = num_layers % num_stages
# calculate the num_layers per stage
layers_per_stage = [quotient] * num_stages
# deal with the rest layers
if remainder > 0:
start_position = num_layers // 2 - remainder // 2
for i in range(start_position, start_position + remainder):
layers_per_stage[i] += 1
return layers_per_stage
def bert_for_pretraining_forward(
self: BertForPreTraining,
@@ -306,8 +290,8 @@ def bert_for_pretraining_forward(
) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
outputs = self.bert(
outputs = bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -320,7 +304,8 @@ def bert_for_pretraining_forward(
)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
if stage_manager.is_last_stage():
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
total_loss = None
if labels is not None and next_sentence_label is not None:
@@ -355,11 +340,12 @@ class BertForPreTrainingPolicy(Policy):
hold_layers = []
if self.stage_manager.is_first_stage():
hold_layers.append(module.bert.embeddings)
num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
hold_layers.extend(module.bert.encoder.layer[num_layers_per_stage_accumulated \
[self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0:
num_layers_per_stage_accumulated[self.stage_manager.stage]])
start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
if self.stage_manager.is_last_stage():
hold_layers.append(module.bert.pooler)
hold_layers.append(module.cls)
return hold_layers
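
Both BERT policies now rely on the same binding pattern: replace_forward rebinds a stage-aware function onto the module with MethodType and partial, and the forward itself guards stage-specific work (here the self.cls head) behind stage_manager.is_last_stage(). A minimal, self-contained sketch of that pattern, using a toy module and a hypothetical ToyStageManager stand-in rather than the real colossalai classes:

from functools import partial
from types import MethodType

class ToyStageManager:
    # hypothetical stand-in for PipelineStageManager: stage bookkeeping only
    def __init__(self, stage: int, num_stages: int):
        self.stage, self.num_stages = stage, num_stages
    def is_first_stage(self) -> bool:
        return self.stage == 0
    def is_last_stage(self) -> bool:
        return self.stage == self.num_stages - 1

class ToyModel:
    def head(self, x):
        return x * 10

def staged_forward(self, x, stage_manager=None):
    # placeholder for this stage's slice of encoder layers
    y = x + 1
    # only the last stage applies the task head, mirroring the
    # is_last_stage() guard around self.cls(...) above
    return self.head(y) if stage_manager.is_last_stage() else y

model = ToyModel()
# bind the pipeline-aware forward the same way replace_forward does
model.forward = MethodType(partial(staged_forward, stage_manager=ToyStageManager(1, 2)), model)
print(model.forward(3))    # last of 2 stages -> head applied -> 40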

View File

@@ -1,3 +1,4 @@
import warnings
from functools import partial
from types import MethodType
from typing import Dict, List, Optional, Tuple, Union
@@ -14,6 +15,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from .base import Policy
logger = logging.get_logger(__name__)
def bloom_model_forward(
self: BloomModel,
@@ -26,6 +29,8 @@ def bloom_model_forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False:
@@ -44,28 +49,45 @@ def bloom_model_forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# add warnings here
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# case: First stage of training
if stage_manager.is_first_stage():
# check input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
# initialize in the first stage and then pass to the next stage
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
# extra recording tensor should be generated in the first stage
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
@@ -77,11 +99,14 @@ def bloom_model_forward(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# Compute alibi tensor: check build_alibi_tensor documentation
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Compute alibi tensor: check build_alibi_tensor documentation; built for every stage
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
past_key_values_length = past_key_values[0][0].shape[2] # source_len
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
@@ -90,13 +115,19 @@ def bloom_model_forward(
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
# causal_mask is constructed on every stage and its input is passed through the stages
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# calculate the num_layers
num_layers_per_stage = len(self.h) // stage_manager.num_stages
start_layer = stage_manager.stage * num_layers_per_stage
end_layer = (stage_manager.stage + 1) * num_layers_per_stage
for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -130,24 +161,60 @@ def bloom_model_forward(
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
all_self_attentions = all_self_attentions + \
(outputs[2 if use_cache else 1],)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if stage_manager.is_last_stage():
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
# TODO: deal with all_hidden_states, all_self_attentions, presents
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
# attention_mask is not returned; presents = past_key_values
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class BloomModelPolicy(Policy):
def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
super().__init__(stage_manager=stage_manager)
self.stage_manager = stage_manager
self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
def get_hold_layers(self, module: BloomModel) -> List[Module]:
"""
get pipeline layers for current stage
"""
hold_layers = []
if self.stage_manager.is_first_stage():
hold_layers.append(module.word_embeddings)
hold_layers.append(module.word_embeddings_layernorm)
start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
hold_layers.extend(module.h[start_idx:end_idx])
if self.stage_manager.is_last_stage():
hold_layers.append(module.ln_f)
return hold_layers
def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]:
'''no shared params in BloomModel'''
pass
def replace_forward(self, module: Module) -> None:
module.forward = MethodType(partial(bloom_model_forward, stage_manager=self.stage_manager), module)
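
Putting the Bloom pieces together: the first stage owns the word embeddings and the embedding layernorm, every stage owns its own slice of self.h, and the last stage owns ln_f, which is exactly what get_hold_layers returns above. A standalone sketch of that ownership plan (hypothetical sizes, no transformers or colossalai imports):

def stage_ownership(num_hidden_layers: int, num_stages: int):
    # which BloomModel submodules each pipeline stage would hold,
    # assuming the layer count divides evenly across stages
    per_stage = num_hidden_layers // num_stages
    plan = {}
    for stage in range(num_stages):
        modules = [f"h.{i}" for i in range(stage * per_stage, (stage + 1) * per_stage)]
        if stage == 0:
            modules = ["word_embeddings", "word_embeddings_layernorm"] + modules
        if stage == num_stages - 1:
            modules = modules + ["ln_f"]
        plan[stage] = modules
    return plan

# e.g. a 4-block BloomModel split over 2 stages
for stage, modules in stage_ownership(4, 2).items():
    print(stage, modules)
# 0 ['word_embeddings', 'word_embeddings_layernorm', 'h.0', 'h.1']
# 1 ['h.2', 'h.3', 'ln_f']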

View File

@@ -27,7 +27,8 @@ def check_bert_model_forward():
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
#print(pg_mesh)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
@@ -72,7 +73,7 @@ def check_bert_model_policy():
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
#print(pg_mesh)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()

View File

@@ -0,0 +1,119 @@
import pytest
import torch
import torch.distributed as dist
from transformers.models.bloom import BloomConfig, BloomModel
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.policy.bloom import BloomModelPolicy, bloom_model_forward
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_bloom_model_forward():
# create a BloomModel
configuration = BloomConfig()
model = BloomModel(configuration)
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
# print(rank)
x = torch.randint(0, 1000, (2, 3))
hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32)
if stage_manager.is_first_stage():
attention_mask = torch.ones_like(x)
output = bloom_model_forward(self=model,
input_ids=x,
attention_mask=attention_mask,
stage_manager=stage_manager)
print(output[0].shape)
assert output[0].shape == (2, 3, 64)
print('start the training')
else:
attention_mask = torch.ones((2, 3))
output = bloom_model_forward(self=model,
hidden_states=hidden_states,
attention_mask=attention_mask,
stage_manager=stage_manager)
print(output[0].shape)
assert output[0].shape == (2, 3, 64)
print('end the training')
print(output)
# assert output[1].shape == (2, 768)
def check_bloom_model_policy():
# create a BloomModel
configuration = BloomConfig()
model = BloomModel(configuration)
DP_DIM, PP_DIM = 0, 1
DP_SIZE, PP_SIZE = 2, 2
RANK_TO_COORDINATE = {
0: (0, 0),
1: (0, 1),
2: (1, 0),
3: (1, 1),
}
PP_RANKS_IN_GROUP = {
0: [0, 1],
1: [0, 1],
2: [2, 3],
3: [2, 3],
}
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
# print(pg_mesh)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
rank = dist.get_rank()
model_policy = BloomModelPolicy(stage_manager=stage_manager, num_layers=len(model.h), num_stages=2)
assert model_policy.layers_per_stage == [1, 1]
layers = model_policy.get_hold_layers(model)
for layer in layers:
print(layer)
def run_dist_model(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bloom_model_forward()
def run_dist_policy(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
check_bloom_model_policy()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bloom_model_forward():
spawn(run_dist_model, 4)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_bloom_model_policy():
spawn(run_dist_policy, 4)
if __name__ == "__main__":
"""test the bloom model forward and bloom model policy"""
test_bloom_model_forward()
test_bloom_model_policy()