mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'prefetch' of github.com:botbw/ColossalAI into botbw-prefetch
commit
1f6b57099c
|
@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase):
|
||||||
chunk_init_device: Optional[torch.device] = None,
|
chunk_init_device: Optional[torch.device] = None,
|
||||||
placement_policy: str = "static",
|
placement_policy: str = "static",
|
||||||
enable_gradient_accumulation: bool = False,
|
enable_gradient_accumulation: bool = False,
|
||||||
|
max_prefetch: int = 0,
|
||||||
shard_param_frac: float = 1.0, # only for static placement
|
shard_param_frac: float = 1.0, # only for static placement
|
||||||
offload_optim_frac: float = 0.0, # only for static placement
|
offload_optim_frac: float = 0.0, # only for static placement
|
||||||
offload_param_frac: float = 0.0, # only for static placement
|
offload_param_frac: float = 0.0, # only for static placement
|
||||||
|
@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase):
|
||||||
memstats=memstats,
|
memstats=memstats,
|
||||||
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
|
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
|
||||||
master_weights=master_weights,
|
master_weights=master_weights,
|
||||||
|
max_prefetch=max_prefetch,
|
||||||
)
|
)
|
||||||
self.zero_optim_config = dict(
|
self.zero_optim_config = dict(
|
||||||
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
|
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
|
||||||
|
|
|
@ -357,14 +357,14 @@ class Chunk:
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def access_chunk(self):
|
def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
|
||||||
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
|
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
|
||||||
# sanity check
|
# sanity check
|
||||||
assert self.chunk_temp is None
|
assert self.chunk_temp is None
|
||||||
|
|
||||||
if not self.is_gathered:
|
if not self.is_gathered:
|
||||||
self.__gather()
|
return self.__gather(async_op=async_access)
|
||||||
self.__update_tensors_ptr()
|
self.__update_tensors_ptr()
|
||||||
|
return None
|
||||||
|
|
||||||
def release_chunk(self):
|
def release_chunk(self):
|
||||||
"""Release the usable chunk. It's an operation done in CUDA."""
|
"""Release the usable chunk. It's an operation done in CUDA."""
|
||||||
|
@ -498,17 +498,19 @@ class Chunk:
|
||||||
def get_tensors(self) -> List[torch.Tensor]:
|
def get_tensors(self) -> List[torch.Tensor]:
|
||||||
return list(self.tensors_info.keys())
|
return list(self.tensors_info.keys())
|
||||||
|
|
||||||
def __gather(self):
|
def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
|
||||||
if not self.is_gathered:
|
if not self.is_gathered:
|
||||||
# sanity check
|
# sanity check
|
||||||
assert self.cuda_shard is not None
|
assert self.cuda_shard is not None
|
||||||
|
|
||||||
alloc_storage(self.cuda_global_chunk)
|
alloc_storage(self.cuda_global_chunk)
|
||||||
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
|
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
|
||||||
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
|
work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
|
||||||
|
|
||||||
self.cuda_shard = None
|
self.cuda_shard = None
|
||||||
self.is_gathered = True
|
self.is_gathered = True
|
||||||
|
return work
|
||||||
|
return None
|
||||||
|
|
||||||
def __scatter(self):
|
def __scatter(self):
|
||||||
if self.keep_gathered:
|
if self.keep_gathered:
|
||||||
|
|
|
@ -111,15 +111,16 @@ class ChunkManager:
|
||||||
for group_name in self.chunk_groups:
|
for group_name in self.chunk_groups:
|
||||||
self.__close_one_chunk(self.chunk_groups[group_name][-1])
|
self.__close_one_chunk(self.chunk_groups[group_name][-1])
|
||||||
|
|
||||||
def access_chunk(self, chunk: Chunk) -> None:
|
def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
|
||||||
"""Make the chunk can be used for calculation."""
|
"""Make the chunk can be used for calculation."""
|
||||||
if chunk in self.accessed_chunks:
|
if chunk in self.accessed_chunks:
|
||||||
return
|
return None
|
||||||
self.__sub_memory_usage(chunk.memory_usage)
|
self.__sub_memory_usage(chunk.memory_usage)
|
||||||
if chunk.device_type == "cpu":
|
if chunk.device_type == "cpu":
|
||||||
chunk.shard_move(get_accelerator().get_current_device())
|
chunk.shard_move(get_accelerator().get_current_device())
|
||||||
self.__add_accessed_chunk(chunk)
|
maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
|
||||||
self.__add_memory_usage(chunk.memory_usage)
|
self.__add_memory_usage(chunk.memory_usage)
|
||||||
|
return maybe_work
|
||||||
|
|
||||||
def release_chunk(self, chunk: Chunk) -> None:
|
def release_chunk(self, chunk: Chunk) -> None:
|
||||||
"""Scatter the chunk in CUDA."""
|
"""Scatter the chunk in CUDA."""
|
||||||
|
@ -251,10 +252,11 @@ class ChunkManager:
|
||||||
for k, v in usage.items():
|
for k, v in usage.items():
|
||||||
self.total_mem[k] += v
|
self.total_mem[k] += v
|
||||||
|
|
||||||
def __add_accessed_chunk(self, chunk: Chunk):
|
def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
|
||||||
chunk.access_chunk()
|
maybe_work = chunk.access_chunk(async_access=async_access)
|
||||||
self.accessed_chunks.add(chunk)
|
self.accessed_chunks.add(chunk)
|
||||||
self.accessed_mem += chunk.chunk_mem
|
self.accessed_mem += chunk.chunk_mem
|
||||||
|
return maybe_work
|
||||||
|
|
||||||
def __sub_accessed_chunk(self, chunk: Chunk):
|
def __sub_accessed_chunk(self, chunk: Chunk):
|
||||||
chunk.release_chunk()
|
chunk.release_chunk()
|
||||||
|
|
|
@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
|
||||||
chunk_init_device: torch.device = torch.device("cpu"),
|
chunk_init_device: torch.device = torch.device("cpu"),
|
||||||
placement_policy: str = "static",
|
placement_policy: str = "static",
|
||||||
enable_gradient_accumulation: bool = False,
|
enable_gradient_accumulation: bool = False,
|
||||||
|
max_prefetch: int = 0,
|
||||||
shard_param_frac: float = 1.0, # only for static placement
|
shard_param_frac: float = 1.0, # only for static placement
|
||||||
offload_optim_frac: float = 0.0, # only for static placement
|
offload_optim_frac: float = 0.0, # only for static placement
|
||||||
offload_param_frac: float = 0.0, # only for static placement
|
offload_param_frac: float = 0.0, # only for static placement
|
||||||
|
@ -132,7 +133,7 @@ class GeminiDDP(ModelWrapper):
|
||||||
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
|
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
|
||||||
)
|
)
|
||||||
self.force_outputs_fp32 = force_outputs_fp32
|
self.force_outputs_fp32 = force_outputs_fp32
|
||||||
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
|
self.param_op_hook = GeminiZeROHook(self.gemini_manager, max_prefetch=max_prefetch)
|
||||||
self.fp32_params: List[torch.Tensor] = list()
|
self.fp32_params: List[torch.Tensor] = list()
|
||||||
self.fp16_params: List[ColoParameter] = list()
|
self.fp16_params: List[ColoParameter] = list()
|
||||||
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
||||||
|
|
|
@ -1,39 +1,67 @@
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import List
|
from typing import Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from colossalai.logging import DistributedLogger
|
||||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||||
from colossalai.utils import is_ddp_ignored
|
from colossalai.utils import is_ddp_ignored
|
||||||
from colossalai.zero.gemini import TensorState
|
from colossalai.zero.gemini import TensorState
|
||||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||||
|
|
||||||
|
from .chunk import Chunk
|
||||||
|
|
||||||
|
|
||||||
class TrainingPhase(Enum):
|
class TrainingPhase(Enum):
|
||||||
FORWARD = 0
|
FORWARD = 0
|
||||||
BACKWARD = 1
|
BACKWARD = 1
|
||||||
|
|
||||||
|
|
||||||
|
logger = DistributedLogger("gemini_hook")
|
||||||
|
|
||||||
|
|
||||||
class GeminiZeROHook(ColoParamOpHook):
|
class GeminiZeROHook(ColoParamOpHook):
|
||||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
def __init__(self, gemini_manager: GeminiManager, max_prefetch: int = 0) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._gemini_manager = gemini_manager
|
self._gemini_manager = gemini_manager
|
||||||
self._chunk_manager = gemini_manager.chunk_manager
|
self._chunk_manager = gemini_manager.chunk_manager
|
||||||
self._training_phase = TrainingPhase.FORWARD
|
self._training_phase = TrainingPhase.FORWARD
|
||||||
|
self._max_prefetch = max_prefetch
|
||||||
|
self._async_works: Dict[Chunk, dist.work] = {}
|
||||||
|
|
||||||
|
def wait_chunks(self, chunks: List[Chunk]) -> List[Chunk]:
|
||||||
|
non_prefetched_chunks = []
|
||||||
|
for chunk in chunks:
|
||||||
|
if chunk in self._async_works:
|
||||||
|
print(f"prefetched {chunk.count_id}")
|
||||||
|
self._async_works[chunk].wait()
|
||||||
|
del self._async_works[chunk]
|
||||||
|
else:
|
||||||
|
non_prefetched_chunks.append(chunk)
|
||||||
|
return non_prefetched_chunks
|
||||||
|
|
||||||
def pre_op(self, params):
|
def pre_op(self, params):
|
||||||
params = [p for p in params if not is_ddp_ignored(p)]
|
params = [p for p in params if not is_ddp_ignored(p)]
|
||||||
chunks = self._chunk_manager.get_chunks(params)
|
all_chunks = self._chunk_manager.get_chunks(params)
|
||||||
|
# wait for prefetched chunks, filter those are not prefetched
|
||||||
|
chunks_fetch_sync = tuple(self.wait_chunks(all_chunks))
|
||||||
for p in params:
|
for p in params:
|
||||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||||
self._gemini_manager.sample_overall_data()
|
self._gemini_manager.sample_overall_data()
|
||||||
self._gemini_manager.adjust_layout(chunks)
|
self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._max_prefetch > 0)
|
||||||
for chunk in chunks:
|
# fetch the rest chunks synchronously
|
||||||
|
for chunk in chunks_fetch_sync:
|
||||||
self._chunk_manager.access_chunk(chunk)
|
self._chunk_manager.access_chunk(chunk)
|
||||||
|
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks(max_prefetch=self._max_prefetch)
|
||||||
|
for chunk in chunks_fetch_async:
|
||||||
|
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
|
||||||
|
if maybe_work is not None:
|
||||||
|
self._async_works[chunk] = maybe_work
|
||||||
|
|
||||||
# record cuda model data of the current OP
|
# record cuda model data of the current OP, including memory for prefetched chunks
|
||||||
self._gemini_manager.record_model_data_volume()
|
self._gemini_manager.record_model_data_volume()
|
||||||
|
|
||||||
def post_op(self, params):
|
def post_op(self, params):
|
||||||
|
@ -60,6 +88,11 @@ class GeminiZeROHook(ColoParamOpHook):
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
|
def switch_training_phase(self, training_phase: TrainingPhase = TrainingPhase.BACKWARD):
|
||||||
|
if training_phase == TrainingPhase.FORWARD:
|
||||||
|
self._cur_param_idx = 0
|
||||||
|
else:
|
||||||
|
self._cur_param_idx = len(self._param_visited_order) - 1
|
||||||
|
|
||||||
old_training_phase = self._training_phase
|
old_training_phase = self._training_phase
|
||||||
try:
|
try:
|
||||||
self._training_phase = training_phase
|
self._training_phase = training_phase
|
||||||
|
|
|
@ -6,7 +6,7 @@ import torch
|
||||||
|
|
||||||
from .chunk import Chunk, ChunkManager
|
from .chunk import Chunk, ChunkManager
|
||||||
from .memory_tracer import ChunkMemStatsCollector, MemStats
|
from .memory_tracer import ChunkMemStatsCollector, MemStats
|
||||||
from .placement_policy import PlacementPolicyFactory
|
from .placement_policy import PlacementPolicy, PlacementPolicyFactory
|
||||||
|
|
||||||
|
|
||||||
class GeminiManager:
|
class GeminiManager:
|
||||||
|
@ -91,13 +91,13 @@ class GeminiManager:
|
||||||
self._warmup = False
|
self._warmup = False
|
||||||
self.reset_attributes()
|
self.reset_attributes()
|
||||||
|
|
||||||
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
|
def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
|
||||||
"""Adjust the layout of stateful tensors according to the information provided
|
"""Adjust the layout of stateful tensors according to the information provided
|
||||||
by mem_stats_collector, which should belongs to a Sharded Model.
|
by mem_stats_collector, which should belongs to a Sharded Model.
|
||||||
"""
|
"""
|
||||||
# find stateful tensor in state COMPUTE
|
# find stateful tensor in state COMPUTE
|
||||||
start = time()
|
start = time()
|
||||||
self._record_chunks_order(chunks)
|
self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
|
||||||
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
|
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
|
||||||
self._layout_time += time() - start
|
self._layout_time += time() - start
|
||||||
|
|
||||||
|
@ -133,9 +133,9 @@ class GeminiManager:
|
||||||
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
|
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
|
||||||
return cuda_demand, can_evict_chunks
|
return cuda_demand, can_evict_chunks
|
||||||
|
|
||||||
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
|
def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
|
||||||
self._compute_idx += 1
|
self._compute_idx += 1
|
||||||
if self._warmup and self._placement_policy.need_mem_stats:
|
if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
|
||||||
self._compute_list.append(chunks)
|
self._compute_list.append(chunks)
|
||||||
|
|
||||||
def sample_overall_data(self):
|
def sample_overall_data(self):
|
||||||
|
@ -156,6 +156,18 @@ class GeminiManager:
|
||||||
return self._mem_stats_collector.cuda_margin_mem
|
return self._mem_stats_collector.cuda_margin_mem
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def compute_list(self) -> List[Tuple[Chunk, ...]]:
|
||||||
|
return self._compute_list
|
||||||
|
|
||||||
|
@property
|
||||||
|
def compute_idx(self) -> int:
|
||||||
|
return self._compute_idx
|
||||||
|
|
||||||
|
@property
|
||||||
|
def placement_policy(self) -> PlacementPolicy:
|
||||||
|
return self._placement_policy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_cuda_margin_mem_avail(self) -> bool:
|
def is_cuda_margin_mem_avail(self) -> bool:
|
||||||
return self._placement_policy.need_mem_stats
|
return self._placement_policy.need_mem_stats
|
||||||
|
|
|
@ -33,6 +33,10 @@ class PlacementPolicy(ABC):
|
||||||
) -> None:
|
) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class StaticPlacementPolicy(PlacementPolicy):
|
class StaticPlacementPolicy(PlacementPolicy):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -95,6 +99,18 @@ class StaticPlacementPolicy(PlacementPolicy):
|
||||||
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
|
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
|
||||||
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
|
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
|
||||||
|
|
||||||
|
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
|
||||||
|
prefetch = []
|
||||||
|
for i in range(self.chunk_manager.compute_idx + 1, len(self.chunk_manager.compute_list)):
|
||||||
|
for chunk in self.chunk_manager.compute_list[i]:
|
||||||
|
if len(prefetch) >= max_prefetch:
|
||||||
|
break
|
||||||
|
if chunk not in prefetch:
|
||||||
|
prefetch.append(chunk)
|
||||||
|
if len(prefetch) >= max_prefetch:
|
||||||
|
break
|
||||||
|
return prefetch
|
||||||
|
|
||||||
|
|
||||||
class AutoPlacementPolicy(PlacementPolicy):
|
class AutoPlacementPolicy(PlacementPolicy):
|
||||||
need_mem_stats: bool = True
|
need_mem_stats: bool = True
|
||||||
|
@ -198,6 +214,9 @@ class AutoPlacementPolicy(PlacementPolicy):
|
||||||
else:
|
else:
|
||||||
grads_device_map[p] = torch.device("cpu")
|
grads_device_map[p] = torch.device("cpu")
|
||||||
|
|
||||||
|
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
|
||||||
|
return [] # TODO @botbw: implement prefetching for auto
|
||||||
|
|
||||||
|
|
||||||
class PlacementPolicyFactory:
|
class PlacementPolicyFactory:
|
||||||
policies: Dict[str, Type[PlacementPolicy]] = {
|
policies: Dict[str, Type[PlacementPolicy]] = {
|
||||||
|
|
Loading…
Reference in New Issue