InternLM/internlm/utils/gputest.py

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
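
"""
Utilities for slow-node detection: benchmark single-GPU flash-attention throughput and
NCCL all-reduce latency, and report CPU/GPU temperatures when a rank falls behind the average.
"""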

import math
import socket

import torch
import torch.distributed as dist
from flash_attn.modules.mha import FlashSelfAttention, SelfAttention
from torch.utils import benchmark

from internlm.utils.logger import get_logger

try:
    import GPUtil
    import psutil
except ImportError:
    GPUtil, psutil = None, None

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device

logger = get_logger(__file__)


def benchmark_forward(
    test_fn,
    *inputs,
    repeats=100,
    amp=True,
    amp_dtype=torch.float16,
    **kwinputs,
):
    """Use Pytorch Benchmark on the forward pass of an arbitrary function."""

    def amp_wrapper(*inputs, **kwinputs):
        with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
            test_fn(*inputs, **kwinputs)

    bench_timer = benchmark.Timer(
        stmt="test_fn_amp(*inputs, **kwinputs)",
        globals={"test_fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
        num_threads=torch.get_num_threads(),
    )
    used_time = bench_timer.timeit(repeats)
    return used_time.mean
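
# Example (illustrative): time a plain CUDA matmul with the helper above; the return value
# is the mean forward time in seconds, measured under autocast by default.
#   x = torch.randn(1024, 1024, device="cuda")
#   mean_time_s = benchmark_forward(torch.matmul, x, x, repeats=10)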


def flops(batch, seqlen, headdim, nheads, time_f):
    """Compute the achieved TFLOPS of a flash-attention forward pass."""
    flop = 4 * batch * seqlen**2 * nheads * headdim
    return (flop / time_f / 10**12) if not math.isnan(time_f) else 0.0
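
# Note: the forward attention cost is modeled as 4 * batch * seqlen**2 * nheads * headdim FLOPs,
# i.e. the two batched matmuls (Q @ K^T and attn @ V) at about 2 * b * h * s**2 * d FLOPs each.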


def get_gpu_temperature():
    """Get current GPU temperature."""
    try:
        gpu_id = torch.cuda.current_device()
    except AssertionError:
        gpu_id = -1

    if GPUtil is not None and gpu_id >= 0:
        gpus = GPUtil.getGPUs()
        gpu_temperature = gpus[gpu_id].temperature
    else:
        gpu_temperature = -1

    return gpu_temperature


def get_cpu_temperature():
    """Get current CPU temperature."""
    if psutil is not None:
        # Assumes a Linux host that exposes the "coretemp" sensor via psutil.
        cpu_temperature = psutil.sensors_temperatures()["coretemp"][0].current
    else:
        cpu_temperature = -1

    return cpu_temperature


def bench_net():
    """Benchmark nccl performance for slow node detection."""
    if gpc.get_world_size(ParallelMode.GLOBAL) <= 1:
        return

    if gpc.is_rank_for_log():
        logger.info("benchmarking network speed ...")

    repeats = 100
    # 8M bfloat16 elements, i.e. a 16 MiB all-reduce payload per iteration.
    input_data = torch.randn(
        8 * 1024 * 1024,
        device=get_current_device(),
        dtype=torch.bfloat16,
    )

    def allreduce_fn(inputs):
        dist.all_reduce(inputs, op=torch.distributed.ReduceOp.AVG, group=gpc.get_group(ParallelMode.NETTEST))

    bench_timer = benchmark.Timer(
        stmt="test_fn_amp(inputs)",
        globals={"test_fn_amp": allreduce_fn, "inputs": input_data},
        num_threads=torch.get_num_threads(),
    )
    allreduce_time = bench_timer.timeit(repeats).mean
    allreduce_time = allreduce_time * 10**3  # seconds -> milliseconds
    allreduce_time_this = allreduce_time
    allreduce_time = torch.Tensor([allreduce_time]).to(device=get_current_device())
    dist.all_reduce(allreduce_time, group=gpc.get_group(ParallelMode.GLOBAL))
    allreduce_time_avg = allreduce_time / gpc.get_world_size(ParallelMode.GLOBAL)
    allreduce_time_avg = float(allreduce_time_avg.item())

    # Flag this rank if its all-reduce is more than 5% slower than the global average.
    if allreduce_time_this >= allreduce_time_avg * 1.05:
        logger.warning(
            f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} NCCL test is slower than avg, "
            f"Hostname {socket.gethostname()}, "
            f"allreduce_time {allreduce_time_this:.2f}, avg {allreduce_time_avg:.2f}, "
            f"CPU temp {get_cpu_temperature()}, GPU temp {get_gpu_temperature()}"
        )


def bench_gpu(use_flash_attn=True):
    """Benchmark single GPU performance for slow node detection."""
    if gpc.is_rank_for_log():
        logger.info("benchmarking gpu speed ...")

    headdim = 64
    dim = 2048
    batch_size, seqlen = 2, 1024
    nheads = dim // headdim

    inner_attn = FlashSelfAttention if use_flash_attn else SelfAttention
    inner_attn = inner_attn(causal=True, softmax_scale=None, attention_dropout=0)

    # Packed qkv tensor in the (batch, seqlen, 3, nheads, headdim) layout expected by flash-attn.
    qkv = torch.randn(
        batch_size,
        seqlen,
        3,
        dim // headdim,
        headdim,
        device=get_current_device(),
        dtype=torch.float16,
        requires_grad=True,
    )

    time_f = benchmark_forward(inner_attn, qkv)
    speed = flops(batch_size, seqlen, headdim, nheads, time_f)
    speed_this = speed
    speed = torch.Tensor([speed]).to(device=get_current_device())
    dist.all_reduce(speed, group=gpc.get_group(ParallelMode.GLOBAL))
    speed_avg = speed / gpc.get_world_size(ParallelMode.GLOBAL)
    speed_avg = float(speed_avg.item())

    # Flag this rank if its attention throughput is more than 5% below the global average.
    if speed_this <= speed_avg * 0.95:
        logger.warning(
            f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} GPU is slower than avg, "
            f"Hostname {socket.gethostname()}, "
            f"tflops {speed_this:.2f}, avg {speed_avg:.2f}, "
            f"CPU temp {get_cpu_temperature()}, GPU temp {get_gpu_temperature()}"
        )
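

# Typical use (illustrative, assumed caller): run both checks once per rank during startup,
# e.g. bench_net() followed by bench_gpu(use_flash_attn=True), then inspect the per-rank logs
# for "slower than avg" warnings to locate slow nodes.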