From ac863a01d6d6b2397c550083756592bb25cccc13 Mon Sep 17 00:00:00 2001 From: Ziyue Jiang Date: Tue, 3 Jan 2023 17:20:59 +0800 Subject: [PATCH] [example] add benchmark (#2276) * add benchmark * merge common func * add total and avg tflops Co-authored-by: Ziyue Jiang --- colossalai/pipeline/rpc/_pipeline_base.py | 13 +++++++ examples/language/gpt/train_gpt_demo.py | 12 +----- examples/language/gpt/train_gpt_pp_demo.py | 44 +++++++++++++++++----- examples/language/gpt/utils.py | 12 ++++++ 4 files changed, 60 insertions(+), 21 deletions(-) create mode 100644 examples/language/gpt/utils.py diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index cbbd317e4..2a7998c14 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -240,6 +240,10 @@ class WorkerBase(ABC): output = [output[i] for i in offsets] return output + def get_numels(self) -> int: + numel = sum(param.numel() for param in self.module_partition.parameters()) + return numel + def get_parameters(self) -> List[torch.Tensor]: return [p for p in self.module_partition.parameters()] @@ -1115,6 +1119,15 @@ class PipelineEngineBase(ABC, nn.Module): for fut in sync_futs: fut.wait() + def remote_numels(self) -> Dict[int, int]: + numels = {} + actual_stage_num = self._get_actual_stage_num() + for stage_id in range(actual_stage_num): + worker_rref = self.pp_rank_to_worker_rref[stage_id] + numel = worker_rref.rpc_sync().get_numels() + numels[stage_id] = numel + return numels + def remote_parameters(self) -> Dict[int, List[torch.Tensor]]: parameters = {} actual_stage_num = self._get_actual_stage_num() diff --git a/examples/language/gpt/train_gpt_demo.py b/examples/language/gpt/train_gpt_demo.py index 0b168b2ad..8704be9e0 100644 --- a/examples/language/gpt/train_gpt_demo.py +++ b/examples/language/gpt/train_gpt_demo.py @@ -8,6 +8,7 @@ import torch.nn as nn from model_zoo import model_builder from packaging import version from torch.nn.parallel import DistributedDataParallel as DDP +from utils import get_data, get_tflops import colossalai from colossalai.logging import disable_existing_loggers, get_dist_logger @@ -95,13 +96,6 @@ class GPTLMLoss(nn.Module): return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) -# Randomly Generated Data -def get_data(batch_size, seq_len, vocab_size): - input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) - attention_mask = torch.ones_like(input_ids) - return input_ids, attention_mask - - def get_cpu_mem(): return psutil.Process().memory_info().rss / 1024**2 @@ -114,10 +108,6 @@ def get_mem_info(prefix=''): return f'{prefix}GPU memory usage: {get_gpu_mem():.2f} MB, CPU memory usage: {get_cpu_mem():.2f} MB' -def get_tflops(model_numel, batch_size, seq_len, step_time): - return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12) - - def get_model_size(model: nn.Module): total_numel = 0 for module in model.modules(): diff --git a/examples/language/gpt/train_gpt_pp_demo.py b/examples/language/gpt/train_gpt_pp_demo.py index bdb2c95cc..a77b76d62 100644 --- a/examples/language/gpt/train_gpt_pp_demo.py +++ b/examples/language/gpt/train_gpt_pp_demo.py @@ -6,6 +6,7 @@ import torch from model_zoo import model_builder from torch import nn from tqdm import tqdm +from utils import get_data, get_tflops from colossalai.fx import ColoTracer from colossalai.fx.passes.adding_split_node_pass import avgnode_split_pass, split_with_split_nodes_pass @@ -26,7 +27,7 @@ def parse_args(): parser.add_argument('--num_microbatches', type=int, default=2) parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda') parser.add_argument('--master_addr', type=str, default='localhost') - parser.add_argument('--master_port', type=str, default='29020') + parser.add_argument('--master_port', type=str, default='29011') parser.add_argument('--num_worker_threads', type=int, default=128) return parser.parse_args() @@ -66,12 +67,10 @@ def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs): return split_submodules[pp_rank + 1] -def partition(logger, model_type, data_kwargs, pp_rank: int, chunk: int, stage_num: int): +def partition(model_type, data_kwargs, pp_rank: int, chunk: int, stage_num: int): # build model model = model_builder(model_type)(checkpoint=False) module = create_partition_module(pp_rank, stage_num, model, data_kwargs) - num_params = sum(param.numel() for param in module.parameters()) - logger.info(f'{pp_rank=} number of args in this partition:{num_params}') return module @@ -86,6 +85,7 @@ def run_master(args): SEQ_LEN = 1024 VOCAB_SIZE = 50257 NUM_STEPS = 10 + WARMUP_STEPS = 1 disable_existing_loggers() logger = get_dist_logger() @@ -102,7 +102,7 @@ def run_master(args): warmup_data_kwargs = {'input_ids': input_ids, 'attention_mask': attn_mask} # set 1f1b pipeline engine - pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, logger, model_type, warmup_data_kwargs), + pp_engine = OneFOneBPipelineEngine(partition_fn=partial(partition, model_type, warmup_data_kwargs), stage_num=stage_num, num_microbatches=num_microbatches, device=device, @@ -111,21 +111,45 @@ def run_master(args): metric=None, checkpoint=False) + partition_numels = pp_engine.remote_numels() + for rank, numel in partition_numels.items(): + logger.info(f'{rank=} numel in the partition:{numel}') + # build optim pp_engine.initialize_optimizer(HybridAdam, lr=1e-3) - times = [] - for n in tqdm(range(NUM_STEPS)): + ranks_tflops = {} + for n in range(NUM_STEPS): # we just use randomly generated data here input_ids, attn_mask = get_data(batch_size, SEQ_LEN, VOCAB_SIZE) batch = {'input_ids': input_ids, 'attention_mask': attn_mask} start = time.time() outputs = pp_engine.forward_backward(batch=batch, labels=input_ids, forward_only=False) - cost_time = time.time() - start - times.append(cost_time) + step_time = time.time() - start - logger.info("avg cost time : {}s".format(sum(times) / len(times))) + for rank, numel in partition_numels.items(): + if rank not in ranks_tflops: + ranks_tflops[rank] = [] + step_tflops = get_tflops(numel, batch_size, SEQ_LEN, step_time) + + logger.info( + f"Rank{rank} , [{n + 1}/{NUM_STEPS}] , Step time: {step_time:.3f}s, TFLOPS: {get_tflops(numel, batch_size, SEQ_LEN, step_time):.3f}", + ranks=[0], + ) + + if n >= WARMUP_STEPS: + ranks_tflops[rank].append(step_tflops) + + median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS + gpu_tflops = [] + for rank, tflops_list in ranks_tflops.items(): + tflops_list.sort() + gpu_tflops.append(tflops_list[median_index]) + logger.info(f"GPU{rank} Median TFLOPS is {tflops_list[median_index]:.3f}") + + logger.info(f"Total TFLOPS is {sum(gpu_tflops):.3f}") + logger.info(f"Avg TFLOPS per GPU is {sum(gpu_tflops) / world_size:.3f}") if __name__ == '__main__': diff --git a/examples/language/gpt/utils.py b/examples/language/gpt/utils.py new file mode 100644 index 000000000..782f546dc --- /dev/null +++ b/examples/language/gpt/utils.py @@ -0,0 +1,12 @@ +import torch + + +# Randomly Generated Data +def get_data(batch_size, seq_len, vocab_size): + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device()) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + +def get_tflops(model_numel, batch_size, seq_len, step_time): + return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)