[example] add profile util for llama

pull/5751/head
hxwang 2024-05-24 03:59:36 +00:00
parent 15d21a077a
commit 63c057cd8e
2 changed files with 57 additions and 20 deletions

View File

@@ -1,11 +1,12 @@
 import argparse
 import resource
+import time
 from contextlib import nullcontext
 
 import torch
 from data_utils import RandomDataset
 from model_utils import format_numel_str, get_model_numel
-from performance_evaluator import PerformanceEvaluator
+from performance_evaluator import PerformanceEvaluator, get_profile_context
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
 from tqdm import tqdm
 from transformers import AutoConfig, AutoModelForCausalLM
@@ -76,6 +77,7 @@ def main():
     parser.add_argument("--mbs", type=int, default=1, help="Micro batch size of pipeline parallel")
     parser.add_argument("--zero", type=int, default=0, help="Zero Stage when hybrid plugin is enabled")
     parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", default=False)
+    parser.add_argument("--profile", action="store_true", help="Enable profiling", default=False)
     args = parser.parse_args()
 
     colossalai.launch_from_torch()
@@ -110,6 +112,7 @@ def main():
             extra_dp_size=args.extra_dp,
             enable_fused_normalization=torch.cuda.is_available(),
             enable_flash_attention=args.xformers,
+            max_prefetch=10,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -246,16 +249,27 @@
         f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss/1024:.2f} MB"
     )
 
-    if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
-        data_iter = iter(dataloader)
-        for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
-            performance_evaluator.on_step_start(step)
-            booster.execute_pipeline(
-                data_iter, model, criterion=lambda outputs, inputs: outputs[0], optimizer=optimizer, return_loss=False
-            )
-            optimizer.step()
-            optimizer.zero_grad()
-            performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
-    else:
-        for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
-            performance_evaluator.on_step_start(step)
+    with get_profile_context(
+        args.profile,
+        1,
+        len(dataloader) - 1,
+        save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
+    ) as prof:
+        if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
+            data_iter = iter(dataloader)
+            for step in tqdm(range(len(dataloader)), desc="Step", disable=not coordinator.is_master()):
+                performance_evaluator.on_step_start(step)
+                booster.execute_pipeline(
+                    data_iter,
+                    model,
+                    criterion=lambda outputs, inputs: outputs[0],
+                    optimizer=optimizer,
+                    return_loss=False,
+                )
+                optimizer.step()
+                optimizer.zero_grad()
+                performance_evaluator.on_step_end(input_ids=torch.empty(args.batch_size, args.max_length))
+                prof.step()
+        else:
+            for step, batch in enumerate(tqdm(dataloader, desc="Step", disable=not coordinator.is_master())):
+                performance_evaluator.on_step_start(step)
@@ -265,6 +279,7 @@ def main():
-            optimizer.step()
-            optimizer.zero_grad()
-            performance_evaluator.on_step_end(**batch)
+                optimizer.step()
+                optimizer.zero_grad()
+                performance_evaluator.on_step_end(**batch)
+                prof.step()
 
     performance_evaluator.on_fit_end()
     coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")
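For orientation (not part of the commit): the two positional arguments passed to get_profile_context above correspond to warmup_steps and active_steps of the helper added in performance_evaluator.py below, so the profiler waits zero steps, treats the first step as warmup, and traces the remaining len(dataloader) - 1 steps. A minimal sketch of the resulting torch.profiler schedule, with a hypothetical step count:

from torch.profiler import schedule

num_steps = 8  # hypothetical len(dataloader)
sched = schedule(wait=0, warmup=1, active=num_steps - 1)
# get_profile_context(args.profile, 1, len(dataloader) - 1, ...) builds this
# schedule internally; prof.step() advances it once per training iteration,
# so step 0 is warmup and steps 1..num_steps-1 are actively recorded.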

View File

@@ -4,6 +4,7 @@ from typing import Optional
 import torch
 import torch.distributed as dist
 from torch import Tensor
+from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler
 
 from colossalai.accelerator import get_accelerator
 from colossalai.cluster import DistCoordinator
@@ -27,6 +28,27 @@ def all_reduce_mean(x: float, world_size: int) -> float:
     return tensor.item()
 
 
+def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
+    class DummyProfiler:
+        def __init__(self):
+            self.step_number = 0
+
+        def step(self):
+            self.step_number += 1
+
+    if enable_flag:
+        return profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
+            on_trace_ready=tensorboard_trace_handler(save_dir),
+            # record_shapes=True,
+            # profile_memory=True,
+            with_stack=True,
+        )
+    else:
+        return DummyProfiler()
+
+
 class Timer:
     def __init__(self) -> None:
         self.start_time: Optional[float] = None
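A minimal usage sketch of the new helper with a toy loop (the model, step counts, and save_dir are illustrative, not from the commit; it assumes performance_evaluator.py is importable from the working directory and a GPU is available, as in the benchmark):

import torch
from performance_evaluator import get_profile_context

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# enable_flag=True wraps the loop in torch.profiler.profile and writes
# TensorBoard traces to save_dir via tensorboard_trace_handler. Note that,
# as defined above, the disabled path returns DummyProfiler, which would
# also need __enter__/__exit__ (or contextlib.nullcontext) to be usable in
# a `with` statement.
with get_profile_context(
    enable_flag=True,
    warmup_steps=1,
    active_steps=4,
    save_dir="profile/example-run",  # hypothetical output directory
) as prof:
    for _ in range(5):  # 1 warmup step + 4 actively traced steps
        loss = model(torch.randn(8, 16)).sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()  # advance the profiler schedule once per iteration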