Merge pull request #5722 from botbw/prefetch

[gemini] prefetch chunks
botbw 2024-05-17 13:46:18 +08:00 committed by GitHub
commit 9690981601
10 changed files with 134 additions and 34 deletions

View File

@@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase):
         chunk_init_device: Optional[torch.device] = None,
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
@@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase):
             memstats=memstats,
             mixed_precision=PRECISION_STR_TO_DTYPE[precision],
             master_weights=master_weights,
+            max_prefetch=max_prefetch,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
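
For orientation, a minimal usage sketch of the new plugin argument. It assumes a distributed context already launched via torchrun; the toy model and optimizer are placeholders, not part of this PR.

    import torch.nn as nn
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    colossalai.launch_from_torch()  # assumes torchrun has set up the process group

    model = nn.Linear(1024, 1024)                       # placeholder model
    optimizer = HybridAdam(model.parameters(), lr=1e-3)  # placeholder optimizer

    # max_prefetch > 0 lets Gemini all-gather up to that many upcoming chunks
    # asynchronously, overlapping communication with computation.
    plugin = GeminiPlugin(placement_policy="static", max_prefetch=2)
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)
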

View File

@@ -357,14 +357,14 @@ class Chunk:
         else:
             raise NotImplementedError

-    def access_chunk(self):
+    def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
         # sanity check
         assert self.chunk_temp is None

         if not self.is_gathered:
-            self.__gather()
+            return self.__gather(async_op=async_access)
         self.__update_tensors_ptr()
+        return None

     def release_chunk(self):
         """Release the usable chunk. It's an operation done in CUDA."""
@@ -498,17 +498,19 @@ class Chunk:
     def get_tensors(self) -> List[torch.Tensor]:
         return list(self.tensors_info.keys())

-    def __gather(self):
+    def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
         if not self.is_gathered:
             # sanity check
             assert self.cuda_shard is not None

             alloc_storage(self.cuda_global_chunk)
             gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
-            dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
+            work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)

             self.cuda_shard = None
             self.is_gathered = True
+            return work
+        return None

     def __scatter(self):
         if self.keep_gathered:
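
The asynchronous path builds on torch.distributed's non-blocking collectives. A minimal sketch of the underlying pattern, with illustrative names not taken from this PR:

    import torch
    import torch.distributed as dist

    def gather_shard_async(global_buf: torch.Tensor, shard: torch.Tensor, pg) -> dist.Work:
        """Start an all-gather of `shard` into `global_buf` without blocking."""
        gather_list = list(torch.chunk(global_buf, chunks=dist.get_world_size(pg), dim=0))
        # async_op=True returns a Work handle instead of blocking until completion
        return dist.all_gather(gather_list, shard, group=pg, async_op=True)

    # later, just before the gathered data is consumed:
    # work = gather_shard_async(buf, shard, pg)
    # work.wait()  # this is what GeminiManager.wait_chunks does per prefetched chunk
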

View File

@@ -111,15 +111,16 @@ class ChunkManager:
         for group_name in self.chunk_groups:
             self.__close_one_chunk(self.chunk_groups[group_name][-1])

-    def access_chunk(self, chunk: Chunk) -> None:
+    def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk can be used for calculation."""
         if chunk in self.accessed_chunks:
-            return
+            return None
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
             chunk.shard_move(get_accelerator().get_current_device())
-        self.__add_accessed_chunk(chunk)
+        maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
         self.__add_memory_usage(chunk.memory_usage)
+        return maybe_work

     def release_chunk(self, chunk: Chunk) -> None:
         """Scatter the chunk in CUDA."""
@@ -251,10 +252,11 @@ class ChunkManager:
         for k, v in usage.items():
             self.total_mem[k] += v

-    def __add_accessed_chunk(self, chunk: Chunk):
-        chunk.access_chunk()
+    def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
+        maybe_work = chunk.access_chunk(async_access=async_access)
         self.accessed_chunks.add(chunk)
         self.accessed_mem += chunk.chunk_mem
+        return maybe_work

     def __sub_accessed_chunk(self, chunk: Chunk):
         chunk.release_chunk()

View File

@@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
         chunk_init_device: torch.device = torch.device("cpu"),
         placement_policy: str = "static",
         enable_gradient_accumulation: bool = False,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,  # only for static placement
         offload_optim_frac: float = 0.0,  # only for static placement
         offload_param_frac: float = 0.0,  # only for static placement
@@ -130,6 +131,7 @@ class GeminiDDP(ModelWrapper):
             offload_param_frac=offload_param_frac,
             warmup_non_model_data_ratio=warmup_non_model_data_ratio,
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
+            max_prefetch=max_prefetch,
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)

View File

@@ -5,6 +5,7 @@ from typing import List

 import torch

+from colossalai.logging import DistributedLogger
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
@@ -16,6 +17,9 @@ class TrainingPhase(Enum):
     BACKWARD = 1

+logger = DistributedLogger("gemini_hook")
+

 class GeminiZeROHook(ColoParamOpHook):
     def __init__(self, gemini_manager: GeminiManager) -> None:
         super().__init__()
@@ -24,16 +28,37 @@ class GeminiZeROHook(ColoParamOpHook):
         self._training_phase = TrainingPhase.FORWARD

     def pre_op(self, params):
+        # map params to chunks
         params = [p for p in params if not is_ddp_ignored(p)]
-        chunks = self._chunk_manager.get_chunks(params)
+        all_chunks = self._chunk_manager.get_chunks(params)
+
+        # wait for prefetched chunks, filter out those that were not prefetched
+        chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
+
+        # transfer state
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(chunks)
-        for chunk in chunks:
+
+        # evict chunks, aware of async fetched ones
+        self._gemini_manager.adjust_layout(
+            all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
+        )
+
+        # fetch the rest synchronously
+        for chunk in chunks_fetch_sync:
             self._chunk_manager.access_chunk(chunk)
-        # record cuda model data of the current OP
+
+        # get possible chunks to prefetch
+        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
+
+        # prefetch
+        for chunk in chunks_fetch_async:
+            maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
+            if maybe_work is not None:
+                self._gemini_manager.add_work(chunk, maybe_work)
+
+        # record cuda model data of the current OP, including memory for prefetched chunks
         self._gemini_manager.record_model_data_volume()

     def post_op(self, params):
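
The point of issuing access_chunk(..., async_access=True) ahead of time is to hide all-gather latency behind computation. A standalone sketch of that overlap, independent of Gemini's classes and with illustrative names:

    import torch
    import torch.distributed as dist

    def overlapped_step(current_input: torch.Tensor, weight: torch.Tensor,
                        next_shard: torch.Tensor, next_buf: torch.Tensor, pg) -> dist.Work:
        """Launch the all-gather for the next op's chunk, then compute the current op."""
        gather_list = list(torch.chunk(next_buf, chunks=dist.get_world_size(pg), dim=0))
        work = dist.all_gather(gather_list, next_shard, group=pg, async_op=True)  # prefetch
        _ = current_input @ weight  # NCCL runs on its own stream, so this can overlap
        return work  # the consumer of next_buf calls work.wait() before reading it
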

View File

@@ -1,12 +1,13 @@
 import functools
 from time import time
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Iterable, List, Optional, Tuple

 import torch
+import torch.distributed as dist

 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector, MemStats
-from .placement_policy import PlacementPolicyFactory
+from .placement_policy import PlacementPolicy, PlacementPolicyFactory


 class GeminiManager:
@@ -41,9 +42,10 @@ class GeminiManager:
         self._mem_stats_collector = (
             ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
         )
-        self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
+        self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs)
         self._compute_list: List[Tuple[Chunk, ...]] = []
         self._compute_idx: int = -1
+        self._async_works: Dict[Chunk, dist.Work] = {}

         self._h2d_volume = 0
         self._d2h_volume = 0
@@ -91,18 +93,20 @@ class GeminiManager:
         self._warmup = False
         self.reset_attributes()

-    def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
+    def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         """Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belongs to a Sharded Model.
         """
         # find stateful tensor in state COMPUTE
         start = time()
-        self._record_chunks_order(chunks)
-        cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
+        self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
+        cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
+        # don't evict chunks that are asynchronously fetched
+        can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
         self._layout_time += time() - start

         vol, evict_time = self._placement_policy.evict_tensors(
-            can_evict_chunks=hold_cuda_tensor_list,
+            can_evict_chunks=can_evict_chunks,
             cuda_demand=cuda_demand,
             warmup=self._warmup,
             compute_list=self._compute_list,
@@ -114,6 +118,21 @@ class GeminiManager:
         # move COMPUTE tensors to CUDA
         self._h2d_volume += cuda_demand

+    def wait_chunks(self, chunks: Iterable[Chunk]) -> Tuple[Chunk]:
+        non_prefetched_chunks = []
+        for chunk in chunks:
+            if chunk in self._async_works:
+                self._async_works[chunk].wait()
+                del self._async_works[chunk]
+            else:
+                non_prefetched_chunks.append(chunk)
+        return tuple(non_prefetched_chunks)
+
+    def add_work(self, chunk: Chunk, work: dist.Work):
+        assert work is not None
+        assert chunk not in self._async_works
+        self._async_works[chunk] = work
+
     @functools.lru_cache(maxsize=None)
     def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
         start = time()
@@ -133,9 +152,9 @@ class GeminiManager:
         can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
         return cuda_demand, can_evict_chunks

-    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
+    def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         self._compute_idx += 1
-        if self._warmup and self._placement_policy.need_mem_stats:
+        if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
             self._compute_list.append(chunks)

     def sample_overall_data(self):
@@ -156,6 +175,18 @@ class GeminiManager:
             return self._mem_stats_collector.cuda_margin_mem
         return None

+    @property
+    def compute_list(self) -> List[Tuple[Chunk, ...]]:
+        return self._compute_list
+
+    @property
+    def compute_idx(self) -> int:
+        return self._compute_idx
+
+    @property
+    def placement_policy(self) -> PlacementPolicy:
+        return self._placement_policy
+
     @property
     def is_cuda_margin_mem_avail(self) -> bool:
         return self._placement_policy.need_mem_stats
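
To make the wait_chunks contract explicit, here is a small self-contained illustration using stand-in objects (a FakeWork class and plain strings in place of real dist.Work and Chunk instances):

    # Chunks with a pending async work get waited on and dropped from the registry;
    # everything else is returned so the caller can fetch it synchronously.
    class FakeWork:
        def __init__(self):
            self.waited = False

        def wait(self):
            self.waited = True

    async_works = {"chunk_a": FakeWork()}  # chunk_a was prefetched asynchronously
    requested = ["chunk_a", "chunk_b"]     # chunks needed by the current op

    non_prefetched = []
    for chunk in requested:
        if chunk in async_works:
            async_works.pop(chunk).wait()      # block until the prefetch finishes
        else:
            non_prefetched.append(chunk)       # must be fetched synchronously now

    assert non_prefetched == ["chunk_b"]
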

View File

@@ -18,10 +18,17 @@ class PlacementPolicy(ABC):
     need_mem_stats: bool = False

     def __init__(
-        self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs
+        self,
+        gemini_manager: "GeminiManager",
+        chunk_manager: ChunkManager,
+        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
+        **kwargs,
     ) -> None:
+        self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
         self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
+        self.max_prefetch = max_prefetch

     @abstractmethod
     def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
@@ -33,18 +40,26 @@ class PlacementPolicy(ABC):
     ) -> None:
         raise NotImplementedError

+    @abstractmethod
+    def get_prefetch_chunks(self) -> List[Chunk]:
+        raise NotImplementedError
+

 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
         shard_param_frac: float = 1.0,
         offload_optim_frac: float = 0.0,
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
@@ -95,19 +110,38 @@ class StaticPlacementPolicy(PlacementPolicy):
         self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
         self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)

+    def get_prefetch_chunks(self) -> List[Chunk]:
+        if self.gemini_manager.is_warmup():  # no prefetch during warmup since we need compute_list
+            return []
+        can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
+        prefetch = []
+        for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
+            for chunk in self.gemini_manager.compute_list[i]:
+                if len(prefetch) >= can_prefetch:
+                    break
+                if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
+                    prefetch.append(chunk)
+            if len(prefetch) >= can_prefetch:
+                break
+        return prefetch
+

 class AutoPlacementPolicy(PlacementPolicy):
     need_mem_stats: bool = True

     def __init__(
         self,
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
         warmup_non_model_data_ratio: float = 0.8,
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
@@ -198,6 +232,9 @@ class AutoPlacementPolicy(PlacementPolicy):
             else:
                 grads_device_map[p] = torch.device("cpu")

+    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+        return []  # TODO @botbw: implement prefetching for auto
+

 class PlacementPolicyFactory:
     policies: Dict[str, Type[PlacementPolicy]] = {
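
A quick worked example of the static policy's selection rule, using plain strings in place of Chunk objects (purely illustrative):

    # Warmup recorded this per-op chunk order; the current op index is 0,
    # chunk "A" is already resident, and the prefetch budget is 2.
    compute_list = [("A", "B"), ("A", "C"), ("D",)]
    compute_idx = 0
    accessed = {"A"}
    budget = 2  # max_prefetch - len(async_works)

    prefetch = []
    for i in range(compute_idx + 1, len(compute_list)):
        for chunk in compute_list[i]:
            if len(prefetch) >= budget:
                break
            if chunk not in prefetch and chunk not in accessed:
                prefetch.append(chunk)
        if len(prefetch) >= budget:
            break

    assert prefetch == ["C", "D"]  # "A" is skipped because it is already on device
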

View File

@@ -30,8 +30,9 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
             activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
             schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
             on_trace_ready=tensorboard_trace_handler(save_dir),
-            record_shapes=True,
-            profile_memory=True,
+            # record_shapes=True,
+            # profile_memory=True,
+            with_stack=True,
         )
     else:
         return nullcontext(DummyProfiler())

View File

@@ -129,7 +129,7 @@ def main():
     WARMUP_STEPS = 1
     assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
     assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
-    PROF_FLAG = False  # The flag of profiling, False by default
+    PROF_FLAG = True  # The flag of profiling, False by default

     disable_existing_loggers()
     colossalai.launch_from_torch()
@@ -166,7 +166,7 @@ def main():
             stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True
         )
     elif args.distplan == "CAI_Gemini":
-        plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
+        plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1)
     else:
         raise RuntimeError
@@ -248,7 +248,7 @@ def main():
         prof.step()

     tflops_list.sort()
-    median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
+    median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1)
     logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")

     torch.cuda.synchronize()

View File

@@ -40,9 +40,7 @@ EXAMPLE_MODELS = [
 ]

 # bfloat16 cannot represent them exactly
-BF16_IGNORED_KEYS = [
-    "masked_bias",
-]
+BF16_IGNORED_KEYS = ["masked_bias"]


 def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):