[Gemini] Use async stream to prefetch and h2d data moving (#5781)

* use async stream to prefetch and h2d data moving

* Remove redundant code
pull/5803/head
Haze188 2024-06-12 15:48:52 +08:00 committed by GitHub
parent 8554585a5f
commit d9dddf574f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 12 additions and 12 deletions

View File

@ -25,6 +25,7 @@ class ChunkManager:
chunk_configuration,
init_device: Optional[torch.device] = None,
reuse_fp16_chunk: bool = True,
max_prefetch: int = 0,
) -> None:
self.device = init_device or get_accelerator().get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@ -42,6 +43,7 @@ class ChunkManager:
# Whether model is accumulating gradients,
self.accumulating_grads = False
self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None
def register_tensor(
self,

View File

@ -21,6 +21,7 @@ def init_chunk_manager(
hidden_dim: Optional[int] = None,
reuse_fp16_chunk: bool = True,
verbose: bool = False,
max_prefetch: int = 0,
**kwargs,
) -> ChunkManager:
if hidden_dim:
@ -51,9 +52,5 @@ def init_chunk_manager(
)
dist.barrier()
chunk_manager = ChunkManager(
config_dict,
init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
)
chunk_manager = ChunkManager(config_dict, init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch)
return chunk_manager

View File

@ -104,9 +104,7 @@ class GeminiDDP(ModelWrapper):
self.enable_gradient_accumulation = enable_gradient_accumulation
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(
chunk_config_dict,
chunk_init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
chunk_config_dict, chunk_init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch
)
else:
# some ugly hotfix for the compatibility with Lightning
@ -122,6 +120,7 @@ class GeminiDDP(ModelWrapper):
process_group=zero_group,
reuse_fp16_chunk=reuse_fp16_chunk,
verbose=verbose,
max_prefetch=max_prefetch,
)
self.gemini_manager = GeminiManager(
placement_policy,

View File

@ -5,6 +5,7 @@ from typing import List
import torch
from colossalai.accelerator import get_accelerator
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
@ -54,10 +55,11 @@ class GeminiZeROHook(ColoParamOpHook):
)
# prefetch
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._gemini_manager.add_work(chunk, maybe_work)
with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._gemini_manager.add_work(chunk, maybe_work)
# record cuda model data of the current OP, including memory for prefetched chunks
self._gemini_manager.record_model_data_volume()