mirror of https://github.com/hpcaitech/ColossalAI

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

branch: pull/5722/head
parent: 82b25524ff
commit: 6bbe956316
@@ -131,7 +131,7 @@ class GeminiDDP(ModelWrapper):
             offload_param_frac=offload_param_frac,
             warmup_non_model_data_ratio=warmup_non_model_data_ratio,
             steady_cuda_cap_ratio=steady_cuda_cap_ratio,
-            max_prefetch=max_prefetch
+            max_prefetch=max_prefetch,
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
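The only change in this hunk (apparently the GeminiDDP constructor forwarding its placement options) is the trailing comma after max_prefetch=max_prefetch. A formatter such as Black adds one when an argument list is already split across lines and treats it as a signal to keep the one-argument-per-line layout. A tiny illustration with a made-up dict and made-up values:

# Black's "magic trailing comma": the comma after the last keyword keeps the
# call exploded one argument per line on future reformatting runs.
placement_kwargs = dict(
    offload_param_frac=0.0,
    warmup_non_model_data_ratio=0.8,
    steady_cuda_cap_ratio=0.9,
    max_prefetch=2,
)
print(placement_kwargs["max_prefetch"])  # 2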
@@ -1,10 +1,9 @@
 from contextlib import contextmanager
 from enum import Enum
 from functools import partial
-from typing import Dict, List, Iterable, Tuple
+from typing import List

 import torch
-import torch.distributed as dist

 from colossalai.logging import DistributedLogger
 from colossalai.tensor.param_op_hook import ColoParamOpHook
@@ -12,8 +11,6 @@ from colossalai.utils import is_ddp_ignored
 from colossalai.zero.gemini import TensorState
 from colossalai.zero.gemini.gemini_mgr import GeminiManager

-from .chunk import Chunk
-

 class TrainingPhase(Enum):
     FORWARD = 0
@@ -23,7 +20,9 @@ class TrainingPhase(Enum):
 logger = DistributedLogger("gemini_hook")

 import os
-rank = int(os.environ['RANK'])
+
+rank = int(os.environ["RANK"])
+

 class GeminiZeROHook(ColoParamOpHook):
     def __init__(self, gemini_manager: GeminiManager) -> None:
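The module-level lookup above assumes the RANK environment variable that distributed launchers such as torchrun export; importing this file without it set would raise a KeyError. A hypothetical, more defensive variant (not what the patch does) would fall back to rank 0:

# Defensive variant (illustration only): default to rank 0 when RANK is unset,
# so the module can still be imported outside a distributed launcher.
import os

rank = int(os.environ.get("RANK", "0"))


def debug_print(*args, **kwargs):
    # Print from a single process to avoid interleaved output from every rank.
    if rank == 0:
        print(*args, **kwargs)


debug_print("hello from rank", rank)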
@@ -32,14 +31,13 @@ class GeminiZeROHook(ColoParamOpHook):
         self._chunk_manager = gemini_manager.chunk_manager
         self._training_phase = TrainingPhase.FORWARD

-
     def pre_op(self, params):
         # map params to chunks
         params = [p for p in params if not is_ddp_ignored(p)]
         all_chunks = self._chunk_manager.get_chunks(params)

         # wait for prefetched chunks, filter those are not prefetched
-        unique_chunks = set(all_chunks)
+        set(all_chunks)
         chunks_fetch_sync = self._gemini_manager.wait_chunks(all_chunks)

         # transfer state
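pre_op waits on the chunks that were already prefetched asynchronously; wait_chunks returns the ones that were not, so they can be fetched synchronously further down. (Dropping the unused unique_chunks binding leaves a bare set(all_chunks) call behind, which is now a no-op.) A minimal, self-contained sketch of that prefetch-then-wait pattern; it uses Python futures purely for illustration, whereas the real hook tracks chunks and CUDA transfers:

# Sketch of the prefetch-then-wait pattern (illustrative, not Gemini's code).
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Dict, List

pool = ThreadPoolExecutor(max_workers=2)
pending: Dict[str, Future] = {}  # chunk name -> in-flight fetch


def prefetch(chunk: str) -> None:
    # Launch an asynchronous "fetch" for a chunk that is expected soon.
    if chunk not in pending:
        pending[chunk] = pool.submit(lambda c=chunk: f"data({c})")


def wait_chunks(needed: List[str]) -> List[str]:
    # Block only on chunks that were actually prefetched; return the rest so
    # the caller can fetch them synchronously.
    not_prefetched: List[str] = []
    for chunk in needed:
        work = pending.pop(chunk, None)
        if work is None:
            not_prefetched.append(chunk)
        else:
            work.result()
    return not_prefetched


prefetch("w1")
print(wait_chunks(["w1", "w2"]))  # ['w2'] still has to be fetched synchronously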
@@ -48,7 +46,9 @@ class GeminiZeROHook(ColoParamOpHook):
         self._gemini_manager.sample_overall_data()

         # evit chunks, aware of async fetched
-        self._gemini_manager.adjust_layout(all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0)
+        self._gemini_manager.adjust_layout(
+            all_chunks, record_anyway=self._gemini_manager.placement_policy.max_prefetch > 0
+        )

         # fetch the rest synchronously
         for chunk in chunks_fetch_sync:
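record_anyway is derived from max_prefetch > 0, which suggests the chunk access order keeps being recorded once prefetching is enabled rather than only during warmup. A hypothetical, heavily simplified sketch of such a guard (the real logic lives in GeminiManager and is not shown in this diff):

# Hypothetical guard (illustration only): keep recording the compute order
# either during warmup or whenever the caller forces it via record_anyway.
def record_order(compute_list, chunks, is_warmup: bool, record_anyway: bool) -> None:
    if is_warmup or record_anyway:
        compute_list.append(chunks)


order = []
record_order(order, ["c0"], is_warmup=True, record_anyway=False)   # warmup: recorded
record_order(order, ["c1"], is_warmup=False, record_anyway=True)   # prefetch wants a fresh order
record_order(order, ["c2"], is_warmup=False, record_anyway=False)  # skipped
print(order)  # [['c0'], ['c1']]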
@@ -57,7 +57,9 @@ class GeminiZeROHook(ColoParamOpHook):
         # get possible chunks to prefetch
         chunks_fetch_async = self._gemini_manager.placement_policy.get_prefetch_chunks()
         if rank == 0 and not self._gemini_manager.is_warmup():
-            print(f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}")
+            print(
+                f"compute_id: {self._gemini_manager.compute_idx} self._gemini_manager.compute_list: {self._gemini_manager.compute_list}"
+            )
             print(f"{all_chunks=}")
             print(f"accessed_chunks={self._chunk_manager.accessed_chunks}")
             print(f"{chunks_fetch_sync=}")
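The debug prints rely on the self-documenting f-string form f"{expr=}" (Python 3.8+), which emits both the expression text and its value:

# f"{expr=}" prints the expression and its repr, handy for throwaway debugging.
chunks_fetch_sync = ["chunk_3", "chunk_7"]
print(f"{chunks_fetch_sync=}")  # chunks_fetch_sync=['chunk_3', 'chunk_7']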
@@ -1,6 +1,6 @@
 import functools
 from time import time
-from typing import Dict, List, Optional, Tuple, Iterable
+from typing import Dict, Iterable, List, Optional, Tuple

 import torch
 import torch.distributed as dist
@@ -101,7 +101,7 @@ class GeminiManager:
         start = time()
         self._record_warmup_chunks_order(chunks, record_anyway=record_anyway)
         cuda_demand, can_evict_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks)
         # don't evict chunks that are asynchronously fetched
         can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in self._async_works]
         self._layout_time += time() - start

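The context above shows adjust_layout excluding chunks that still have an asynchronous fetch in flight (keys of _async_works) from the eviction candidates, so a chunk is never evicted mid-transfer. The same filter restated standalone with placeholder values:

# Standalone restatement of the eviction filter with placeholder values.
can_evict_chunks = ["c0", "c1", "c2"]
_async_works = {"c1": "pending-all-gather-handle"}  # chunk -> in-flight work

can_evict_chunks = [chunk for chunk in can_evict_chunks if chunk not in _async_works]
print(can_evict_chunks)  # ['c0', 'c2']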
@@ -13,11 +13,17 @@ from colossalai.zero.gemini.chunk import Chunk
 from .chunk import Chunk, ChunkManager
 from .memory_tracer import ChunkMemStatsCollector

+
 class PlacementPolicy(ABC):
     need_mem_stats: bool = False

     def __init__(
-        self, gemini_manager: 'GeminiManager', chunk_manager: ChunkManager, mem_stats_collector: Optional[ChunkMemStatsCollector] = None, max_prefetch:int = 0, **kwargs
+        self,
+        gemini_manager: "GeminiManager",
+        chunk_manager: ChunkManager,
+        mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
+        max_prefetch: int = 0,
+        **kwargs,
     ) -> None:
         self.gemini_manager = gemini_manager
         self.chunk_manager = chunk_manager
@@ -38,13 +44,16 @@ class PlacementPolicy(ABC):
     def get_prefetch_chunks(self) -> List[Chunk]:
         raise NotImplementedError

+
 import os
+
 rank = int(os.environ["RANK"])

+
 class StaticPlacementPolicy(PlacementPolicy):
     def __init__(
         self,
-        gemini_manager: 'GeminiManager',
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
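get_prefetch_chunks() is the extension point a placement policy overrides to decide what to fetch ahead of the current compute step. A standalone toy sketch of one plausible strategy, walking a recorded compute order like the compute_list/compute_idx pair the debug prints expose; the names and structure here are simplified assumptions, not the repository's implementation:

# Toy prefetch policy (illustration only): look ahead in the recorded compute
# order and return up to max_prefetch upcoming chunks.
from typing import List


class ToyPrefetchPolicy:
    def __init__(self, compute_list: List[List[str]], max_prefetch: int = 0) -> None:
        self.compute_list = compute_list  # chunk groups in recorded compute order
        self.compute_idx = 0              # index of the group being computed now
        self.max_prefetch = max_prefetch

    def get_prefetch_chunks(self) -> List[str]:
        if self.max_prefetch == 0:
            return []
        prefetch: List[str] = []
        for group in self.compute_list[self.compute_idx + 1:]:
            for chunk in group:
                if len(prefetch) >= self.max_prefetch:
                    return prefetch
                prefetch.append(chunk)
        return prefetch


policy = ToyPrefetchPolicy([["c0"], ["c1", "c2"], ["c3"]], max_prefetch=2)
print(policy.get_prefetch_chunks())  # ['c1', 'c2']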
@@ -53,7 +62,9 @@ class StaticPlacementPolicy(PlacementPolicy):
         offload_param_frac: float = 0.0,
         **kwargs,
     ) -> None:
-        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
             warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
             offload_param_frac = 0.0
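The reformatted super().__init__ call does not change the guard below it: offload_param_frac only takes effect when parameters are fully sharded (shard_param_frac == 1.0) and optimizer states are fully offloaded (offload_optim_frac == 1.0); otherwise it is reset to 0.0 with a warning. The same rule as a standalone helper, for illustration:

# Standalone restatement of the offload_param_frac guard above.
import warnings


def resolve_offload_param_frac(
    shard_param_frac: float, offload_optim_frac: float, offload_param_frac: float
) -> float:
    if offload_param_frac > 0.0 and (shard_param_frac != 1.0 or offload_optim_frac != 1.0):
        warnings.warn("offload_param_frac is ignored when shard_param_frac != 1.0 or offload_optim_frac != 1.0")
        return 0.0
    return offload_param_frac


print(resolve_offload_param_frac(1.0, 1.0, 0.5))  # 0.5 is kept
print(resolve_offload_param_frac(0.5, 1.0, 0.5))  # warns and returns 0.0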
@@ -124,7 +135,7 @@ class AutoPlacementPolicy(PlacementPolicy):

     def __init__(
         self,
-        gemini_manager: 'GeminiManager',
+        gemini_manager: "GeminiManager",
         chunk_manager: ChunkManager,
         mem_stats_collector: Optional[ChunkMemStatsCollector] = None,
         max_prefetch: int = 0,
@@ -132,7 +143,9 @@ class AutoPlacementPolicy(PlacementPolicy):
         steady_cuda_cap_ratio: float = 0.9,
         **kwargs,
     ) -> None:
-        super().__init__(gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch)
+        super().__init__(
+            gemini_manager, chunk_manager, mem_stats_collector=mem_stats_collector, max_prefetch=max_prefetch
+        )
         # model data will use 1-_warmup_non_model_data_ratio CUDA memory in warmup phase
         # you can set them by AutoPlacementPolicy.set_warmup_non_model_data_ratio()
         # and AutoPlacementPolicy.set_steady_cuda_cap_ratio()
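The comments above describe AutoPlacementPolicy's two knobs: during warmup, model data may use roughly 1 - warmup_non_model_data_ratio of CUDA memory, and in the steady phase total CUDA usage is capped at steady_cuda_cap_ratio of capacity. Rough illustrative arithmetic, assuming a 40 GiB device and the ratios 0.8 and 0.9 (0.9 is the default visible in this hunk; 0.8 is an assumption):

# Back-of-the-envelope budgets for the two ratios (illustration only; the exact
# accounting lives in AutoPlacementPolicy/GeminiManager and is not shown here).
GiB = 1024**3
cuda_capacity = 40 * GiB

warmup_non_model_data_ratio = 0.8  # assumed default
steady_cuda_cap_ratio = 0.9        # default visible in this hunk

warmup_model_data_budget = (1 - warmup_non_model_data_ratio) * cuda_capacity
steady_cuda_cap = steady_cuda_cap_ratio * cuda_capacity

print(f"{warmup_model_data_budget / GiB:.1f} GiB for model data during warmup")  # 8.0
print(f"{steady_cuda_cap / GiB:.1f} GiB total CUDA cap in the steady phase")     # 36.0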