mirror of https://github.com/hpcaitech/ColossalAI
[Gemini] Use async stream to prefetch and h2d data moving (#5781)
* use async stream to prefetch and h2d data moving
* Remove redundant code

branch pull/5803/head
parent 8554585a5f
commit d9dddf574f
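Note: this diff moves Gemini's chunk prefetch and the associated host-to-device (h2d) copies onto a dedicated accelerator stream so they can overlap with compute on the default stream. For readers unfamiliar with side-stream prefetching, here is a minimal, generic PyTorch sketch of that pattern (illustrative names only, not ColossalAI code):

import torch

def prefetch_h2d(cpu_tensors, device="cuda"):
    """Issue h2d copies on a side stream; return the GPU tensors plus an event
    that the compute stream must wait on before consuming them."""
    copy_stream = torch.cuda.Stream()
    gpu_tensors = []
    with torch.cuda.stream(copy_stream):
        for t in cpu_tensors:
            # non_blocking only overlaps with compute if the source is in pinned memory
            gpu_tensors.append(t.to(device, non_blocking=True))
    done = torch.cuda.Event()
    done.record(copy_stream)
    return gpu_tensors, done

if torch.cuda.is_available():
    batches = [torch.randn(1024, 1024).pin_memory() for _ in range(4)]
    gpu_batches, done = prefetch_h2d(batches)
    # ... compute on the default stream can run here, overlapping the copies ...
    torch.cuda.current_stream().wait_event(done)  # synchronize before use
    total = sum(b.sum() for b in gpu_batches)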
@@ -25,6 +25,7 @@ class ChunkManager:
         chunk_configuration,
         init_device: Optional[torch.device] = None,
         reuse_fp16_chunk: bool = True,
+        max_prefetch: int = 0,
     ) -> None:
         self.device = init_device or get_accelerator().get_current_device()
         self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
@@ -42,6 +43,7 @@ class ChunkManager:
         # Whether model is accumulating gradients,
         self.accumulating_grads = False
         self.overflow_counter = torch.tensor([0], dtype=torch.int, device=get_accelerator().get_current_device())
+        self._prefetch_stream = get_accelerator().Stream() if max_prefetch else None

     def register_tensor(
         self,
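A short sketch of the gating added above, assuming a CUDA backend where get_accelerator().Stream() behaves like torch.cuda.Stream: when max_prefetch == 0 (the default) no side stream is allocated, _prefetch_stream stays None, and the prefetch path is skipped entirely. The class name below is illustrative.

import torch

class ChunkManagerSketch:
    def __init__(self, max_prefetch: int = 0) -> None:
        self.max_prefetch = max_prefetch
        # the side stream only exists when prefetching is requested
        self._prefetch_stream = torch.cuda.Stream() if max_prefetch else None

    def prefetch_enabled(self) -> bool:
        return self._prefetch_stream is not None

Because the stream is created once in __init__, every prefetch issued later reuses it instead of allocating a new stream per step.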
@@ -21,6 +21,7 @@ def init_chunk_manager(
     hidden_dim: Optional[int] = None,
     reuse_fp16_chunk: bool = True,
     verbose: bool = False,
+    max_prefetch: int = 0,
     **kwargs,
 ) -> ChunkManager:
     if hidden_dim:
@@ -51,9 +52,5 @@ def init_chunk_manager(
         )
     dist.barrier()

-    chunk_manager = ChunkManager(
-        config_dict,
-        init_device,
-        reuse_fp16_chunk=reuse_fp16_chunk,
-    )
+    chunk_manager = ChunkManager(config_dict, init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch)
     return chunk_manager
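A hedged construction example based on the one-line call above; the import path and the chunk configuration layout shown here are assumptions, only the max_prefetch keyword comes from this commit:

import torch
from colossalai.zero.gemini.chunk import ChunkManager  # assumed import path

# assumed configuration layout: one entry per data-parallel degree
config_dict = {8: {"chunk_size": 32 * 1024 * 1024, "keep_gathered": False}}

chunk_manager = ChunkManager(
    config_dict,
    torch.device("cuda"),   # init_device
    reuse_fp16_chunk=True,
    max_prefetch=2,         # 0 (default) disables the async prefetch stream
)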
@@ -104,9 +104,7 @@ class GeminiDDP(ModelWrapper):
         self.enable_gradient_accumulation = enable_gradient_accumulation
         if chunk_config_dict is not None:
             self.chunk_manager = ChunkManager(
-                chunk_config_dict,
-                chunk_init_device,
-                reuse_fp16_chunk=reuse_fp16_chunk,
+                chunk_config_dict, chunk_init_device, reuse_fp16_chunk=reuse_fp16_chunk, max_prefetch=max_prefetch
             )
         else:
             # some ugly hotfix for the compatibility with Lightning
@@ -122,6 +120,7 @@ class GeminiDDP(ModelWrapper):
                 process_group=zero_group,
                 reuse_fp16_chunk=reuse_fp16_chunk,
                 verbose=verbose,
+                max_prefetch=max_prefetch,
             )
         self.gemini_manager = GeminiManager(
             placement_policy,
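A user-level sketch of switching prefetch on through GeminiDDP, assuming max_prefetch is exposed as a constructor keyword (the hunks above consume it inside __init__) and that the remaining arguments follow their usual defaults; running it requires an initialized distributed environment (e.g. via colossalai.launch):

import torch
import torch.nn as nn
from colossalai.zero import GeminiDDP  # assumed import path

model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
gemini_model = GeminiDDP(
    model,
    chunk_init_device=torch.device("cuda"),
    placement_policy="auto",   # assumed, not part of this commit
    max_prefetch=2,            # prefetch up to 2 upcoming chunks on the side stream
)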
@@ -5,6 +5,7 @@ from typing import List

 import torch

+from colossalai.accelerator import get_accelerator
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
@@ -54,10 +55,11 @@ class GeminiZeROHook(ColoParamOpHook):
         )

         # prefetch
-        for chunk in chunks_fetch_async:
-            maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
-            if maybe_work is not None:
-                self._gemini_manager.add_work(chunk, maybe_work)
+        with get_accelerator().stream(self._gemini_manager.chunk_manager._prefetch_stream):
+            for chunk in chunks_fetch_async:
+                maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
+                if maybe_work is not None:
+                    self._gemini_manager.add_work(chunk, maybe_work)

         # record cuda model data of the current OP, including memory for prefetched chunks
         self._gemini_manager.record_model_data_volume()
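A hedged sketch of the hook-side pattern above: async fetches are issued inside the prefetch-stream context and the returned handles are recorded, so the compute stream waits only when a prefetched chunk is actually consumed. PrefetchQueue and prefetch below are illustrative stand-ins, not Gemini internals.

from typing import Dict

import torch

class PrefetchQueue:
    def __init__(self) -> None:
        self._pending: Dict[int, torch.cuda.Event] = {}

    def add_work(self, chunk_id: int, event: torch.cuda.Event) -> None:
        self._pending[chunk_id] = event

    def wait(self, chunk_id: int) -> None:
        # block the compute stream only until this chunk's copy has finished
        event = self._pending.pop(chunk_id, None)
        if event is not None:
            torch.cuda.current_stream().wait_event(event)

def prefetch(chunks: Dict[int, torch.Tensor], stream: torch.cuda.Stream, queue: PrefetchQueue) -> Dict[int, torch.Tensor]:
    """Issue non-blocking h2d copies on `stream` (pinned CPU memory assumed)."""
    gpu_chunks = {}
    with torch.cuda.stream(stream):
        for cid, cpu_chunk in chunks.items():
            gpu_chunks[cid] = cpu_chunk.to("cuda", non_blocking=True)
            event = torch.cuda.Event()
            event.record(stream)
            queue.add_work(cid, event)
    return gpu_chunks

In the real hook, access_chunk(chunk, async_access=True) plays the role of the copy and add_work stores the returned handle on the GeminiManager until the chunk is used.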