Merge remote-tracking branch 'origin/develop'

pull/274/head v0.2.1dev20230901
Sun Peng 2023-09-01 08:37:34 +00:00
commit ad0cddce66
8 changed files with 281 additions and 32 deletions

View File

@@ -7,6 +7,7 @@ from .parallel_context import (
from .process_group_initializer import (
Initializer_Data,
Initializer_Model,
Initializer_Nettest,
Initializer_Pipeline,
Initializer_Tensor,
Initializer_Zero1,
@@ -34,6 +35,7 @@ __all__ = [
"Initializer_Pipeline",
"Initializer_Data",
"Initializer_Zero1",
"Initializer_Nettest",
"ProcessGroupInitializer",
"Initializer_Model",
"seed",

View File

@@ -143,6 +143,7 @@ class ParallelContext(metaclass=SingletonMeta):
self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1
self.zero1_parallel_size = -1
self.nettest_parallel_size = 1
self.num_processes_on_current_node = -1
self.virtual_pipeline_parallel_size = None
self.virtual_pipeline_parallel_rank = None
@@ -442,6 +443,9 @@ class ParallelContext(metaclass=SingletonMeta):
# instead, it should be calculated based on other parallel config
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
# the recommended nettest_parallel_size is 32 GPUs
self.nettest_parallel_size = 32
if self.zero1_parallel_size <= 0:
self.zero1_parallel_size = self.data_parallel_size
@@ -454,6 +458,7 @@
self.pipeline_parallel_size,
self.tensor_parallel_size,
self.zero1_parallel_size,
self.nettest_parallel_size,
]
# run initialization of different process groups
@@ -462,6 +467,7 @@
initializers.append(pgroup_initializer.Initializer_Model(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Tensor(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Zero1(*initializer_args))
initializers.append(pgroup_initializer.Initializer_Nettest(*initializer_args))
if self.pipeline_parallel_size > 1:
initializers.append(pgroup_initializer.Initializer_Pipeline(*initializer_args))
for initializer in initializers:
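For intuition, a minimal sketch of the size arithmetic above, assuming a hypothetical 64-GPU job with pipeline_parallel_size=2 and tensor_parallel_size=8 (the formulas mirror the code above; the concrete numbers are illustrative only):

import math

world_size = 64                                        # hypothetical
pipeline_parallel_size, tensor_parallel_size = 2, 8    # hypothetical
zero1_parallel_size = -1                               # "auto", as in the default

data_parallel_size = world_size // (pipeline_parallel_size * tensor_parallel_size)  # 4
if zero1_parallel_size <= 0:
    zero1_parallel_size = data_parallel_size                                         # 4

nettest_parallel_size = 32             # the fixed recommendation from the diff
num_nettest_groups = math.ceil(world_size / nettest_parallel_size)                   # 2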

View File

@@ -3,6 +3,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context
import math
from abc import ABC, abstractmethod
from enum import Enum
@@ -31,6 +32,9 @@ class ParallelMode(Enum):
# zero1 parallel
ZERO1 = "zero1"
# runtime network test
NETTEST = "nettest"
class ProcessGroupInitializer(ABC):
"""An object, knowing the parallelism configuration, that initializes parallel groups.
@@ -52,6 +56,7 @@ class ProcessGroupInitializer(ABC):
pipeline_parallel_size: int,
tensor_parallel_size: int,
zero1_parallel_size: int,
nettest_parallel_size: int,
):
self.rank = rank
self.world_size = world_size
@@ -59,6 +64,7 @@
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
self.zero1_parallel_size = zero1_parallel_size
self.nettest_parallel_size = nettest_parallel_size
super().__init__()
@abstractmethod
@@ -332,3 +338,52 @@ class Initializer_Zero1(ProcessGroupInitializer):
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
class Initializer_Nettest(ProcessGroupInitializer):
"""A ProcessGroupInitializer for network test, especailly for NCCL.
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
nettest_parallel_size (int): Size of a network test group.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.num_nettest_group = math.ceil(self.world_size / self.nettest_parallel_size)
def init_dist_group(self, use_cpu: bool = False):
"""Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.
Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
A Tensor parallelism's information tuple.
"""
local_rank = None
ranks_in_group = None
process_group = None
cpu_group = None
group_world_size = None
mode = ParallelMode.NETTEST
for i in range(self.num_nettest_group):
ranks = []
for j in range(self.nettest_parallel_size):
rank = i * self.nettest_parallel_size + j
if rank < self.world_size:
ranks.append(rank)
group = dist.new_group(ranks)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
else:
group_cpu = None
if self.rank in ranks:
local_rank = ranks.index(self.rank)
group_world_size = len(ranks)
process_group = group
cpu_group = group_cpu
ranks_in_group = ranks
return local_rank, group_world_size, process_group, cpu_group, ranks_in_group, mode
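For intuition, a standalone sketch of the rank partitioning performed by init_dist_group above, assuming a hypothetical world size of 70 that does not divide evenly by 32:

import math

world_size, nettest_parallel_size = 70, 32                           # hypothetical world size
num_nettest_group = math.ceil(world_size / nettest_parallel_size)    # 3

groups = [
    [r for r in range(i * nettest_parallel_size, (i + 1) * nettest_parallel_size) if r < world_size]
    for i in range(num_nettest_group)
]
# groups -> ranks 0..31, 32..63 and 64..69; the trailing group is simply smaller.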

View File

@@ -497,6 +497,7 @@ class HybridZeroOptimizer(BaseOptimizer):
grads = [self.padding_grad]
params = [self.padding_tensor]
norm = 0
if self._clip_grad_norm > 0:
# this norm is before scaling, it will be very large
norm = compute_norm(
@@ -542,15 +543,15 @@
self._param_store.clear_grads_of_previous_reduced_params()
# compute norm for gradients in the last bucket
total_norms = []
total_norms = {}
for group_id in range(self.num_param_groups):
total_norms.append(
self._compute_norm_with_stage(
group_id=group_id,
last_bucket=True,
last_stage=True,
previous_norm=groups_norms[group_id],
)
group_name = self.param_groups[group_id]["name"] if "name" in self.param_groups[group_id] else "default"
group_name = f"{group_id}_{group_name}"
total_norms[group_name] = self._compute_norm_with_stage(
group_id=group_id,
last_bucket=True,
last_stage=True,
previous_norm=groups_norms[group_id],
)
timer("sync_grad").start()
@@ -569,7 +570,7 @@
# found_inf = self._check_overflow()
# Because you may encounter inf when computing norm
if -1 in norms:
if -1 in norms.values():
found_inf = True
loss_scale = float(self.loss_scale.item()) # backup
@@ -617,15 +618,17 @@
# unscale and clip grads
# get the global norm
global_norm_groups = []
global_norm_groups = {}
if self._clip_grad_norm > 0:
for norm in norms:
global_norm_groups.append(norm**0.5)
for group_name, norm in norms.items():
global_norm_groups[group_name] = norm**0.5
# the following operations are performed only on the rank to which parameters are assigned.
if gpc.config.model.dtype is not torch.float32:
if len(single_grad_partition_groups) != 0:
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)
if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
self._unscale_and_clip_grads(
single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
)
# update the parameters
timer("step").start()
@@ -652,7 +655,9 @@
# update gradients may not be needed here, because the sync_params function is used in initialization,
# so synchronization is maintained
return True, [global_norm / loss_scale for global_norm in global_norm_groups]
for group_name, global_norm in global_norm_groups.items():
global_norm_groups[group_name] = global_norm / loss_scale
return True, global_norm_groups
def broadcast_params(self):
handles = []
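A hedged sketch of how a caller might consume the new dict-shaped return value; the keys follow the f"{group_id}_{group_name}" pattern above, and the concrete values here are made up:

grad_norm_groups = {"0_default": 1.73, "1_fp32": 0.42}   # hypothetical per-group norms

if -1 in grad_norm_groups.values():                      # same overflow convention as the list version
    print("inf/nan detected, skipping parameter update")
else:
    for name, norm in grad_norm_groups.items():
        print(f"grad_norm/{name} = {norm:.2f}")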

View File

@@ -389,23 +389,31 @@ def record_current_batch_training_metrics(
line = ""
for key, value in infos.items():
line += f"{key}={value} "
writer.add_scalar(key=key, value=value, step=train_state.step_count)
if isinstance(value, dict):
writer.add_scalars(key=key, value=value, step=train_state.step_count)
else:
writer.add_scalar(key=key, value=value, step=train_state.step_count)
if update_panel:
# metrics shown with dashboard panels
panel_metrics = {
"step": batch_count,
"lr": lr,
"num_consumed_tokens": train_state.num_consumed_tokens,
"loss": loss.item(),
"flops": tflops,
"tgs": tk_per_gpu,
"acc": acc_perplex["acc"],
"perplexity": acc_perplex["perplexity"],
"fwd_bwd_time": fwd_bwd_time,
}
for norm_key, norm_value in grad_norm.items():
panel_metrics[norm_key] = norm_value
logger.info(
line,
extra={
"step": batch_count,
"lr": lr,
"num_consumed_tokens": train_state.num_consumed_tokens,
"grad_norm": grad_norm,
"loss": loss.item(),
"flops": tflops,
"tgs": tk_per_gpu,
"acc": acc_perplex["acc"],
"perplexity": acc_perplex["perplexity"],
"fwd_bwd_time": fwd_bwd_time,
},
"{line}",
line=line,
extra=panel_metrics,
)
else:
logger.info(line)
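As a quick illustration of the change above, the per-group norms are flattened into the panel metrics next to the fixed fields instead of being nested under a single grad_norm entry (the values here are made up):

grad_norm = {"0_default": 1.73, "1_fp32": 0.42}             # hypothetical per-group norms
panel_metrics = {"step": 100, "loss": 2.31, "tgs": 1500.0}  # trimmed for brevity
for norm_key, norm_value in grad_norm.items():
    panel_metrics[norm_key] = norm_value
# panel_metrics -> {"step": 100, "loss": 2.31, "tgs": 1500.0, "0_default": 1.73, "1_fp32": 0.42}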

163 internlm/utils/gputest.py Normal file
View File

@@ -0,0 +1,163 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import socket
import torch
import torch.distributed as dist
from flash_attn.modules.mha import FlashSelfAttention, SelfAttention
from torch.utils import benchmark
from internlm.utils.logger import get_logger
try:
import GPUtil
import psutil
except ImportError:
GPUtil, psutil = None, None
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device
logger = get_logger(__file__)
def benchmark_forward(
test_fn,
*inputs,
repeats=100,
amp=True,
amp_dtype=torch.float16,
**kwinputs,
):
"""Use Pytorch Benchmark on the forward pass of an arbitrary function."""
def amp_wrapper(*inputs, **kwinputs):
with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=amp):
test_fn(*inputs, **kwinputs)
bench_timer = benchmark.Timer(
stmt="test_fn_amp(*inputs, **kwinputs)",
globals={"test_fn_amp": amp_wrapper, "inputs": inputs, "kwinputs": kwinputs},
num_threads=torch.get_num_threads(),
)
used_time = bench_timer.timeit(repeats)
return used_time.mean
def flops(batch, seqlen, headdim, nheads, time_f):
"""Compute the flops value of a GPU with give flashattention function"""
flop = 4 * batch * seqlen**2 * nheads * headdim
return (flop / time_f / 10**12) if not math.isnan(time_f) else 0.0
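A quick worked example of the formula above, using the same shapes as bench_gpu below and a hypothetical 1 ms mean forward time:

batch, seqlen, nheads, headdim = 2, 1024, 32, 64
flop = 4 * batch * seqlen**2 * nheads * headdim    # ~1.72e10 FLOPs per forward pass
time_f = 1e-3                                      # hypothetical mean forward time in seconds
tflops = flop / time_f / 10**12                    # ~17.2 TFLOPS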
def get_gpu_temperature():
"""Get current GPU temperature."""
try:
gpu_id = torch.cuda.current_device()
except AssertionError:
gpu_id = -1
if GPUtil is not None and gpu_id >= 0:
gpus = GPUtil.getGPUs()
gpu_temperature = gpus[gpu_id].temperature
else:
gpu_temperature = -1
return gpu_temperature
def get_cpu_temperature():
"""Get current CPU temperature."""
if psutil is not None:
cpu_temperature = psutil.sensors_temperatures()["coretemp"][0].current
else:
cpu_temperature = -1
return cpu_temperature
def bench_net():
"""Benchmark nccl performance for slow node detection."""
if gpc.get_world_size(ParallelMode.GLOBAL) <= 1:
return
if gpc.is_rank_for_log():
logger.info("benchmarking network speed ...")
repeats = 100
input_data = torch.randn(
8 * 1024 * 1024,
device=get_current_device(),
dtype=torch.bfloat16,
)
def allreduce_fn(inputs):
dist.all_reduce(inputs, op=torch.distributed.ReduceOp.AVG, group=gpc.get_group(ParallelMode.NETTEST))
bench_timer = benchmark.Timer(
stmt="test_fn_amp(inputs)",
globals={"test_fn_amp": allreduce_fn, "inputs": input_data},
num_threads=torch.get_num_threads(),
)
allreduce_time = bench_timer.timeit(repeats).mean
allreduce_time = allreduce_time * 10**3
allreduce_time_this = allreduce_time
allreduce_time = torch.Tensor([allreduce_time]).to(device=get_current_device())
dist.all_reduce(allreduce_time, group=gpc.get_group(ParallelMode.GLOBAL))
allreduce_time_avg = allreduce_time / gpc.get_world_size(ParallelMode.GLOBAL)
allreduce_time_avg = float(allreduce_time_avg.item())
if allreduce_time_this >= allreduce_time_avg * 1.05:
logger.warning(
f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} NCCL test is slower than avg, "
f"Hostname {socket.gethostname()}, "
f"allreduce_time {allreduce_time_this:.2f}, avg {allreduce_time_avg:.2f}, "
f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}"
)
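The slow-node criterion above is a plain 5% threshold against the cluster-wide average; a minimal numeric sketch with made-up timings in milliseconds:

allreduce_time_this, allreduce_time_avg = 13.1, 12.0          # hypothetical timings (ms)
is_slow = allreduce_time_this >= allreduce_time_avg * 1.05    # True, so the warning fires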
def bench_gpu(use_flash_attn=True):
"""Benchmark single GPU performance for slow node detection."""
if gpc.is_rank_for_log():
logger.info("benchmarking gpu speed ...")
headdim = 64
dim = 2048
batch_size, seqlen = 2, 1024
nheads = dim // headdim
inner_attn = FlashSelfAttention if use_flash_attn else SelfAttention
inner_attn = inner_attn(causal=True, softmax_scale=None, attention_dropout=0)
qkv = torch.randn(
batch_size,
seqlen,
3,
dim // headdim,
headdim,
device=get_current_device(),
dtype=torch.float16,
requires_grad=True,
)
time_f = benchmark_forward(inner_attn, qkv)
speed = flops(batch_size, seqlen, headdim, nheads, time_f)
speed_this = speed
speed = torch.Tensor([speed]).to(device=get_current_device())
dist.all_reduce(speed, group=gpc.get_group(ParallelMode.GLOBAL))
speed_avg = speed / gpc.get_world_size(ParallelMode.GLOBAL)
speed_avg = float(speed_avg.item())
if speed_this <= speed_avg * 0.95:
logger.warning(
f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} GPU is slower than avg, "
f"Hostname {socket.gethostname()}, "
f"tflops {speed_this:.2f}, avg {speed_avg:.2f}, "
f"CPU temp {get_cpu_temperature()}, GPU temp { get_gpu_temperature()}"
)

View File

@@ -134,6 +134,14 @@ class Writer:
except Exception:
traceback.print_exc()
def add_scalars(self, key, value, step):
try:
assert isinstance(value, dict)
if self.enable_tb and self.tb_writer is not None:
self.tb_writer.add_scalars(main_tag=key, tag_scalar_dict=value, global_step=step)
except Exception:
traceback.print_exc()
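A hedged usage example for the new method; the tag and values are made up, and the call maps directly onto TensorBoard's SummaryWriter.add_scalars(main_tag, tag_scalar_dict, global_step):

writer.add_scalars(key="grad_norm", value={"0_default": 1.73, "1_fp32": 0.42}, step=100)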
def add_text(self, key, value, step):
try:
if self.enable_tb and self.tb_writer is not None:

View File

@@ -6,7 +6,6 @@ import time
import traceback
from functools import partial
import numpy as np
import torch
import torch.distributed as dist
@@ -36,6 +35,7 @@ from internlm.utils.common import (
parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import bench_gpu, bench_net
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager
@@ -197,6 +197,8 @@ def main(args):
for batch_count in range(train_state.batch_count, total_steps):
if batch_count % 50 == 0:
torch.cuda.empty_cache()
bench_gpu()
bench_net()
start_time = time.time()
timer("one-batch").start()
@@ -236,7 +238,7 @@
train_state.step_count += 1
else:
train_state.inf_nan_skip_batches += 1 # record the number of batches whose parameter update was skipped.
if -1 in grad_norm_groups and gpc.is_rank_for_log(): # -1 encodes a specific failure case
if -1 in grad_norm_groups.values() and gpc.is_rank_for_log(): # -1 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
send_alert_message(
address=gpc.config.alert_address,
@@ -257,7 +259,7 @@
trainer=trainer,
start_time=start_time,
loss=loss,
grad_norm=np.array(grad_norm_groups),
grad_norm=grad_norm_groups,
metric=metric,
update_panel=uniscale_logger is not None,
)