[gemini] use compute_chunk to find next chunk

pull/5722/head
hxwang 2024-05-16 13:17:26 +08:00
parent b2e9745888
commit 4148ceed9f
5 changed files with 52 additions and 79 deletions

View File

@ -114,7 +114,7 @@ class ChunkManager:
def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
"""Make the chunk can be used for calculation."""
if chunk in self.accessed_chunks:
return
return None
self.__sub_memory_usage(chunk.memory_usage)
if chunk.device_type == "cpu":
chunk.shard_move(get_accelerator().get_current_device())

View File

@ -133,6 +133,7 @@ class GeminiDDP(ModelWrapper):
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@ -157,8 +158,6 @@ class GeminiDDP(ModelWrapper):
for p in module.parameters():
param_order.append(p)
self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch)
for name, param in module.named_parameters():
self.param2name[param] = name
for m_name, m_var in module.named_modules():

View File

@ -6,16 +6,15 @@ from typing import Dict, List
import torch
import torch.distributed as dist
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.logging import DistributedLogger
from colossalai.tensor.param_op_hook import ColoParamOpHook
from colossalai.utils import is_ddp_ignored
from colossalai.zero.gemini import TensorState
from colossalai.zero.gemini.gemini_mgr import GeminiManager
from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator
from colossalai.logging import DistributedLogger
from .chunk import Chunk
class TrainingPhase(Enum):
FORWARD = 0
BACKWARD = 1
@ -23,51 +22,15 @@ class TrainingPhase(Enum):
logger = DistributedLogger("gemini_hook")
class GeminiZeROHook(ColoParamOpHook):
def __init__(
self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0
) -> None:
def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None:
super().__init__()
self._gemini_manager = gemini_manager
self._chunk_manager = gemini_manager.chunk_manager
self._training_phase = TrainingPhase.FORWARD
# param_visited_order might be updated somewhere else
self._param_visited_order = param_order.param_visited_order
self._max_prefetch = max_prefetch
self._async_works: Dict[Chunk, dist.work] = {}
# used by get_prefetch_chunks to track current param
self._cur_param_idx = 0
def get_prefetch_chunks(self, all_params: List[ColoParameter], cur_chunks: List[Chunk]) -> List[Chunk]:
chunks_to_prefetch = set()
if self._training_phase == TrainingPhase.FORWARD: # forward phrase: increase
self._cur_param_idx += len(all_params) # need to update first
idx = self._cur_param_idx + 1
# still have params and prefetched chunks don't exceed the limit
while idx < len(self._param_visited_order) and len(chunks_to_prefetch) + 1 < self._max_prefetch:
param = self._param_visited_order[idx]
if is_ddp_ignored(param):
idx += 1
continue
chunk = self._chunk_manager.get_chunk(param)
if chunk not in cur_chunks:
chunks_to_prefetch.add(chunk)
idx += 1
else:
self._cur_param_idx -= len(all_params)
idx = self._cur_param_idx - 1
chunks_to_prefetch = set()
while idx >= 0 and len(chunks_to_prefetch) + 1 < self._max_prefetch:
param = self._param_visited_order[idx]
if is_ddp_ignored(param):
idx -= 1
continue
chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx])
if chunk not in cur_chunks:
chunks_to_prefetch.add(chunk)
idx -= 1
print(f"cur id {self._cur_param_idx}")
return list(chunks_to_prefetch)
def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
non_prefetched_chunks = []
@ -80,45 +43,25 @@ class GeminiZeROHook(ColoParamOpHook):
non_prefetched_chunks.append(chunk)
return non_prefetched_chunks
def pre_op(self, all_params):
# def find_idx(param):
# for i, p in enumerate(self._param_visited_order):
# if param is p:
# return i
# assert False
# idxs = [find_idx(p) for p in all_params]
# max_id = min(idxs)
# idxs = [i - max_id for i in idxs]
# assert list(range(len(idxs))) == sorted(idxs), f'{idxs}'
# deal with current needed chunks
params = [p for p in all_params if not is_ddp_ignored(p)]
def pre_op(self, params):
params = [p for p in params if not is_ddp_ignored(p)]
all_chunks = self._chunk_manager.get_chunks(params)
chunks_need_to_fetch_sync = tuple(self.wait_chunks(all_chunks))
# wait for prefetched chunks, filter those are not prefetched
chunks_fetch_sync = tuple(self.wait_chunks(all_chunks))
for p in params:
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
self._gemini_manager.sample_overall_data()
self._gemini_manager.adjust_layout(chunks_need_to_fetch_sync)
# deal with chunks that are to be async fetched
chunks_can_be_fetch_async = self.get_prefetch_chunks(all_params=all_params, cur_chunks=chunks_need_to_fetch_sync)
print(f"cur_chunks {' '.join([str(x.count_id) for x in chunks_need_to_fetch_sync])}, prefetch {' '.join([str(x.count_id) for x in chunks_can_be_fetch_async])}")
# deal with chunks that are to be fetched now
for chunk in chunks_need_to_fetch_sync:
self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0)
# fetch the rest chunks synchronously
for chunk in chunks_fetch_sync:
self._chunk_manager.access_chunk(chunk)
# deal with chunks that are to be pre fetched TODO @botbw: the order here matters?
for chunk in chunks_can_be_fetch_async:
if chunk in self._async_works:
continue
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch)
for chunk in chunks_fetch_async:
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
if maybe_work is not None:
print(f"prefetch {chunk.count_id}")
self._async_works[chunk] = maybe_work
# record cuda model data of the current OP
# record cuda model data of the current OP, including memory for prefetched chunks
self._gemini_manager.record_model_data_volume()
def post_op(self, params):

View File

@ -6,7 +6,7 @@ import torch
from .chunk import Chunk, ChunkManager
from .memory_tracer import ChunkMemStatsCollector, MemStats
from .placement_policy import PlacementPolicyFactory
from .placement_policy import PlacementPolicy, PlacementPolicyFactory
class GeminiManager:
@ -91,13 +91,13 @@ class GeminiManager:
self._warmup = False
self.reset_attributes()
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
"""Adjust the layout of stateful tensors according to the information provided
by mem_stats_collector, which should belongs to a Sharded Model.
"""
# find stateful tensor in state COMPUTE
start = time()
self._record_chunks_order(chunks)
self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
self._layout_time += time() - start
@ -133,9 +133,9 @@ class GeminiManager:
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
return cuda_demand, can_evict_chunks
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
self._compute_idx += 1
if self._warmup and self._placement_policy.need_mem_stats:
if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
self._compute_list.append(chunks)
def sample_overall_data(self):
@ -156,6 +156,18 @@ class GeminiManager:
return self._mem_stats_collector.cuda_margin_mem
return None
@property
def compute_list(self) -> List[Tuple[Chunk, ...]]:
return self._compute_list
@property
def compute_idx(self) -> int:
return self._compute_idx
@property
def placement_policy(self) -> PlacementPolicy:
return self._placement_policy
@property
def is_cuda_margin_mem_avail(self) -> bool:
return self._placement_policy.need_mem_stats

View File

@ -33,6 +33,10 @@ class PlacementPolicy(ABC):
) -> None:
raise NotImplementedError
@abstractmethod
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
raise NotImplementedError
class StaticPlacementPolicy(PlacementPolicy):
def __init__(
@ -95,6 +99,18 @@ class StaticPlacementPolicy(PlacementPolicy):
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
prefetch = []
for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)):
for chunk in self.chunk_manager.compute_list[i]:
if len(prefetch) >= max_prefetch:
break
if chunk not in prefetch:
prefetch.append(chunk)
if len(prefetch) >= max_prefetch:
break
return prefetch
class AutoPlacementPolicy(PlacementPolicy):
need_mem_stats: bool = True
@ -198,6 +214,9 @@ class AutoPlacementPolicy(PlacementPolicy):
else:
grads_device_map[p] = torch.device("cpu")
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
return [] # TODO @botbw: implement prefetching for auto
class PlacementPolicyFactory:
policies: Dict[str, Type[PlacementPolicy]] = {