Merge branch 'prefetch' of github.com:botbw/ColossalAI into botbw-prefetch

pull/5722/head
genghaozhe 2024-05-16 07:23:40 +00:00
commit 1f6b57099c
7 changed files with 93 additions and 22 deletions


@@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase):
chunk_init_device: Optional[torch.device] = None,
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
max_prefetch: int = 0,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase):
memstats=memstats,
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
master_weights=master_weights,
max_prefetch=max_prefetch,
)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
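
For reference, a minimal usage sketch of the new argument at the plugin level. The Booster wiring is the usual ColossalAI pattern and is assumed here, not part of this diff:

    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin

    # max_prefetch=2 lets Gemini start gathering up to two upcoming chunks asynchronously;
    # max_prefetch=0 (the new default) keeps the previous, fully synchronous behavior.
    plugin = GeminiPlugin(placement_policy="static", max_prefetch=2)
    booster = Booster(plugin=plugin)
    # model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(...)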


@@ -357,14 +357,14 @@ class Chunk:
else:
raise NotImplementedError
def access_chunk(self):
def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
# sanity check
assert self.chunk_temp is None
if not self.is_gathered:
self.__gather()
return self.__gather(async_op=async_access)
self.__update_tensors_ptr()
return None
def release_chunk(self):
"""Release the usable chunk. It's an operation done in CUDA."""
@@ -498,17 +498,19 @@ class Chunk:
def get_tensors(self) -> List[torch.Tensor]:
return list(self.tensors_info.keys())
def __gather(self):
def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
if not self.is_gathered:
# sanity check
assert self.cuda_shard is not None
alloc_storage(self.cuda_global_chunk)
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
self.cuda_shard = None
self.is_gathered = True
return work
return None
def __scatter(self):
if self.keep_gathered:
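
The async path above leans on torch.distributed's non-blocking collectives. A standalone sketch of that primitive, with a hypothetical helper name and assuming a process group is already initialized:

    import torch
    import torch.distributed as dist

    def gather_shard_async(shard: torch.Tensor, world_size: int):
        # With async_op=True the collective returns immediately and hands back a Work
        # handle, which is what __gather now propagates up to its caller.
        gather_list = [torch.empty_like(shard) for _ in range(world_size)]
        work = dist.all_gather(gather_list, shard, async_op=True)
        return gather_list, work

    # The gathered tensors must not be read until work.wait() has returned.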


@@ -111,15 +111,16 @@ class ChunkManager:
for group_name in self.chunk_groups:
self.__close_one_chunk(self.chunk_groups[group_name][-1])
def access_chunk(self, chunk: Chunk) -> None:
def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
"""Make the chunk can be used for calculation."""
if chunk in self.accessed_chunks:
return
return None
self.__sub_memory_usage(chunk.memory_usage)
if chunk.device_type == "cpu":
chunk.shard_move(get_accelerator().get_current_device())
self.__add_accessed_chunk(chunk)
maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
self.__add_memory_usage(chunk.memory_usage)
return maybe_work
def release_chunk(self, chunk: Chunk) -> None:
"""Scatter the chunk in CUDA."""
@@ -251,10 +252,11 @@ class ChunkManager:
for k, v in usage.items():
self.total_mem[k] += v
def __add_accessed_chunk(self, chunk: Chunk):
chunk.access_chunk()
def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
maybe_work = chunk.access_chunk(async_access=async_access)
self.accessed_chunks.add(chunk)
self.accessed_mem += chunk.chunk_mem
return maybe_work
def __sub_accessed_chunk(self, chunk: Chunk):
chunk.release_chunk()
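
The caller-side contract implied by the new Optional[dist.Work] return value, as a sketch (chunk_manager and chunk are assumed to come from an existing Gemini setup):

    # Kick off an asynchronous gather; None means there is nothing to wait for.
    maybe_work = chunk_manager.access_chunk(chunk, async_access=True)
    pending = {}
    if maybe_work is not None:
        pending[chunk] = maybe_work  # remember the handle, like the hook's _async_works dict below

    # ... later, before any tensor inside the chunk is touched:
    if chunk in pending:
        pending[chunk].wait()
        del pending[chunk]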


@@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
chunk_init_device: torch.device = torch.device("cpu"),
placement_policy: str = "static",
enable_gradient_accumulation: bool = False,
max_prefetch: int = 0,
shard_param_frac: float = 1.0, # only for static placement
offload_optim_frac: float = 0.0, # only for static placement
offload_param_frac: float = 0.0, # only for static placement
@@ -132,7 +133,7 @@ class GeminiDDP(ModelWrapper):
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.grads_device: Dict[torch.Tensor, torch.device] = dict()


@@ -1,39 +1,67 @@
from contextlib import contextmanager
from enum import Enum
from functools import partial
from typing import List
from typing import Dict, List
import torch
import torch.distributed as dist
from colossalai.logging import DistributedLogger
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from .chunk import Chunk
class TrainingPhase(Enum):
FORWARD = 0
BACKWARD = 1
logger = DistributedLogger("gemini_hook")
class GeminiZeROHook(ColoParamOpHook):
def __init__(self, gemini_manager: GeminiManager) -> None:
def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None:
super().__init__()
self._gemini_manager = gemini_manager
self._chunk_manager = gemini_manager.chunk_manager
self._training_phase = TrainingPhase.FORWARD
self._max_prefetch = max_prefetch
self._async_works: Dict[Chunk, dist.Work] = {}
def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
non_prefetched_chunks = []
for chunk in chunks:
if chunk in self._async_works:
print(f"prefetched {chunk.count_id}")
self._async_works[chunk].wait()
del self._async_works[chunk]
else:
non_prefetched_chunks.append(chunk)
return non_prefetched_chunks
def pre_op(self, params):
params = [p for p in params if not is_ddp_ignored(p)]
chunks = self._chunk_manager.get_chunks(params)
all_chunks = self._chunk_manager.get_chunks(params)
# wait for prefetched chunks and filter out the ones that were not prefetched
chunks_fetch_sync = tuple(self.wait_chunks(all_chunks))
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._gemini_manager.sample_overall_data()
self._gemini_manager.adjust_layout(chunks)
for chunk in chunks:
self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0)
# fetch the rest chunks synchronously
for chunk in chunks_fetch_sync:
self._chunk_manager.access_chunk(chunk)
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch)
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
self._async_works[chunk] = maybe_work
# record cuda model data of the current OP
# record cuda model data of the current OP, including memory for prefetched chunks
self._gemini_manager.record_model_data_volume()
def post_op(self, params):
@@ -60,6 +88,11 @@ class GeminiZeROHook(ColoParamOpHook):
@contextmanager
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
if training_phase == TrainingPhase.FORWARD:
self._cur_param_idx = 0
else:
self._cur_param_idx = len(self._param_visited_order) - 1
old_training_phase = self._training_phase
try:
self._training_phase = training_phase
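
The hook is installed by GeminiDDP (see the change above). A sketch of how it wraps a forward pass, assuming an existing gemini_manager and model; ColoParamOpHookManager.use_hooks is the hook-activation helper GeminiDDP already uses internally:

    from colossalai.tensor.param_op_hook import ColoParamOpHookManager

    hook = GeminiZeROHook(gemini_manager, max_prefetch=2)
    with ColoParamOpHookManager.use_hooks(hook):
        # pre_op fires before every parameter access: it waits on chunks that were
        # prefetched earlier, fetches the remaining ones synchronously, then schedules
        # up to max_prefetch asynchronous gathers chosen by the placement policy.
        output = model(data)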


@@ -6,7 +6,7 @@ import torch
from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector, MemStats
from .placement_policy import PlacementPolicyFactory
from .placement_policy import PlacementPolicy, PlacementPolicyFactory
class GeminiManager:
@@ -91,13 +91,13 @@ class GeminiManager:
self._warmup = False
self.reset_attributes()
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
"""Adjust the layout of stateful tensors according to the information provided
by mem_stats_collector, which should belong to a sharded model.
"""
# find stateful tensor in state COMPUTE
start = time()
self._record_chunks_order(chunks)
self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
self._layout_time += time() - start
@@ -133,9 +133,9 @@ class GeminiManager:
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
return cuda_demand, can_evict_chunks
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
self._compute_idx += 1
if self._warmup and self._placement_policy.need_mem_stats:
if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
self._compute_list.append(chunks)
def sample_overall_data(self):
@@ -156,6 +156,18 @@ class GeminiManager:
return self._mem_stats_collector.cuda_margin_mem
return None
@property
def compute_list(self) -> List[Tuple[Chunk, ...]]:
return self._compute_list
@property
def compute_idx(self) -> int:
return self._compute_idx
@property
def placement_policy(self) -> PlacementPolicy:
return self._placement_policy
@property
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats
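
The new read-only properties expose the warmup-recorded schedule to the placement policies. A hypothetical reader of them, for illustration:

    schedule = gemini_manager.compute_list    # chunk tuples recorded during the warmup iteration
    position = gemini_manager.compute_idx     # index of the op currently executing
    policy = gemini_manager.placement_policy  # the active PlacementPolicy instance
    upcoming = schedule[position + 1:]        # everything a policy may still choose to prefetch from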


@@ -33,6 +33,10 @@ class PlacementPolicy(ABC):
) -> None:
raise NotImplementedError
@abstractmethod
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
raise NotImplementedError
class StaticPlacementPolicy(PlacementPolicy):
def __init__(
@@ -95,6 +99,18 @@ class StaticPlacementPolicy(PlacementPolicy):
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
prefetch = []
for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)):
for chunk in self.chunk_manager.compute_list[i]:
if len(prefetch) >= max_prefetch:
break
if chunk not in prefetch:
prefetch.append(chunk)
if len(prefetch) >= max_prefetch:
break
return prefetch
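
A worked example of the selection loop above, rewritten over plain strings standing in for chunks (hypothetical values, same control flow):

    def pick_prefetch(compute_list, compute_idx, max_prefetch):
        prefetch = []
        for ops in compute_list[compute_idx + 1:]:
            for chunk in ops:
                if len(prefetch) >= max_prefetch:
                    break
                if chunk not in prefetch:
                    prefetch.append(chunk)
            if len(prefetch) >= max_prefetch:
                break
        return prefetch

    # Warmup recorded four ops; the op at index 1 is executing, so the next two
    # distinct chunks are returned for asynchronous prefetching.
    assert pick_prefetch([("A",), ("B", "C"), ("D",), ("E",)], 1, 2) == ["D", "E"]
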
class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True
@@ -198,6 +214,9 @@ class AutoPlacementPolicy(PlacementPolicy):
else:
grads_device_map[p] = torch.device("cpu")
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
return [] # TODO @botbw: implement prefetching for auto
class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = {