ColossalAI/examples/tutorial/sequence_parallel/model/bert.py

283 lines
12 KiB
Python

from colossalai.context.parallel_mode import ParallelMode
import torch
import torch.nn as nn
import inspect
from .layers import Embedding, BertLayer, BertDualHead, PreProcessor, VocabEmbedding
from .layers.init_method import init_normal, output_init_normal
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.kernel import LayerNorm
from colossalai.nn.layer.wrapper import PipelineSharedModuleWrapper
from colossalai.logging import get_dist_logger
from colossalai.pipeline.utils import partition_uniform
class BertForPretrain(nn.Module):
    """BERT pre-training model whose sequence dimension is split across the
    ``ParallelMode.SEQUENCE`` process group.

    Every rank builds the full set of parameters but processes a sub-sequence
    of length ``max_sequence_length // <sequence-parallel world size>``.

    Args:
        vocab_size: Size of the token vocabulary.
        hidden_size: Hidden dimension of the transformer.
        max_sequence_length: Full (un-split) sequence length; must be divisible
            by the sequence-parallel world size.
        num_attention_heads: Number of attention heads per ``BertLayer``.
        num_layers: Number of ``BertLayer`` blocks; also used to scale the
            initialisation of output projections.
        add_binary_head: Forwarded to ``BertDualHead``. When False, token-type
            embeddings are disabled (``num_tokentypes`` is forced to 0).
        is_naive_fp16: Forwarded to every ``BertLayer``.
        num_tokentypes: Number of token types for the embedding.
        dropout_prob: Dropout used for embeddings, attention and hidden states.
        mlp_ratio: MLP expansion ratio forwarded to ``BertLayer``.
        init_std: Standard deviation used by the normal initialisers.
        convert_fp16_to_fp32_in_softmax: Forwarded to every ``BertLayer``.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 max_sequence_length,
                 num_attention_heads,
                 num_layers,
                 add_binary_head,
                 is_naive_fp16,
                 num_tokentypes=2,
                 dropout_prob=0.1,
                 mlp_ratio=4,
                 init_std=0.02,
                 convert_fp16_to_fp32_in_softmax=False,
                 ):
        super().__init__()
        self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
        self.sub_seq_length = max_sequence_length // self.seq_parallel_size
        self.init_std = init_std
        self.num_layers = num_layers

        # Token-type embeddings are only built when the binary head is used.
        if not add_binary_head:
            num_tokentypes = 0

        self.preprocessor = PreProcessor(self.sub_seq_length)
        self.embedding = Embedding(hidden_size=hidden_size,
                                   vocab_size=vocab_size,
                                   max_sequence_length=max_sequence_length,
                                   embedding_dropout_prob=dropout_prob,
                                   num_tokentypes=num_tokentypes)

        self.bert_layers = nn.ModuleList()
        for i in range(num_layers):
            bert_layer = BertLayer(layer_number=i + 1,
                                   hidden_size=hidden_size,
                                   num_attention_heads=num_attention_heads,
                                   attention_dropout=dropout_prob,
                                   mlp_ratio=mlp_ratio,
                                   hidden_dropout=dropout_prob,
                                   convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
                                   is_naive_fp16=is_naive_fp16
                                   )
            self.bert_layers.append(bert_layer)

        self.layer_norm = LayerNorm(hidden_size)
        # The head is sized by the vocabulary dimension of the input word
        # embedding; the same weight is handed to the head in forward().
        self.head = BertDualHead(hidden_size, self.embedding.word_embedding_weight.size(0),
                                 add_binary_head=add_binary_head)
        self.reset_parameters()

    def _init_normal(self, tensor):
        """Initialise ``tensor`` with ``init_normal`` using ``self.init_std``."""
        init_normal(tensor, sigma=self.init_std)

    def _output_init_normal(self, tensor):
        """Initialise an output projection, scaled by ``self.num_layers``."""
        output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers)

    def reset_parameters(self):
        """Re-initialise embeddings, transformer layers and the head weights."""
        # initialize embedding
        self._init_normal(self.embedding.word_embedding_weight)
        self._init_normal(self.embedding.position_embeddings.weight)
        if self.embedding.tokentype_embeddings is not None:
            self._init_normal(self.embedding.tokentype_embeddings.weight)

        # initialize bert layers: input projections get the plain normal init,
        # output projections get the layer-count-scaled init
        for layer in self.bert_layers:
            self._init_normal(layer.self_attention.query_key_value.weight)
            self._output_init_normal(layer.self_attention.dense.weight)
            self._init_normal(layer.mlp.dense_h_to_4h.weight)
            self._output_init_normal(layer.mlp.dense_4h_to_h.weight)

        # initialize head
        self._init_normal(self.head.lm_head.dense.weight)
        if self.head.binary_head is not None:
            self._init_normal(self.head.binary_head.pooler.dense.weight)
            self._init_normal(self.head.binary_head.dense.weight)

    def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels):
        """Run the full model on this rank's sub-sequence.

        Inputs:
            input_ids: [batch_size, sub_seq_len]
            attention_masks: [batch_size, seq_len]
            tokentype_ids: [batch_size, sub_seq_len]
            lm_labels: forwarded to ``BertDualHead``.

        Returns:
            Whatever ``BertDualHead`` returns for the normalised hidden states.
        """
        # Preprocessor outputs:
        #   pos_ids: [batch_size, sub_seq_len]
        #   attention_masks: [batch_size, 1, sub_seq_len, seq_len]
        pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks)
        hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids)

        # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size]
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        for layer in self.bert_layers:
            hidden_states = layer(hidden_states, attention_masks)

        # [sub_seq_len, batch_size, hidden_size] -> [batch_size, sub_seq_len, hidden_size]
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        output = self.layer_norm(hidden_states)

        # output: [batch_size, sub_seq_len, hidden_size]
        # word_embedding_weight: presumably [vocab_size, hidden_size] (TODO confirm)
        return self.head(output, self.embedding.word_embedding_weight, lm_labels)
class PipelineBertForPretrain(nn.Module):
    """One pipeline-parallel chunk of the sequence-parallel BERT model.

    The chunk owns transformer layers ``[start_idx, end_idx)``.

    - Only a first-stage chunk builds the input ``Embedding``.
    - Only a last-stage chunk builds the final ``LayerNorm``, the
      ``BertDualHead`` and a ``word_embeddings`` module whose weight is passed
      to the head.
    - When one chunk is both first and last stage, ``word_embeddings`` is the
      input embedding's own ``word_embeddings`` module. The head then uses the
      same weight as the input embedding, matching ``BertForPretrain``.
    - Otherwise the last-stage ``word_embeddings`` is a separate module.
      ``build_pipeline_bert`` registers it with ``PipelineSharedModuleWrapper``.

    Args:
        vocab_size, hidden_size, max_sequence_length, num_attention_heads,
        add_binary_head, is_naive_fp16, num_tokentypes, dropout_prob,
        mlp_ratio, init_std, convert_fp16_to_fp32_in_softmax:
            Same meaning as in ``BertForPretrain``.
        num_layers: Total number of layers in the whole model, not just this
            chunk. It is used for the scaled output initialisation and as the
            default ``end_idx``.
        first_stage: Whether this chunk holds the input embedding.
        last_stage: Whether this chunk holds the final norm and head.
        start_idx: First layer index (inclusive) of this chunk; defaults to 0.
        end_idx: Last layer index (exclusive); defaults to ``num_layers``.
    """

    def __init__(self,
                 vocab_size,
                 hidden_size,
                 max_sequence_length,
                 num_attention_heads,
                 num_layers,
                 add_binary_head,
                 is_naive_fp16,
                 num_tokentypes=2,
                 dropout_prob=0.1,
                 mlp_ratio=4,
                 init_std=0.02,
                 convert_fp16_to_fp32_in_softmax=False,
                 first_stage=True,
                 last_stage=True,
                 start_idx=None,
                 end_idx=None):
        super().__init__()
        self.seq_parallel_size = gpc.get_world_size(ParallelMode.SEQUENCE)
        assert max_sequence_length % self.seq_parallel_size == 0, 'sequence length is not divisible by the sequence parallel size'
        self.sub_seq_length = max_sequence_length // self.seq_parallel_size
        self.init_std = init_std
        self.num_layers = num_layers

        # Token-type embeddings are only built when the binary head is used.
        if not add_binary_head:
            num_tokentypes = 0

        self.first_stage = first_stage
        self.last_stage = last_stage
        self.preprocessor = PreProcessor(self.sub_seq_length)

        if self.first_stage:
            self.embedding = Embedding(hidden_size=hidden_size,
                                       vocab_size=vocab_size,
                                       max_sequence_length=max_sequence_length,
                                       embedding_dropout_prob=dropout_prob,
                                       num_tokentypes=num_tokentypes)

        # transformer layers; each bound defaults independently so that
        # supplying only one of them still works.
        if start_idx is None:
            start_idx = 0
        if end_idx is None:
            end_idx = num_layers
        self.bert_layers = nn.ModuleList()
        for i in range(start_idx, end_idx):
            bert_layer = BertLayer(layer_number=i + 1,
                                   hidden_size=hidden_size,
                                   num_attention_heads=num_attention_heads,
                                   attention_dropout=dropout_prob,
                                   mlp_ratio=mlp_ratio,
                                   hidden_dropout=dropout_prob,
                                   convert_fp16_to_fp32_in_softmax=convert_fp16_to_fp32_in_softmax,
                                   is_naive_fp16=is_naive_fp16
                                   )
            self.bert_layers.append(bert_layer)

        if self.last_stage:
            if self.first_stage:
                # Single chunk holding both ends of the model. Reuse the input
                # word embedding so the head sees the same weight as the input,
                # as BertForPretrain does. No cross-stage sharing is involved.
                self.word_embeddings = self.embedding.word_embeddings
            else:
                self.word_embeddings = VocabEmbedding(vocab_size, hidden_size)
            self.layer_norm = LayerNorm(hidden_size)
            self.head = BertDualHead(hidden_size, vocab_size,
                                     add_binary_head=add_binary_head)
        self.reset_parameters()

    def _init_normal(self, tensor):
        """Initialise ``tensor`` with ``init_normal`` using ``self.init_std``."""
        init_normal(tensor, sigma=self.init_std)

    def _output_init_normal(self, tensor):
        """Initialise an output projection, scaled by ``self.num_layers``."""
        output_init_normal(tensor, sigma=self.init_std, num_layers=self.num_layers)

    def reset_parameters(self):
        """Re-initialise the parameters owned by this chunk."""
        # initialize embedding (first stage only)
        if self.first_stage:
            self._init_normal(self.embedding.word_embedding_weight)
            self._init_normal(self.embedding.position_embeddings.weight)
            if self.embedding.tokentype_embeddings is not None:
                self._init_normal(self.embedding.tokentype_embeddings.weight)

        # initialize bert layers: input projections get the plain normal init,
        # output projections get the layer-count-scaled init
        for layer in self.bert_layers:
            self._init_normal(layer.self_attention.query_key_value.weight)
            self._output_init_normal(layer.self_attention.dense.weight)
            self._init_normal(layer.mlp.dense_h_to_4h.weight)
            self._output_init_normal(layer.mlp.dense_4h_to_h.weight)

        # initialize head (last stage only)
        if self.last_stage:
            self._init_normal(self.head.lm_head.dense.weight)
            if self.head.binary_head is not None:
                self._init_normal(self.head.binary_head.pooler.dense.weight)
                self._init_normal(self.head.binary_head.dense.weight)

    def forward(self, input_ids, attention_masks, tokentype_ids, lm_labels):
        """Run this chunk.

        Inputs:
            input_ids: On the first stage, token ids [batch_size, sub_seq_len].
                On later stages, the hidden states produced by the previous
                stage, [sub_seq_len, batch_size, hidden_size] (assuming
                ``BertLayer`` preserves that layout).
            attention_masks: [batch_size, seq_len]
            tokentype_ids: [batch_size, sub_seq_len] (used on the first stage only)
            lm_labels: forwarded to ``BertDualHead`` on the last stage.

        Returns:
            The head output on the last stage. Otherwise the hidden states,
            [sub_seq_len, batch_size, hidden_size], for the next stage.
        """
        # Preprocessor outputs:
        #   pos_ids: [batch_size, sub_seq_len]
        #   attention_masks: [batch_size, 1, sub_seq_len, seq_len]
        if self.first_stage:
            pos_ids, attention_masks = self.preprocessor(input_ids, attention_masks)
            hidden_states = self.embedding(input_ids, pos_ids, tokentype_ids)
            # [batch_size, sub_seq_len, hidden_size] -> [sub_seq_len, batch_size, hidden_size]
            hidden_states = hidden_states.transpose(0, 1).contiguous()
        else:
            _, attention_masks = self.preprocessor(None, attention_masks)
            hidden_states = input_ids

        for layer in self.bert_layers:
            hidden_states = layer(hidden_states, attention_masks)

        if not self.last_stage:
            # Passed to the next stage unchanged: [sub_seq_len, batch_size, hidden_size]
            return hidden_states

        # [sub_seq_len, batch_size, hidden_size] -> [batch_size, sub_seq_len, hidden_size]
        hidden_states = hidden_states.transpose(0, 1).contiguous()
        output = self.layer_norm(hidden_states)
        # word_embeddings.weight: presumably [vocab_size, hidden_size] (TODO confirm)
        return self.head(output, self.word_embeddings.weight, lm_labels)
def _filter_kwargs(func, kwargs):
sig = inspect.signature(func)
return {k: v for k, v in kwargs.items() if k in sig.parameters}
def build_pipeline_bert(num_layers, num_chunks, device=torch.device('cuda'), **kwargs):
    """Build the BERT chunk(s) assigned to the current pipeline rank.

    Layers are split with ``partition_uniform`` over the pipeline world size
    and ``num_chunks``. Each resulting ``(start, end)`` range becomes a
    ``PipelineBertForPretrain`` moved to ``device``. The first and last
    pipeline ranks share a ``PipelineSharedModuleWrapper``:

    - the chunk starting at layer 0 registers ``embedding.word_embeddings``;
    - otherwise, the chunk ending at ``num_layers`` registers its own
      ``word_embeddings``.

    Args:
        num_layers: Total number of transformer layers in the model.
        num_chunks: Number of chunks per pipeline rank.
        device: Device every chunk is moved to.
        **kwargs: Extra constructor arguments. Only those accepted by
            ``PipelineBertForPretrain.__init__`` are forwarded.

    Returns:
        The single chunk when this rank owns exactly one, otherwise an
        ``nn.ModuleList`` of chunks.
    """
    logger = get_dist_logger()
    pp_world = gpc.get_world_size(ParallelMode.PIPELINE)
    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    global_rank = gpc.get_global_rank()

    shared_embedding = PipelineSharedModuleWrapper([0, pp_world - 1])
    layer_ranges = partition_uniform(num_layers, pp_world, num_chunks)[pp_rank]

    chunks = []
    for start, end in layer_ranges:
        is_first = start == 0
        is_last = end == num_layers
        chunk_kwargs = dict(kwargs,
                            num_layers=num_layers,
                            start_idx=start,
                            end_idx=end,
                            first_stage=is_first,
                            last_stage=is_last)
        logger.info(f'Rank{global_rank} build layer {start}-{end}, {end-start}/{num_layers} layers')
        accepted = _filter_kwargs(PipelineBertForPretrain.__init__, chunk_kwargs)
        chunk = PipelineBertForPretrain(**accepted).to(device)
        if is_first:
            shared_embedding.register_module(chunk.embedding.word_embeddings)
        elif is_last:
            shared_embedding.register_module(chunk.word_embeddings)
        chunks.append(chunk)

    return chunks[0] if len(chunks) == 1 else nn.ModuleList(chunks)