mirror of https://github.com/hpcaitech/ColossalAI
commit
9690981601
|
@ -329,6 +329,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
chunk_init_device: Optional[torch.device] = None,
|
||||
placement_policy: str = "static",
|
||||
enable_gradient_accumulation: bool = False,
|
||||
max_prefetch: int = 0,
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
|
@ -386,6 +387,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
memstats=memstats,
|
||||
mixed_precision=PRECISION_STR_TO_DTYPE[precision],
|
||||
master_weights=master_weights,
|
||||
max_prefetch=max_prefetch,
|
||||
)
|
||||
self.zero_optim_config = dict(
|
||||
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
|
||||
|
|
|
@ -357,14 +357,14 @@ class Chunk:
|
|||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def access_chunk(self):
|
||||
def access_chunk(self, async_access: bool = False) -> Optional[dist.Work]:
|
||||
"""Make the chunk usable for the parameters inside it. It's an operation done in CUDA."""
|
||||
# sanity check
|
||||
assert self.chunk_temp is None
|
||||
|
||||
if not self.is_gathered:
|
||||
self.__gather()
|
||||
return self.__gather(async_op=async_access)
|
||||
self.__update_tensors_ptr()
|
||||
return None
|
||||
|
||||
def release_chunk(self):
|
||||
"""Release the usable chunk. It's an operation done in CUDA."""
|
||||
|
@ -498,17 +498,19 @@ class Chunk:
|
|||
def get_tensors(self) -> List[torch.Tensor]:
|
||||
return list(self.tensors_info.keys())
|
||||
|
||||
def __gather(self):
|
||||
def __gather(self, async_op: bool = False) -> Optional[dist.Work]:
|
||||
if not self.is_gathered:
|
||||
# sanity check
|
||||
assert self.cuda_shard is not None
|
||||
|
||||
alloc_storage(self.cuda_global_chunk)
|
||||
gather_list = list(torch.chunk(input=self.cuda_global_chunk, chunks=self.pg_size, dim=0))
|
||||
dist.all_gather(gather_list, self.cuda_shard, self.torch_pg)
|
||||
work = dist.all_gather(gather_list, self.cuda_shard, self.torch_pg, async_op=async_op)
|
||||
|
||||
self.cuda_shard = None
|
||||
self.is_gathered = True
|
||||
return work
|
||||
return None
|
||||
|
||||
def __scatter(self):
|
||||
if self.keep_gathered:
|
||||
|
|
|
@ -111,15 +111,16 @@ class ChunkManager:
|
|||
for group_name in self.chunk_groups:
|
||||
self.__close_one_chunk(self.chunk_groups[group_name][-1])
|
||||
|
||||
def access_chunk(self, chunk: Chunk) -> None:
|
||||
def access_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
|
||||
"""Make the chunk can be used for calculation."""
|
||||
if chunk in self.accessed_chunks:
|
||||
return
|
||||
return None
|
||||
self.__sub_memory_usage(chunk.memory_usage)
|
||||
if chunk.device_type == "cpu":
|
||||
chunk.shard_move(get_accelerator().get_current_device())
|
||||
self.__add_accessed_chunk(chunk)
|
||||
maybe_work = self.__add_accessed_chunk(chunk, async_access=async_access)
|
||||
self.__add_memory_usage(chunk.memory_usage)
|
||||
return maybe_work
|
||||
|
||||
def release_chunk(self, chunk: Chunk) -> None:
|
||||
"""Scatter the chunk in CUDA."""
|
||||
|
@ -251,10 +252,11 @@ class ChunkManager:
|
|||
for k, v in usage.items():
|
||||
self.total_mem[k] += v
|
||||
|
||||
def __add_accessed_chunk(self, chunk: Chunk):
|
||||
chunk.access_chunk()
|
||||
def __add_accessed_chunk(self, chunk: Chunk, async_access: bool = False) -> Optional[dist.Work]:
|
||||
maybe_work = chunk.access_chunk(async_access=async_access)
|
||||
self.accessed_chunks.add(chunk)
|
||||
self.accessed_mem += chunk.chunk_mem
|
||||
return maybe_work
|
||||
|
||||
def __sub_accessed_chunk(self, chunk: Chunk):
|
||||
chunk.release_chunk()
|
||||
|
|
|
@ -78,6 +78,7 @@ class GeminiDDP(ModelWrapper):
|
|||
chunk_init_device: torch.device = torch.device("cpu"),
|
||||
placement_policy: str = "static",
|
||||
enable_gradient_accumulation: bool = False,
|
||||
max_prefetch: int = 0,
|
||||
shard_param_frac: float = 1.0, # only for static placement
|
||||
offload_optim_frac: float = 0.0, # only for static placement
|
||||
offload_param_frac: float = 0.0, # only for static placement
|
||||
|
@ -130,6 +131,7 @@ class GeminiDDP(ModelWrapper):
|
|||
offload_param_frac=offload_param_frac,
|
||||
warmup_non_model_data_ratio=warmup_non_model_data_ratio,
|
||||
steady_cuda_cap_ratio=steady_cuda_cap_ratio,
|
||||
max_prefetch=max_prefetch,
|
||||
)
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import List
|
|||
|
||||
import torch
|
||||
|
||||
from colossalai.logging import DistributedLogger
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHook
|
||||
from colossalai.utils import is_ddp_ignored
|
||||
from colossalai.zero.gemini import TensorState
|
||||
|
@ -16,6 +17,9 @@ class TrainingPhase(Enum):
|
|||
BACKWARD = 1
|
||||
|
||||
|
||||
logger = DistributedLogger("gemini_hook")
|
||||
|
||||
|
||||
class GeminiZeROHook(ColoParamOpHook):
|
||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||
super().__init__()
|
||||
|
@ -24,16 +28,37 @@ class GeminiZeROHook(ColoParamOpHook):
|
|||
self._training_phase = TrainingPhase.FORWARD
|
||||
|
||||
def pre_op(self, params):
|
||||
# map params to chunks
|
||||
params = [p for p in params if not is_ddp_ignored(p)]
|
||||
chunks = self._chunk_manager.get_chunks(params)
|
||||
all_chunks = self._chunk_manager.get_chunks(params)
|
||||
|
||||
# wait for prefetched chunks, filter those are not prefetched
|
||||
chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)
|
||||
|
||||
# transfer state
|
||||
for p in params:
|
||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||
self._gemini_manager.sample_overall_data()
|
||||
self._gemini_manager.adjust_layout(chunks)
|
||||
for chunk in chunks:
|
||||
|
||||
# evit chunks, aware of async fetched
|
||||
self._gemini_manager.adjust_layout(
|
||||
all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
|
||||
)
|
||||
|
||||
# fetch the rest synchronously
|
||||
for chunk in chunks_fetch_sync:
|
||||
self._chunk_manager.access_chunk(chunk)
|
||||
|
||||
# record cuda model data of the current OP
|
||||
# get possible chunks to prefetch
|
||||
chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
|
||||
|
||||
# prefetch
|
||||
for chunk in chunks_fetch_async:
|
||||
maybe_work = self._chunk_manager.access_chunk(chunk, async_access=True)
|
||||
if maybe_work is not None:
|
||||
self._gemini_manager.add_work(chunk, maybe_work)
|
||||
|
||||
# record cuda model data of the current OP, including memory for prefetched chunks
|
||||
self._gemini_manager.record_model_data_volume()
|
||||
|
||||
def post_op(self, params):
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
import functools
|
||||
from time import time
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
from .memory_tracer import ChunkMemStatsCollector, MemStats
|
||||
from .placement_policy import PlacementPolicyFactory
|
||||
from .placement_policy import PlacementPolicy, PlacementPolicyFactory
|
||||
|
||||
|
||||
class GeminiManager:
|
||||
|
@ -41,9 +42,10 @@ class GeminiManager:
|
|||
self._mem_stats_collector = (
|
||||
ChunkMemStatsCollector(chunk_manager, self._memstats) if policy_cls.need_mem_stats else None
|
||||
)
|
||||
self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector, **placement_kwargs)
|
||||
self._placement_policy = policy_cls(self, chunk_manager, self._mem_stats_collector, **placement_kwargs)
|
||||
self._compute_list: List[Tuple[Chunk, ...]] = []
|
||||
self._compute_idx: int = -1
|
||||
self._async_works: Dict[Chunk, dist.work] = {}
|
||||
|
||||
self._h2d_volume = 0
|
||||
self._d2h_volume = 0
|
||||
|
@ -91,18 +93,20 @@ class GeminiManager:
|
|||
self._warmup = False
|
||||
self.reset_attributes()
|
||||
|
||||
def adjust_layout(self, chunks: Tuple[Chunk, ...]) -> None:
|
||||
def adjust_layout(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
|
||||
"""Adjust the layout of stateful tensors according to the information provided
|
||||
by mem_stats_collector, which should belongs to a Sharded Model.
|
||||
"""
|
||||
# find stateful tensor in state COMPUTE
|
||||
start = time()
|
||||
self._record_chunks_order(chunks)
|
||||
cuda_demand, hold_cuda_tensor_list = self._get_layout_info(self._compute_idx, self._warmup, chunks)
|
||||
self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
|
||||
cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
|
||||
# don't evict chunks that are asynchronously fetched
|
||||
can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
|
||||
self._layout_time += time() - start
|
||||
|
||||
vol, evict_time = self._placement_policy.evict_tensors(
|
||||
can_evict_chunks=hold_cuda_tensor_list,
|
||||
can_evict_chunks=can_evict_chunks,
|
||||
cuda_demand=cuda_demand,
|
||||
warmup=self._warmup,
|
||||
compute_list=self._compute_list,
|
||||
|
@ -114,6 +118,21 @@ class GeminiManager:
|
|||
# move COMPUTE tensors to CUDA
|
||||
self._h2d_volume += cuda_demand
|
||||
|
||||
def wait_chunks(self, chunks: Iterable[Chunk]) -> Tuple[Chunk]:
|
||||
non_prefetched_chunks = []
|
||||
for chunk in chunks:
|
||||
if chunk in self._async_works:
|
||||
self._async_works[chunk].wait()
|
||||
del self._async_works[chunk]
|
||||
else:
|
||||
non_prefetched_chunks.append(chunk)
|
||||
return tuple(non_prefetched_chunks)
|
||||
|
||||
def add_work(self, chunk: Chunk, work: dist.Work):
|
||||
assert work is not None
|
||||
assert chunk not in self._async_works
|
||||
self._async_works[chunk] = work
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...]):
|
||||
start = time()
|
||||
|
@ -133,9 +152,9 @@ class GeminiManager:
|
|||
can_evict_chunks = self._chunk_manager.get_cuda_movable_chunks()
|
||||
return cuda_demand, can_evict_chunks
|
||||
|
||||
def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
|
||||
def _record_warmup_chunks_order(self, chunks: Tuple[Chunk, ...], record_anyway: bool = False) -> None:
|
||||
self._compute_idx += 1
|
||||
if self._warmup and self._placement_policy.need_mem_stats:
|
||||
if self._warmup and (self._placement_policy.need_mem_stats or record_anyway):
|
||||
self._compute_list.append(chunks)
|
||||
|
||||
def sample_overall_data(self):
|
||||
|
@ -156,6 +175,18 @@ class GeminiManager:
|
|||
return self._mem_stats_collector.cuda_margin_mem
|
||||
return None
|
||||
|
||||
@property
|
||||
def compute_list(self) -> List[Tuple[Chunk, ...]]:
|
||||
return self._compute_list
|
||||
|
||||
@property
|
||||
def compute_idx(self) -> int:
|
||||
return self._compute_idx
|
||||
|
||||
@property
|
||||
def placement_policy(self) -> PlacementPolicy:
|
||||
return self._placement_policy
|
||||
|
||||
@property
|
||||
def is_cuda_margin_mem_avail(self) -> bool:
|
||||
return self._placement_policy.need_mem_stats
|
||||
|
|
|
@ -18,10 +18,17 @@ class PlacementPolicy(ABC):
|
|||
need_mem_stats: bool = False
|
||||
|
||||
def __init__(
|
||||
self, chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, **kwargs
|
||||
self,
|
||||
gemini_manager: "GeminiManager",
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
max_prefetch: int = 0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager = chunk_manager
|
||||
self.mem_stats_collector: Optional[ChunkMemStatsCollector] = mem_stats_collector
|
||||
self.max_prefetch = max_prefetch
|
||||
|
||||
@abstractmethod
|
||||
def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
|
||||
|
@ -33,18 +40,26 @@ class PlacementPolicy(ABC):
|
|||
) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_prefetch_chunks(self) -> List[Chunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StaticPlacementPolicy(PlacementPolicy):
|
||||
def __init__(
|
||||
self,
|
||||
gemini_manager: "GeminiManager",
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
max_prefetch: int = 0,
|
||||
shard_param_frac: float = 1.0,
|
||||
offload_optim_frac: float = 0.0,
|
||||
offload_param_frac: float = 0.0,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
super().__init__(
|
||||
gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
|
||||
)
|
||||
if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
|
||||
warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
|
||||
offload_param_frac = 0.0
|
||||
|
@ -95,19 +110,38 @@ class StaticPlacementPolicy(PlacementPolicy):
|
|||
self.keep_gathered_chunk_mem = total_chunk_mem * (1 - self.shard_param_frac)
|
||||
self.keep_cuda_chunk_mem = total_chunk_mem * (1 - self.offload_param_frac)
|
||||
|
||||
def get_prefetch_chunks(self) -> List[Chunk]:
|
||||
if self.gemini_manager.is_warmup(): # no prefetch during warmup since we need compute_list
|
||||
return []
|
||||
can_prefetch = self.max_prefetch - len(self.gemini_manager._async_works)
|
||||
prefetch = []
|
||||
for i in range(self.gemini_manager.compute_idx + 1, len(self.gemini_manager.compute_list)):
|
||||
for chunk in self.gemini_manager.compute_list[i]:
|
||||
if len(prefetch) >= can_prefetch:
|
||||
break
|
||||
if chunk not in prefetch and chunk not in self.chunk_manager.accessed_chunks:
|
||||
prefetch.append(chunk)
|
||||
if len(prefetch) >= can_prefetch:
|
||||
break
|
||||
return prefetch
|
||||
|
||||
|
||||
class AutoPlacementPolicy(PlacementPolicy):
|
||||
need_mem_stats: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
gemini_manager: "GeminiManager",
|
||||
chunk_manager: ChunkManager,
|
||||
mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
|
||||
max_prefetch: int = 0,
|
||||
warmup_non_model_data_ratio: float = 0.8,
|
||||
steady_cuda_cap_ratio: float = 0.9,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
|
||||
super().__init__(
|
||||
gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
|
||||
)
|
||||
# model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
|
||||
# you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
|
||||
# and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
|
||||
|
@ -198,6 +232,9 @@ class AutoPlacementPolicy(PlacementPolicy):
|
|||
else:
|
||||
grads_device_map[p] = torch.device("cpu")
|
||||
|
||||
def get_prefetch_chunks(self, max_prefetch: int) -> List[Chunk]:
|
||||
return [] # TODO @botbw: implement prefetching for auto
|
||||
|
||||
|
||||
class PlacementPolicyFactory:
|
||||
policies: Dict[str, Type[PlacementPolicy]] = {
|
||||
|
|
|
@ -30,8 +30,9 @@ def get_profile_context(enable_flag, warmup_steps, active_steps, save_dir):
|
|||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
|
||||
on_trace_ready=tensorboard_trace_handler(save_dir),
|
||||
record_shapes=True,
|
||||
profile_memory=True,
|
||||
# record_shapes=True,
|
||||
# profile_memory=True,
|
||||
with_stack=True,
|
||||
)
|
||||
else:
|
||||
return nullcontext(DummyProfiler())
|
||||
|
|
|
@ -129,7 +129,7 @@ def main():
|
|||
WARMUP_STEPS = 1
|
||||
assert WARMUP_STEPS < NUM_STEPS, "warmup steps should smaller than the total steps"
|
||||
assert (NUM_STEPS - WARMUP_STEPS) % 2 == 1, "the number of valid steps should be odd to take the median"
|
||||
PROF_FLAG = False # The flag of profiling, False by default
|
||||
PROF_FLAG = True # The flag of profiling, False by default
|
||||
|
||||
disable_existing_loggers()
|
||||
colossalai.launch_from_torch()
|
||||
|
@ -166,7 +166,7 @@ def main():
|
|||
stage=zero_stage, reduce_bucket_size_in_m=12, overlap_communication=True, verbose=True
|
||||
)
|
||||
elif args.distplan == "CAI_Gemini":
|
||||
plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
|
||||
plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd, max_prefetch=1)
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
|
@ -248,7 +248,7 @@ def main():
|
|||
prof.step()
|
||||
|
||||
tflops_list.sort()
|
||||
median_index = ((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS
|
||||
median_index = min(((NUM_STEPS - WARMUP_STEPS) >> 1) + WARMUP_STEPS, len(tflops_list) - 1)
|
||||
logger.info(f"Median TFLOPS is {tflops_list[median_index]:.3f}")
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
|
|
@ -40,9 +40,7 @@ EXAMPLE_MODELS = [
|
|||
]
|
||||
|
||||
# bfloat16 cannot represent them exactly
|
||||
BF16_IGNORED_KEYS = [
|
||||
"masked_bias",
|
||||
]
|
||||
BF16_IGNORED_KEYS = ["masked_bias"]
|
||||
|
||||
|
||||
def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dtype):
|
||||
|
|
Loading…
Reference in New Issue