mirror of https://github.com/InternLM/InternLM
Feat/sync grad use async op (#277)
* fix: broadcast should not be in the communication stream
* fix: broadcast should not be in the communication stream
* feat: support all-reduce of gradients using an async op
* fix: bug in the async op path
* use ReduceOp.AVG
* use torch flat tensors
* delete unused stream
* delete unused stream
* feat: overlap all-reduce with memcpy

Co-authored-by: yingtongxiong <974106207@qq.com>
parent 7c99e01ca7
commit b7a8af8133
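The core change described in the commit message is to launch the gradient reduction as an asynchronous collective and keep the returned work handle, instead of synchronizing on a dedicated communication stream. Below is a minimal, standalone sketch of that pattern; it is not InternLM code, and it uses a single-process gloo group purely so it runs as-is (the optimizer itself runs on NCCL, where ReduceOp.AVG is available).

# Minimal sketch of the async collective pattern this commit adopts: launch the
# reduction with async_op=True, keep the returned work handle, and call wait()
# only when the result is needed, so communication can overlap with other work.
import os

import torch
import torch.distributed as dist

# single-process "world" on the gloo backend, just to make the example runnable
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

flat_grad = torch.randn(1024)  # stands in for a flattened gradient bucket

# asynchronous launch: returns immediately with a work handle
handle = dist.all_reduce(flat_grad, op=dist.ReduceOp.SUM, async_op=True)

# ... other work can run here (e.g. flattening the next bucket) ...

handle.wait()                       # block only when the reduced values are needed
flat_grad /= dist.get_world_size()  # gloo has no ReduceOp.AVG, so average manually

dist.destroy_process_group()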
@@ -125,6 +125,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._param_store = ParameterStore(ParallelMode.ZERO1)
         self._grad_store = GradientStore(ParallelMode.DATA)
         self._bucket_store = BucketStore(ParallelMode.DATA)
+        self._bucket_in_progress = []

         # fp16 and fp32 params for mixed precision training
         self._fp16_param_groups = dict()
@@ -232,13 +233,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
         self.skip_grad_reduce = False

-        # initialize communication stream for
-        # communication-computation overlapping
-        if self._overlap_sync_grad:
-            self._comm_stream = torch.cuda.Stream()
-        else:
-            self._comm_stream = torch.cuda.current_stream()
-
         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
         if self._overlap_sync_grad:
@@ -384,34 +378,41 @@ class HybridZeroOptimizer(BaseOptimizer):

     def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
         grad_buckets_by_dtype = split_half_float_double(grads)

+        next_bucket_list = []
         # add parameters into bucket for reduction
         for tensor_list in grad_buckets_by_dtype:
             param_bucket = TensorBucket(size=bucket_size)
             for tensor in tensor_list:
                 param_bucket.add_to_bucket(tensor, allow_oversize=True)
                 if param_bucket.is_full_or_oversized():
                     self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
                     param_bucket.empty()
             if not param_bucket.is_empty():
                 self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+            next_bucket_list.append(param_bucket)
+
+        # wait for the completion of the previous bucket list reduction, and do unflatten_and_copy()
+        # here we can also overlap the communication with some memcpy operation caused by bucket.flatten()
+        for bucket in self._bucket_in_progress:
+            bucket.commu_handle.wait()
+            bucket.unflatten_and_copy()
+            bucket.empty()
+        self._bucket_in_progress = []
+        self._param_store.clear_grads_of_previous_reduced_params()
+
+        # after the completion of bucket list reduction, add new buckets into _bucket_in_progress
+        self._bucket_in_progress = next_bucket_list.copy()

     def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
-        if self._overlap_sync_grad:
-            self._comm_stream.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
+        # flatten the tensors and do allreduce
+        bucket.flatten()
+        bucket.commu_handle = reduce_tensor(
+            tensor=bucket.get_flat_tensor(),
+            dtype=None,
+            dst_rank=reduce_rank,
+            parallel_mode=ParallelMode.DATA,
+        )

-        with torch.cuda.stream(self._comm_stream):
-            flat = bucket.flatten()
-            reduced_flat = reduce_tensor(
-                tensor=flat,
-                dtype=self.dtype,
-                dst_rank=reduce_rank,
-                parallel_mode=ParallelMode.DATA,
-            )
-
-            # update the reduced tensor
-            if reduce_rank is None or reduce_rank == self._zero_local_rank:
-                bucket.unflatten_and_copy(reduced_flat)
+        # update the reduced tensor
+        if reduce_rank is None or reduce_rank == self._zero_local_rank:
+            bucket.set_unflatten_and_copy_flag(flag=True)

     def _has_inf_or_nan(self, tensor):
         try:
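The comments in the hunk above describe the overlap scheme: buckets launched during the previous call sit in `_bucket_in_progress`, and they are only waited on and unflattened after the current call has kicked off its own async reductions. The following is a toy, self-contained sketch of that double-buffering idea; `ToyBucket`, `reduce_step`, and the single-process gloo setup are illustrative stand-ins, not the repository's classes.

# Toy sketch of the "bucket in progress" double-buffering: reductions launched for
# the new buckets stay in flight while the buckets from the previous call are
# waited on and copied back.
import os

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


class ToyBucket:
    def __init__(self, grads):
        self.grads = grads
        self.flat = None
        self.handle = None

    def launch(self):
        self.flat = _flatten_dense_tensors(self.grads)           # memcpy into one buffer
        self.handle = dist.all_reduce(self.flat, async_op=True)  # async collective

    def finish(self):
        self.handle.wait()  # communication is complete from here on
        for old, new in zip(self.grads, _unflatten_dense_tensors(self.flat, self.grads)):
            old.copy_(new)  # scatter the reduced values back into the original grads


in_progress = []  # buckets whose reduction was launched on a previous call


def reduce_step(grads):
    global in_progress
    new_bucket = ToyBucket(grads)
    new_bucket.launch()         # returns immediately
    for bucket in in_progress:  # settle the buckets launched last time
        bucket.finish()
    in_progress = [new_bucket]


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    reduce_step([torch.ones(4), torch.ones(8)])  # launches bucket 1
    reduce_step([torch.ones(2)])                 # waits on bucket 1, launches bucket 2
    for bucket in in_progress:                   # drain the last in-flight bucket
        bucket.finish()
    dist.destroy_process_group()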
@@ -536,10 +537,13 @@ class HybridZeroOptimizer(BaseOptimizer):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))

         # clear reduced grads
         if self._overlap_sync_grad:
-            # grads in the last bucket is reduced
-            self._comm_stream.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
+            # grads in the last bucket is reduced
+            for bucket in self._bucket_in_progress:
+                bucket.commu_handle.wait()
+                bucket.unflatten_and_copy()
+                bucket.empty()
+            self._bucket_in_progress = []
+            self._param_store.clear_grads_of_previous_reduced_params()

         # compute norm for gradients in the last bucket
         total_norms = {}
@@ -626,7 +630,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         if gpc.config.model.dtype is not torch.float32:
             if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
                 self._unscale_and_clip_grads(
-                    single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
+                    single_grad_partition_groups,
+                    list(global_norm_groups.values()),
+                    loss_scale,
                 )

         # update the parameters
@@ -249,11 +249,17 @@ class ParameterStore(BaseStore):
         if not last_bucket:
             if group_id not in self._former_bucket_reduced_param:
                 return [], []
-            return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
+            return (
+                self._former_bucket_reduced_param[group_id],
+                self._former_bucket_reduced_grad[group_id],
+            )
         else:
             if group_id not in self._last_bucket_reduced_param:
                 return [], []
-            return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
+            return (
+                self._last_bucket_reduced_param[group_id],
+                self._last_bucket_reduced_grad[group_id],
+            )

     def reset_reduced_data_for_compute_norm(self):
         self._former_bucket_reduced_param = {}
@@ -277,6 +283,9 @@ class TensorBucket:
         self._max_size = size
         self._current_size = 0
         self._bucket = []
+        self._flat_tensor = None
+        self._unflatten_and_copy_flag = False
+        self.commu_handle = None

     @property
     def max_size(self):
@@ -292,6 +301,15 @@ class TensorBucket:
     def is_empty(self):
         return len(self._bucket) == 0

+    def set_unflatten_and_copy_flag(self, flag):
+        self._unflatten_and_copy_flag = flag
+
+    def get_unflatten_and_copy_flag(self):
+        return self._unflatten_and_copy_flag
+
+    def get_flat_tensor(self):
+        return self._flat_tensor
+
     def add_to_bucket(self, tensor, allow_oversize=False):
         tensor_size = tensor.numel()
@@ -312,11 +330,14 @@ class TensorBucket:
     def empty(self):
         self._bucket = []
         self._size = 0
+        self._flat_tensor = None
+        self.commu_handle = None

     def flatten(self):
-        return _flatten_dense_tensors(self._bucket)
+        self._flat_tensor = _flatten_dense_tensors(self._bucket)

-    def unflatten_and_copy(self, flat_tensor):
-        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
-        for old, new in zip(self._bucket, unflattened_tensor_list):
-            old.copy_(new)
+    def unflatten_and_copy(self):
+        if self._unflatten_and_copy_flag:
+            unflattened_tensor_list = _unflatten_dense_tensors(self._flat_tensor, self._bucket)
+            for old, new in zip(self._bucket, unflattened_tensor_list):
+                old.copy_(new)
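The flatten/unflatten pair used by TensorBucket comes from torch._utils, which is what the commit message calls "use torch flat". A quick standalone demonstration of the round trip (example tensors only, not repository code):

# Flatten a list of tensors into one contiguous buffer, modify the buffer
# (as a reduction would), and copy the values back into the original tensors.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.full((2, 2), 2.0)]

flat = _flatten_dense_tensors(grads)  # one 1-D buffer of 3 + 4 elements
flat *= 0.5                           # stands in for the result of an averaging reduction

for old, new in zip(grads, _unflatten_dense_tensors(flat, grads)):
    old.copy_(new)                    # write the reduced values back in place

print(grads[0])  # tensor([0.5000, 0.5000, 0.5000])
print(grads[1])  # tensor([[1., 1.], [1., 1.]])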
@@ -95,37 +95,34 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
     :type parallel_mode: ParallelMode, optional
     """
     # use the original dtype
-    if dtype is None:
-        dtype = tensor.dtype
+    # if dtype is None:
+    assert dtype is None
+    dtype = tensor.dtype

     # cast the data to specified dtype for reduce/all-reduce
-    if tensor.dtype != dtype:
-        tensor_to_reduce = tensor.to(dtype)
-    else:
-        tensor_to_reduce = tensor
+    # if tensor.dtype != dtype:
+    #     tensor_to_reduce = tensor.to(dtype)
+    # else:
+    #     tensor_to_reduce = tensor

-    world_size = gpc.get_world_size(parallel_mode)
+    # world_size = gpc.get_world_size(parallel_mode)
+    # tensor.div_(world_size)
     group = gpc.get_group(parallel_mode)
-    tensor_to_reduce.div_(world_size)

     # if rank is None, all reduce will be used
     # else, reduce is used
     use_all_reduce = dst_rank is None

     if use_all_reduce:
-        dist.all_reduce(tensor_to_reduce, group=group)
+        handle = dist.all_reduce(tensor=tensor, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True)
     else:
         ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
         global_rank = ranks_in_group[dst_rank]
-        dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
+        handle = dist.reduce(
+            tensor=tensor, dst=global_rank, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True
+        )

-    # recover the original dtype
-    if tensor.dtype != dtype and tensor is not tensor_to_reduce:
-        local_rank = gpc.get_local_rank(parallel_mode)
-        if use_all_reduce or dst_rank == local_rank:
-            tensor.copy_(tensor_to_reduce)
-
-    return tensor
+    return handle


 def has_inf_or_nan(tensor):