[example] upload auto parallel gpt2 demo (#2354)

pull/2365/head^2
YuliangLiu0306 2023-01-06 11:38:38 +08:00 committed by GitHub
parent 00a9c781fd
commit 8b1e0dfd80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 410 additions and 0 deletions

View File

@ -0,0 +1,44 @@
# Auto-Parallelism with GPT2
## Requirements
Before you can launch training, you need to install the following requirements.
### Install PyTorch
```bash
#conda
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
#pip
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
```
### Install [Colossal-AI v0.2.0](https://colossalai.org/download/) From Official Website
```bash
pip install colossalai==0.2.0+torch1.12cu11.3 -f https://release.colossalai.org
```
### Install transformers
```bash
pip install transformers
```
### Install pulp and coin-or-cbc
```bash
pip install pulp
conda install -c conda-forge coin-or-cbc
```
## Dataset
For simplicity, the input data is randonly generated here.
## Training
```bash
#Run the auto parallel resnet example with 4 GPUs with a dummy dataset.
colossalai run --nproc_per_node 4 auto_parallel_with_gpt.py
```

View File

@ -0,0 +1,109 @@
from functools import partial
from time import time
from typing import Dict, Optional, Tuple, Union
import psutil
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from gpt_modules import GPT2LMHeadModel, GPTLMLoss
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
BATCH_SIZE = 8
SEQ_LENGTH = 128
HIDDEN_DIM = 3072
NUM_HEADS = 16
NUM_LAYERS = 1
VOCAB_SIZE = 50257
NUM_STEPS = 10
FP16 = False
def get_cpu_mem():
return psutil.Process().memory_info().rss / 1024**2
def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2
def get_mem_info(prefix=''):
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
def get_tflops(model_numel, batch_size, seq_len, step_time):
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
# Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
def main():
disable_existing_loggers()
launch_from_torch(config={})
logger = get_dist_logger()
config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
if FP16:
model = GPT2LMHeadModel(config=config).half().to('cuda')
else:
model = GPT2LMHeadModel(config=config).to('cuda')
global_numel = sum([p.numel() for p in model.parameters()])
meta_input_sample = {
'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
}
# Both device mesh initialization and model initialization will be integrated into autoparallelize
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# Enable auto-parallel
gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True)
# print solution on rank 0
if gpc.get_global_rank() == 0:
for node_strategy in solution:
print(node_strategy)
# build criterion
criterion = GPTLMLoss()
optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH)
torch.cuda.synchronize()
model.train()
for n in range(10):
# we just use randomly generated data here
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)
optimizer.zero_grad()
start = time()
outputs = gm(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
loss.backward()
optimizer.step()
torch.cuda.synchronize()
step_time = time() - start
logger.info(
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
ranks=[0])
torch.cuda.synchronize()
if __name__ == '__main__':
main()

View File

@ -0,0 +1,253 @@
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel
from transformers.pytorch_utils import Conv1D
class GPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
embed_dim = config.hidden_size
self.c_fc = Conv1D(intermediate_size, embed_dim)
self.c_proj = Conv1D(embed_dim, intermediate_size)
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
return hidden_states
# The reason Why we don't import GPT2Attention from transformers directly is that:
# 1. The tracer will not work correctly when we feed meta_args and concrete_args at same time,
# so we have to build the customized GPT2Attention class and remove the conditional branch manually.
# 2. The order of split and view op has been changed in the customized GPT2Attention class, the new
# order is same as megatron-lm gpt model.
class GPT2Attention(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions),
dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
)
self.register_buffer("masked_bias", torch.tensor(-1e4))
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
self.scale_attn_weights = config.scale_attn_weights
# Layer-wise attention scaling, reordering, and upcasting
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
self.layer_idx = layer_idx
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.pruned_heads = set()
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
attn_weights = attn_weights / (value.size(-1)**0.5)
# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)
# if only "normal" attention layer implements causal mask
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.type(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def _merge_heads(self, tensor, num_heads, attn_head_size):
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
qkv = self.c_attn(hidden_states)
query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3)
present = (key, value)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
return attn_output
class GPT2Block(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
)
# residual connection
hidden_states = attn_outputs + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
class GPT2Model(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
device = input_ids.device
past_length = 0
past_key_values = tuple([None] * len(self.h))
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# GPT2Attention mask.
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
output_shape = input_shape + (hidden_states.size(-1),)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
hidden_states = outputs
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
return hidden_states
class GPT2LMHeadModel(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = GPT2Model(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
):
transformer_outputs = self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
)
lm_logits = self.lm_head(transformer_outputs)
return lm_logits
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

View File

@ -0,0 +1,4 @@
colossalai >= 0.1.12
torch >= 1.8.1
transformers >= 4.231
PuLP >= 2.7.0