mirror of https://github.com/hpcaitech/ColossalAI
[gemini] prefetch chunks
parent
785cd9a9c9
commit
6e38eafebe
|
@ -357,14 +357,14 @@ class Chunk:
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def access_chunk(self):
|
||||
def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
|
||||
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
if not self.is_gathered:
|
||||
self.__gather()
|
||||
return self.__gather(async_op=async_access)
|
||||
self.__update_tensors_ptr()
|
||||
return None
|
||||
|
||||
def release_chunk(self):
|
||||
"""Release the usable chunk. It's an operation done in CUDA."""
|
||||
|
@ -498,17 +498,19 @@ class Chunk:
|
|||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
||||
|
||||
def __gather(self):
|
||||
def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
|
||||
if not self.is_gathered:
|
||||
# sanity check
|
||||
assert self.cuda_shard is not None
|
||||
|
||||
alloc_storage(self.cuda_global_chunk)
|
||||
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
|
||||
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
|
||||
work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
|
||||
|
||||
self.cuda_shard = None
|
||||
self.is_gathered = True
|
||||
return work
|
||||
return None
|
||||
|
||||
def __scatter(self):
|
||||
if self.keep_gathered:
|
||||
|
|
|
@ -111,15 +111,16 @@ class ChunkManager:
|
|||
for group_name in self.chunk_groups:
|
||||
self.__close_one_chunk(self.chunk_groups[group_name][-1])
|
||||
|
||||
def access_chunk(self, chunk: Chunk) -> None:
|
||||
def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
|
||||
"""Make the chunk can be used for calculation."""
|
||||
if chunk in self.accessed_chunks:
|
||||
return
|
||||
self.__sub_memory_usage(chunk.memory_usage)
|
||||
if chunk.device_type == "cpu":
|
||||
chunk.shard_move(get_accelerator().get_current_device())
|
||||
self.__add_accessed_chunk(chunk)
|
||||
maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
return maybe_work
|
||||
|
||||
def release_chunk(self, chunk: Chunk) -> None:
|
||||
"""Scatter the chunk in CUDA."""
|
||||
|
@ -251,10 +252,11 @@ class ChunkManager:
|
|||
for k, v in usage.items():
|
||||
self.total_mem[k] += v
|
||||
|
||||
def __add_accessed_chunk(self, chunk: Chunk):
|
||||
chunk.access_chunk()
|
||||
def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
|
||||
maybe_work = chunk.access_chunk(async_access=async_access)
|
||||
self.accessed_chunks.add(chunk)
|
||||
self.accessed_mem += chunk.chunk_mem
|
||||
return maybe_work
|
||||
|
||||
def __sub_accessed_chunk(self, chunk: Chunk):
|
||||
chunk.release_chunk()
|
||||
|
|
|
@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
|
|||
chunk_init_device: torch.device = torch.device("cpu"),
|
||||
placement_policy: str = "static",
|
||||
enable_gradient_accumulation: bool = False,
|
||||
max_prefetch: int = 0,
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
|
@ -132,7 +133,6 @@ class GeminiDDP(ModelWrapper):
|
|||
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
|
||||
)
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
|
||||
self.fp32_params: List[torch.Tensor] = list()
|
||||
self.fp16_params: List[ColoParameter] = list()
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
||||
|
@ -157,6 +157,8 @@ class GeminiDDP(ModelWrapper):
|
|||
for p in module.parameters():
|
||||
param_order.append(p)
|
||||
|
||||
self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch)
|
||||
|
||||
for name, param in module.named_parameters():
|
||||
self.param2name[param] = name
|
||||
for m_name, m_var in module.named_modules():
|
||||
|
|
|
@ -1,14 +1,18 @@
|
|||
from chunk import Chunk
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from typing import Dict, List
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
from colossalai.utils import is_ddp_ignored
|
||||
from colossalai.zero.gemini import TensorState
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
|
@ -16,23 +20,92 @@ class TrainingPhase(Enum):
|
|||
BACKWARD = 1
|
||||
|
||||
|
||||
DEBUG = True # TODO @botbw: remove
|
||||
|
||||
|
||||
class GeminiZeROHook(ColoParamOpHook):
|
||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||
def __init__(
|
||||
self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._gemini_manager = gemini_manager
|
||||
self._chunk_manager = gemini_manager.chunk_manager
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
self._cur_param = None
|
||||
# param_visited_order might be updated somewhere else
|
||||
self._param_visited_order = param_order.param_visited_order
|
||||
self._max_prefetch = max_prefetch
|
||||
self._async_works: Dict[Chunk, dist.work] = {}
|
||||
|
||||
def pre_op(self, params):
|
||||
params = [p for p in params if not is_ddp_ignored(p)]
|
||||
chunks = self._chunk_manager.get_chunks(params)
|
||||
# used by get_prefetch_chunks to track current param
|
||||
self._cur_param_idx = 0
|
||||
|
||||
def get_prefetch_chunks(self, all_params: List[ColoParameter]) -> List[Chunk]:
|
||||
chunks_to_prefetch = set()
|
||||
if self._training_phase == TrainingPhase.FORWARD: # forward phrase: increase
|
||||
self._cur_param_idx += len(all_params) # need to update first
|
||||
idx = self._cur_param_idx + 1
|
||||
# still have params and prefetched chunks don't exceed the limit
|
||||
while idx < len(self._param_visited_order) and len(chunks_to_prefetch) + 1 < self._max_prefetch:
|
||||
param = self._param_visited_order[idx]
|
||||
if is_ddp_ignored(param):
|
||||
idx += 1
|
||||
continue
|
||||
chunk = self._chunk_manager.get_chunk(param)
|
||||
chunks_to_prefetch.add(chunk)
|
||||
idx += 1
|
||||
else:
|
||||
assert self._training_phase == TrainingPhase.BACKWARD
|
||||
self._cur_param_idx -= len(all_params)
|
||||
idx = self._cur_param_idx - 1
|
||||
chunks_to_prefetch = set()
|
||||
while idx >= 0 and len(chunks_to_prefetch) + 1 < self._max_prefetch:
|
||||
param = self._param_visited_order[idx]
|
||||
if is_ddp_ignored(param):
|
||||
idx -= 1
|
||||
continue
|
||||
chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx])
|
||||
chunks_to_prefetch.add(chunk)
|
||||
idx -= 1
|
||||
return list(chunks_to_prefetch)
|
||||
|
||||
def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
|
||||
non_prefetched_chunks = []
|
||||
for chunk in chunks:
|
||||
if chunk in self._async_works:
|
||||
self._async_works[chunk].wait()
|
||||
del self._async_works[chunk]
|
||||
else:
|
||||
non_prefetched_chunks.append(chunk)
|
||||
return non_prefetched_chunks
|
||||
|
||||
def pre_op(self, all_params):
|
||||
if DEBUG: # TODO @botbw: remove
|
||||
idxs = list(map(lambda x: self._linked_param_order.param_visited_order.index(x), all_params))
|
||||
mx = max(idxs)
|
||||
idxs = sorted(map(lambda x: x - mx, idxs))
|
||||
assert list(range(len(idxs))) == idxs, f"{idxs=}"
|
||||
|
||||
# deal with current needed chunks
|
||||
params = [p for p in all_params if not is_ddp_ignored(p)]
|
||||
all_chunks = self._chunk_manager.get_chunks(params)
|
||||
chunks_wo_work = self.wait_chunks(all_chunks)
|
||||
for p in params:
|
||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||
self._gemini_manager.sample_overall_data()
|
||||
self._gemini_manager.adjust_layout(chunks)
|
||||
for chunk in chunks:
|
||||
self._gemini_manager.adjust_layout(chunks_wo_work)
|
||||
|
||||
# deal with chunks that are to be async fetched
|
||||
prefetch_chunks = self.get_prefetch_chunks(all_params)
|
||||
|
||||
# deal with chunks that are to be fetched now
|
||||
for chunk in chunks_wo_work:
|
||||
self._chunk_manager.access_chunk(chunk)
|
||||
|
||||
# deal with chunks that are to be pre fetched TODO @botbw: the order here matters?
|
||||
for chunk in prefetch_chunks:
|
||||
self._async_works[chunk] = self._chunk_manager.access_chunk(chunk, async_access=True)
|
||||
|
||||
# record cuda model data of the current OP
|
||||
self._gemini_manager.record_model_data_volume()
|
||||
|
||||
|
|
Loading…
Reference in New Issue