mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] build bloom model and policy , revise the base class of policy (#4161)
* add pipeline policy and bert forward to be done * add bertmodel pipeline forward and make tests * add Bert_Policy and test for policy * update formatting * update formatting * update the code * fix bugs * fix name confilt * add bloom model and policy ,revise the base class of policy * revise * revision * add bert_for_pretrainingpull/4445/head
parent
c552cefa93
commit
90a65ea682
|
@ -1,13 +1,15 @@
|
|||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from colossalai.lazy import LazyTensor
|
||||
import numpy as np
|
||||
from torch import Tensor
|
||||
from torch.nn import Module, Parameter
|
||||
|
||||
from colossalai.lazy import LazyTensor
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
|
||||
class Policy:
|
||||
|
||||
def __init__(self, stage_manager: PipelineStageManager) -> None:
|
||||
self.stage_manager = stage_manager
|
||||
|
||||
|
@ -93,7 +95,8 @@ class Policy:
|
|||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def parallelize_model(self, module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
|
||||
def parallelize_model(self,
|
||||
module: Module) -> Tuple[Dict[str, Parameter], Dict[str, Tensor], List[Dict[int, Tensor]]]:
|
||||
"""Parallelize model for pipeline parallel
|
||||
|
||||
Args:
|
||||
|
@ -106,3 +109,33 @@ class Policy:
|
|||
self.replace_forward(module)
|
||||
shared_params = self.get_shared_params(module)
|
||||
return hold_params, hold_buffers, shared_params
|
||||
|
||||
@staticmethod
|
||||
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
|
||||
"""
|
||||
divide layers into stages
|
||||
"""
|
||||
quotient = num_layers // num_stages
|
||||
remainder = num_layers % num_stages
|
||||
|
||||
# calculate the num_layers per stage
|
||||
layers_per_stage = [quotient] * num_stages
|
||||
|
||||
# deal with the rest layers
|
||||
if remainder > 0:
|
||||
start_position = num_layers // 2 - remainder // 2
|
||||
for i in range(start_position, start_position + remainder):
|
||||
layers_per_stage[i] += 1
|
||||
return layers_per_stage
|
||||
|
||||
@staticmethod
|
||||
def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
|
||||
"""
|
||||
get the start index and end index of layers for each stage.
|
||||
"""
|
||||
num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
|
||||
|
||||
start_idx = num_layers_per_stage_accumulated[stage]
|
||||
end_idx = num_layers_per_stage_accumulated[stage + 1]
|
||||
|
||||
return [start_idx, end_idx]
|
||||
|
|
|
@ -22,25 +22,26 @@ logger = logging.get_logger(__name__)
|
|||
|
||||
|
||||
def bert_model_forward(
|
||||
self: BertModel,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
#labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None, #this is from the previous stage
|
||||
self: BertModel,
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
token_type_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
# labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
# this is from the previous stage
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
#TODO: add explaination of the output here.
|
||||
# TODO: add explaination of the output here.
|
||||
r"""
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
||||
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
||||
|
@ -93,6 +94,7 @@ def bert_model_forward(
|
|||
batch_size, seq_length = input_shape
|
||||
device = hidden_states.device
|
||||
|
||||
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
|
@ -144,7 +146,7 @@ def bert_model_forward(
|
|||
else:
|
||||
encoder_extended_attention_mask = None
|
||||
|
||||
#inherit from bert_layer
|
||||
# inherit from bert_layer
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||
|
@ -156,12 +158,12 @@ def bert_model_forward(
|
|||
use_cache = False
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
#calculate the num_layers
|
||||
# calculate the num_layers
|
||||
num_layers_per_stage = len(self.encoder.layer) // stage_manager.num_stages
|
||||
start_layer = stage_manager.stage * num_layers_per_stage
|
||||
end_layer = (stage_manager.stage + 1) * num_layers_per_stage
|
||||
|
||||
#layer_outputs
|
||||
# layer_outputs
|
||||
layer_outputs = hidden_states if hidden_states is not None else None
|
||||
for idx, encoder_layer in enumerate(self.encoder.layer[start_layer:end_layer], start=start_layer):
|
||||
if stage_manager.is_first_stage() and idx == 0:
|
||||
|
@ -206,12 +208,13 @@ def bert_model_forward(
|
|||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
if self.config.add_cross_attention:
|
||||
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
||||
all_cross_attentions = all_cross_attentions + \
|
||||
(layer_outputs[2],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
#end of a stage loop
|
||||
# end of a stage loop
|
||||
sequence_output = layer_outputs[0] if layer_outputs is not None else None
|
||||
|
||||
if stage_manager.is_last_stage():
|
||||
|
@ -219,7 +222,7 @@ def bert_model_forward(
|
|||
if not return_dict:
|
||||
return (sequence_output, pooled_output) + layer_outputs[1:]
|
||||
|
||||
#output of non-first and non-last stages:
|
||||
# output of non-first and non-last stages:
|
||||
if not return_dict:
|
||||
return tuple(v for v in [
|
||||
hidden_states,
|
||||
|
@ -229,7 +232,7 @@ def bert_model_forward(
|
|||
all_cross_attentions,
|
||||
] if v is not None)
|
||||
|
||||
#return dict is not supported at this moment
|
||||
# return dict is not supported at this moment
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_decoder_cache,
|
||||
|
@ -243,6 +246,7 @@ def bert_model_forward(
|
|||
class BertModelPolicy(Policy):
|
||||
|
||||
def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
|
||||
super().__init__(stage_manager=stage_manager)
|
||||
self.stage_manager = stage_manager
|
||||
self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
|
||||
|
||||
|
@ -253,11 +257,8 @@ class BertModelPolicy(Policy):
|
|||
hold_layers = []
|
||||
if self.stage_manager.is_first_stage():
|
||||
hold_layers.append(module.embeddings)
|
||||
num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
|
||||
hold_layers.extend(module.encoder.layer[num_layers_per_stage_accumulated \
|
||||
[self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0:
|
||||
num_layers_per_stage_accumulated[self.stage_manager.stage]])
|
||||
|
||||
start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
|
||||
hold_layers.extend(module.encoder.layer[start_idx:end_idx])
|
||||
if self.stage_manager.is_last_stage():
|
||||
hold_layers.append(module.pooler)
|
||||
|
||||
|
@ -270,23 +271,6 @@ class BertModelPolicy(Policy):
|
|||
def replace_forward(self, module: Module) -> None:
|
||||
module.model.forward = MethodType(partial(bert_model_forward, stage_manager=self.stage_manager), module.model)
|
||||
|
||||
def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
|
||||
"""
|
||||
divide layers into stages
|
||||
"""
|
||||
quotient = num_layers // num_stages
|
||||
remainder = num_layers % num_stages
|
||||
|
||||
# calculate the num_layers per stage
|
||||
layers_per_stage = [quotient] * num_stages
|
||||
|
||||
# deal with the rest layers
|
||||
if remainder > 0:
|
||||
start_position = num_layers // 2 - remainder // 2
|
||||
for i in range(start_position, start_position + remainder):
|
||||
layers_per_stage[i] += 1
|
||||
return layers_per_stage
|
||||
|
||||
|
||||
def bert_for_pretraining_forward(
|
||||
self: BertForPreTraining,
|
||||
|
@ -306,8 +290,8 @@ def bert_for_pretraining_forward(
|
|||
) -> Union[Tuple[torch.Tensor], BertForPreTrainingOutput]:
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
outputs = self.bert(
|
||||
outputs = bert_model_forward(
|
||||
self.bert,
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
|
@ -320,7 +304,8 @@ def bert_for_pretraining_forward(
|
|||
)
|
||||
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
if stage_manager.is_last_stage():
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
|
||||
total_loss = None
|
||||
if labels is not None and next_sentence_label is not None:
|
||||
|
@ -355,11 +340,12 @@ class BertForPreTrainingPolicy(Policy):
|
|||
hold_layers = []
|
||||
if self.stage_manager.is_first_stage():
|
||||
hold_layers.append(module.bert.embeddings)
|
||||
num_layers_per_stage_accumulated = np.cumsum(self.layers_per_stage)
|
||||
hold_layers.extend(module.bert.encoder.layer[num_layers_per_stage_accumulated \
|
||||
[self.stage_manager.stage-1] if self.stage_manager.stage > 0 else 0:
|
||||
num_layers_per_stage_accumulated[self.stage_manager.stage]])
|
||||
|
||||
start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
|
||||
hold_layers.extend(module.bert.encoder.layer[start_idx:end_idx])
|
||||
|
||||
if self.stage_manager.is_last_stage():
|
||||
hold_layers.append(module.bert.pooler)
|
||||
hold_layers.append(module.cls)
|
||||
|
||||
return hold_layers
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import warnings
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
@ -14,6 +15,8 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
|
||||
from .base import Policy
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def bloom_model_forward(
|
||||
self: BloomModel,
|
||||
|
@ -26,6 +29,8 @@ def bloom_model_forward(
|
|||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
**deprecated_arguments,
|
||||
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
|
||||
if deprecated_arguments.pop("position_ids", False) is not False:
|
||||
|
@ -44,28 +49,45 @@ def bloom_model_forward(
|
|||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
|
||||
# add warnings here
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
output_hidden_states = False
|
||||
if use_cache:
|
||||
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
|
||||
use_cache = False
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape batch_size x num_heads x N x N
|
||||
|
||||
# head_mask has shape n_layer x batch x num_heads x N x N
|
||||
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
# case: First stage of training
|
||||
if stage_manager.is_first_stage():
|
||||
# check input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
elif input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.word_embeddings(input_ids)
|
||||
|
||||
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
|
||||
# initialize in the first stage and then pass to the next stage
|
||||
else:
|
||||
input_shape = hidden_states.shape[:-1]
|
||||
batch_size, seq_length = input_shape
|
||||
|
||||
# extra recording tensor should be generated in the first stage
|
||||
|
||||
presents = () if use_cache else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
@ -77,11 +99,14 @@ def bloom_model_forward(
|
|||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
use_cache = False
|
||||
|
||||
# Compute alibi tensor: check build_alibi_tensor documentation
|
||||
if past_key_values is None:
|
||||
past_key_values = tuple([None] * len(self.h))
|
||||
# Compute alibi tensor: check build_alibi_tensor documentation,build for every stage
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
if past_key_values[0] is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
past_key_values_length = past_key_values[0][0].shape[2] # source_len
|
||||
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
|
||||
|
@ -90,13 +115,19 @@ def bloom_model_forward(
|
|||
|
||||
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
|
||||
|
||||
# causal_mask is constructed every stage and its input is passed through different stages
|
||||
causal_mask = self._prepare_attn_mask(
|
||||
attention_mask,
|
||||
input_shape=(batch_size, seq_length),
|
||||
past_key_values_length=past_key_values_length,
|
||||
)
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
||||
# calculate the num_layers
|
||||
num_layers_per_stage = len(self.h) // stage_manager.num_stages
|
||||
start_layer = stage_manager.stage * num_layers_per_stage
|
||||
end_layer = (stage_manager.stage + 1) * num_layers_per_stage
|
||||
|
||||
for i, (block, layer_past) in enumerate(zip(self.h[start_layer:end_layer], past_key_values[start_layer:end_layer])):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
|
@ -130,24 +161,60 @@ def bloom_model_forward(
|
|||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
|
||||
if use_cache is True:
|
||||
presents = presents + (outputs[1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
||||
all_self_attentions = all_self_attentions + \
|
||||
(outputs[2 if use_cache else 1],)
|
||||
|
||||
# Add last hidden state
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
if stage_manager.is_last_stage():
|
||||
# Add last hidden state
|
||||
hidden_states = self.ln_f(hidden_states)
|
||||
|
||||
# TODO: deal with all_hidden_states, all_self_attentions, presents
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
||||
|
||||
# attention_mask is not returned ; presents = past_key_values
|
||||
return BaseModelOutputWithPastAndCrossAttentions(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=presents,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attentions,
|
||||
)
|
||||
|
||||
|
||||
class BloomModelPolicy(Policy):
|
||||
|
||||
def __init__(self, stage_manager: PipelineStageManager, num_layers: int, num_stages: int):
|
||||
super().__init__(stage_manager=stage_manager)
|
||||
self.stage_manager = stage_manager
|
||||
self.layers_per_stage = self.distribute_layers(num_layers, num_stages)
|
||||
|
||||
def get_hold_layers(self, module: BloomModel) -> List[Module]:
|
||||
"""
|
||||
get pipeline layers for current stage
|
||||
"""
|
||||
hold_layers = []
|
||||
if self.stage_manager.is_first_stage():
|
||||
hold_layers.append(module.word_embeddings)
|
||||
hold_layers.append(module.word_embeddings_layernorm)
|
||||
|
||||
start_idx, end_idx = self.get_stage_index(self.layers_per_stage, self.stage_manager.stage)
|
||||
hold_layers.extend(module.h[start_idx:end_idx])
|
||||
|
||||
if self.stage_manager.is_last_stage():
|
||||
hold_layers.append(module.ln_f)
|
||||
|
||||
return hold_layers
|
||||
|
||||
def get_shared_params(self, module: BloomModel) -> List[Dict[int, Tensor]]:
|
||||
'''no shared params in bloommodel'''
|
||||
pass
|
||||
|
||||
def replace_forward(self, module: Module) -> None:
|
||||
module.forward = MethodType(partial(bloom_model_forward, stage_manager=self.stage_manager), module.model)
|
||||
|
|
|
@ -27,7 +27,8 @@ def check_bert_model_forward():
|
|||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
#print(pg_mesh)
|
||||
|
||||
# print(pg_mesh)
|
||||
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
|
@ -72,7 +73,7 @@ def check_bert_model_policy():
|
|||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
#print(pg_mesh)
|
||||
# print(pg_mesh)
|
||||
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from transformers.models.bloom import BloomConfig, BloomModel
|
||||
|
||||
import colossalai
|
||||
from colossalai.cluster import ProcessGroupMesh
|
||||
from colossalai.pipeline.policy.bloom import BloomModelPolicy, bloom_model_forward
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
def check_bloom_model_forward():
|
||||
# create a BloomModel
|
||||
configuration = BloomConfig()
|
||||
model = BloomModel(configuration)
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
# print(pg_mesh)
|
||||
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
# print(rank)
|
||||
|
||||
x = torch.randint(0, 1000, (2, 3))
|
||||
hidden_states = torch.randint(0, 1000, (2, 3, 64)).to(torch.float32)
|
||||
if stage_manager.is_first_stage():
|
||||
attention_mask = torch.ones_like(x)
|
||||
output = bloom_model_forward(self=model,
|
||||
input_ids=x,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
print(output[0].shape)
|
||||
assert output[0].shape == (2, 3, 64)
|
||||
print('start the training')
|
||||
else:
|
||||
attention_mask = torch.ones((2, 3))
|
||||
output = bloom_model_forward(self=model,
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
stage_manager=stage_manager)
|
||||
print(output[0].shape)
|
||||
assert output[0].shape == (2, 3, 64)
|
||||
print('end the training')
|
||||
print(output)
|
||||
|
||||
# assert output[1].shape == (2, 768)
|
||||
|
||||
|
||||
def check_bloom_model_policy():
|
||||
# create a BloomModel
|
||||
configuration = BloomConfig()
|
||||
model = BloomModel(configuration)
|
||||
DP_DIM, PP_DIM = 0, 1
|
||||
DP_SIZE, PP_SIZE = 2, 2
|
||||
RANK_TO_COORDINATE = {
|
||||
0: (0, 0),
|
||||
1: (0, 1),
|
||||
2: (1, 0),
|
||||
3: (1, 1),
|
||||
}
|
||||
PP_RANKS_IN_GROUP = {
|
||||
0: [0, 1],
|
||||
1: [0, 1],
|
||||
2: [2, 3],
|
||||
3: [2, 3],
|
||||
}
|
||||
pg_mesh = ProcessGroupMesh(DP_SIZE, PP_SIZE)
|
||||
# print(pg_mesh)
|
||||
|
||||
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
|
||||
rank = dist.get_rank()
|
||||
|
||||
model_policy = BloomModelPolicy(stage_manager=stage_manager, num_layers=len(model.h), num_stages=2)
|
||||
assert model_policy.layers_per_stage == [1, 1]
|
||||
layers = model_policy.get_hold_layers(model)
|
||||
for layer in layers:
|
||||
print(layer)
|
||||
|
||||
|
||||
def run_dist_model(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
check_bloom_model_forward()
|
||||
|
||||
|
||||
def run_dist_policy(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
check_bloom_model_policy()
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bloom_model_forward():
|
||||
spawn(run_dist_model, 4)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_bloom_model_policy():
|
||||
spawn(run_dist_policy, 4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
"""test the bloom model forward and bloom model policy"""
|
||||
test_bloom_model_forward()
|
||||
test_bloom_model_policy()
|
Loading…
Reference in New Issue