mirror of https://github.com/hpcaitech/ColossalAI
[example] update gpt benchmark (#2219)
parent
54de05da5d
commit
3629e611cd
|
@ -2,12 +2,12 @@
|
||||||
export DISTPAN="colossalai"
|
export DISTPAN="colossalai"
|
||||||
|
|
||||||
# The following options only valid when DISTPAN="colossalai"
|
# The following options only valid when DISTPAN="colossalai"
|
||||||
export TPDEGREE=4
|
export TPDEGREE=1
|
||||||
export GPUNUM=8
|
export GPUNUM=1
|
||||||
export PLACEMENT='cpu'
|
export PLACEMENT='const'
|
||||||
export USE_SHARD_INIT=False
|
export USE_SHARD_INIT=False
|
||||||
export BATCH_SIZE=32
|
export BATCH_SIZE=32
|
||||||
# export MODEL_TYPE="gpt2_24b"
|
# export MODEL_TYPE="gpt2_10b"
|
||||||
|
|
||||||
mkdir -p logs
|
mkdir -p logs
|
||||||
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log
|
torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
|
import os
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from time import time
|
from time import time
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from model_zoo import model_builder
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
|
@ -15,7 +17,6 @@ from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, Proces
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
|
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
|
||||||
from model_zoo import model_builder
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
@ -88,7 +89,7 @@ class GPTLMLoss(nn.Module):
|
||||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||||
|
|
||||||
|
|
||||||
## Randomly Generated Data
|
# Randomly Generated Data
|
||||||
def get_data(batch_size, seq_len, vocab_size):
|
def get_data(batch_size, seq_len, vocab_size):
|
||||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
@ -111,6 +112,22 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
|
||||||
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
||||||
|
|
||||||
|
|
||||||
|
def get_model_size(model: nn.Module):
|
||||||
|
total_numel = 0
|
||||||
|
for module in model.modules():
|
||||||
|
for p in module.parameters(recurse=False):
|
||||||
|
total_numel += p.numel()
|
||||||
|
return total_numel
|
||||||
|
|
||||||
|
|
||||||
|
def set_cpu_maximum_parallelism():
|
||||||
|
conf_str = torch.__config__.parallel_info()
|
||||||
|
inter_str = conf_str.split("hardware_concurrency() : ")[1]
|
||||||
|
max_concurrency = inter_str.split('\n')[0]
|
||||||
|
os.environ["OMP_NUM_THREADS"] = max_concurrency
|
||||||
|
print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
|
||||||
|
|
||||||
|
|
||||||
# Tensor Parallel
|
# Tensor Parallel
|
||||||
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||||
"""tensor_parallelize
|
"""tensor_parallelize
|
||||||
|
@ -157,10 +174,10 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
|
||||||
device=get_current_device(),
|
device=get_current_device(),
|
||||||
placement_policy=placememt_policy,
|
placement_policy=placememt_policy,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
hidden_dim=4096,
|
hidden_dim=8192,
|
||||||
search_range_mb=64)
|
search_range_mb=64)
|
||||||
if placememt_policy == 'const':
|
if placememt_policy == 'const':
|
||||||
model.gemini_manager._placement_policy.set_const_memory_boundary(10 * 1024)
|
model.gemini_manager._placement_policy.set_const_memory_boundary(2 * 1024)
|
||||||
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
||||||
from colossalai.gemini import ChunkManager, GeminiManager
|
from colossalai.gemini import ChunkManager, GeminiManager
|
||||||
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
||||||
|
@ -176,6 +193,7 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
set_cpu_maximum_parallelism()
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
|
|
||||||
if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
|
if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
|
||||||
|
@ -187,6 +205,9 @@ def main():
|
||||||
VOCAB_SIZE = 50257
|
VOCAB_SIZE = 50257
|
||||||
|
|
||||||
NUM_STEPS = 10
|
NUM_STEPS = 10
|
||||||
|
WARMUP_STEPS = 1
|
||||||
|
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
|
||||||
|
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median "
|
||||||
|
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
colossalai.launch_from_torch(config={})
|
colossalai.launch_from_torch(config={})
|
||||||
|
@ -239,7 +260,7 @@ def main():
|
||||||
verbose=True)
|
verbose=True)
|
||||||
|
|
||||||
# model is shared after TP
|
# model is shared after TP
|
||||||
numel = sum([p.numel() for p in model.parameters()])
|
numel = get_model_size(model)
|
||||||
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
||||||
|
|
||||||
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
|
# Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
|
||||||
|
@ -249,29 +270,48 @@ def main():
|
||||||
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
model.train()
|
model.train()
|
||||||
|
tflops_list = []
|
||||||
for n in range(NUM_STEPS):
|
for n in range(NUM_STEPS):
|
||||||
# we just use randomly generated data here
|
# we just use randomly generated data here
|
||||||
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
|
input_ids, attn_mask = get_data(BATCH_SIZE, SEQ_LEN, VOCAB_SIZE)
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
start = time()
|
start = time()
|
||||||
outputs = model(input_ids, attn_mask)
|
outputs = model(input_ids, attn_mask)
|
||||||
loss = criterion(outputs, input_ids)
|
loss = criterion(outputs, input_ids)
|
||||||
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
|
torch.cuda.synchronize()
|
||||||
|
fwd_end = time()
|
||||||
|
fwd_time = fwd_end - start
|
||||||
|
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Forward '), ranks=[0])
|
||||||
|
|
||||||
if args.distplan in ["colossalai", "zero1", "zero2"]:
|
if args.distplan in ["colossalai", "zero1", "zero2"]:
|
||||||
optimizer.backward(loss)
|
optimizer.backward(loss)
|
||||||
elif args.distplan in ["torch_ddp", "torch_zero"]:
|
elif args.distplan in ["torch_ddp", "torch_zero"]:
|
||||||
loss.backward()
|
loss.backward()
|
||||||
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
|
torch.cuda.synchronize()
|
||||||
|
bwd_end = time()
|
||||||
|
bwd_time = bwd_end - fwd_end
|
||||||
|
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Backward '), ranks=[0])
|
||||||
|
|
||||||
if args.distplan in ["zero1", "zero2"]:
|
if args.distplan in ["zero1", "zero2"]:
|
||||||
optimizer.sync_grad()
|
optimizer.sync_grad()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
|
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
optim_time = time() - bwd_end
|
||||||
step_time = time() - start
|
step_time = time() - start
|
||||||
logger.info(
|
logger.info(get_mem_info(prefix=f'[{n + 1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
|
||||||
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
|
|
||||||
ranks=[0])
|
|
||||||
|
|
||||||
|
step_tflops = get_tflops_func(step_time)
|
||||||
|
logger.info(
|
||||||
|
f"[{n + 1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}, FWD time: {fwd_time:.3f}s, BWD time: {bwd_time:.3f}s, OPTIM time: {optim_time:.3f}s",
|
||||||
|
ranks=[0],
|
||||||
|
)
|
||||||
|
if n >= WARMUP_STEPS:
|
||||||
|
tflops_list.append(step_tflops)
|
||||||
|
|
||||||
|
tflops_list.sort()
|
||||||
|
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
|
||||||
|
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue