mirror of https://github.com/hpcaitech/ColossalAI

[gemini] use compute_chunk to find next chunk

parent b2e9745888
commit 4148ceed9f
colossalai/zero/gemini/chunk/manager.py

@@ -114,7 +114,7 @@ class ChunkManager:
     def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
         """Make the chunk usable for calculation."""
         if chunk in self.accessed_chunks:
-            return
+            return None
         self.__sub_memory_usage(chunk.memory_usage)
         if chunk.device_type == "cpu":
             chunk.shard_move(get_accelerator().get_current_device())
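With `async_access=True`, `access_chunk` now hands back an optional `dist.Work` handle instead of blocking until the chunk is usable. The following is a minimal sketch of the calling pattern this enables; `FakeWork` and `ToyChunkManager` are hypothetical stand-ins, not ColossalAI classes:

from typing import Optional, Set


class FakeWork:
    """Hypothetical stand-in for torch.distributed.Work (an in-flight collective)."""

    def wait(self) -> None:
        pass  # a real Work.wait() blocks until the communication completes


class ToyChunkManager:
    def __init__(self) -> None:
        self.accessed_chunks: Set[str] = set()

    def access_chunk(self, chunk: str, async_access: bool = False) -> Optional[FakeWork]:
        # mirrors the new contract: None when nothing was launched,
        # a work handle when an asynchronous gather is in flight
        if chunk in self.accessed_chunks:
            return None
        self.accessed_chunks.add(chunk)
        return FakeWork() if async_access else None


manager = ToyChunkManager()
work = manager.access_chunk("chunk_0", async_access=True)
if work is not None:
    work.wait()  # the caller chooses when to synchronize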
colossalai/zero/gemini/gemini_ddp.py

@@ -133,6 +133,7 @@ class GeminiDDP(ModelWrapper):
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
         )
         self.force_outputs_fp32 = force_outputs_fp32
+        self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch)
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()

@@ -157,8 +158,6 @@ class GeminiDDP(ModelWrapper):
         for p in module.parameters():
             param_order.append(p)
 
-        self.param_op_hook = GeminiZeROHook(self.gemini_manager, param_order=param_order, max_prefetch=max_prefetch)
-
         for name, param in module.named_parameters():
             self.param2name[param] = name
         for m_name, m_var in module.named_modules():
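Because the lookahead moved into the placement policy, the hook no longer needs the parameter visiting order and can be constructed before `module.parameters()` is walked. A rough sketch of the new, smaller dependency, using a hypothetical minimal class rather than the real `GeminiZeROHook`:

class ToyHook:
    """Hypothetical stand-in for GeminiZeROHook after this commit: it only
    needs the manager and a prefetch budget, not the parameter order."""

    def __init__(self, gemini_manager: object, max_prefetch: int = 0) -> None:
        self._gemini_manager = gemini_manager
        self._max_prefetch = max_prefetch


# construction no longer waits for the parameter order to be collected
hook = ToyHook(gemini_manager=object(), max_prefetch=2)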
colossalai/zero/gemini/gemini_hook.py

@@ -6,16 +6,15 @@ from typing import Dict, List
 import torch
 import torch.distributed as dist
 
-from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.logging import DistributedLogger
 from colossalai.tensor.param_op_hook import ColoParamOpHook
 from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
 from colossalai.zero.gemini.gemini_mgr import GeminiManager
 from colossalai.zero.gemini.memory_tracer.param_runtime_order import OrderedParamGenerator
-from colossalai.logging import DistributedLogger
 
 from .chunk import Chunk
 
 
 class TrainingPhase(Enum):
     FORWARD = 0
     BACKWARD = 1
@@ -23,51 +22,15 @@ class TrainingPhase(Enum):
 
 logger = DistributedLogger("gemini_hook")
 
 
 class GeminiZeROHook(ColoParamOpHook):
-    def __init__(
-        self, gemini_manager: GeminiManager, param_order: OrderedParamGenerator, max_prefetch: int = 0
-    ) -> None:
+    def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None:
         super().__init__()
         self._gemini_manager = gemini_manager
         self._chunk_manager = gemini_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD
-        # param_visited_order might be updated somewhere else
-        self._param_visited_order = param_order.param_visited_order
         self._max_prefetch = max_prefetch
         self._async_works: Dict[Chunk, dist.Work] = {}
-        # used by get_prefetch_chunks to track current param
-        self._cur_param_idx = 0
-
-    def get_prefetch_chunks(self, all_params: List[ColoParameter], cur_chunks: List[Chunk]) -> List[Chunk]:
-        chunks_to_prefetch = set()
-        if self._training_phase == TrainingPhase.FORWARD:  # forward phase: increase
-            self._cur_param_idx += len(all_params)  # need to update first
-            idx = self._cur_param_idx + 1
-            # still have params and prefetched chunks don't exceed the limit
-            while idx < len(self._param_visited_order) and len(chunks_to_prefetch) + 1 < self._max_prefetch:
-                param = self._param_visited_order[idx]
-                if is_ddp_ignored(param):
-                    idx += 1
-                    continue
-                chunk = self._chunk_manager.get_chunk(param)
-                if chunk not in cur_chunks:
-                    chunks_to_prefetch.add(chunk)
-                idx += 1
-        else:
-            self._cur_param_idx -= len(all_params)
-            idx = self._cur_param_idx - 1
-            chunks_to_prefetch = set()
-            while idx >= 0 and len(chunks_to_prefetch) + 1 < self._max_prefetch:
-                param = self._param_visited_order[idx]
-                if is_ddp_ignored(param):
-                    idx -= 1
-                    continue
-                chunk = self._chunk_manager.get_chunk(self._param_visited_order[idx])
-                if chunk not in cur_chunks:
-                    chunks_to_prefetch.add(chunk)
-                idx -= 1
-        print(f"cur id {self._cur_param_idx}")
-        return list(chunks_to_prefetch)
 
     def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
         non_prefetched_chunks = []
@@ -80,45 +43,25 @@ class GeminiZeROHook(ColoParamOpHook):
                 non_prefetched_chunks.append(chunk)
         return non_prefetched_chunks
 
-    def pre_op(self, all_params):
-        # def find_idx(param):
-        #     for i, p in enumerate(self._param_visited_order):
-        #         if param is p:
-        #             return i
-        #     assert False
-
-        # idxs = [find_idx(p) for p in all_params]
-        # max_id = min(idxs)
-        # idxs = [i - max_id for i in idxs]
-        # assert list(range(len(idxs))) == sorted(idxs), f'{idxs}'
-
-        # deal with current needed chunks
-        params = [p for p in all_params if not is_ddp_ignored(p)]
+    def pre_op(self, params):
+        params = [p for p in params if not is_ddp_ignored(p)]
         all_chunks = self._chunk_manager.get_chunks(params)
-        chunks_need_to_fetch_sync = tuple(self.wait_chunks(all_chunks))
+        # wait for prefetched chunks, filter out those not prefetched
+        chunks_fetch_sync = tuple(self.wait_chunks(all_chunks))
         for p in params:
             self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
         self._gemini_manager.sample_overall_data()
-        self._gemini_manager.adjust_layout(chunks_need_to_fetch_sync)
-
-        # deal with chunks that are to be async fetched
-        chunks_can_be_fetch_async = self.get_prefetch_chunks(all_params=all_params, cur_chunks=chunks_need_to_fetch_sync)
-
-        print(f"cur_chunks {' '.join([str(x.count_id) for x in chunks_need_to_fetch_sync])}, prefetch {' '.join([str(x.count_id) for x in chunks_can_be_fetch_async])}")
-        # deal with chunks that are to be fetched now
-        for chunk in chunks_need_to_fetch_sync:
+        self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0)
+        # fetch the rest of the chunks synchronously
+        for chunk in chunks_fetch_sync:
             self._chunk_manager.access_chunk(chunk)
 
         # deal with chunks that are to be prefetched  TODO @botbw: the order here matters?
-        for chunk in chunks_can_be_fetch_async:
-            if chunk in self._async_works:
-                continue
+        chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch)
+        for chunk in chunks_fetch_async:
             maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
             if maybe_work is not None:
-                print(f"prefetch {chunk.count_id}")
                 self._async_works[chunk] = maybe_work
 
-        # record cuda model data of the current OP
+        # record cuda model data of the current OP, including memory for prefetched chunks
         self._gemini_manager.record_model_data_volume()
 
     def post_op(self, params):
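Putting the new `pre_op` flow together: wait on in-flight prefetches for the chunks the current op needs, fetch the rest synchronously, then ask the policy for the next chunks and launch them asynchronously. Below is a self-contained toy replay of that flow under assumed simplifications (string chunk ids, a fake work handle, an inlined lookahead); none of these names are ColossalAI APIs:

from typing import Dict, List


class FakeWork:
    def wait(self) -> None:
        pass  # placeholder for torch.distributed.Work.wait()


class ToyPrefetchHook:
    """Hypothetical replay of the pre_op flow over a recorded compute order."""

    def __init__(self, compute_list: List[List[str]], max_prefetch: int) -> None:
        self.compute_list = compute_list   # chunks used by each op, recorded in warm-up
        self.compute_idx = -1              # position of the op being executed
        self.async_works: Dict[str, FakeWork] = {}
        self.max_prefetch = max_prefetch

    def get_prefetch_chunks(self) -> List[str]:
        # same lookahead idea as StaticPlacementPolicy.get_prefetch_chunks
        prefetch: List[str] = []
        for step in self.compute_list[self.compute_idx + 1:]:
            for chunk in step:
                if len(prefetch) >= self.max_prefetch:
                    return prefetch
                if chunk not in prefetch:
                    prefetch.append(chunk)
        return prefetch

    def pre_op(self) -> List[str]:
        self.compute_idx += 1
        needed = self.compute_list[self.compute_idx]
        # 1. chunks whose prefetch is already in flight: just wait on their handles
        fetch_sync = []
        for chunk in needed:
            work = self.async_works.pop(chunk, None)
            if work is not None:
                work.wait()
            else:
                fetch_sync.append(chunk)  # 2. must be fetched synchronously
        # 3. launch the next prefetches asynchronously
        for chunk in self.get_prefetch_chunks():
            if chunk not in self.async_works:
                self.async_works[chunk] = FakeWork()
        return fetch_sync


hook = ToyPrefetchHook([["a"], ["b"], ["c", "d"]], max_prefetch=2)
assert hook.pre_op() == ["a"]               # nothing prefetched yet
assert set(hook.async_works) >= {"b", "c"}  # next distinct chunks are in flight
assert hook.pre_op() == []                  # "b" was already prefetched

Note the design choice visible in the diff: the hook only waits on handles for chunks the current op actually needs; handles for chunks not yet needed stay parked in `_async_works` until their op arrives.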
colossalai/zero/gemini/gemini_mgr.py

@@ -6,7 +6,7 @@ import torch
 
 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector, MemStats
-from .placement_policy import PlacementPolicyFactory
+from .placement_policy import PlacementPolicy, PlacementPolicyFactory
 
 
 class GeminiManager:
@@ -91,13 +91,13 @@ class GeminiManager:
         self._warmup = False
         self.reset_attributes()
 
-    def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
+    def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         """Adjust the layout of stateful tensors according to the information provided
         by mem_stats_collector, which should belong to a Sharded Model.
         """
         # find stateful tensors in state COMPUTE
         start = time()
-        self._record_chunks_order(chunks)
+        self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
         cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
         self._layout_time += time() - start
@@ -133,9 +133,9 @@ class GeminiManager:
         can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
         return cuda_demand, can_evict_chunks
 
-    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
+    def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
         self._compute_idx += 1
-        if self._warmup and self._placement_policy.need_mem_stats:
+        if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
             self._compute_list.append(chunks)
 
     def sample_overall_data(self):
@@ -156,6 +156,18 @@ class GeminiManager:
             return self._mem_stats_collector.cuda_margin_mem
         return None
 
+    @property
+    def compute_list(self) -> List[Tuple[Chunk, ...]]:
+        return self._compute_list
+
+    @property
+    def compute_idx(self) -> int:
+        return self._compute_idx
+
+    @property
+    def placement_policy(self) -> PlacementPolicy:
+        return self._placement_policy
+
     @property
     def is_cuda_margin_mem_avail(self) -> bool:
         return self._placement_policy.need_mem_stats
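The new properties expose the order recorded during warm-up, which is what makes the policy-side lookahead possible. A small sketch of the record-then-replay idea behind `compute_list`, `compute_idx`, and `record_anyway` (a hypothetical simplification, not the real GeminiManager):

from typing import List, Tuple


class ToyManager:
    """Hypothetical sketch: warm-up records the chunk order, later iterations
    replay it; record_anyway forces recording even without memory stats."""

    def __init__(self, need_mem_stats: bool = False) -> None:
        self._warmup = True
        self._need_mem_stats = need_mem_stats
        self._compute_list: List[Tuple[str, ...]] = []
        self._compute_idx = -1

    def _record_warmup_chunks_order(self, chunks: Tuple[str, ...], record_anyway: bool = False) -> None:
        self._compute_idx += 1
        if self._warmup and (self._need_mem_stats or record_anyway):
            self._compute_list.append(chunks)

    @property
    def compute_list(self) -> List[Tuple[str, ...]]:
        return self._compute_list

    @property
    def compute_idx(self) -> int:
        return self._compute_idx


mgr = ToyManager()
# with max_prefetch > 0 the hook passes record_anyway=True, so the order is
# captured even for policies that do not collect memory stats
mgr._record_warmup_chunks_order(("chunk_0",), record_anyway=True)
mgr._record_warmup_chunks_order(("chunk_1", "chunk_2"), record_anyway=True)
assert mgr.compute_list == [("chunk_0",), ("chunk_1", "chunk_2")]
assert mgr.compute_idx == 1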
colossalai/zero/gemini/placement_policy.py

@@ -33,6 +33,10 @@ class PlacementPolicy(ABC):
     ) -> None:
         raise NotImplementedError
 
+    @abstractmethod
+    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+        raise NotImplementedError
+
 
 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
@@ -95,6 +99,18 @@ class StaticPlacementPolicy(PlacementPolicy):
         self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
         self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
 
+    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+        prefetch = []
+        for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)):
+            for chunk in self.chunk_manager.compute_list[i]:
+                if len(prefetch) >= max_prefetch:
+                    break
+                if chunk not in prefetch:
+                    prefetch.append(chunk)
+            if len(prefetch) >= max_prefetch:
+                break
+        return prefetch
+
 
 class AutoPlacementPolicy(PlacementPolicy):
     need_mem_stats: bool = True
@@ -198,6 +214,9 @@ class AutoPlacementPolicy(PlacementPolicy):
         else:
             grads_device_map[p] = torch.device("cpu")
 
+    def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
+        return []  # TODO @botbw: implement prefetching for auto
+
 
 class PlacementPolicyFactory:
     policies: Dict[str, Type[PlacementPolicy]] = {
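The lookahead itself is policy-specific: the static policy walks the recorded order past the current position and collects the first `max_prefetch` distinct chunks, while the auto policy opts out for now. A free-standing version of that scan, assuming plain chunk ids in place of `Chunk` objects:

from typing import List, Tuple


def get_prefetch_chunks(
    compute_list: List[Tuple[str, ...]], compute_idx: int, max_prefetch: int
) -> List[str]:
    """Collect up to max_prefetch distinct chunks used after compute_idx,
    preserving first-use order, mirroring the loop in StaticPlacementPolicy."""
    prefetch: List[str] = []
    for i in range(compute_idx + 1, len(compute_list)):
        for chunk in compute_list[i]:
            if len(prefetch) >= max_prefetch:
                break
            if chunk not in prefetch:
                prefetch.append(chunk)
        if len(prefetch) >= max_prefetch:
            break
    return prefetch


order = [("a",), ("b", "c"), ("c", "d")]
assert get_prefetch_chunks(order, compute_idx=0, max_prefetch=2) == ["b", "c"]
assert get_prefetch_chunks(order, compute_idx=1, max_prefetch=2) == ["c", "d"]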