[example] update gpt benchmark (#2219)

HELSON 2022-12-29 10:51:42 +08:00 committed by GitHub
parent 54de05da5d
commit 3629e611cd
2 changed files with 56 additions and 16 deletions

File 1 of 2: benchmark launch script (shell)

@@ -2,12 +2,12 @@
 export DISTPAN="colossalai"
 # The following options only valid when DISTPAN="colossalai"
-export TPDEGREE=4
-export GPUNUM=8
-export PLACEMENT='cpu'
+export TPDEGREE=1
+export GPUNUM=1
+export PLACEMENT='const'
 export USE_SHARD_INIT=False
 export BATCH_SIZE=32
-# export MODEL_TYPE="gpt2_24b"
+# export MODEL_TYPE="gpt2_10b"
 mkdir -p logs
-env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log
+torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log
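Note: the hard-coded "env OMP_NUM_THREADS=16" prefix is dropped from the launch command because the training script now derives the thread count from the host CPU at startup (see the set_cpu_maximum_parallelism() helper added to train_gpt_demo.py below). A minimal standalone sketch of that logic; the helper name here is illustrative, and the only assumption is that torch.__config__.parallel_info() reports a "hardware_concurrency() : N" line, as it does in current PyTorch:

    # sketch: derive OMP_NUM_THREADS from the core count reported by PyTorch
    import os

    import torch


    def set_omp_threads_from_torch() -> str:
        # parallel_info() returns a multi-line report of ATen's threading setup,
        # including the detected hardware concurrency of the machine
        info = torch.__config__.parallel_info()
        concurrency = info.split("hardware_concurrency() : ")[1].split("\n")[0]
        os.environ["OMP_NUM_THREADS"] = concurrency
        return concurrency


    if __name__ == "__main__":
        print(f"OMP_NUM_THREADS set to {set_omp_threads_from_torch()}")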

File 2 of 2: train_gpt_demo.py

@@ -1,9 +1,11 @@
+import os
 from functools import partial
 from time import time

 import psutil
 import torch
 import torch.nn as nn
+from model_zoo import model_builder
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -15,7 +17,6 @@ from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
-from model_zoo import model_builder


 def parse_args():
@@ -88,7 +89,7 @@ class GPTLMLoss(nn.Module):
         return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


-## Randomly Generated Data
+# Randomly Generated Data
 def get_data(batch_size, seq_len, vocab_size):
     input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
     attention_mask = torch.ones_like(input_ids)
@@ -111,6 +112,22 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
     return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)


+def get_model_size(model: nn.Module):
+    total_numel = 0
+    for module in model.modules():
+        for p in module.parameters(recurse=False):
+            total_numel += p.numel()
+    return total_numel
+
+
+def set_cpu_maximum_parallelism():
+    conf_str = torch.__config__.parallel_info()
+    inter_str = conf_str.split("hardware_concurrency() : ")[1]
+    max_concurrency = inter_str.split('\n')[0]
+    os.environ["OMP_NUM_THREADS"] = max_concurrency
+    print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
+
+
 # Tensor Parallel
 def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """tensor_parallelize
@@ -157,10 +174,10 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
                           device=get_current_device(),
                           placement_policy=placememt_policy,
                           pin_memory=True,
-                          hidden_dim=4096,
+                          hidden_dim=8192,
                           search_range_mb=64)
         if placememt_policy == 'const':
-            model.gemini_manager._placement_policy.set_const_memory_boundary(10 * 1024)
+            model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
     elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
         from colossalai.gemini import ChunkManager, GeminiManager
         chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
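Note: the keyword arguments in this hunk are the tail of the Gemini wrapper construction inside gemini_zero_dpp(). The commit raises hidden_dim from 4096 to 8192 and, for the 'const' placement policy, lowers the fixed memory boundary from 10 * 1024 to 2 * 1024 (presumably MB, i.e. roughly 10 GB down to 2 GB). A hedged sketch of the surrounding call for orientation only; the GeminiDDP class and its import path are not shown in this excerpt and are assumptions based on ColossalAI releases of that period:

    # orientation sketch; class name / import path are assumptions, kwargs mirror this hunk
    from colossalai.nn.parallel import GeminiDDP          # assumption: not shown in this diff
    from colossalai.utils import get_current_device


    def wrap_with_gemini(model, placememt_policy: str = "const"):
        model = GeminiDDP(model,
                          device=get_current_device(),
                          placement_policy=placememt_policy,
                          pin_memory=True,
                          hidden_dim=8192,                # new value from this commit
                          search_range_mb=64)
        if placememt_policy == 'const':
            # cap the const policy at a 2 * 1024 memory budget (presumably MB)
            model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
        return model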
@@ -176,6 +193,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
 def main():
+    set_cpu_maximum_parallelism()
     args = parse_args()

     if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
@@ -187,6 +205,9 @@ def main():
     VOCAB_SIZE = 50257
     NUM_STEPS = 10
+    WARMUP_STEPS = 1
+    assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
+    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "

     disable_existing_loggers()
     colossalai.launch_from_torch(config={})
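Note: WARMUP_STEPS and the two asserts exist for the median-TFLOPS summary added at the end of the training loop (final hunk below): warmup iterations are excluded from the collected samples, and the odd-count check matches the median-index arithmetic used there. A standalone sketch of that bookkeeping with made-up throughput numbers, using the same index expression as the benchmark:

    # made-up per-step throughputs; step 0 is the slow warmup iteration
    NUM_STEPS, WARMUP_STEPS = 10, 1
    assert WARMUP_STEPS < NUM_STEPS
    assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1

    measured = [11.2, 34.8, 34.9, 35.0, 35.1, 35.2, 35.3, 35.4, 35.5, 35.6]
    tflops_list = [t for n, t in enumerate(measured) if n >= WARMUP_STEPS]   # drop warmup steps

    tflops_list.sort()
    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS          # index used by the benchmark
    print(f"Median TFLOPS is {tflops_list[median_index]:.3f}")               # -> 35.300 here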
@@ -239,7 +260,7 @@ def main():
                                               verbose=True)

     # model is shared after TP
-    numel = sum([p.numel() for p in model.parameters()])
+    numel = get_model_size(model)
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

     # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
@@ -249,29 +270,48 @@ def main():
     torch.cuda.synchronize()
     model.train()
+    tflops_list = []
     for n in range(NUM_STEPS):
         # we just use randomly generated data here
         input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
         optimizer.zero_grad()
         start = time()
         outputs = model(input_ids, attn_mask)
         loss = criterion(outputs, input_ids)
-        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
+        torch.cuda.synchronize()
+        fwd_end = time()
+        fwd_time = fwd_end - start
+        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
+
         if args.distplan in ["colossalai", "zero1", "zero2"]:
             optimizer.backward(loss)
         elif args.distplan in ["torch_ddp", "torch_zero"]:
             loss.backward()
-        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
+        torch.cuda.synchronize()
+        bwd_end = time()
+        bwd_time = bwd_end - fwd_end
+        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
+
         if args.distplan in ["zero1", "zero2"]:
             optimizer.sync_grad()
         optimizer.step()
-        logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
         torch.cuda.synchronize()
+        optim_time = time() - bwd_end
         step_time = time() - start
-        logger.info(
-            f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
-            ranks=[0])
+        logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
+
+        step_tflops = get_tflops_func(step_time)
+        logger.info(
+            f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
+            ranks=[0],
+        )
+        if n >= WARMUP_STEPS:
+            tflops_list.append(step_tflops)
+
+    tflops_list.sort()
+    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
+    logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
     torch.cuda.synchronize()
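Note: each phase boundary in the reworked loop calls torch.cuda.synchronize() before reading time(). CUDA kernels are launched asynchronously, so without the synchronization the host clock would stop before the GPU has actually finished the forward, backward, or optimizer work. A minimal sketch of the same timing pattern on a dummy workload, plain PyTorch with illustrative sizes, assuming a CUDA device is available:

    # per-phase GPU timing with explicit synchronization, as in the updated loop
    from time import time

    import torch
    import torch.nn as nn

    model = nn.Linear(4096, 4096).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    data = torch.randn(32, 4096, device="cuda")

    optimizer.zero_grad()
    start = time()

    loss = model(data).pow(2).mean()
    torch.cuda.synchronize()                 # wait for queued forward kernels
    fwd_time = time() - start

    loss.backward()
    torch.cuda.synchronize()                 # backward kernels are asynchronous too
    bwd_time = time() - start - fwd_time

    optimizer.step()
    torch.cuda.synchronize()                 # make sure the optimizer update finished
    step_time = time() - start
    optim_time = step_time - fwd_time - bwd_time
    print(f"FWD {fwd_time:.4f}s, BWD {bwd_time:.4f}s, OPTIM {optim_time:.4f}s, step {step_time:.4f}s")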