mirror of https://github.com/hpcaitech/ColossalAI
[gemini] optimize reduce scatter d2h copy (#5760)
* [gemini] optimize reduce scatter d2h copy
* [fix] fix missing reduce variable
* [refactor] remove legacy async reduce scatter code
* [gemini] missing sync
* Revert "[refactor] remove legacy async reduce scatter code"
This reverts commit 58ad76d466
.
* [gemini] further optimize with async all reduce
* [fix] pass flag from manager to chunk
pull/5787/head (parent 10a19e22c6, commit 3f7e3131d9)
@@ -369,6 +369,11 @@ class GeminiPlugin(DPPluginBase):
         assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"
         if get_accelerator().name == "npu":
             assert placement_policy == "static", "NPU only supports static placement policy"
+        if placement_policy == "auto" and enable_async_reduce:
+            logging.warning(
+                f"enable_async_reduce requires pin_memory to achieve best performance, which is not implicitly set."
+            )
+            pin_memory = True
         self.gemini_config = dict(
             chunk_config_dict=chunk_config_dict,
             chunk_init_device=(chunk_init_device or get_accelerator().get_current_device()),
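For context on the new warning above: placement_policy, pin_memory, and enable_async_reduce are all GeminiPlugin constructor arguments, and with this change an "auto" policy combined with async reduce silently turns pin_memory on. A minimal usage sketch follows; the surrounding Booster setup is illustrative and not part of this diff.

from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

# Passing pin_memory=True explicitly avoids relying on the implicit override
# that the warning above refers to.
plugin = GeminiPlugin(
    placement_policy="auto",
    enable_async_reduce=True,
    pin_memory=True,
)
booster = Booster(plugin=plugin)
# model, optimizer, ... = booster.boost(model, optimizer, ...)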
@@ -316,12 +316,13 @@ class Chunk:
         if self.shard_device.type == "cpu":
             self.cuda_shard = None
 
-    def shard_move(self, device: torch.device, force_copy: bool = False):
+    def shard_move(self, device: torch.device, force_copy: bool = False, non_blocking=False):
         """Move the shard tensor in the chunk.
 
         Args:
             device: the device to which the shard will move
             force_copy: if True, copy function is called mandatorily
+            non_blocking: if True, the operation is non-blocking, the caller is responsible for synchronization
         """
         # sanity check
         assert not self.is_gathered
@@ -329,7 +330,7 @@ class Chunk:
         # just use another way for the movement
         if not self.optim_sync_flag:
             assert device.type == "cuda" or device.type == "npu", "each chunk should first be moved to CUDA"
-            self.__paired_shard_move()
+            self.__paired_shard_move(non_blocking=non_blocking)
             self.optim_sync_flag = True
             return
 
@@ -339,7 +340,7 @@ class Chunk:
             if self.cuda_shard:
                 return
 
-            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device())
+            self.cuda_shard = self.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=non_blocking)
 
             if not self.pin_memory:
                 self.cpu_shard = None
@@ -349,11 +350,11 @@ class Chunk:
 
             if self.pin_memory:
                 if force_copy or not self.cpu_vis_flag:
-                    self.cpu_shard.copy_(self.cuda_shard)
+                    self.cpu_shard.copy_(self.cuda_shard, non_blocking=non_blocking)
                 # if cpu_shard has been visited
                 # copy operation is not need
             else:
-                self.cpu_shard = self.cuda_shard.cpu()
+                self.cpu_shard = self.cuda_shard.to("cpu", non_blocking=non_blocking)
                 self.cpu_vis_flag = True
             self.cuda_shard = None
         else:
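A note on why pin_memory shows up together with non_blocking in this path: a device-to-host copy can only overlap with GPU work when the host buffer is pinned, and a non_blocking copy returns before the data has actually landed, so whoever sets non_blocking=True must synchronize before reading the CPU shard. A small stand-alone sketch of that contract in plain PyTorch; the names below are illustrative and not part of the Chunk API.

import torch

def d2h_copy_async(src_cuda: torch.Tensor, dst_cpu: torch.Tensor) -> None:
    # Pinned (page-locked) host memory is what lets the copy engine run the
    # transfer asynchronously with respect to GPU compute.
    assert dst_cpu.is_pinned(), "non_blocking D2H copies need a pinned destination"
    dst_cpu.copy_(src_cuda, non_blocking=True)

if torch.cuda.is_available():
    shard = torch.randn(1024, device="cuda")
    host_buf = torch.empty(1024, pin_memory=True)
    d2h_copy_async(shard, host_buf)
    # The copy has only been enqueued; synchronize before the CPU reads host_buf.
    torch.cuda.current_stream().synchronize()
    assert torch.allclose(host_buf, shard.cpu())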
@@ -542,7 +543,7 @@ class Chunk:
             free_storage(self.cuda_global_chunk)
             self.is_gathered = False
 
-    def __paired_shard_move(self):
+    def __paired_shard_move(self, non_blocking=False):
         assert self.paired_chunk is not None, "chunks should be paired before training"
         optim_chunk = self.paired_chunk
         assert self.chunk_size == optim_chunk.chunk_size
@@ -550,7 +551,7 @@ class Chunk:
         # only be called when optimizer state is in CPU memory
         # the grad and param should be in the same device
         assert self.cuda_shard is None
-        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device())
+        temp = optim_chunk.cpu_shard.to(get_accelerator().get_current_device(), non_blocking=non_blocking)
         # avoid to transform FP32 in CPU
         self.cuda_shard = temp.to(self.dtype)
 
@@ -117,7 +117,7 @@ class ChunkManager:
             return None
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
-            chunk.shard_move(get_accelerator().get_current_device())
+            chunk.shard_move(get_accelerator().get_current_device(), non_blocking=async_access)
         maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
         self.__add_memory_usage(chunk.memory_usage)
         return maybe_work
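This is the "pass flag from manager to chunk" fix from the commit message: when the access is asynchronous, the H2D upload of a CPU shard is now issued with non_blocking=True as well. Correctness still holds because work enqueued on the same CUDA stream runs in issue order, so a collective launched afterwards consumes the fully copied tensor. A rough sketch of that ordering argument, assuming a CUDA build and an initialized NCCL process group; none of these names are ColossalAI APIs.

import torch
import torch.distributed as dist

def upload_and_all_gather(cpu_shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # Enqueue the H2D copy on the current stream; with a pinned source this
    # overlaps with host work, yet stays ordered on the stream.
    cuda_shard = torch.empty_like(cpu_shard, device="cuda")
    cuda_shard.copy_(cpu_shard, non_blocking=True)

    out = torch.empty(world_size * cpu_shard.numel(), device="cuda", dtype=cpu_shard.dtype)
    # The all-gather is issued after the copy, so it sees the complete shard.
    work = dist.all_gather_into_tensor(out, cuda_shard, async_op=True)
    work.wait()
    return out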
@@ -147,6 +147,12 @@ class GeminiDDP(ModelWrapper):
         self.extra_dp_group = extra_dp_group
 
         self.master_weights = master_weights
+        self.enable_async_reduce = enable_async_reduce
+
+        if enable_async_reduce:
+            self.async_reduce_stream = torch.cuda.Stream()
+        else:
+            self.async_reduce_stream = None
 
         self._logger = get_dist_logger()
 
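For readers less familiar with side streams, the shape of the pattern wired up here is: make the reduce stream wait on the default stream so the gradients it consumes are ready, launch the communication under with torch.cuda.stream(...) so it overlaps with the rest of backward, and synchronize the stream before anything consumes the reduced values. A compact, self-contained sketch with illustrative names, not GeminiDDP internals:

import torch

if torch.cuda.is_available():
    reduce_stream = torch.cuda.Stream()
    grad = torch.randn(1 << 20, device="cuda")  # stand-in for a produced gradient chunk

    # Order the side stream after the work that produced `grad`.
    reduce_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(reduce_stream):
        grad.mul_(0.5)  # stand-in for the real reduce-scatter / all-reduce

    # Done once at the end of backward (cf. the _post_backward change below).
    reduce_stream.synchronize()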
@@ -176,6 +182,7 @@ class GeminiDDP(ModelWrapper):
         super().__init__(module)
         self._non_persistent_buffers_set = self._get_non_persistent_buffers_set(module)
         self._cast_buffers()
 
         # register grad hook
         for p in module.parameters():
             if is_ddp_ignored(p):
@@ -191,7 +198,7 @@ class GeminiDDP(ModelWrapper):
                         master_weights=self.master_weights,
                         enable_gradient_accumulation=self.enable_gradient_accumulation,
                         p=p,
-                        async_reduce=enable_async_reduce,
+                        async_reduce_stream=self.async_reduce_stream,
                     )
                 )
 
@@ -339,10 +346,8 @@ class GeminiDDP(ModelWrapper):
             setattr(param, "_gemini_reduced", False)
 
     def _post_backward(self):
-        for param in self.param2name:
-            if hasattr(param, "_release_grad_chunk_cb"):
-                param._release_grad_chunk_cb()
-                delattr(param, "_release_grad_chunk_cb")
+        if self.enable_async_reduce:
+            self.async_reduce_stream.synchronize()
 
         if self.chunk_manager.accessed_mem != 0:
             error_params = ["Reduction failed at followed parameters:"]
@@ -381,7 +386,7 @@ class GeminiDDP(ModelWrapper):
         master_weights: bool,
         enable_gradient_accumulation: bool,
         p: nn.Parameter,
-        async_reduce: bool,
+        async_reduce_stream: Optional[torch.cuda.Stream] = None,
     ):
         setattr(p, "_gemini_reduced", True)
         empty_grad = torch.empty_like(grad)
@@ -417,56 +422,35 @@ class GeminiDDP(ModelWrapper):
                 grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
             else:
                 grad_chunk.add_tensor_to_chunk_slice(p, grad)
-            reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=async_reduce)
-            if reduced:  # if not async, can release immediately, else release in when work finished
-                if async_reduce:
-                    # dirty fix by installing callback
-                    assert not hasattr(p, "_release_grad_chunk_cb")
-
-                    def _release_grad_chunk_cb():
-                        grad_chunk.wait_async_reduce()
-                        GeminiDDP.release_grad_chunk_handle(
-                            chunk_manager,
-                            grads_device,
-                            master_weights,
-                            enable_gradient_accumulation,
-                            p,
-                            chunk,
-                            grad_chunk,
-                        )
-
-                    p._release_grad_chunk_cb = _release_grad_chunk_cb
-                else:
-                    GeminiDDP.release_grad_chunk_handle(
-                        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
-                    )
-        return empty_grad
-
-    @staticmethod
-    def release_grad_chunk_handle(
-        chunk_manager, grads_device, master_weights, enable_gradient_accumulation, p, chunk, grad_chunk
-    ):
-        if not chunk_manager.reuse_fp16_chunk:
-            if chunk.keep_gathered:
-                chunk_manager.fake_release_chunk(chunk)
-            else:
-                chunk_manager.release_chunk(chunk)
-        if grad_chunk.is_gathered:
-            grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
-            if chunk.extra_dp_group is not None:
-                grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
-        else:
-            grad_chunk.cuda_shard.div_(chunk.pg_size)
-            if chunk.extra_dp_group is not None:
-                grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
-        # check overflow elements
-        chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
-        # record l2 norm for gradient clipping. flag is bound to fp16 chunk
-        if chunk.l2_norm_flag:
-            grad_chunk.set_l2_norm()
-        chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
-        if not (master_weights) or (enable_gradient_accumulation):
-            chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+            if async_reduce_stream is not None:
+                async_reduce_stream.wait_stream(torch.cuda.current_stream())
+
+            with torch.cuda.stream(async_reduce_stream):
+                reduced = chunk_manager.reduce_chunk(grad_chunk, async_op=(async_reduce_stream is not None))
+                if reduced:
+                    grad_chunk.wait_async_reduce()
+                    if not chunk_manager.reuse_fp16_chunk:
+                        if chunk.keep_gathered:
+                            chunk_manager.fake_release_chunk(chunk)
+                        else:
+                            chunk_manager.release_chunk(chunk)
+                    if grad_chunk.is_gathered:
+                        grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
+                        if chunk.extra_dp_group is not None:
+                            grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
+                    else:
+                        grad_chunk.cuda_shard.div_(chunk.pg_size)
+                        if chunk.extra_dp_group is not None:
+                            grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
+                    # check overflow elements
+                    chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
+                    # record l2 norm for gradient clipping. flag is bound to fp16 chunk
+                    if chunk.l2_norm_flag:
+                        grad_chunk.set_l2_norm()
+                    chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
+                    if not (master_weights) or (enable_gradient_accumulation):
+                        chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
+        return empty_grad
 
     def zero_grad(self, set_to_none: bool = False) -> None:
         self.module.zero_grad(set_to_none=True)
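One more note on the rewritten grad_handle above: reduce_chunk is now called with async_op=True whenever a side stream exists, so the underlying collective hands back a work handle that is only completed by grad_chunk.wait_async_reduce(). Stripped of the chunk bookkeeping, the deferred-wait idiom looks roughly like this in plain torch.distributed; it assumes an initialized process group, and the helper name is made up for the sketch.

import torch
import torch.distributed as dist

def launch_grad_reduce(grad_buffer: torch.Tensor):
    # async_op=True returns immediately with a work handle; the collective
    # proceeds in the background on the communication stream.
    return dist.all_reduce(grad_buffer, op=dist.ReduceOp.SUM, async_op=True)

# Later, before the reduced gradient is read (the role wait_async_reduce plays):
# work = launch_grad_reduce(grad_buffer)
# work.wait()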