mirror of https://github.com/InternLM/InternLM
Feat/overlap_bcast_forward (#218)
* feat/support bcast forward overlap
* feat/optimize the bcast call
* feat/optimize the bcast call
* feat/optimize the bcast call
* fix lint
* fix lint
* fix lint
* fix lint
* add torch.cuda.synchronize in save_checkpoint

---------

Co-authored-by: sunpeng <sunpengsdu@gmail.com>
parent a48210f1f3
commit 32664328e7
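What the change does, at a glance: after the optimizer step, broadcast_params() now issues the broadcasts of the updated fp16 parameter shards asynchronously on the optimizer's communication stream and hands each handle to a ParamBcastSyncHandler; a forward pre-hook registered on each transformer/embedding/head/norm block then waits only for the handles that block needs, right before it runs, so the broadcasts overlap with the next forward pass. Below is a minimal, self-contained sketch of that pattern, not the InternLM implementation: BcastSyncHandler and broadcast_updated_params are hypothetical stand-ins, and it assumes a multi-process launch (e.g. torchrun --nproc_per_node=2) with the gloo backend.

# Minimal sketch of the overlap pattern (hypothetical names, not InternLM code).
# Launch with: torchrun --nproc_per_node=2 sketch.py
import torch
import torch.distributed as dist
from torch import nn


class BcastSyncHandler:
    """Toy stand-in for ParamBcastSyncHandler: each block waits for its own
    pending parameter broadcasts right before its forward runs."""

    def __init__(self, blocks):
        self._handles = {block: [] for block in blocks}
        for block in blocks:
            block.register_forward_pre_hook(self._wait_for_block)

    def add_handle(self, block, handle):
        self._handles[block].append(handle)

    def _wait_for_block(self, block, inputs):  # forward pre-hook signature
        for handle in self._handles[block]:
            handle.wait()
        self._handles[block].clear()


def broadcast_updated_params(blocks, handler):
    # After the optimizer step, the owning rank broadcasts each block's
    # parameters asynchronously; nothing blocks here.
    for i, block in enumerate(blocks):
        src = i % dist.get_world_size()  # toy block -> rank assignment
        for p in block.parameters():
            handler.add_handle(block, dist.broadcast(p.data, src=src, async_op=True))


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    blocks = nn.ModuleList(nn.Linear(8, 8) for _ in range(4))
    handler = BcastSyncHandler(list(blocks))
    broadcast_updated_params(blocks, handler)
    x = torch.randn(2, 8)
    for block in blocks:
        x = block(x)  # the pre-hook waits only for this block's broadcasts
    dist.destroy_process_group()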
@@ -3,6 +3,7 @@

 import math
 from functools import partial
+from itertools import product

 import torch
 import torch.distributed as dist
@@ -19,6 +20,7 @@ from internlm.solver.optimizer.store import (
 )
 from internlm.solver.optimizer.utils import (
     DynamicGradScaler,
+    ParamBcastSyncHandler,
     flatten,
     get_grad_accumulate_object,
     has_inf_or_nan,
@@ -87,9 +89,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         self,
         optimizer: Optimizer,
         cpu_offload=False,
-        overlap_broadcast=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
+        param_bcast_sync_handler: ParamBcastSyncHandler = None,
     ):
         # DynamicGradScaler related args
         if gpc.config.model.dtype is torch.float32:
@@ -158,7 +160,9 @@ class HybridZeroOptimizer(BaseOptimizer):
                 + f"zo-{self._zero_local_rank}.pt"
             )
         self.params_per_rank_id_dict = []
-        self.overlap_broadcast = overlap_broadcast
+        self._param_bcast_sync_handler = param_bcast_sync_handler
+        if self._overlap_communication:
+            assert self._param_bcast_sync_handler is not None

         # iterate over the param group in the optimizer
         # partition these param groups for data parallel training
@@ -230,6 +234,8 @@ class HybridZeroOptimizer(BaseOptimizer):
         # communication-computation overlapping
         if self._overlap_communication:
             self._comm_stream = torch.cuda.Stream()
+        else:
+            self._comm_stream = torch.cuda.current_stream()

         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
@@ -267,8 +273,10 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_id = str(i)
             for j in range(len(param.size())):
                 global_id = "_".join([global_id, str(param.size()[j])])

-            rank_to_go = numel_per_rank.index(min(numel_per_rank))
+            if self._overlap_communication:
+                rank_to_go = self._param_bcast_sync_handler.get_rank_by_param(param)
+            else:
+                rank_to_go = numel_per_rank.index(min(numel_per_rank))
             params_per_rank[rank_to_go].append(param)
             self.params_per_rank_id_dict[-1][rank_to_go].append(global_id)
             numel_per_rank[rank_to_go] += param.numel()
@@ -299,7 +307,9 @@ class HybridZeroOptimizer(BaseOptimizer):
                         self._grad_store.add_accumulate_grad_object(accum_grad_obj)

                         reduction_func = partial(
-                            self._store_and_try_reduce_grads_by_bucket, param=param, reduce_rank=reduce_rank
+                            self._store_and_try_reduce_grads_by_bucket,
+                            param=param,
+                            reduce_rank=reduce_rank,
                         )

                         # define hook
@@ -385,16 +395,16 @@ class HybridZeroOptimizer(BaseOptimizer):

     def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
         if self._overlap_communication:
-            stream = self._comm_stream
-            stream.synchronize()
+            self._comm_stream.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()
-        else:
-            stream = torch.cuda.current_stream()

-        with torch.cuda.stream(stream):
+        with torch.cuda.stream(self._comm_stream):
             flat = bucket.flatten()
             reduced_flat = reduce_tensor(
-                tensor=flat, dtype=self.dtype, dst_rank=reduce_rank, parallel_mode=ParallelMode.DATA
+                tensor=flat,
+                dtype=self.dtype,
+                dst_rank=reduce_rank,
+                parallel_mode=ParallelMode.DATA,
             )

             # update the reduced tensor
@@ -532,7 +542,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         for group_id in range(self.num_param_groups):
             total_norms.append(
                 self._compute_norm_with_stage(
-                    group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+                    group_id=group_id,
+                    last_bucket=True,
+                    last_stage=True,
+                    previous_norm=groups_norms[group_id],
                 )
             )

@@ -562,7 +575,10 @@ class HybridZeroOptimizer(BaseOptimizer):
         if found_inf:
             if gpc.is_rank_for_log():
                 logger.warning("Overflow occurs, please check it.")
-                send_alert_message(address=gpc.config.alert_address, message="Overflow occurs, please check it.")
+                send_alert_message(
+                    address=gpc.config.alert_address,
+                    message="Overflow occurs, please check it.",
+                )
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
             return False, None
@@ -625,35 +641,40 @@ class HybridZeroOptimizer(BaseOptimizer):
                 fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
                 fp16_param.data.copy_(fp32_param)

-        # TODO: support broadcast overlap
-        self.broadcast_params(overlap=False)
+        with torch.cuda.stream(self._comm_stream):
+            self.broadcast_params()

         timer("step").stop()

         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
         return True, [global_norm / loss_scale for global_norm in global_norm_groups]

-    def broadcast_params(self, overlap=False):
+    def broadcast_params(self):
         handles = []

-        for group_id in range(self.num_param_groups):
-            for rank in range(self._zero_world_size):
-                # The following operations are performed only on the rank to which parameters are assigned.
-                if rank not in self.param_group_no_params_ranks[group_id]:
-                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                    # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
-                    # assert grank == rank, f"{grank} == {rank}"
-                    g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
-                    handle = dist.broadcast(
-                        fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
-                    )
-                    handles.append(handle)
-
-        if not overlap:
-            for handle in handles:
-                handle.wait()
-        else:
-            return handles
+        for rank, group_id in product(range(self._zero_world_size), range(self.num_param_groups)):
+            # The following operations are performed only on the rank to which parameters are assigned.
+            if rank in self.param_group_no_params_ranks[group_id]:
+                continue
+            fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
+            # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
+            # assert grank == rank, f"{grank} == {rank}"
+            g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
+            handle = dist.broadcast(
+                fp16_param,
+                src=g_rank,
+                group=gpc.get_group(ParallelMode.ZERO1),
+                async_op=True,
+            )
+
+            if self._overlap_communication:
+                self._param_bcast_sync_handler.add_bcast_handle(rank, handle)
+            else:
+                handles.append(handle)
+
+        for handle in handles:
+            handle.wait()

     ##################
     # FP16 Utilities #
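For reference, the itertools.product call in the new broadcast_params covers the same set of (rank, group_id) pairs as the removed nested loops, but iterates rank-major, which is what lets each handle be registered with the sync handler per rank as soon as it is created. A tiny standalone check (the sizes below are made up):

from itertools import product

zero_world_size, num_param_groups = 2, 3  # illustrative sizes only
flat = list(product(range(zero_world_size), range(num_param_groups)))
# same pairs, with rank as the outer loop (the old code iterated group-major)
nested = [(rank, group_id) for rank in range(zero_world_size) for group_id in range(num_param_groups)]
assert flat == nested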
@@ -671,7 +692,11 @@ class HybridZeroOptimizer(BaseOptimizer):
                 if avg_grad is not None and has_inf_or_nan(avg_grad):
                     self._found_overflow.fill_(1.0)
                     break
-        dist.all_reduce(self._found_overflow, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.GLOBAL))
+        dist.all_reduce(
+            self._found_overflow,
+            op=dist.ReduceOp.MAX,
+            group=gpc.get_group(ParallelMode.GLOBAL),
+        )

         return self._found_overflow.item() > 0
@@ -3,15 +3,18 @@

 import math
 from abc import ABC, abstractmethod
-from typing import Dict, Optional
+from collections import OrderedDict
+from functools import partial
+from typing import Any, Dict, Optional, Union

 import torch
 import torch.distributed as dist
-from torch import Tensor
+from torch import Tensor, nn
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
+from internlm.core.naive_amp import NaiveAMPModel
 from internlm.utils.common import get_tensor_norm, move_norm_to_cuda
 from internlm.utils.logger import get_logger
 from internlm.utils.parallel import is_model_parallel_parameter
@@ -60,12 +63,19 @@ def get_grad_accumulate_object(tensor):


 def split_half_float_double(tensor_list):
-    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
-    buckets = []
-    for _, dtype in enumerate(dtypes):
-        bucket = [t for t in tensor_list if t.type() == dtype]
-        if bucket:
-            buckets.append(bucket)
+    dtype_buckets = {
+        "torch.cuda.HalfTensor": [],
+        "torch.cuda.FloatTensor": [],
+        "torch.cuda.DoubleTensor": [],
+        "torch.cuda.BFloat16Tensor": [],
+    }
+
+    for t in tensor_list:
+        dtype = t.type()
+        if dtype in dtype_buckets:
+            dtype_buckets[dtype].append(t)
+
+    buckets = [bucket for bucket in dtype_buckets.values() if bucket]
     return buckets

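As a quick illustration of the reworked bucketing above (illustrative only: it assumes a CUDA device, and that split_half_float_double is importable from internlm.solver.optimizer.utils, the module the imports earlier in this diff point at):

import torch

from internlm.solver.optimizer.utils import split_half_float_double

if torch.cuda.is_available():
    tensors = [
        torch.zeros(4, dtype=torch.float16, device="cuda"),
        torch.zeros(4, dtype=torch.float32, device="cuda"),
        torch.zeros(4, dtype=torch.bfloat16, device="cuda"),
    ]
    buckets = split_half_float_double(tensors)
    # one bucket per dtype actually present, in dict insertion order
    print([bucket[0].dtype for bucket in buckets])
    # -> [torch.float16, torch.float32, torch.bfloat16]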
@@ -184,7 +194,10 @@ def calc_l2_norm(grads):
         if APEX_AVAILABLE:
             dummy_overflow_buf = torch.cuda.IntTensor([0])
             norm, _ = multi_tensor_applier(
-                amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads], False  # no per-parameter norm
+                amp_C.multi_tensor_l2norm,
+                dummy_overflow_buf,
+                [grads],
+                False,  # no per-parameter norm
             )
         else:
             norm, _ = multi_tensor_l2norm_torch(grads, False)
@@ -228,7 +241,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no

         # Take max across all model-parallel GPUs.
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
-            dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))
+            dist.all_reduce(
+                total_norm_cuda,
+                op=dist.ReduceOp.MAX,
+                group=gpc.get_group(ParallelMode.MODEL),
+            )
         total_norm = total_norm_cuda[0].item()
     else:
         tensor_parallel_grads = []
@@ -280,7 +297,11 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no

         # Sum across all model-parallel GPUs.
         if gpc.is_initialized(ParallelMode.MODEL):
-            dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
+            dist.all_reduce(
+                total_norm,
+                op=dist.ReduceOp.SUM,
+                group=gpc.get_group(ParallelMode.MODEL),
+            )

         # This is because we use zero1, so we need to use this reduction.
         # TODO: Check zero group to be a subset of dp group.
@@ -459,3 +480,90 @@ class DynamicGradScaler(BaseGradScaler):
         self._scale = self._scale.fill_(state_dict["_scale"])
         self._growth_step = state_dict["_growth_step"]
         self._hysteresis_step = state_dict["_hysteresis_step"]
+
+
+class ParamBcastSyncHandler:
+    """
+    Model partition handler for overlapping parameter broadcast with the forward pass.
+    """
+
+    def __init__(self, model: Union[nn.Module, nn.ModuleList]) -> None:
+        self._block_to_param = OrderedDict()  # <key: nn.Module> <value: list(param)>
+        self._param_to_rank = dict()  # <key: param> <value: rank>
+        self._block_to_rank = dict()  # <key: nn.Module> <value: rank>
+        self._bcast_handles = dict()  # <key: rank> <value: list(bcast handles)>
+
+        zero1_size = gpc.get_world_size(ParallelMode.ZERO1)
+        total_param_num = sum(p.numel() for p in model.parameters())
+        avg_param_num = total_param_num * 1.0 // zero1_size
+
+        # just want to share the same for loop for ModuleList and Module
+        if not isinstance(model, nn.ModuleList):
+            model = [model]
+
+        # record the parameters to transformer/embedding/head/norm block
+        for _chunk in model:
+            if isinstance(_chunk, NaiveAMPModel):
+                _chunk = _chunk.model
+
+            for _, children in _chunk.named_children():
+                # should be the transformer block definition in modeling_xxx.py
+                if isinstance(children, nn.ModuleList):
+                    # record the block that a parameter belongs to
+                    for _, block in enumerate(children):
+                        # self._block_to_param[f"{name}.{idx}"] = list(block.parameters())
+                        self._block_to_param[block] = list(block.parameters())
+                else:
+                    # record the block that a parameter belongs to
+                    # self._block_to_param[name] = list(children.parameters())
+                    self._block_to_param[children] = list(children.parameters())
+
+        alloc_num = 0
+        rank_to_go = 0
+
+        # process the parameters in block_to_param sequentially,
+        # allocate each parameter to a local rank of ParallelMode.ZERO1,
+        # NOTE that we do NOT consider the following scenarios:
+        # 1) whether a parameter is trainable;
+        # 2) parameters may be in different optimizer groups
+        for block, params in self._block_to_param.items():
+            # allocate a model block to a local rank of ParallelMode.ZERO1
+            self._block_to_rank[block] = [rank_to_go]
+            for p in params:
+                alloc_num = alloc_num + p.numel()
+                # in this case, allocate the param to the next rank if possible
+                if alloc_num > avg_param_num * 1.01 and rank_to_go < zero1_size - 1:
+                    rank_to_go = rank_to_go + 1
+                    alloc_num = 0
+                    self._block_to_rank[block].append(rank_to_go)
+                # allocate a parameter to a local rank of ParallelMode.ZERO1
+                self._param_to_rank[p] = rank_to_go
+
+        # initialize an empty list for _bcast_handles of each rank
+        for rank in range(gpc.get_world_size(ParallelMode.ZERO1)):
+            self._bcast_handles[rank] = []
+
+        # register_forward_pre_hook for transformer/embedding/norm/xxx block
+        self._register_sync_parameters_hook()
+
+    def _register_sync_parameters_hook(self) -> None:
+        def _pre_forward_hook(model: nn.Module, inputs: Any):  # pylint: disable=W0613
+            bcast_handles = []
+            # gather all required broadcast handles into a list
+            for rank in self._block_to_rank[model]:
+                bcast_handles.extend(self._bcast_handles[rank])
+                # need to clear _bcast_handles since they would be processed later
+                self._bcast_handles[rank] = []
+            # wait for all required broadcast handles to be completed
+            for handle in bcast_handles:
+                handle.wait()

+        # register_forward_pre_hook for transformer/embedding/norm/xxx block
+        for block, _ in self._block_to_rank.items():
+            block.register_forward_pre_hook(partial(_pre_forward_hook))
+
+    def get_rank_by_param(self, param) -> int:
+        return self._param_to_rank[param]
+
+    def add_bcast_handle(self, rank, handle) -> None:
+        self._bcast_handles[rank].append(handle)
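The allocation rule inside ParamBcastSyncHandler.__init__ is easier to see outside the diff: parameters are walked block by block in order, and the target ZeRO-1 rank advances once the running parameter count passes roughly 101% of the per-rank average, so each rank owns a contiguous run of blocks. A standalone toy of just that rule (the numbers are made up and the block/rank bookkeeping is simplified):

def allocate(param_numels_per_block, zero1_size):
    # mirrors the greedy rule above: move to the next rank once the running
    # count exceeds ~101% of the average share, never past the last rank
    total = sum(sum(block) for block in param_numels_per_block)
    avg = total / zero1_size
    alloc, rank, result = 0, 0, []
    for block in param_numels_per_block:
        ranks = []
        for numel in block:
            alloc += numel
            if alloc > avg * 1.01 and rank < zero1_size - 1:
                rank += 1
                alloc = 0
            ranks.append(rank)
        result.append(ranks)
    return result


print(allocate([[10, 10], [10, 10], [10, 10], [10, 10]], zero1_size=2))
# -> [[0, 0], [0, 0], [1, 1], [1, 1]]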
@@ -14,18 +14,19 @@ class _Timer:
         self.elapsed_ = 0.0
         self.started_ = False
         self.start_time = time.time()
+        self.stream = torch.cuda.current_stream()

     def start(self):
         """Start the timer."""
         assert not self.started_, "timer has already been started"
-        torch.cuda.synchronize()
+        self.stream.synchronize()
         self.start_time = time.time()
         self.started_ = True

     def stop(self):
         """Stop the timer."""
         assert self.started_, "timer is not started"
-        torch.cuda.synchronize()
+        self.stream.synchronize()
         self.elapsed_ += time.time() - self.start_time
         self.started_ = False

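The _Timer change above replaces the device-wide torch.cuda.synchronize() with synchronizing only the stream captured at construction, so timing a compute region no longer has to wait for work queued on the separate communication stream. A small illustration (assumes a CUDA device; the matrix size is arbitrary):

import time

import torch

if torch.cuda.is_available():
    stream = torch.cuda.current_stream()
    x = torch.randn(1024, 1024, device="cuda")
    start = time.time()
    y = x @ x
    stream.synchronize()  # wait only for work queued on this stream
    print(f"matmul took {time.time() - start:.4f}s")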
@@ -565,6 +565,7 @@ set load_ckpt_folder or use default value \

         start = time.time()
         self.set_save_folder(folder, train_state.step_count)
+        torch.cuda.synchronize()
         torch.distributed.barrier()
         if gpc.is_rank_for_log():
             logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")
train.py (15 changed lines)

@@ -5,7 +5,7 @@ import socket
 import time
 import traceback
 from functools import partial
-from typing import Iterable
+from typing import Iterable, Union

 import numpy as np
 import torch

@@ -36,6 +36,7 @@ from internlm.monitor.monitor import monitor_manager as mm
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
 from internlm.solver.optimizer import HybridZeroOptimizer
+from internlm.solver.optimizer.utils import ParamBcastSyncHandler
 from internlm.utils.common import (
     BatchSkipper,
     DummyProfile,

@@ -291,7 +292,7 @@ def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: Trai
     return batch, train_iter


-def initialize_optimizer(model: nn.Module):
+def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     """
     Initialize optimizer.


@@ -300,6 +301,7 @@ def initialize_optimizer(model: nn.Module):

     Returns: A tuple of (optimizer, beta2_scheduler, lr_scheduler).
     """
+    param_bcast_sync_handler = ParamBcastSyncHandler(model)
     adam_cfg = gpc.config.adam
     naive_optimizer = torch.optim.AdamW(
         params=[{"params": model.parameters(), "weight_decay": adam_cfg.weight_decay}],

@@ -309,7 +311,10 @@ def initialize_optimizer(model: nn.Module):
     )

     optimizer = HybridZeroOptimizer(
-        naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
+        naive_optimizer,
+        grad_scal_cfg=gpc.config.grad_scaler,
+        zero_cfg=gpc.config.hybrid_zero_optimizer,
+        param_bcast_sync_handler=param_bcast_sync_handler,
     )

     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)

@@ -599,6 +604,7 @@ def main(args):

         # do forward and backward
         timer("fwd-bwd").start()
+
         _, _, loss = trainer.execute_schedule(
             batch, forward_only=False, return_loss=True, return_output_label=False
         )

@@ -661,7 +667,8 @@ def main(args):
         if memory_profiler is not None:
             memory_profiler.step()

-        prof.step()
+        if batch_count % 2 == 0:
+            prof.step()

     ckpt_manager.wait_async_upload_finish()