Feat/overlap_bcast_forward (#218)

* feat/support bcast forward overlap

* feat/optimize the bcast call

* feat/optimize the bcast call

* feat/optimize the bcast call

* fix lint

* fix lint

* fix lint

* fix lint

* add torch.cuda.synchronize in save_checkpoint

---------

Co-authored-by: sunpeng <sunpengsdu@gmail.com>
pull/212/head
Sun Peng 2023-08-23 16:59:59 +08:00 committed by GitHub
parent a48210f1f3
commit 32664328e7
5 changed files with 193 additions and 51 deletions
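
The change set, at a high level: after the optimizer step, fp16 parameter shards are broadcast asynchronously, and each transformer/embedding/head/norm block waits for its own broadcasts in a forward pre-hook just before it runs, so the communication hides behind the next forward pass. A minimal, self-contained sketch of that pattern follows; it assumes an initialized NCCL process group and a CUDA model, and the names blocks, handles_per_block, and src_rank are illustrative rather than the repository's API:

import torch
import torch.distributed as dist
from torch import nn

def overlap_bcast_with_forward(blocks: nn.ModuleList, src_rank: int = 0) -> None:
    # Launch one async broadcast per parameter and remember the handles per block.
    handles_per_block = {
        block: [dist.broadcast(p.data, src=src_rank, async_op=True) for p in block.parameters()]
        for block in blocks
    }

    def _pre_forward_hook(block, inputs):  # hook(module, input) signature
        # Wait only for this block's broadcasts, right before its forward runs.
        for handle in handles_per_block.pop(block, []):
            handle.wait()

    for block in blocks:
        block.register_forward_pre_hook(_pre_forward_hook)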

File 1 of 5

@@ -3,6 +3,7 @@
import math
from functools import partial
from itertools import product
import torch
import torch.distributed as dist
@@ -19,6 +20,7 @@ from internlm.solver.optimizer.store (
)
from internlm.solver.optimizer.utils import (
DynamicGradScaler,
ParamBcastSyncHandler,
flatten,
get_grad_accumulate_object,
has_inf_or_nan,
@@ -87,9 +89,9 @@ class HybridZeroOptimizer(BaseOptimizer):
self,
optimizer: Optimizer,
cpu_offload=False,
overlap_broadcast=False,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
param_bcast_sync_handler: ParamBcastSyncHandler = None,
):
# DynamicGradScaler related args
if gpc.config.model.dtype is torch.float32:
@@ -158,7 +160,9 @@ class HybridZeroOptimizer(BaseOptimizer):
+ f"zo-{self._zero_local_rank}.pt"
)
self.params_per_rank_id_dict = []
self.overlap_broadcast = overlap_broadcast
self._param_bcast_sync_handler = param_bcast_sync_handler
if self._overlap_communication:
assert self._param_bcast_sync_handler is not None
# iterate over the param group in the optimizer
# partition these param groups for data parallel training
@@ -230,6 +234,8 @@ class HybridZeroOptimizer(BaseOptimizer):
# communication-computation overlapping
if self._overlap_communication:
self._comm_stream = torch.cuda.Stream()
else:
self._comm_stream = torch.cuda.current_stream()
# reduction hook is only used if overlapping communication
# if it is stage 1 without overlapping, no hook will be attached
@@ -267,8 +273,10 @@ class HybridZeroOptimizer(BaseOptimizer):
global_id = str(i)
for j in range(len(param.size())):
global_id = "_".join([global_id, str(param.size()[j])])
rank_to_go = numel_per_rank.index(min(numel_per_rank))
if self._overlap_communication:
rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
else:
rank_to_go = numel_per_rank.index(min(numel_per_rank))
params_per_rank[rank_to_go].append(param)
self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
numel_per_rank[rank_to_go] += param.numel()
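
For reference, the non-overlapping branch in the hunk above keeps the original greedy rule: each parameter goes to whichever ZeRO rank currently owns the fewest elements. A standalone sketch of just that heuristic (params is assumed to be a flat list of torch.nn.Parameter):

def greedy_partition(params, world_size):
    # Assign every parameter to the rank that currently holds the fewest elements.
    numel_per_rank = [0] * world_size
    params_per_rank = [[] for _ in range(world_size)]
    for p in params:
        rank_to_go = numel_per_rank.index(min(numel_per_rank))
        params_per_rank[rank_to_go].append(p)
        numel_per_rank[rank_to_go] += p.numel()
    return params_per_rank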
@@ -299,7 +307,9 @@ class HybridZeroOptimizer(BaseOptimizer):
self._grad_store.add_accumulate_grad_object(accum_grad_obj)
reduction_func = partial(
self._store_and_try_reduce_grads_by_bucket, param=param, reduce_rank=reduce_rank
self._store_and_try_reduce_grads_by_bucket,
param=param,
reduce_rank=reduce_rank,
)
# define hook
@@ -385,16 +395,16 @@ class HybridZeroOptimizer(BaseOptimizer):
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_communication:
stream = self._comm_stream
stream.synchronize()
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
else:
stream = torch.cuda.current_stream()
with torch.cuda.stream(stream):
with torch.cuda.stream(self._comm_stream):
flat = bucket.flatten()
reduced_flat = reduce_tensor(
tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=ParallelMode.DATA
tensor=flat,
dtype=self.dtype,
dst_rank=reduce_rank,
parallel_mode=ParallelMode.DATA,
)
# update the reduced tensor
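
With this hunk, _comm_stream always exists (a dedicated stream when overlapping, the current stream otherwise), so the reduction path no longer branches on which stream to use. The general pattern of running a collective on a side stream looks roughly like the sketch below; it assumes CUDA and an initialized process group, and reduce_on_comm_stream / flat_grads are illustrative names rather than the optimizer's own:

import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()  # dedicated communication stream

def reduce_on_comm_stream(flat_grads: torch.Tensor, dst_rank: int) -> None:
    # Finish whatever the comm stream was doing before its buffers are reused.
    comm_stream.synchronize()
    # Order the comm stream after the kernels that produced flat_grads.
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comm_stream):
        # Runs on comm_stream, so compute on the default stream can keep going.
        dist.reduce(flat_grads, dst=dst_rank)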
@@ -532,7 +542,10 @@ class HybridZeroOptimizer(BaseOptimizer):
for group_id in range(self.num_param_groups):
total_norms.append(
self._compute_norm_with_stage(
group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
group_id=group_id,
last_bucket=True,
last_stage=True,
previous_norm=groups_norms[group_id],
)
)
@@ -562,7 +575,10 @@ class HybridZeroOptimizer(BaseOptimizer):
if found_inf:
if gpc.is_rank_for_log():
logger.warning("Overflow occurs, please check it.")
send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
send_alert_message(
address=gpc.config.alert_address,
message="Overflow occurs, please check it.",
)
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, None
@@ -625,35 +641,40 @@ class HybridZeroOptimizer(BaseOptimizer):
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
fp16_param.data.copy_(fp32_param)
# TODO: support broadcast overlap
self.broadcast_params(overlap=False)
with torch.cuda.stream(self._comm_stream):
self.broadcast_params()
timer("step").stop()
# update gradients may not be needed here, because the sync_params function is used in initialization,
# so synchronization is maintained
return True, [global_norm / loss_scale for global_norm in global_norm_groups]
def broadcast_params(self, overlap=False):
def broadcast_params(self):
handles = []
for group_id in range(self.num_param_groups):
for rank in range(self._zero_world_size):
# The following operations are performed only on the rank to which parameters are assigned.
if rank not in self.param_group_no_params_ranks[group_id]:
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
handle = dist.broadcast(
fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
)
handles.append(handle)
for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
# The following operations are performed only on the rank to which parameters are assigned.
if rank in self.param_group_no_params_ranks[group_id]:
continue
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
handle = dist.broadcast(
fp16_param,
src=g_rank,
group=gpc.get_group(ParallelMode.ZERO1),
async_op=True,
)
if not overlap:
for handle in handles:
handle.wait()
else:
return handles
if self._overlap_communication:
self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
else:
handles.append(handle)
for handle in handles:
handle.wait()
##################
# FP16 Utilities #
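
The rewritten broadcast_params above launches every broadcast with async_op=True and then picks one of two completion policies: hand the handle to the sync handler so a later forward pre-hook waits on it, or wait on all handles before returning. A condensed sketch of both policies (flat_params, global_ranks, and sync_handler are illustrative stand-ins for the optimizer's internal state):

import torch.distributed as dist

def broadcast_flat_params(flat_params, global_ranks, group, sync_handler=None):
    # flat_params[rank] is the flat fp16 shard owned by that ZeRO rank.
    handles = []
    for rank, flat in enumerate(flat_params):
        handle = dist.broadcast(flat, src=global_ranks[rank], group=group, async_op=True)
        if sync_handler is not None:
            # Overlap: completion is deferred to the pre-forward hooks of the
            # blocks whose parameters live on this rank.
            sync_handler.add_bcast_handle(rank, handle)
        else:
            handles.append(handle)
    for handle in handles:
        handle.wait()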
@@ -671,7 +692,11 @@ class HybridZeroOptimizer(BaseOptimizer):
if avg_grad is not None and has_inf_or_nan(avg_grad):
self._found_overflow.fill_(1.0)
break
dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
dist.all_reduce(
self._found_overflow,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.GLOBAL),
)
return self._found_overflow.item() > 0
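
The overflow check above is a MAX all-reduce of a single flag, so one rank seeing inf/nan is enough to make every rank skip the step. A minimal sketch of the same idea, assuming CUDA tensors and an initialized NCCL group (the default group is used here instead of ParallelMode.GLOBAL):

import torch
import torch.distributed as dist

def any_rank_overflowed(local_grads) -> bool:
    # Each rank inspects only its own gradients ...
    flag = torch.zeros(1, device="cuda")
    if any(not torch.isfinite(g).all() for g in local_grads):
        flag.fill_(1.0)
    # ... and a MAX all-reduce tells every rank whether anyone saw inf/nan.
    dist.all_reduce(flag, op=dist.ReduceOp.MAX)
    return flag.item() > 0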

File 2 of 5

@@ -3,15 +3,18 @@
import math
from abc import ABC, abstractmethod
from typing import Dict, Optional
from collections import OrderedDict
from functools import partial
from typing import Any, Dict, Optional, Union
import torch
import torch.distributed as dist
from torch import Tensor
from torch import Tensor, nn
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
from internlm.utils.logger import get_logger
from internlm.utils.parallel import is_model_parallel_parameter
@@ -60,12 +63,19 @@ def get_grad_accumulate_object(tensor):
def split_half_float_double(tensor_list):
dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
buckets = []
for _, dtype in enumerate(dtypes):
bucket = [t for t in tensor_list if t.type() == dtype]
if bucket:
buckets.append(bucket)
dtype_buckets = {
"torch.cuda.HalfTensor": [],
"torch.cuda.FloatTensor": [],
"torch.cuda.DoubleTensor": [],
"torch.cuda.BFloat16Tensor": [],
}
for t in tensor_list:
dtype = t.type()
if dtype in dtype_buckets:
dtype_buckets[dtype].append(t)
buckets = [bucket for bucket in dtype_buckets.values() if bucket]
return buckets
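
Usage of the reworked helper stays the same: tensors come back grouped by dtype so each bucket can be flattened and reduced with one collective per dtype. A small illustrative call, assuming a CUDA device and that the helper lives in internlm.solver.optimizer.utils as the surrounding imports suggest:

import torch
from internlm.solver.optimizer.utils import split_half_float_double

grads = [
    torch.randn(8, device="cuda", dtype=torch.float16),
    torch.randn(8, device="cuda", dtype=torch.float32),
    torch.randn(8, device="cuda", dtype=torch.bfloat16),
]
# Empty buckets (e.g. float64 here) are dropped, so this yields three buckets:
# [[fp16 tensors], [fp32 tensors], [bf16 tensors]]
buckets = split_half_float_double(grads)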
@@ -184,7 +194,10 @@ def calc_l2_norm(grads):
if APEX_AVAILABLE:
dummy_overflow_buf = torch.cuda.IntTensor([0])
norm, _ = multi_tensor_applier(
amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False # no per-parameter norm
amp_C.multi_tensor_l2norm,
dummy_overflow_buf,
[grads],
False, # no per-parameter norm
)
else:
norm, _ = multi_tensor_l2norm_torch(grads, False)
@@ -228,7 +241,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
# Take max across all model-parallel GPUs.
if gpc.get_world_size(ParallelMode.MODEL) > 1:
dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
dist.all_reduce(
total_norm_cuda,
op=dist.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.MODEL),
)
total_norm = total_norm_cuda[0].item()
else:
tensor_parallel_grads = []
@@ -280,7 +297,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
# Sum across all model-parallel GPUs.
if gpc.is_initialized(ParallelMode.MODEL):
dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
dist.all_reduce(
total_norm,
op=dist.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.MODEL),
)
# This is because we use zero1, so we need to use this reduction.
# TODO: Check zero group to be a subset of dp group.
@@ -459,3 +480,90 @@ class DynamicGradScaler(BaseGradScaler):
self._scale = self._scale.fill_(state_dict["_scale"])
self._growth_step = state_dict["_growth_step"]
self._hysteresis_step = state_dict["_hysteresis_step"]
class ParamBcastSyncHandler:
"""
Model partition handler that overlaps parameter broadcast with the forward pass.
"""
def __init__(self, model: Union[nn.Module, nn.ModuleList]) -> None:
self._block_to_param = OrderedDict() # <key: nn.Module> <value: list(param)>
self._param_to_rank = dict() # <key: param> <value: rank)>
self._block_to_rank = dict() # <key: nn.Module> <value: rank)>
self._bcast_handles = dict() # <key: rank> <value: list(bcast handles))>
zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
total_param_num = sum(p.numel() for p in model.parameters())
avg_param_num = total_param_num * 1.0 // zero1_size
# wrap a bare Module in a list so Module and ModuleList share the same loop below
if not isinstance(model, nn.ModuleList):
model = [model]
# record the parameters belonging to each transformer/embedding/head/norm block
for _chunk in model:
if isinstance(_chunk, NaiveAMPModel):
_chunk = _chunk.model
for _, children in _chunk.named_children():
# should be the transformer block definition in modeling_xxx.py
if isinstance(children, nn.ModuleList):
# record the block that a parameter belongs to
for _, block in enumerate(children):
# self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
self._block_to_param[block] = list(block.parameters())
else:
# record the block that a parameter belongs to
# self._block_to_param[name] = list(children.parameters())
self._block_to_param[children] = list(children.parameters())
alloc_num = 0
rank_to_go = 0
# process the parameters in block_to_param sequentially,
# allocate each parameter to a local rank of ParallelMode.ZERO1,
# NOTE that we do NOT consider the following scenarios:
# 1) whether a parameter is trainable;
# 2) parameters may be in different optimizer groups
for block, params in self._block_to_param.items():
# allocate a model block to a local rank of ParallelMode.ZERO1
self._block_to_rank[block] = [rank_to_go]
for p in params:
alloc_num = alloc_num + p.numel()
# in this case, allocate the param to next rank if possible
if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
rank_to_go = rank_to_go + 1
alloc_num = 0
self._block_to_rank[block].append(rank_to_go)
# allocate a parameter to a local rank of ParallelMode.ZERO1
self._param_to_rank[p] = rank_to_go
# initialize an empty list for _bcast_handles of each rank
for rank in range(gpc.get_world_size(ParallelMode.ZERO1)):
self._bcast_handles[rank] = []
# register_forward_pre_hook for transformer/embedding/norm/xxx block
self._register_sync_parameters_hook()
def _register_sync_parameters_hook(self) -> None:
def _pre_forward_hook(model: nn.Module, inputs: Any): # pylint: disable=W0613
bcast_handles = []
# gather all required broadcast handles into a list
for rank in self._block_to_rank[model]:
bcast_handles.extend(self._bcast_handles[rank])
# clear _bcast_handles for this rank, since the gathered handles are waited on below
self._bcast_handles[rank] = []
# wait for all required broadcast handles to complete
for handle in bcast_handles:
handle.wait()
# register_forward_pre_hook for transformer/embedding/norm/xxx block
for block, _ in self._block_to_rank.items():
block.register_forward_pre_hook(partial(_pre_forward_hook))
def get_rank_by_param(self, param) -> int:
return self._param_to_rank[param]
def add_bcast_handle(self, rank, handle) -> None:
self._bcast_handles[rank].append(handle)
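
The allocation loop above walks the blocks in definition order and moves to the next ZERO1 rank once the running parameter count passes roughly 1.01x the per-rank average, which keeps whole blocks (and therefore their broadcasts) mostly on one rank. A simplified standalone sketch of just that heuristic, omitting the block_to_rank bookkeeping (block_params is assumed to be an ordered dict of block -> list of parameters):

def allocate_params_to_ranks(block_params, world_size):
    # block_params: OrderedDict mapping each block to its list of parameters.
    total = sum(p.numel() for params in block_params.values() for p in params)
    avg = total / world_size
    param_to_rank, alloc, rank = {}, 0, 0
    for params in block_params.values():
        for p in params:
            alloc += p.numel()
            # Spill over to the next rank once this rank is ~1% above the average.
            if alloc > avg * 1.01 and rank < world_size - 1:
                rank += 1
                alloc = 0
            param_to_rank[p] = rank
    return param_to_rank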

File 3 of 5

@@ -14,18 +14,19 @@ class _Timer:
self.elapsed_ = 0.0
self.started_ = False
self.start_time = time.time()
self.stream = torch.cuda.current_stream()
def start(self):
"""Start the timer."""
assert not self.started_, "timer has already been started"
torch.cuda.synchronize()
self.stream.synchronize()
self.start_time = time.time()
self.started_ = True
def stop(self):
"""Stop the timer."""
assert self.started_, "timer is not started"
torch.cuda.synchronize()
self.stream.synchronize()
self.elapsed_ += time.time() - self.start_time
self.started_ = False
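
Synchronizing only the stream captured at construction (instead of torch.cuda.synchronize(), which drains every stream) keeps the timers from blocking on the communication stream and silently serializing the overlap this PR introduces. A minimal sketch of the same idea, assuming a CUDA device (StreamTimer is an illustrative name, not the repository's _Timer):

import time
import torch

class StreamTimer:
    """Times work on a single CUDA stream without a device-wide synchronize (sketch)."""

    def __init__(self, stream: torch.cuda.Stream = None):
        self.stream = stream if stream is not None else torch.cuda.current_stream()
        self.elapsed = 0.0
        self.start_time = None

    def start(self):
        self.stream.synchronize()  # wait only for this stream's pending work
        self.start_time = time.time()

    def stop(self):
        self.stream.synchronize()
        self.elapsed += time.time() - self.start_time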

File 4 of 5

@@ -565,6 +565,7 @@ set load_ckpt_folder or use default value \
start = time.time()
self.set_save_folder(folder, train_state.step_count)
torch.cuda.synchronize()
torch.distributed.barrier()
if gpc.is_rank_for_log():
logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")

File 5 of 5

@@ -5,7 +5,7 @@ import socket
import time
import traceback
from functools import partial
from typing import Iterable
from typing import Iterable, Union
import numpy as np
import torch
@@ -36,6 +36,7 @@ from internlm.monitor.monitor import monitor_manager as mm
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler
from internlm.utils.common import (
BatchSkipper,
DummyProfile,
@@ -291,7 +292,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
return batch, train_iter
def initialize_optimizer(model: nn.Module):
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
"""
Initialize optimizer.
@@ -300,6 +301,7 @@ def initialize_optimizer(model: nn.Module):
Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
"""
param_bcast_sync_handler = ParamBcastSyncHandler(model)
adam_cfg = gpc.config.adam
naive_optimizer = torch.optim.AdamW(
params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
@@ -309,7 +311,10 @@ def initialize_optimizer(model: nn.Module):
)
optimizer = HybridZeroOptimizer(
naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
param_bcast_sync_handler=param_bcast_sync_handler,
)
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
@@ -599,6 +604,7 @@ def main(args):
# do forward and backward
timer("fwd-bwd").start()
_, _, loss = trainer.execute_schedule(
batch, forward_only=False, return_loss=True, return_output_label=False
)
@@ -661,7 +667,8 @@ def main(args):
if memory_profiler is not None:
memory_profiler.step()
prof.step()
if batch_count % 2 == 0:
prof.step()
ckpt_manager.wait_async_upload_finish()