mirror of https://github.com/InternLM/InternLM
Feat/sync grad use async op (#277)
* fix: broadcast should not be in the communication stream
* feat: support all-reduce of grads using an async op
* fix bug of async op
* use ReduceOp.AVG
* use torch flatten
* delete unused stream
* feat: overlap allreduce with memcpy

Co-authored-by: yingtongxiong <974106207@qq.com>
parent 7c99e01ca7
commit b7a8af8133
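The gist of the change: instead of issuing blocking reduce/all-reduce calls on a dedicated CUDA stream, gradient buckets are reduced with async_op=True and ReduceOp.AVG, and the returned handles are waited on later, so the memcpy work of flattening the next bucket overlaps with communication. A minimal standalone sketch of that pattern (hypothetical names, not code from this commit; assumes an initialized torch.distributed process group):

# Minimal sketch only; `reduce_bucket_async` and `drain` are illustrative names.
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def reduce_bucket_async(grads, group=None):
    flat = _flatten_dense_tensors(grads)          # memcpy into one contiguous buffer
    handle = dist.all_reduce(flat, op=dist.ReduceOp.AVG, group=group, async_op=True)
    return flat, handle                           # communication proceeds in the background

def drain(in_progress):
    # wait for previously launched reductions, then scatter results back
    for grads, flat, handle in in_progress:
        handle.wait()
        for old, new in zip(grads, _unflatten_dense_tensors(flat, grads)):
            old.copy_(new)
    in_progress.clear()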
@@ -125,6 +125,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._param_store = ParameterStore(ParallelMode.ZERO1)
         self._grad_store = GradientStore(ParallelMode.DATA)
         self._bucket_store = BucketStore(ParallelMode.DATA)
+        self._bucket_in_progress = []

         # fp16 and fp32 params for mixed precision training
         self._fp16_param_groups = dict()
@@ -232,13 +233,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
         self.skip_grad_reduce = False

-        # initialize communication stream for
-        # communication-computation overlapping
-        if self._overlap_sync_grad:
-            self._comm_stream = torch.cuda.Stream()
-        else:
-            self._comm_stream = torch.cuda.current_stream()
-
         # reduction hook is only used if overlapping communication
         # if it is stage 1 without overlapping, no hook will be attached
         if self._overlap_sync_grad:
@@ -384,34 +378,41 @@ class HybridZeroOptimizer(BaseOptimizer):

     def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
         grad_buckets_by_dtype = split_half_float_double(grads)
+        next_bucket_list = []
+        # add parameters into bucket for reduction
         for tensor_list in grad_buckets_by_dtype:
             param_bucket = TensorBucket(size=bucket_size)
             for tensor in tensor_list:
                 param_bucket.add_to_bucket(tensor, allow_oversize=True)
-                if param_bucket.is_full_or_oversized():
-                    self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
-                    param_bucket.empty()
             if not param_bucket.is_empty():
                 self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
+                next_bucket_list.append(param_bucket)

-    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
-        if self._overlap_sync_grad:
-            self._comm_stream.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
-
-        with torch.cuda.stream(self._comm_stream):
-            flat = bucket.flatten()
-            reduced_flat = reduce_tensor(
-                tensor=flat,
-                dtype=self.dtype,
-                dst_rank=reduce_rank,
-                parallel_mode=ParallelMode.DATA,
-            )
-
-            # update the reduced tensor
-            if reduce_rank is None or reduce_rank == self._zero_local_rank:
-                bucket.unflatten_and_copy(reduced_flat)
+        # wait for the completion of previous bucket list reduction, and do unflatten_and_copy()
+        # here we can also overlap the communication with some memcpy operation caused by bucket.flatten()
+        for bucket in self._bucket_in_progress:
+            bucket.commu_handle.wait()
+            bucket.unflatten_and_copy()
+            bucket.empty()
+        self._bucket_in_progress = []
+        self._param_store.clear_grads_of_previous_reduced_params()
+
+        # after the completion of bucket list reduction, add new buckets into _bucket_in_progress
+        self._bucket_in_progress = next_bucket_list.copy()
+
+    def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
+        # flatten the tensors and do allreduce
+        bucket.flatten()
+        bucket.commu_handle = reduce_tensor(
+            tensor=bucket.get_flat_tensor(),
+            dtype=None,
+            dst_rank=reduce_rank,
+            parallel_mode=ParallelMode.DATA,
+        )
+
+        # update the reduced tensor
+        if reduce_rank is None or reduce_rank == self._zero_local_rank:
+            bucket.set_unflatten_and_copy_flag(flag=True)

     def _has_inf_or_nan(self, tensor):
         try:
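Put together, the new control flow launches the current step's reduction before waiting on the previous one, which is where the overlap comes from. A simplified sketch of that schedule (hypothetical driver loop; the real code keeps the in-progress list on the optimizer):

# Simplified schedule sketch (assumed names): the reduction launched for bucket i
# is waited on only after bucket i+1 has been flattened and launched.
in_progress = []
for bucket in buckets:
    bucket.flatten()                                   # memcpy into the flat buffer
    bucket.commu_handle = reduce_tensor(               # async launch, returns a handle
        tensor=bucket.get_flat_tensor(),
        dtype=None,
        dst_rank=None,
        parallel_mode=ParallelMode.DATA,
    )
    for prev in in_progress:                           # previous step's communication
        prev.commu_handle.wait()                       # has been running meanwhile
        prev.unflatten_and_copy()
        prev.empty()
    in_progress = [bucket]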
@@ -536,9 +537,12 @@ class HybridZeroOptimizer(BaseOptimizer):
             groups_norms.append(self._compute_norm_with_stage(group_id=group_id))

         # clear reduced grads
-        if self._overlap_sync_grad:
-            # grads in the last bucket is reduced
-            self._comm_stream.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
+        # grads in the last bucket is reduced
+        for bucket in self._bucket_in_progress:
+            bucket.commu_handle.wait()
+            bucket.unflatten_and_copy()
+            bucket.empty()
+        self._bucket_in_progress = []
+        self._param_store.clear_grads_of_previous_reduced_params()

         # compute norm for gradients in the last bucket
@@ -626,7 +630,9 @@ class HybridZeroOptimizer(BaseOptimizer):
         if gpc.config.model.dtype is not torch.float32:
             if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
                 self._unscale_and_clip_grads(
-                    single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
+                    single_grad_partition_groups,
+                    list(global_norm_groups.values()),
+                    loss_scale,
                 )

         # update the parameters
@@ -249,11 +249,17 @@ class ParameterStore(BaseStore):
         if not last_bucket:
             if group_id not in self._former_bucket_reduced_param:
                 return [], []
-            return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
+            return (
+                self._former_bucket_reduced_param[group_id],
+                self._former_bucket_reduced_grad[group_id],
+            )
         else:
             if group_id not in self._last_bucket_reduced_param:
                 return [], []
-            return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
+            return (
+                self._last_bucket_reduced_param[group_id],
+                self._last_bucket_reduced_grad[group_id],
+            )

     def reset_reduced_data_for_compute_norm(self):
         self._former_bucket_reduced_param = {}
@@ -277,6 +283,9 @@ class TensorBucket:
         self._max_size = size
         self._current_size = 0
         self._bucket = []
+        self._flat_tensor = None
+        self._unflatten_and_copy_flag = False
+        self.commu_handle = None

     @property
     def max_size(self):
@@ -292,6 +301,15 @@ class TensorBucket:
     def is_empty(self):
         return len(self._bucket) == 0

+    def set_unflatten_and_copy_flag(self, flag):
+        self._unflatten_and_copy_flag = flag
+
+    def get_unflatten_and_copy_flag(self):
+        return self._unflatten_and_copy_flag
+
+    def get_flat_tensor(self):
+        return self._flat_tensor
+
     def add_to_bucket(self, tensor, allow_oversize=False):
         tensor_size = tensor.numel()

@@ -312,11 +330,14 @@ class TensorBucket:
     def empty(self):
         self._bucket = []
         self._size = 0
+        self._flat_tensor = None
+        self.commu_handle = None

     def flatten(self):
-        return _flatten_dense_tensors(self._bucket)
+        self._flat_tensor = _flatten_dense_tensors(self._bucket)

-    def unflatten_and_copy(self, flat_tensor):
-        unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
-        for old, new in zip(self._bucket, unflattened_tensor_list):
-            old.copy_(new)
+    def unflatten_and_copy(self):
+        if self._unflatten_and_copy_flag:
+            unflattened_tensor_list = _unflatten_dense_tensors(self._flat_tensor, self._bucket)
+            for old, new in zip(self._bucket, unflattened_tensor_list):
+                old.copy_(new)
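With the new fields and accessors, a bucket's lifecycle under the async scheme looks roughly as follows (illustrative sequence only; `grads`, `bucket_size` and the timing of the final wait are placeholders):

# Illustrative TensorBucket lifecycle under the async scheme (assumed values).
bucket = TensorBucket(size=bucket_size)
for grad in grads:
    bucket.add_to_bucket(grad, allow_oversize=True)
bucket.flatten()                                   # stores the flat tensor internally
bucket.commu_handle = reduce_tensor(
    tensor=bucket.get_flat_tensor(), dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA
)
bucket.set_unflatten_and_copy_flag(flag=True)      # this rank keeps the reduced result

# later, e.g. at the next reduction step or in step():
bucket.commu_handle.wait()                         # block until the collective finishes
bucket.unflatten_and_copy()                        # copies back only if the flag is set
bucket.empty()                                     # drops tensors, flat buffer and handle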
@@ -95,37 +95,34 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
     :type parallel_mode: ParallelMode, optional
     """
     # use the original dtype
-    if dtype is None:
-        dtype = tensor.dtype
+    # if dtype is None:
+    assert dtype is None
+    dtype = tensor.dtype

     # cast the data to specified dtype for reduce/all-reduce
-    if tensor.dtype != dtype:
-        tensor_to_reduce = tensor.to(dtype)
-    else:
-        tensor_to_reduce = tensor
+    # if tensor.dtype != dtype:
+    #     tensor_to_reduce = tensor.to(dtype)
+    # else:
+    #     tensor_to_reduce = tensor

-    world_size = gpc.get_world_size(parallel_mode)
+    # world_size = gpc.get_world_size(parallel_mode)
+    # tensor.div_(world_size)
     group = gpc.get_group(parallel_mode)
-    tensor_to_reduce.div_(world_size)

     # if rank is None, all reduce will be used
     # else, reduce is used
     use_all_reduce = dst_rank is None

     if use_all_reduce:
-        dist.all_reduce(tensor_to_reduce, group=group)
+        handle = dist.all_reduce(tensor=tensor, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True)
     else:
         ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
         global_rank = ranks_in_group[dst_rank]
-        dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
+        handle = dist.reduce(
+            tensor=tensor, dst=global_rank, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True
+        )

-    # recover the original dtype
-    if tensor.dtype != dtype and tensor is not tensor_to_reduce:
-        local_rank = gpc.get_local_rank(parallel_mode)
-        if use_all_reduce or dst_rank == local_rank:
-            tensor.copy_(tensor_to_reduce)
-
-    return tensor
+    return handle


 def has_inf_or_nan(tensor):
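The modified reduce_tensor now returns the torch.distributed work handle instead of the reduced tensor: ReduceOp.AVG replaces the old explicit division by world size, the averaging happens in place on the passed tensor, and callers must call wait() on the handle before reading it. A minimal caller-side sketch (assuming an initialized process group; `flat_grad` is an illustrative name):

# Caller-side contract sketch: the tensor must not be read before handle.wait().
handle = reduce_tensor(tensor=flat_grad, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA)
# ... other work (e.g. flattening the next bucket) can overlap here ...
handle.wait()           # flat_grad now holds the average over the data-parallel group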