[example] make gpt example directory more clear (#2353)

pull/2357/head
Jiarui Fang 2023-01-06 11:11:26 +08:00 committed by GitHub
parent 7027540d3d
commit 509a87f3ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 26 additions and 584 deletions

View File

@ -57,12 +57,6 @@ The `train_gpt_demo.py` provides three distributed plans, you can choose the pla
- Pytorch ZeRO
### Pipeline Parallel
```bash
bash run_pp.sh
```
## Performance
Testbed: a cluster of 8xA100 (80GB) and 1xAMD EPYC 7543 32-Core Processor (512 GB). GPUs are connected via PCI-e.
@ -119,3 +113,9 @@ Touch the bar of model scale and batch size.
| model | #GPU | policy | TP | batch per DP | Tflops |
| ---------- | --------- |--------- |--------- |--------- |--------- |
| gpt2_20b | 8 | cpu | 2 | 8 | 46.895 |
### Experimental Features
#### [Pipeline Parallel](./experiments/pipeline_parallel/)
#### [Auto Parallel](./experiments/auto_parallel_with_gpt/)

View File

@ -1,44 +0,0 @@
# Auto-Parallelism with GPT2
## Requirements
Before you can launch training, you need to install the following requirements.
### Install PyTorch
```bash
#conda
conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit=11.3 -c pytorch
#pip
pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
```
### Install [Colossal-AI v0.1.12](https://colossalai.org/download/) From Official Website
```bash
pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
```
### Install transformers
```bash
pip install transformers
```
### Install pulp and coin-or-cbc
```bash
pip install pulp
conda install -c conda-forge coin-or-cbc
```
## Dataset
For simplicity, the input data is randomly generated here.
## Training
```bash
# Run the auto-parallel GPT2 example on 4 GPUs with a dummy dataset.
colossalai run --nproc_per_node 4 auto_parallel_with_gpt.py
```

View File

@ -1,109 +0,0 @@
from functools import partial
from time import time
from typing import Dict, Optional, Tuple, Union
import psutil
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import transformers
from gpt_modules import GPT2LMHeadModel, GPTLMLoss
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize, initialize_model
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch_from_torch
from colossalai.logging import disable_existing_loggers, get_dist_logger
BATCH_SIZE = 8
SEQ_LENGTH = 128
HIDDEN_DIM = 3072
NUM_HEADS = 16
NUM_LAYERS = 1
VOCAB_SIZE = 50257
NUM_STEPS = 10
FP16 = False
def get_cpu_mem():
return psutil.Process().memory_info().rss / 1024**2
def get_gpu_mem():
return torch.cuda.memory_allocated() / 1024**2
def get_mem_info(prefix=''):
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
def get_tflops(model_numel, batch_size, seq_len, step_time):
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
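# the trailing division by 4 is hardcoded for the 4-GPU (2, 2) device mesh set up in main()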
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) / 4
# Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
def main():
disable_existing_loggers()
launch_from_torch(config={})
logger = get_dist_logger()
config = transformers.GPT2Config(n_position=SEQ_LENGTH, n_layer=NUM_LAYERS, n_head=NUM_HEADS, n_embd=HIDDEN_DIM)
if FP16:
model = GPT2LMHeadModel(config=config).half().to('cuda')
else:
model = GPT2LMHeadModel(config=config).to('cuda')
global_numel = sum([p.numel() for p in model.parameters()])
meta_input_sample = {
'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
'attention_mask': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64).to('meta'),
}
# Both device mesh initialization and model initialization will be integrated into autoparallelize
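# 4 ranks (matching the `--nproc_per_node 4` launch in the README) arranged as a 2 x 2 logical mesh for the solver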
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# Enable auto-parallel
gm, solution = initialize_model(model, meta_input_sample, device_mesh, return_solution=True)
# print solution on rank 0
if gpc.get_global_rank() == 0:
for node_strategy in solution:
print(node_strategy)
# build criterion
criterion = GPTLMLoss()
optimizer = torch.optim.Adam(gm.parameters(), lr=0.01)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
get_tflops_func = partial(get_tflops, global_numel, BATCH_SIZE, SEQ_LENGTH)
torch.cuda.synchronize()
model.train()
for n in range(NUM_STEPS):
# we just use randomly generated data here
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LENGTH, VOCAB_SIZE)
optimizer.zero_grad()
start = time()
outputs = gm(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
loss.backward()
optimizer.step()
torch.cuda.synchronize()
step_time = time() - start
logger.info(
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
ranks=[0])
torch.cuda.synchronize()
if __name__ == '__main__':
main()

View File

@ -1,253 +0,0 @@
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
from transformers.models.gpt2.modeling_gpt2 import BaseModelOutputWithPastAndCrossAttentions, GPT2PreTrainedModel
from transformers.pytorch_utils import Conv1D
class GPT2MLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
embed_dim = config.hidden_size
self.c_fc = Conv1D(intermediate_size, embed_dim)
self.c_proj = Conv1D(embed_dim, intermediate_size)
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
return hidden_states
# The reason why we don't import GPT2Attention from transformers directly is that:
# 1. The tracer will not work correctly when we feed meta_args and concrete_args at the same time,
# so we have to build a customized GPT2Attention class and remove the conditional branch manually.
# 2. The order of the split and view ops has been changed in the customized GPT2Attention class; the new
# order is the same as in the megatron-lm GPT model.
class GPT2Attention(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
max_positions = config.max_position_embeddings
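# lower-triangular causal mask kept as a buffer; _attn slices it to the current sequence length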
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions),
dtype=torch.uint8)).view(1, 1, max_positions, max_positions),
)
self.register_buffer("masked_bias", torch.tensor(-1e4))
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
self.scale_attn_weights = config.scale_attn_weights
# Layer-wise attention scaling, reordering, and upcasting
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
self.layer_idx = layer_idx
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
self.pruned_heads = set()
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
attn_weights = attn_weights / (value.size(-1)**0.5)
# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
attn_weights = attn_weights / float(self.layer_idx + 1)
# apply the causal mask unconditionally (the conditional branch from transformers was removed in this customized class)
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].to(torch.bool)
attn_weights = torch.where(causal_mask, attn_weights, self.masked_bias.to(attn_weights.dtype))
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.type(value.dtype)
# Mask heads if we want to
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def _merge_heads(self, tensor, num_heads, attn_head_size):
tensor = tensor.permute(0, 2, 1, 3).contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
qkv = self.c_attn(hidden_states)
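# qkv: (batch, seq, 3 * embed_dim); split into heads first and then into query/key/value,
# following the megatron-lm ordering mentioned in the comment above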
query, key, value = self._split_heads(qkv, self.num_heads, 3 * self.head_dim).split(self.head_dim, dim=3)
present = (key, value)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
return attn_output
class GPT2Block(nn.Module):
def __init__(self, config, layer_idx=None):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPT2Attention(config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPT2MLP(inner_dim, config)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
)
# residual connection
hidden_states = attn_outputs + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
class GPT2Model(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
device = input_ids.device
past_length = 0
past_key_values = tuple([None] * len(self.h))
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# GPT2Attention mask.
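# expand the padding mask to (batch, 1, 1, seq_len) and convert it into an additive bias:
# 0.0 for positions to attend to, -10000.0 for masked positions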
attention_mask = attention_mask.view(batch_size, -1)
attention_mask = attention_mask[:, None, None, :]
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicates we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
output_shape = input_shape + (hidden_states.size(-1),)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
outputs = block(hidden_states, attention_mask=attention_mask, head_mask=head_mask[i])
hidden_states = outputs
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
return hidden_states
class GPT2LMHeadModel(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = GPT2Model(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
):
transformer_outputs = self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
)
lm_logits = self.lm_head(transformer_outputs)
return lm_logits
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
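# shift so that tokens < n predict token n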
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

View File

@ -1,4 +0,0 @@
colossalai >= 0.1.12
torch >= 1.8.1
transformers >= 4.23.1
PuLP >= 2.7.0

View File

@ -9,7 +9,7 @@ for MODEL_TYPE in "gpt2_medium"; do
echo "****************** Begin ***************************" echo "****************** Begin ***************************"
echo "* benchmrking MODEL_TYPE ${MODEL_TYPE} BS ${BATCH_SIZE} BS ${BS} GPUNUM ${GPUNUM} TPDEGREE ${TPDEGREE} PLACEMENT ${PLACEMENT}" echo "* benchmrking MODEL_TYPE ${MODEL_TYPE} BS ${BATCH_SIZE} BS ${BS} GPUNUM ${GPUNUM} TPDEGREE ${TPDEGREE} PLACEMENT ${PLACEMENT}"
MODEL_TYPE=${MODEL_TYPE} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \ MODEL_TYPE=${MODEL_TYPE} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./run_gemini.sh bash ./gemini/run_gemini.sh
echo "****************** Finished ***************************" echo "****************** Finished ***************************"
echo "" echo ""
echo "" echo ""

View File

@ -0,0 +1,12 @@
import torch
# Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
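# rough per-step TFLOPS estimate; the factor 8 presumably accounts for forward (2x), backward (4x),
# and activation recomputation (2x) FLOPs per parameter per token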
def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)

View File

@ -1,3 +1,4 @@
set -x
# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
export DISTPAN=${DISTPAN:-"colossalai"}
@ -9,8 +10,11 @@ export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
# export PYTHONPATH=$PWD:$PYTHONPATH
mkdir -p gemini_logs
torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py \
torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
--tp_degree=${TPDEGREE} \
--model_type=${MODEL_TYPE} \
--batch_size=${BATCH_SIZE} \

View File

@ -5,9 +5,10 @@ from time import time
import psutil
import torch
import torch.nn as nn
from commons.model_zoo import model_builder
from commons.utils import get_data, get_tflops
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP
from utils import get_data, get_tflops
import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
@ -15,7 +16,6 @@ from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from model_zoo import model_builder
CAI_VERSION = colossalai.__version__

View File

@ -1,7 +0,0 @@
export GPUNUM=${GPUNUM:-2}
export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export NUM_MICROBATCH=${NUM_MICROBATCH:-4}
mkdir -p pp_logs
python train_gpt_pp_demo.py --device="cuda" --model_type=${MODEL_TYPE} --num_microbatches=${NUM_MICROBATCH} --world_size=${GPUNUM} --batch_size=${BATCH_SIZE} 2>&1 | tee ./pp_logs/${MODEL_TYPE}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_nm_${NUM_MICROBATCH}.log

View File

@ -1,157 +0,0 @@
import argparse
import time
from functools import partial
import torch
from model_zoo import model_builder
from torch import nn
from tqdm import tqdm
from utils import get_data, get_tflops
from colossalai.fx import ColoTracer
from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.pipeline.middleware.adaptor import get_fx_topology
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine
from colossalai.pipeline.rpc.utils import rpc_run
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--model_type', type=str, default="gpt2_medium")
parser.add_argument('--world_size', type=int, default=2)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--dp_degree', type=int, default=1)
parser.add_argument('--tp_degree', type=int, default=1)
parser.add_argument('--num_microbatches', type=int, default=2)
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
parser.add_argument('--master_addr', type=str, default='localhost')
parser.add_argument('--master_port', type=str, default='29011')
parser.add_argument('--num_worker_threads', type=int, default=128)
return parser.parse_args()
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
# Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
return input_ids, attention_mask
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
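# trace the model with meta tensors, split the graph into `stage_num` submodules,
# attach the pipeline topology to each split, and return the submodule owned by this pipeline rank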
tracer = ColoTracer()
meta_args = {k: v.to('meta') for k, v in data_kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = torch.fx.GraphModule(model, graph, model.__class__.__name__)
annotated_model = avgnode_split_pass(gm, stage_num)
top_module, split_submodules = split_with_split_nodes_pass(annotated_model, merge_output=True)
topo = get_fx_topology(top_module)
for submodule in split_submodules:
if isinstance(submodule, torch.fx.GraphModule):
setattr(submodule, '_topo', topo)
return split_submodules[pp_rank + 1]
def partition(model_type, data_kwargs, pp_rank: int, chunk: int, stage_num: int):
# build model
model = model_builder(model_type)(checkpoint=False)
module = create_partition_module(pp_rank, stage_num, model, data_kwargs)
return module
def run_master(args):
batch_size = args.batch_size
device = args.device
world_size = args.world_size
stage_num = world_size
num_microbatches = args.num_microbatches
model_type = args.model_type
# batch size per DP degree
SEQ_LEN = 1024
VOCAB_SIZE = 50257
NUM_STEPS = 10
WARMUP_STEPS = 1
disable_existing_loggers()
logger = get_dist_logger()
logger.info(f"{args.model_type}, batch size {batch_size}, num stage {stage_num}, num microbatch {num_microbatches}",
ranks=[0])
torch.manual_seed(123)
# build criterion
criterion = GPTLMLoss()
# warm up pipeline fx partition
input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE)
warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask}
# set 1f1b pipeline engine
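# partition_fn is presumably invoked on each worker with its own (pp_rank, chunk, stage_num),
# so every rank only builds and holds its own stage of the traced model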
pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model_type, warmup_data_kwargs),
stage_num=stage_num,
num_microbatches=num_microbatches,
device=device,
chunk=1,
criterion=criterion,
metric=None,
checkpoint=False)
partition_numels = pp_engine.remote_numels()
for rank, numel in partition_numels.items():
logger.info(f'{rank=} numel in the partition:{numel}')
# build optim
pp_engine.initialize_optimizer(HybridAdam, lr=1e-3)
ranks_tflops = {}
for n in range(NUM_STEPS):
# we just use randomly generated data here
input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE)
batch = {'input_ids': input_ids, 'attention_mask': attn_mask}
start = time.time()
outputs = pp_engine.forward_backward(batch=batch, labels=input_ids, forward_only=False)
step_time = time.time() - start
for rank, numel in partition_numels.items():
if rank not in ranks_tflops:
ranks_tflops[rank] = []
step_tflops = get_tflops(numel, batch_size, SEQ_LEN, step_time)
logger.info(
f"Rank{rank} , [{n + 1}/{NUM_STEPS}] , Step time: {step_time:.3f}s, TFLOPS: {step_tflops:.3f}",
ranks=[0],
)
if n >= WARMUP_STEPS:
ranks_tflops[rank].append(step_tflops)
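# take (approximately) the median post-warmup throughput for each rank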
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
gpu_tflops = []
for rank, tflops_list in ranks_tflops.items():
tflops_list.sort()
gpu_tflops.append(tflops_list[median_index])
logger.info(f"GPU{rank} Median TFLOPS is {tflops_list[median_index]:.3f}")
logger.info(f"Total TFLOPS is {sum(gpu_tflops):.3f}")
logger.info(f"Avg TFLOPS per GPU is {sum(gpu_tflops) / world_size:.3f}")
if __name__ == '__main__':
args = parse_args()
rpc_run(args, run_master)