mirror of https://github.com/hpcaitech/ColossalAI
[example] add benchmark (#2276)
* add benchmark * merge common func * add total and avg tflops Co-authored-by: Ziyue Jiang <ziyue.jiang@gmail.com>pull/2284/head
parent
1405b4381e
commit
ac863a01d6
|
@ -240,6 +240,10 @@ class WorkerBase(ABC):
|
|||
output = [output[i] for i in offsets]
|
||||
return output
|
||||
|
||||
def get_numels(self) -> int:
|
||||
numel = sum(param.numel() for param in self.module_partition.parameters())
|
||||
return numel
|
||||
|
||||
def get_parameters(self) -> List[torch.Tensor]:
|
||||
return [p for p in self.module_partition.parameters()]
|
||||
|
||||
|
@ -1115,6 +1119,15 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
for fut in sync_futs:
|
||||
fut.wait()
|
||||
|
||||
def remote_numels(self) -> Dict[int, int]:
|
||||
numels = {}
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
for stage_id in range(actual_stage_num):
|
||||
worker_rref = self.pp_rank_to_worker_rref[stage_id]
|
||||
numel = worker_rref.rpc_sync().get_numels()
|
||||
numels[stage_id] = numel
|
||||
return numels
|
||||
|
||||
def remote_parameters(self) -> Dict[int, List[torch.Tensor]]:
|
||||
parameters = {}
|
||||
actual_stage_num = self._get_actual_stage_num()
|
||||
|
|
|
@ -8,6 +8,7 @@ import torch.nn as nn
|
|||
from model_zoo import model_builder
|
||||
from packaging import version
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from utils import get_data, get_tflops
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||
|
@ -95,13 +96,6 @@ class GPTLMLoss(nn.Module):
|
|||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
|
||||
# Randomly Generated Data
|
||||
def get_data(batch_size, seq_len, vocab_size):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def get_cpu_mem():
|
||||
return psutil.Process().memory_info().rss / 1024**2
|
||||
|
||||
|
@ -114,10 +108,6 @@ def get_mem_info(prefix=''):
|
|||
return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB'
|
||||
|
||||
|
||||
def get_tflops(model_numel, batch_size, seq_len, step_time):
|
||||
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
||||
|
||||
|
||||
def get_model_size(model: nn.Module):
|
||||
total_numel = 0
|
||||
for module in model.modules():
|
||||
|
|
|
@ -6,6 +6,7 @@ import torch
|
|||
from model_zoo import model_builder
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
from utils import get_data, get_tflops
|
||||
|
||||
from colossalai.fx import ColoTracer
|
||||
from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass
|
||||
|
@ -26,7 +27,7 @@ def parse_args():
|
|||
parser.add_argument('--num_microbatches', type=int, default=2)
|
||||
parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
|
||||
parser.add_argument('--master_addr', type=str, default='localhost')
|
||||
parser.add_argument('--master_port', type=str, default='29020')
|
||||
parser.add_argument('--master_port', type=str, default='29011')
|
||||
parser.add_argument('--num_worker_threads', type=int, default=128)
|
||||
return parser.parse_args()
|
||||
|
||||
|
@ -66,12 +67,10 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
|
|||
return split_submodules[pp_rank + 1]
|
||||
|
||||
|
||||
def partition(logger, model_type, data_kwargs, pp_rank: int, chunk: int, stage_num: int):
|
||||
def partition(model_type, data_kwargs, pp_rank: int, chunk: int, stage_num: int):
|
||||
# build model
|
||||
model = model_builder(model_type)(checkpoint=False)
|
||||
module = create_partition_module(pp_rank, stage_num, model, data_kwargs)
|
||||
num_params = sum(param.numel() for param in module.parameters())
|
||||
logger.info(f'{pp_rank=} number of args in this partition:{num_params}')
|
||||
return module
|
||||
|
||||
|
||||
|
@ -86,6 +85,7 @@ def run_master(args):
|
|||
SEQ_LEN = 1024
|
||||
VOCAB_SIZE = 50257
|
||||
NUM_STEPS = 10
|
||||
WARMUP_STEPS = 1
|
||||
|
||||
disable_existing_loggers()
|
||||
logger = get_dist_logger()
|
||||
|
@ -102,7 +102,7 @@ def run_master(args):
|
|||
warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask}
|
||||
|
||||
# set 1f1b pipeline engine
|
||||
pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, logger, model_type, warmup_data_kwargs),
|
||||
pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model_type, warmup_data_kwargs),
|
||||
stage_num=stage_num,
|
||||
num_microbatches=num_microbatches,
|
||||
device=device,
|
||||
|
@ -111,21 +111,45 @@ def run_master(args):
|
|||
metric=None,
|
||||
checkpoint=False)
|
||||
|
||||
partition_numels = pp_engine.remote_numels()
|
||||
for rank, numel in partition_numels.items():
|
||||
logger.info(f'{rank=} numel in the partition:{numel}')
|
||||
|
||||
# build optim
|
||||
pp_engine.initialize_optimizer(HybridAdam, lr=1e-3)
|
||||
|
||||
times = []
|
||||
for n in tqdm(range(NUM_STEPS)):
|
||||
ranks_tflops = {}
|
||||
for n in range(NUM_STEPS):
|
||||
# we just use randomly generated data here
|
||||
input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE)
|
||||
batch = {'input_ids': input_ids, 'attention_mask': attn_mask}
|
||||
|
||||
start = time.time()
|
||||
outputs = pp_engine.forward_backward(batch=batch, labels=input_ids, forward_only=False)
|
||||
cost_time = time.time() - start
|
||||
times.append(cost_time)
|
||||
step_time = time.time() - start
|
||||
|
||||
logger.info("avg cost time : {}s".format(sum(times) / len(times)))
|
||||
for rank, numel in partition_numels.items():
|
||||
if rank not in ranks_tflops:
|
||||
ranks_tflops[rank] = []
|
||||
step_tflops = get_tflops(numel, batch_size, SEQ_LEN, step_time)
|
||||
|
||||
logger.info(
|
||||
f"Rank{rank} , [{n + 1}/{NUM_STEPS}] , Step time: {step_time:.3f}s, TFLOPS: {get_tflops(numel, batch_size, SEQ_LEN, step_time):.3f}",
|
||||
ranks=[0],
|
||||
)
|
||||
|
||||
if n >= WARMUP_STEPS:
|
||||
ranks_tflops[rank].append(step_tflops)
|
||||
|
||||
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
|
||||
gpu_tflops = []
|
||||
for rank, tflops_list in ranks_tflops.items():
|
||||
tflops_list.sort()
|
||||
gpu_tflops.append(tflops_list[median_index])
|
||||
logger.info(f"GPU{rank} Median TFLOPS is {tflops_list[median_index]:.3f}")
|
||||
|
||||
logger.info(f"Total TFLOPS is {sum(gpu_tflops):.3f}")
|
||||
logger.info(f"Avg TFLOPS per GPU is {sum(gpu_tflops) / world_size:.3f}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
import torch
|
||||
|
||||
|
||||
# Randomly Generated Data
|
||||
def get_data(batch_size, seq_len, vocab_size):
|
||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
return input_ids, attention_mask
|
||||
|
||||
|
||||
def get_tflops(model_numel, batch_size, seq_len, step_time):
|
||||
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
Loading…
Reference in New Issue