[example] update gpt benchmark (#2219)

pull/2221/head^2
HELSON 2022-12-29 10:51:42 +08:00 committed by GitHub
parent 54de05da5d
commit 3629e611cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 16 deletions

View File

@ -2,12 +2,12 @@
export DISTPAN="colossalai"
# The following options only valid when DISTPAN="colossalai"
export TPDEGREE=4
export GPUNUM=8
export PLACEMENT='cpu'
export TPDEGREE=1
export GPUNUM=1
export PLACEMENT='const'
export USE_SHARD_INIT=False
export BATCH_SIZE=32
# export MODEL_TYPE="gpt2_24b"
# export MODEL_TYPE="gpt2_10b"
mkdir -p logs
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log
torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log

View File

@ -1,9 +1,11 @@
import os
from functools import partial
from time import time
import psutil
import torch
import torch.nn as nn
from model_zoo import model_builder
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP
@ -15,7 +17,6 @@ from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, Proces
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
from model_zoo import model_builder
def parse_args():
@ -88,7 +89,7 @@ class GPTLMLoss(nn.Module):
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
## Randomly Generated Data
# Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
attention_mask = torch.ones_like(input_ids)
@ -111,6 +112,22 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
def get_model_size(model: nn.Module):
total_numel = 0
for module in model.modules():
for p in module.parameters(recurse=False):
total_numel += p.numel()
return total_numel
def set_cpu_maximum_parallelism():
conf_str = torch.__config__.parallel_info()
inter_str = conf_str.split("hardware_concurrency() : ")[1]
max_concurrency = inter_str.split('\n')[0]
os.environ["OMP_NUM_THREADS"] = max_concurrency
print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
@ -157,10 +174,10 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
device=get_current_device(),
placement_policy=placememt_policy,
pin_memory=True,
hidden_dim=4096,
hidden_dim=8192,
search_range_mb=64)
if placememt_policy == 'const':
model.gemini_manager._placement_policy.set_const_memory_boundary(10 * 1024)
model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
from colossalai.gemini import ChunkManager, GeminiManager
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
@ -176,6 +193,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
def main():
set_cpu_maximum_parallelism()
args = parse_args()
if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
@ -187,6 +205,9 @@ def main():
VOCAB_SIZE = 50257
NUM_STEPS = 10
WARMUP_STEPS = 1
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "
disable_existing_loggers()
colossalai.launch_from_torch(config={})
@ -239,7 +260,7 @@ def main():
verbose=True)
# model is shared after TP
numel = sum([p.numel() for p in model.parameters()])
numel = get_model_size(model)
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
@ -249,29 +270,48 @@ def main():
torch.cuda.synchronize()
model.train()
tflops_list = []
for n in range(NUM_STEPS):
# we just use randomly generated data here
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
optimizer.zero_grad()
start = time()
outputs = model(input_ids, attn_mask)
loss = criterion(outputs, input_ids)
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
torch.cuda.synchronize()
fwd_end = time()
fwd_time = fwd_end - start
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
if args.distplan in ["colossalai", "zero1", "zero2"]:
optimizer.backward(loss)
elif args.distplan in ["torch_ddp", "torch_zero"]:
loss.backward()
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
torch.cuda.synchronize()
bwd_end = time()
bwd_time = bwd_end - fwd_end
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
if args.distplan in ["zero1", "zero2"]:
optimizer.sync_grad()
optimizer.step()
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
torch.cuda.synchronize()
optim_time = time() - bwd_end
step_time = time() - start
logger.info(
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
ranks=[0])
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
step_tflops = get_tflops_func(step_time)
logger.info(
f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
ranks=[0],
)
if n >= WARMUP_STEPS:
tflops_list.append(step_tflops)
tflops_list.sort()
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
torch.cuda.synchronize()