Feat/overlap_bcast_forward (#218)

* feat/support bcast forward overlap

* feat/optimize the bcast call

* feat/optimize the bcast call

* feat/optimize the bcast call

* fix lint

* fix lint

* fix lint

* fix lint

* add torch.cuda.synchronize in save_checkpoint

---------

Co-authored-by: sunpeng <sunpengsdu@gmail.com>
Sun Peng 2023-08-23 16:59:59 +08:00 committed by GitHub
parent a48210f1f3
commit 32664328e7
5 changed files with 193 additions and 51 deletions

---------- changed file 1 of 5 ----------

@@ -3,6 +3,7 @@
 import math
 from functools import partial
+from itertools import product

 import torch
 import torch.distributed as dist
@@ -19,6 +20,7 @@ from internlm.solver.optimizer.store import (
 )
 from internlm.solver.optimizer.utils import (
     DynamicGradScaler,
+    ParamBcastSyncHandler,
     flatten,
     get_grad_accumulate_object,
     has_inf_or_nan,
@@ -87,9 +89,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         self,
         optimizer: Optimizer,
         cpu_offload=False,
-        overlap_broadcast=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
+        param_bcast_sync_handler: ParamBcastSyncHandler = None,
     ):
         # DynamicGradScaler related args
         if gpc.config.model.dtype is torch.float32:
@@ -158,7 +160,9 @@ class HybridZeroOptimizer(BaseOptimizer):
             + f"zo-{self._zero_local_rank}.pt"
         )
         self.params_per_rank_id_dict = []
-        self.overlap_broadcast = overlap_broadcast
+        self._param_bcast_sync_handler = param_bcast_sync_handler
+        if self._overlap_communication:
+            assert self._param_bcast_sync_handler is not None

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -230,6 +234,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # communication-computation overlapping
         if self._overlap_communication:
             self._comm_stream = torch.cuda.Stream()
+        else:
+            self._comm_stream = torch.cuda.current_stream()

         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
@@ -267,8 +273,10 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_id = str(i)
             for j in range(len(param.size())):
                 global_id = "_".join([global_id, str(param.size()[j])])
-            rank_to_go = numel_per_rank.index(min(numel_per_rank))
+            if self._overlap_communication:
+                rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
+            else:
+                rank_to_go = numel_per_rank.index(min(numel_per_rank))
             params_per_rank[rank_to_go].append(param)
             self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
             numel_per_rank[rank_to_go] += param.numel()
@@ -299,7 +307,9 @@ class HybridZeroOptimizer(BaseOptimizer):
                 self._grad_store.add_accumulate_grad_object(accum_grad_obj)

                 reduction_func = partial(
-                    self._store_and_try_reduce_grads_by_bucket, param=param, reduce_rank=reduce_rank
+                    self._store_and_try_reduce_grads_by_bucket,
+                    param=param,
+                    reduce_rank=reduce_rank,
                 )

                 # define hook
@@ -385,16 +395,16 @@ class HybridZeroOptimizer(BaseOptimizer):
     def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
         if self._overlap_communication:
-            stream = self._comm_stream
-            stream.synchronize()
+            self._comm_stream.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()
-        else:
-            stream = torch.cuda.current_stream()

-        with torch.cuda.stream(stream):
+        with torch.cuda.stream(self._comm_stream):
             flat = bucket.flatten()
             reduced_flat = reduce_tensor(
-                tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=ParallelMode.DATA
+                tensor=flat,
+                dtype=self.dtype,
+                dst_rank=reduce_rank,
+                parallel_mode=ParallelMode.DATA,
             )

             # update the reduced tensor
@@ -532,7 +542,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             total_norms.append(
                 self._compute_norm_with_stage(
-                    group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+                    group_id=group_id,
+                    last_bucket=True,
+                    last_stage=True,
+                    previous_norm=groups_norms[group_id],
                 )
             )
@@ -562,7 +575,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         if found_inf:
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
-                send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
+                send_alert_message(
+                    address=gpc.config.alert_address,
+                    message="Overflow occurs, please check it.",
+                )
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, None
@@ -625,35 +641,40 @@ class HybridZeroOptimizer(BaseOptimizer):
             fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
             fp16_param.data.copy_(fp32_param)

-        # TODO: support broadcast overlap
-        self.broadcast_params(overlap=False)
+        with torch.cuda.stream(self._comm_stream):
+            self.broadcast_params()

         timer("step").stop()

         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
         return True, [global_norm / loss_scale for global_norm in global_norm_groups]

-    def broadcast_params(self, overlap=False):
+    def broadcast_params(self):
         handles = []

-        for group_id in range(self.num_param_groups):
-            for rank in range(self._zero_world_size):
-                # The following operations are performed only on the rank to which parameters are assigned.
-                if rank not in self.param_group_no_params_ranks[group_id]:
-                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                    # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
-                    # assert grank == rank, f"{grank} == {rank}"
-                    g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
-                    handle = dist.broadcast(
-                        fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
-                    )
-                    handles.append(handle)
-
-        if not overlap:
-            for handle in handles:
-                handle.wait()
-        else:
-            return handles
+        for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
+            # The following operations are performed only on the rank to which parameters are assigned.
+            if rank in self.param_group_no_params_ranks[group_id]:
+                continue
+            fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
+            # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
+            # assert grank == rank, f"{grank} == {rank}"
+            g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
+            handle = dist.broadcast(
+                fp16_param,
+                src=g_rank,
+                group=gpc.get_group(ParallelMode.ZERO1),
+                async_op=True,
+            )
+
+            if self._overlap_communication:
+                self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
+            else:
+                handles.append(handle)
+
+        for handle in handles:
+            handle.wait()

     ##################
     # FP16 Utilities #
@@ -671,7 +692,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             if avg_grad is not None and has_inf_or_nan(avg_grad):
                 self._found_overflow.fill_(1.0)
                 break
-        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
+        dist.all_reduce(
+            self._found_overflow,
+            op=dist.ReduceOp.MAX,
+            group=gpc.get_group(ParallelMode.GLOBAL),
+        )

         return self._found_overflow.item() > 0
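For orientation, here is a minimal, self-contained sketch (not part of the commit; handles_per_rank, launch_broadcasts, and owner_rank_of are illustrative names) of the overlap pattern the optimizer changes above implement: after the optimizer step, parameter broadcasts are launched asynchronously and their handles are queued per owning ZeRO-1 rank; each model block then waits only for its own broadcasts in a forward pre-hook, so the communication overlaps with the forward of earlier blocks.

from typing import Dict, List

import torch
import torch.distributed as dist
from torch import nn

# in-flight broadcast handles, keyed by the ZeRO-1 rank that owns the parameters
handles_per_rank: Dict[int, List] = {}


def launch_broadcasts(flat_fp16_by_owner: Dict[int, torch.Tensor]) -> None:
    # called right after optimizer.step(): broadcast each owner rank's flat
    # fp16 shard asynchronously and stash the handle instead of waiting here
    for owner_rank, flat in flat_fp16_by_owner.items():
        handle = dist.broadcast(flat, src=owner_rank, async_op=True)
        handles_per_rank.setdefault(owner_rank, []).append(handle)


def make_pre_hook(owner_rank: int):
    def _pre_forward_hook(module: nn.Module, inputs):  # signature expected by PyTorch pre-hooks
        # wait only for the broadcasts this block actually depends on
        for handle in handles_per_rank.get(owner_rank, []):
            handle.wait()
        handles_per_rank[owner_rank] = []

    return _pre_forward_hook


# during setup, each block registers a hook for the rank that owns its parameters:
#     block.register_forward_pre_hook(make_pre_hook(owner_rank_of(block)))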

---------- changed file 2 of 5 ----------

@@ -3,15 +3,18 @@
 import math
 from abc import ABC, abstractmethod
-from typing import Dict, Optional
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Dict, Optional, Union

 import torch
 import torch.distributed as dist
-from torch import Tensor
+from torch import Tensor, nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.naive_amp import NaiveAMPModel
 from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
 from internlm.utils.logger import get_logger
 from internlm.utils.parallel import is_model_parallel_parameter
@@ -60,12 +63,19 @@ def get_grad_accumulate_object(tensor):
 def split_half_float_double(tensor_list):
-    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
-    buckets = []
-    for _, dtype in enumerate(dtypes):
-        bucket = [t for t in tensor_list if t.type() == dtype]
-        if bucket:
-            buckets.append(bucket)
+    dtype_buckets = {
+        "torch.cuda.HalfTensor": [],
+        "torch.cuda.FloatTensor": [],
+        "torch.cuda.DoubleTensor": [],
+        "torch.cuda.BFloat16Tensor": [],
+    }
+
+    for t in tensor_list:
+        dtype = t.type()
+        if dtype in dtype_buckets:
+            dtype_buckets[dtype].append(t)
+
+    buckets = [bucket for bucket in dtype_buckets.values() if bucket]
+
     return buckets
@@ -184,7 +194,10 @@ def calc_l2_norm(grads):
     if APEX_AVAILABLE:
         dummy_overflow_buf = torch.cuda.IntTensor([0])
         norm, _ = multi_tensor_applier(
-            amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
+            amp_C.multi_tensor_l2norm,
+            dummy_overflow_buf,
+            [grads],
+            False,  # no per-parameter norm
         )
     else:
         norm, _ = multi_tensor_l2norm_torch(grads, False)
@@ -228,7 +241,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
         # Take max across all model-parallel GPUs.
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
+            dist.all_reduce(
+                total_norm_cuda,
+                op=dist.ReduceOp.MAX,
+                group=gpc.get_group(ParallelMode.MODEL),
+            )
         total_norm = total_norm_cuda[0].item()
     else:
         tensor_parallel_grads = []
@@ -280,7 +297,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
         # Sum across all model-parallel GPUs.
         if gpc.is_initialized(ParallelMode.MODEL):
-            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
+            dist.all_reduce(
+                total_norm,
+                op=dist.ReduceOp.SUM,
+                group=gpc.get_group(ParallelMode.MODEL),
+            )

         # This is because we use zero1, so we need to use this reduction.
         # TODO: Check zero group to be a subset of dp group.
@@ -459,3 +480,90 @@ class DynamicGradScaler(BaseGradScaler):
         self._scale = self._scale.fill_(state_dict["_scale"])
         self._growth_step = state_dict["_growth_step"]
         self._hysteresis_step = state_dict["_hysteresis_step"]
+
+
+class ParamBcastSyncHandler:
+    """
+    Model partition handler for overlapping parameter broadcast with the forward pass.
+    """
+
+    def __init__(self, model: Union[nn.Module, nn.ModuleList]) -> None:
+        self._block_to_param = OrderedDict()  # <key: nn.Module> <value: list(param)>
+        self._param_to_rank = dict()  # <key: param> <value: rank>
+        self._block_to_rank = dict()  # <key: nn.Module> <value: rank>
+        self._bcast_handles = dict()  # <key: rank> <value: list(bcast handles)>
+
+        zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
+        total_param_num = sum(p.numel() for p in model.parameters())
+        avg_param_num = total_param_num * 1.0 // zero1_size
+
+        # just want to share the same for loop for ModuleList and Module
+        if not isinstance(model, nn.ModuleList):
+            model = [model]
+
+        # record the parameters of each transformer/embedding/head/norm block
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+
+            for _, children in _chunk.named_children():
+                # should be the transformer block definition in modeling_xxx.py
+                if isinstance(children, nn.ModuleList):
+                    # record the block that a parameter belongs to
+                    for _, block in enumerate(children):
+                        # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
+                        self._block_to_param[block] = list(block.parameters())
+                else:
+                    # record the block that a parameter belongs to
+                    # self._block_to_param[name] = list(children.parameters())
+                    self._block_to_param[children] = list(children.parameters())
+
+        alloc_num = 0
+        rank_to_go = 0
+
+        # process the parameters in block_to_param sequentially,
+        # allocating each parameter to a local rank of ParallelMode.ZERO1.
+        # NOTE that we do NOT consider the following scenarios:
+        # 1) whether a parameter is trainable;
+        # 2) parameters may belong to different optimizer groups
+        for block, params in self._block_to_param.items():
+            # allocate a model block to a local rank of ParallelMode.ZERO1
+            self._block_to_rank[block] = [rank_to_go]
+            for p in params:
+                alloc_num = alloc_num + p.numel()
+                # in this case, allocate the param to the next rank if possible
+                if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
+                    rank_to_go = rank_to_go + 1
+                    alloc_num = 0
+                    self._block_to_rank[block].append(rank_to_go)
+                # allocate a parameter to a local rank of ParallelMode.ZERO1
+                self._param_to_rank[p] = rank_to_go
+
+        # initialize an empty list for _bcast_handles of each rank
+        for rank in range(gpc.get_world_size(ParallelMode.ZERO1)):
+            self._bcast_handles[rank] = []
+
+        # register_forward_pre_hook for transformer/embedding/norm/xxx blocks
+        self._register_sync_parameters_hook()
+
+    def _register_sync_parameters_hook(self) -> None:
+        def _pre_forward_hook(model: nn.Module, inputs: Any):  # pylint: disable=W0613
+            bcast_handles = []
+            # gather all required broadcast handles into a list
+            for rank in self._block_to_rank[model]:
+                bcast_handles.extend(self._bcast_handles[rank])
+                # need to clear _bcast_handles since they would be processed later
+                self._bcast_handles[rank] = []
+            # wait for all required broadcast handles to complete
+            for handle in bcast_handles:
+                handle.wait()
+
+        # register_forward_pre_hook for transformer/embedding/norm/xxx blocks
+        for block, _ in self._block_to_rank.items():
+            block.register_forward_pre_hook(partial(_pre_forward_hook))
+
+    def get_rank_by_param(self, param) -> int:
+        return self._param_to_rank[param]
+
+    def add_bcast_handle(self, rank, handle) -> None:
+        self._bcast_handles[rank].append(handle)
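For reference, a usage sketch of the new handler, mirroring the train.py hunk later in this commit (it assumes model, naive_optimizer, and the gpc config entries already exist in the caller, as they do in initialize_optimizer):

from internlm.core.context import global_context as gpc
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.solver.optimizer.utils import ParamBcastSyncHandler

# build the handler first so it can partition blocks across ZeRO-1 ranks and
# register its forward pre-hooks on the model
param_bcast_sync_handler = ParamBcastSyncHandler(model)

optimizer = HybridZeroOptimizer(
    naive_optimizer,
    grad_scal_cfg=gpc.config.grad_scaler,
    zero_cfg=gpc.config.hybrid_zero_optimizer,
    param_bcast_sync_handler=param_bcast_sync_handler,  # required when communication overlap is enabled (see the assert in __init__)
)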

---------- changed file 3 of 5 ----------

@@ -14,18 +14,19 @@ class _Timer:
         self.elapsed_ = 0.0
         self.started_ = False
         self.start_time = time.time()
+        self.stream = torch.cuda.current_stream()

     def start(self):
         """Start the timer."""
         assert not self.started_, "timer has already been started"
-        torch.cuda.synchronize()
+        self.stream.synchronize()
         self.start_time = time.time()
         self.started_ = True

     def stop(self):
         """Stop the timer."""
         assert self.started_, "timer is not started"
-        torch.cuda.synchronize()
+        self.stream.synchronize()
         self.elapsed_ += time.time() - self.start_time
         self.started_ = False

---------- changed file 4 of 5 ----------

@@ -565,6 +565,7 @@ set load_ckpt_folder or use default value \
         start = time.time()
         self.set_save_folder(folder, train_state.step_count)
+        torch.cuda.synchronize()
         torch.distributed.barrier()
         if gpc.is_rank_for_log():
             logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")

---------- changed file 5 of 5 ----------

@@ -5,7 +5,7 @@ import socket
 import time
 import traceback
 from functools import partial
-from typing import Iterable
+from typing import Iterable, Union

 import numpy as np
 import torch
@@ -36,6 +36,7 @@ from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
+from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import (
     BatchSkipper,
     DummyProfile,
@@ -291,7 +292,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
     return batch, train_iter


-def initialize_optimizer(model: nn.Module):
+def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     """
     Initialize optimizer.
@@ -300,6 +301,7 @@ def initialize_optimizer(model: nn.Module):
     Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
     """
+    param_bcast_sync_handler = ParamBcastSyncHandler(model)
     adam_cfg = gpc.config.adam
     naive_optimizer = torch.optim.AdamW(
         params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],
@@ -309,7 +311,10 @@ def initialize_optimizer(model: nn.Module):
     )

     optimizer = HybridZeroOptimizer(
-        naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
+        naive_optimizer,
+        grad_scal_cfg=gpc.config.grad_scaler,
+        zero_cfg=gpc.config.hybrid_zero_optimizer,
+        param_bcast_sync_handler=param_bcast_sync_handler,
     )

     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
@@ -599,6 +604,7 @@ def main(args):
         # do forward and backward
         timer("fwd-bwd").start()
         _, _, loss = trainer.execute_schedule(
             batch, forward_only=False, return_loss=True, return_output_label=False
         )
@@ -661,7 +667,8 @@ def main(args):
             if memory_profiler is not None:
                 memory_profiler.step()

-            prof.step()
+            if batch_count % 2 == 0:
+                prof.step()

     ckpt_manager.wait_async_upload_finish()