Merge branch 'prefetch' of github.com:botbw/ColossalAI into feature/prefetch

pull/5722/head
genghaozhe 2024-05-16 08:05:32 +00:00
commit fc2248cf99
7 changed files with 100 additions and 49 deletions

View File

@@ -567,6 +567,7 @@ class Chunk:
         return self is __o

     def __repr__(self, detailed: bool = True):
+        return f"Chunk({self.count_id})"
         output = [
             "Chunk Information:\n",
             "\tchunk size: {}, chunk dtype: {}, process group size: {}\n".format(

View File

@@ -131,9 +131,10 @@ class GeminiDDP(ModelWrapper):
             offload_param_frac=offload_param_frac,
             warmup_non_model_data_ratio=warmup_non_model_data_ratio,
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
+            max_prefetch=max_prefetch,
         )
         self.force_outputs_fp32 = force_outputs_fp32
-        self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch)
+        self.param_op_hook = GeminiZeROHook(self.gemini_manager)
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
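
Note on the refactor above: max_prefetch now travels with the placement kwargs into GeminiManager and its placement policy instead of being held by the op hook. A minimal, self-contained sketch of that wiring, using hypothetical toy classes rather than the real ColossalAI ones:

# Toy sketch of the new wiring (stand-in classes, not the real ColossalAI API).
class ToyPolicy:
    def __init__(self, max_prefetch: int = 0, **kwargs) -> None:
        self.max_prefetch = max_prefetch  # the policy owns the prefetch budget now


class ToyManager:
    def __init__(self, policy_cls, **placement_kwargs) -> None:
        # placement kwargs (including max_prefetch) are forwarded untouched
        self.placement_policy = policy_cls(**placement_kwargs)


class ToyHook:
    def __init__(self, manager: ToyManager) -> None:
        # the hook only keeps the manager and reads the budget through it
        self.manager = manager


manager = ToyManager(ToyPolicy, max_prefetch=2)
hook = ToyHook(manager)
print(hook.manager.placement_policy.max_prefetch)  # 2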

View File

@@ -1,10 +1,9 @@
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial
-from typing import Dict, List
+from typing import List

 import torch
-import torch.distributed as dist

 from colossalai.logging import DistributedLogger
 from colossalai.tensor.param_op_hook import ColoParamOpHook
@@ -12,8 +11,6 @@ from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
 from colossalai.zero.gemini.gemini_mgr import GeminiManager

-from .chunk import Chunk
-

 class TrainingPhase(Enum):
     FORWARD = 0
@@ -22,45 +19,60 @@ class TrainingPhase(Enum):
 logger = DistributedLogger("gemini_hook")

+import os
+
+rank = int(os.environ["RANK"])
+

 class GeminiZeROHook(ColoParamOpHook):
-    def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None:
+    def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
         self._gemini_manager = gemini_manager
         self._chunk_manager = gemini_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD
-        self._max_prefetch = max_prefetch
-        self._async_works: Dict[Chunk, dist.work] = {}
-
-    def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
-        non_prefetched_chunks = []
-        for chunk in chunks:
-            if chunk in self._async_works:
-                print(f"prefetched {chunk.count_id}")
-                self._async_works[chunk].wait()
-                del self._async_works[chunk]
-            else:
-                non_prefetched_chunks.append(chunk)
-        return non_prefetched_chunks

     def pre_op(self, params):
+        # map params to chunks
         params = [p for p in params if not is_ddp_ignored(p)]
         all_chunks = self._chunk_manager.get_chunks(params)
+
         # wait for prefetched chunks, filter those are not prefetched
-        chunks_fetch_sync = tuple(self.wait_chunks(all_chunks))
+        set(all_chunks)
+        chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
+
+        # transfer state
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0)
-        # fetch the rest chunks synchronously
+
+        # evit chunks, aware of async fetched
+        self._gemini_manager.adjust_layout(
+            all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
+        )
+
+        # fetch the rest synchronously
         for chunk in chunks_fetch_sync:
             self._chunk_manager.access_chunk(chunk)
-        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch)
+
+        # get possible chunks to prefetch
+        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
+        if rank == 0 and not self._gemini_manager.is_warmup():
+            print(
+                f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}"
+            )
+            print(f"{all_chunks=}")
+            print(f"accessed_chunks={self._chunk_manager.accessed_chunks}")
+            print(f"{chunks_fetch_sync=}")
+            print(f"{chunks_fetch_async=}")
+            print(f"works={list(self._gemini_manager._async_works.keys())}")
+
+        # prefetch
         for chunk in chunks_fetch_async:
             maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
             if maybe_work is not None:
-                self._async_works[chunk] = maybe_work
+                self._gemini_manager.add_work(chunk, maybe_work)
+        if rank == 0 and not self._gemini_manager.is_warmup():
+            print(f"post accessed_chunks={self._chunk_manager.accessed_chunks}")
+
         # record cuda model data of the current OP, including memory for prefetched chunks
         self._gemini_manager.record_model_data_volume()
@@ -88,11 +100,6 @@ class GeminiZeROHook(ColoParamOpHook):

     @contextmanager
     def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
-        if training_phase == TrainingPhase.FORWARD:
-            self._cur_param_idx = 0
-        else:
-            self._cur_param_idx = len(self._param_visited_order) - 1
-
         old_training_phase = self._training_phase
         try:
             self._training_phase = training_phase
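
The reshaped pre_op above is easier to follow as a plain-Python sketch: wait on chunks that a previous step already prefetched, fetch whatever is still missing synchronously, then launch the next round of asynchronous fetches and remember their handles. FakeChunk and FakeWork below are hypothetical stand-ins for Gemini chunks and torch.distributed work handles.

from typing import Dict, List


class FakeWork:
    # stand-in for a torch.distributed async work handle
    def wait(self) -> None:
        pass


class FakeChunk:
    def __init__(self, count_id: int) -> None:
        self.count_id = count_id

    def __repr__(self) -> str:
        return f"Chunk({self.count_id})"


def pre_op_sketch(
    needed: List[FakeChunk],
    async_works: Dict[FakeChunk, FakeWork],
    prefetch_candidates: List[FakeChunk],
) -> None:
    # 1) wait on chunks that were prefetched earlier; everything else must be fetched now
    fetch_sync = []
    for chunk in needed:
        if chunk in async_works:
            async_works.pop(chunk).wait()
        else:
            fetch_sync.append(chunk)

    # 2) (in Gemini: adjust_layout / evict, skipping chunks whose fetch is still in flight)

    # 3) fetch the remaining chunks synchronously
    for chunk in fetch_sync:
        print(f"sync fetch {chunk}")

    # 4) launch async prefetch for upcoming chunks and keep the handles for the next step
    for chunk in prefetch_candidates:
        async_works[chunk] = FakeWork()
        print(f"async prefetch {chunk}")


works: Dict[FakeChunk, FakeWork] = {}
c0, c1, c2 = FakeChunk(0), FakeChunk(1), FakeChunk(2)
works[c1] = FakeWork()                # pretend c1 was prefetched last step
pre_op_sketch([c0, c1], works, [c2])  # c0 fetched sync, c1 waited on, c2 prefetched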

View File

@@ -1,8 +1,9 @@
 import functools
 from time import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple

 import torch
+import torch.distributed as dist

 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector, MemStats
@@ -41,9 +42,10 @@ class GeminiManager:
         self._mem_stats_collector = (
             ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
         )
-        self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
+        self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
+        self._async_works: Dict[Chunk, dist.work] = {}

         self._h2d_volume = 0
         self._d2h_volume = 0
@@ -98,11 +100,13 @@ class GeminiManager:
         # find stateful tensor in state COMPUTE
         start = time()
         self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
-        cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
+        cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
+        # don't evict chunks that are asynchronously fetched
+        can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
         self._layout_time += time() - start

         vol, evict_time = self._placement_policy.evict_tensors(
-            can_evict_chunks=hold_cuda_tensor_list,
+            can_evict_chunks=can_evict_chunks,
             cuda_demand=cuda_demand,
             warmup=self._warmup,
             compute_list=self._compute_list,
@@ -114,6 +118,21 @@ class GeminiManager:
         # move COMPUTE tensors to CUDA
         self._h2d_volume += cuda_demand

+    def wait_chunks(self, chunks: Iterable[Chunk]) -> Tuple[Chunk]:
+        non_prefetched_chunks = []
+        for chunk in chunks:
+            if chunk in self._async_works:
+                self._async_works[chunk].wait()
+                del self._async_works[chunk]
+            else:
+                non_prefetched_chunks.append(chunk)
+        return tuple(non_prefetched_chunks)
+
+    def add_work(self, chunk: Chunk, work: dist.Work):
+        assert work is not None
+        assert chunk not in self._async_works
+        self._async_works[chunk] = work
+
     @functools.lru_cache(maxsize=None)
     def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
         start = time()
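
In one line, the eviction guard added to adjust_layout keeps anything with an outstanding async fetch out of the candidate list. A tiny illustration, with strings standing in for Chunk objects and a plain object standing in for a dist.Work handle:

def filter_evictable(can_evict_chunks, async_works):
    # chunks whose fetch is still in flight must stay resident
    return [chunk for chunk in can_evict_chunks if chunk not in async_works]


candidates = ["chunk_a", "chunk_b", "chunk_c"]
in_flight = {"chunk_b": object()}  # value would be a dist.Work in the real manager
print(filter_evictable(candidates, in_flight))  # ['chunk_a', 'chunk_c']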

View File

@@ -18,10 +18,17 @@ class PlacementPolicy(ABC):
     need_mem_stats: bool = False

     def __init__(
-        self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs
+        self,
+        gemini_manager: "GeminiManager",
+        chunk_manager: ChunkManager,
+        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
+        **kwargs,
     ) -> None:
+        self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
         self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
+        self.max_prefetch = max_prefetch

     @abstractmethod
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
@@ -34,21 +41,30 @@ class PlacementPolicy(ABC):
         raise NotImplementedError

     @abstractmethod
-    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+    def get_prefetch_chunks(self) -> List[Chunk]:
         raise NotImplementedError

+
+import os
+
+rank = int(os.environ["RANK"])
+
+
 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,
         offload_optim_frac: float = 0.0,
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
@@ -99,15 +115,17 @@ class StaticPlacementPolicy(PlacementPolicy):
         self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
         self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)

-    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+    def get_prefetch_chunks(self) -> List[Chunk]:
+        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+            return []
         prefetch = []
-        for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)):
-            for chunk in self.chunk_manager.compute_list[i]:
-                if len(prefetch) >= max_prefetch:
+        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
+            for chunk in self.gemini_manager.compute_list[i]:
+                if len(prefetch) >= self.max_prefetch:
                     break
-                if chunk not in prefetch:
+                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
                     prefetch.append(chunk)
-            if len(prefetch) >= max_prefetch:
+            if len(prefetch) >= self.max_prefetch:
                 break
         return prefetch
@@ -117,13 +135,17 @@ class AutoPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
         warmup_non_model_data_ratio: float = 0.8,
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
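
The look-ahead that get_prefetch_chunks now performs can be tried in isolation: starting just past the current compute index, collect up to max_prefetch chunks that are neither already picked nor already resident. Here compute_list is a plain list of tuples and chunks are strings, standing in for the recorded warmup order and Chunk objects.

from typing import List, Tuple


def get_prefetch_sketch(compute_list: List[Tuple[str, ...]], compute_idx: int,
                        accessed: set, max_prefetch: int) -> List[str]:
    # pick the next chunks to prefetch from the recorded compute order
    prefetch: List[str] = []
    for step in compute_list[compute_idx + 1:]:
        for chunk in step:
            if len(prefetch) >= max_prefetch:
                break
            if chunk not in prefetch and chunk not in accessed:
                prefetch.append(chunk)
        if len(prefetch) >= max_prefetch:
            break
    return prefetch


# usage: look two chunks ahead, skipping the one that is already on device
order = [("c0",), ("c1", "c2"), ("c3",), ("c4",)]
print(get_prefetch_sketch(order, compute_idx=0, accessed={"c1"}, max_prefetch=2))
# ['c2', 'c3']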

View File

@@ -30,8 +30,9 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
             on_trace_ready=tensorboard_trace_handler(save_dir),
-            record_shapes=True,
-            profile_memory=True,
+            # record_shapes=True,
+            # profile_memory=True,
+            with_stack=True,
         )
     else:
         return nullcontext(DummyProfiler())
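
For anyone reproducing the trace locally, a minimal profiler context in the same shape as the example's helper; the save path and step counts here are placeholder values, record_shapes/profile_memory stay off as in the diff, and with_stack=True adds Python stack traces at some extra overhead.

from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler


def make_profiler(save_dir: str = "./prof_logs", warmup_steps: int = 1, active_steps: int = 3):
    # same knobs as get_profile_context, trimmed to the essentials
    return profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
        on_trace_ready=tensorboard_trace_handler(save_dir),
        # record_shapes=True and profile_memory=True are left off to keep traces small
        with_stack=True,  # include Python stacks in the exported trace
    )


# usage: step the profiler once per training iteration
# with make_profiler() as prof:
#     for _ in range(num_steps):
#         ...  # one training step
#         prof.step()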

View File

@@ -129,7 +129,7 @@ def main():
     WARMUP_STEPS = 1
     assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
     assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
-    PROF_FLAG = False  # The flag of profiling, False by default
+    PROF_FLAG = True  # The flag of profiling, False by default

     disable_existing_loggers()
     colossalai.launch_from_torch()
@@ -166,7 +166,7 @@ def main():
             stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True
         )
     elif args.distplan == "CAI_Gemini":
-        plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
+        plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1)
     else:
         raise RuntimeError
@@ -248,7 +248,7 @@ def main():
             prof.step()

     tflops_list.sort()
-    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
+    median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1)
     logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
     torch.cuda.synchronize()
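
A quick sanity check of the clamped median index, using assumed step counts: with NUM_STEPS = 10 and WARMUP_STEPS = 1 the raw index would be 5, but min() keeps it inside the list when fewer TFLOPS samples were recorded than expected.

# Worked example of the clamp (values here are assumptions, not from an actual run).
NUM_STEPS, WARMUP_STEPS = 10, 1
tflops_list = [100.0, 105.0, 110.0]  # e.g. profiling reduced the number of recorded steps

median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1)
print(median_index)  # 2: clamped to the last valid index instead of the raw 5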