Feat/sync grad use async op (#277)

* fix/broadcast should not be in comm stream

* fix/broadcast should not be in comm stream

* feat: support allreduce grad using async op

* fix bug of async op

* use ReduceOp.AVG

* use torch flat

* delete unused stream

* delete unused stream

* feat: overlap allreduce with memcpy (sketched below)

---------

Co-authored-by: yingtongxiong <974106207@qq.com>
Sun Peng 2023-09-07 21:51:30 +08:00 committed by GitHub
parent 7c99e01ca7
commit b7a8af8133
3 changed files with 79 additions and 55 deletions
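
The commits above all serve one pattern: flatten a bucket of gradients, launch the reduction as a non-blocking collective with `ReduceOp.AVG`, and wait on the returned handle only when the averaged gradients are actually consumed. A minimal, self-contained sketch of that pattern in plain `torch.distributed` (the helper names here are illustrative, not this repo's API):

```python
import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def launch_bucket_reduce(grads, group=None):
    # Flatten once (a memcpy), then kick off a non-blocking all-reduce that
    # averages across ranks; the Work handle is returned immediately.
    flat = _flatten_dense_tensors(grads)
    handle = dist.all_reduce(flat, op=dist.ReduceOp.AVG, group=group, async_op=True)
    return flat, handle


def finish_bucket_reduce(flat, handle, grads):
    # Block only when the averaged gradients are needed, then copy them back
    # into the original (unflattened) gradient tensors.
    handle.wait()
    for old, new in zip(grads, _unflatten_dense_tensors(flat, grads)):
        old.copy_(new)
```

Because the collective already averages (`ReduceOp.AVG`), the explicit `tensor.div_(world_size)` is dropped, and because it is launched with `async_op=True`, the dedicated CUDA communication stream is no longer needed.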

View File

@@ -125,6 +125,7 @@ class HybridZeroOptimizer(BaseOptimizer):
self._param_store = ParameterStore(ParallelMode.ZERO1)
self._grad_store = GradientStore(ParallelMode.DATA)
self._bucket_store = BucketStore(ParallelMode.DATA)
self._bucket_in_progress = []
# fp16 and fp32 params for mixed precision training
self._fp16_param_groups = dict()
@@ -232,13 +233,6 @@ class HybridZeroOptimizer(BaseOptimizer):
# flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
self.skip_grad_reduce = False
# initialize communication stream for
# communication-computation overlapping
if self._overlap_sync_grad:
self._comm_stream = torch.cuda.Stream()
else:
self._comm_stream = torch.cuda.current_stream()
# reduction hook is only used if overlapping communication
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_sync_grad:
@@ -384,34 +378,41 @@ class HybridZeroOptimizer(BaseOptimizer):
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
grad_buckets_by_dtype = split_half_float_double(grads)
next_bucket_list = []
# add parameters into bucket for reduction
for tensor_list in grad_buckets_by_dtype:
param_bucket = TensorBucket(size=bucket_size)
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if param_bucket.is_full_or_oversized():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
param_bucket.empty()
if not param_bucket.is_empty():
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
next_bucket_list.append(param_bucket)
# wait for the completion of the previous bucket list reduction, and do unflatten_and_copy()
# here we can also overlap the communication with the memcpy caused by bucket.flatten()
for bucket in self._bucket_in_progress:
bucket.commu_handle.wait()
bucket.unflatten_and_copy()
bucket.empty()
self._bucket_in_progress = []
self._param_store.clear_grads_of_previous_reduced_params()
# after the completion of bucket list reduction, add new buckets into _bucket_in_progress
self._bucket_in_progress = next_bucket_list.copy()
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
if self._overlap_sync_grad:
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# flatten the tensors and do allreduce
bucket.flatten()
bucket.commu_handle = reduce_tensor(
tensor=bucket.get_flat_tensor(),
dtype=None,
dst_rank=reduce_rank,
parallel_mode=ParallelMode.DATA,
)
with torch.cuda.stream(self._comm_stream):
flat = bucket.flatten()
reduced_flat = reduce_tensor(
tensor=flat,
dtype=self.dtype,
dst_rank=reduce_rank,
parallel_mode=ParallelMode.DATA,
)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._zero_local_rank:
bucket.unflatten_and_copy(reduced_flat)
# update the reduced tensor
if reduce_rank is None or reduce_rank == self._zero_local_rank:
bucket.set_unflatten_and_copy_flag(flag=True)
def _has_inf_or_nan(self, tensor):
try:
@@ -536,10 +537,13 @@ class HybridZeroOptimizer(BaseOptimizer):
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
# clear reduced grads
if self._overlap_sync_grad:
# grads in the last bucket are reduced
self._comm_stream.synchronize()
self._param_store.clear_grads_of_previous_reduced_params()
# grads in the last bucket are reduced
for bucket in self._bucket_in_progress:
bucket.commu_handle.wait()
bucket.unflatten_and_copy()
bucket.empty()
self._bucket_in_progress = []
self._param_store.clear_grads_of_previous_reduced_params()
# compute norm for gradients in the last bucket
total_norms = {}
@@ -626,7 +630,9 @@ class HybridZeroOptimizer(BaseOptimizer):
if gpc.config.model.dtype is not torch.float32:
if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
self._unscale_and_clip_grads(
single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
single_grad_partition_groups,
list(global_norm_groups.values()),
loss_scale,
)
# update the parameters
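
Taken together, the optimizer changes turn gradient reduction into a two-stage pipeline: each call to `_reduce_grads_by_rank` first flattens and launches the new buckets, and only then drains `self._bucket_in_progress`, the buckets launched on the previous call, whose all-reduce has meanwhile been overlapping with this call's memcpy work. A simplified paraphrase of that control flow (the bucket-size/oversize handling and the copy-back flag are omitted; this is an illustration, not a drop-in replacement):

```python
def _reduce_grads_by_rank_sketch(self, reduce_rank, grads, bucket_size):
    # 1) Launch this round's buckets: flatten (memcpy) + async reduce.
    launched = []
    for tensor_list in split_half_float_double(grads):
        bucket = TensorBucket(size=bucket_size)
        for tensor in tensor_list:
            bucket.add_to_bucket(tensor, allow_oversize=True)
        bucket.flatten()
        bucket.commu_handle = reduce_tensor(
            tensor=bucket.get_flat_tensor(),
            dtype=None,
            dst_rank=reduce_rank,
            parallel_mode=ParallelMode.DATA,
        )
        launched.append(bucket)

    # 2) Drain the buckets launched on the *previous* call; their collectives
    #    have been running in the background while step 1 did its memcpy work.
    for bucket in self._bucket_in_progress:
        bucket.commu_handle.wait()
        bucket.unflatten_and_copy()
        bucket.empty()

    # 3) Hand this round's buckets over to be drained on the next call.
    self._bucket_in_progress = launched
```

`step()` performs the same drain before computing gradient norms, so no reduction is left in flight when the norms are taken.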

View File

@@ -249,11 +249,17 @@ class ParameterStore(BaseStore):
if not last_bucket:
if group_id not in self._former_bucket_reduced_param:
return [], []
return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
return (
self._former_bucket_reduced_param[group_id],
self._former_bucket_reduced_grad[group_id],
)
else:
if group_id not in self._last_bucket_reduced_param:
return [], []
return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
return (
self._last_bucket_reduced_param[group_id],
self._last_bucket_reduced_grad[group_id],
)
def reset_reduced_data_for_compute_norm(self):
self._former_bucket_reduced_param = {}
@@ -277,6 +283,9 @@ class TensorBucket:
self._max_size = size
self._current_size = 0
self._bucket = []
self._flat_tensor = None
self._unflatten_and_copy_flag = False
self.commu_handle = None
@property
def max_size(self):
@@ -292,6 +301,15 @@ class TensorBucket:
def is_empty(self):
return len(self._bucket) == 0
def set_unflatten_and_copy_flag(self, flag):
self._unflatten_and_copy_flag = flag
def get_unflatten_and_copy_flag(self):
return self._unflatten_and_copy_flag
def get_flat_tensor(self):
return self._flat_tensor
def add_to_bucket(self, tensor, allow_oversize=False):
tensor_size = tensor.numel()
@@ -312,11 +330,14 @@ class TensorBucket:
def empty(self):
self._bucket = []
self._size = 0
self._flat_tensor = None
self.commu_handle = None
def flatten(self):
return _flatten_dense_tensors(self._bucket)
self._flat_tensor = _flatten_dense_tensors(self._bucket)
def unflatten_and_copy(self, flat_tensor):
unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
def unflatten_and_copy(self):
if self._unflatten_and_copy_flag:
unflattened_tensor_list = _unflatten_dense_tensors(self._flat_tensor, self._bucket)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
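
With these fields, a `TensorBucket` now carries all the state the async path needs: the flat buffer produced by `flatten()`, the communication handle, and a flag saying whether this rank should copy the averaged result back. A rough sketch of the intended lifecycle (illustrative; `bucket_size` and `grads` stand in for values supplied by the surrounding optimizer code):

```python
bucket = TensorBucket(size=bucket_size)
for grad in grads:
    bucket.add_to_bucket(grad, allow_oversize=True)

bucket.flatten()                                  # fills the internal flat tensor (memcpy)
bucket.commu_handle = reduce_tensor(              # non-blocking reduce on the flat buffer
    tensor=bucket.get_flat_tensor(),
    dtype=None,
    dst_rank=None,                                # None -> all-reduce
    parallel_mode=ParallelMode.DATA,
)
bucket.set_unflatten_and_copy_flag(flag=True)     # this rank keeps the averaged result

# ... later, once the reduced gradients are actually needed ...
bucket.commu_handle.wait()
bucket.unflatten_and_copy()                       # copies back only if the flag is set
bucket.empty()                                    # drops the flat tensor and the handle
```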

View File

@@ -95,37 +95,34 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
:type parallel_mode: ParallelMode, optional
"""
# use the original dtype
if dtype is None:
dtype = tensor.dtype
# if dtype is None:
assert dtype is None
dtype = tensor.dtype
# cast the data to specified dtype for reduce/all-reduce
if tensor.dtype != dtype:
tensor_to_reduce = tensor.to(dtype)
else:
tensor_to_reduce = tensor
# if tensor.dtype != dtype:
# tensor_to_reduce = tensor.to(dtype)
# else:
# tensor_to_reduce = tensor
world_size = gpc.get_world_size(parallel_mode)
# world_size = gpc.get_world_size(parallel_mode)
# tensor.div_(world_size)
group = gpc.get_group(parallel_mode)
tensor_to_reduce.div_(world_size)
# if rank is None, all reduce will be used
# else, reduce is used
use_all_reduce = dst_rank is None
if use_all_reduce:
dist.all_reduce(tensor_to_reduce, group=group)
handle = dist.all_reduce(tensor=tensor, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True)
else:
ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
global_rank = ranks_in_group[dst_rank]
dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
handle = dist.reduce(
tensor=tensor, dst=global_rank, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True
)
# recover the original dtype
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
local_rank = gpc.get_local_rank(parallel_mode)
if use_all_reduce or dst_rank == local_rank:
tensor.copy_(tensor_to_reduce)
return tensor
return handle
def has_inf_or_nan(tensor):
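
In short, `reduce_tensor` changes its contract: it no longer casts, divides by the world size, or copies a temporary back; it launches the collective in place with `ReduceOp.AVG` and `async_op=True` and returns the `Work` handle. A hedged usage fragment (`flat_grad` is an illustrative flat gradient buffer; the other names come from the hunk above):

```python
handle = reduce_tensor(
    tensor=flat_grad,                 # reduced/averaged in place
    dtype=None,                       # must be None now; the tensor's own dtype is used
    dst_rank=None,                    # None -> all_reduce, otherwise reduce to that rank
    parallel_mode=ParallelMode.DATA,
)

# ... overlap window: flatten and launch the next bucket, or other memcpy work ...

handle.wait()                         # only now does flat_grad hold the averaged gradients
```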