mirror of https://github.com/hpcaitech/ColossalAI
[gemini] zero supports gemini (#1093)
* add placement policy * add gemini mgr * update mem stats collector * update zero * update zero optim * fix bugs * zero optim monitor os * polish unit test * polish unit test * add assert
pull/1098/head
parent
2b2dc1c86b
commit
1f894e033f
|
@ -0,0 +1,113 @@
|
|||
import functools
|
||||
from .memory_tracer.memstats_collector import MemStatsCollectorV2
|
||||
from typing import List, Optional, Tuple
|
||||
from time import time
|
||||
from colossalai.tensor.chunk import Chunk, ChunkManager
|
||||
from .placement_policy import PlacementPolicy, PlacementPolicyFactory
|
||||
|
||||
|
||||
class GeminiManager:
    """Coordinate the placement of chunks between CPU and CUDA.

    The manager builds a placement policy from ``placement_policy`` and, when that
    policy sets ``need_mem_stats``, a ``MemStatsCollectorV2`` on the same chunk manager.
    It also keeps per-iteration counters (layout time, evict time, CPU<->CUDA move volume).
    The design is inspired by PatrickStar.

    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818

    Args:
        placement_policy (str): name of a policy known to ``PlacementPolicyFactory``.
            Only ``'cuda'`` is accepted at the moment.
        chunk_manager (ChunkManager): manager of the chunks to be placed.
    """

    def __init__(self, placement_policy: str, chunk_manager: ChunkManager) -> None:
        # TODO: remove assert
        assert placement_policy == 'cuda', 'placement_policy can only be "cuda" now'
        assert placement_policy in PlacementPolicyFactory.get_polocy_names()
        self._chunk_manager = chunk_manager
        policy_cls = PlacementPolicyFactory.create(placement_policy)
        # A collector is only created for policies that declare they need memory statistics.
        if policy_cls.need_mem_stats:
            self._mem_stats_collector = MemStatsCollectorV2(chunk_manager)
        else:
            self._mem_stats_collector = None
        self._placement_policy = policy_cls(chunk_manager, self._mem_stats_collector)

        # Chunk groups recorded in the order they were computed during warmup.
        self._compute_list: List[Tuple[Chunk, ...]] = []
        # Index of the current computation step within the iteration; -1 means "not started".
        self._compute_idx: int = -1

        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0
        self._warmup = True

    def pre_iter(self):
        """Start memory statistics collection if a collector exists and warmup is ongoing."""
        collector = self._mem_stats_collector
        if collector and self._warmup:
            collector.start_collection()

    def post_iter(self):
        """This function must be called when each iteration finishes.

        It finishes statistics collection (warmup only), ends the warmup phase and
        resets the per-iteration step index and counters.
        """
        collector = self._mem_stats_collector
        if collector and self._warmup:
            collector.finish_collection()
        self._warmup = False
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0
        self._layout_time = 0
        self._evict_time = 0

    def adjust_layout(self, chunks: Tuple[Chunk, ...], group_name: str) -> None:
        """Ask the placement policy to make room on CUDA for ``chunks``.

        Args:
            chunks (Tuple[Chunk, ...]): the chunks needed by the upcoming computation.
            group_name (str): chunk group whose chunks are candidates for eviction.
        """
        layout_begin = time()
        self._record_chunks_order(chunks)
        cuda_demand, evictable_chunks = self._get_layout_info(self._compute_idx, self._warmup, chunks, group_name)
        self._layout_time += time() - layout_begin

        moved_volume, evict_elapsed = self._placement_policy.evict_tensors(evictable_chunks,
                                                                           cuda_demand=cuda_demand,
                                                                           warmup=self._warmup,
                                                                           compute_list=self._compute_list,
                                                                           compute_idx=self._compute_idx)
        self._cpu_gpu_move_volume += moved_volume
        self._evict_time += evict_elapsed
        # The demanded chunks are accounted as CPU -> CUDA traffic here;
        # this method does not move them itself.
        self._cpu_gpu_move_volume += cuda_demand

    @property
    def cpu_gpu_move_volume(self):
        """Volume accumulated by ``adjust_layout`` since the counters were last reset."""
        return self._cpu_gpu_move_volume

    # TODO: test memoising this method with functools.lru_cache
    def _get_layout_info(self, compute_idx: int, warmup: bool, chunks: Tuple[Chunk, ...], group_name: str):
        """Return ``(cuda_demand, can_evict_chunks)`` for the requested chunks.

        ``cuda_demand`` sums ``mem`` of the requested chunks that are on CPU or free.
        ``can_evict_chunks`` lists the chunks of ``group_name`` that are not free, live on
        CUDA and may be moved. ``compute_idx`` and ``warmup`` are currently unused.
        """
        cuda_demand = sum(chunk.mem for chunk in chunks if chunk.device_type == 'cpu' or chunk.is_free)
        can_evict_chunks = [
            chunk for chunk in self._chunk_manager.chunk_groups[group_name]
            if not chunk.is_free and chunk.device_type == 'cuda' and chunk.can_move_device
        ]
        return cuda_demand, can_evict_chunks

    def _record_chunks_order(self, chunks: Tuple[Chunk, ...]) -> None:
        """Advance the step index and, during warmup of a stats-based policy, log ``chunks``."""
        self._compute_idx += 1
        if self._warmup and self._placement_policy.need_mem_stats:
            self._compute_list.append(chunks)

    @property
    def default_device(self):
        """Default device reported by the placement policy."""
        return self._placement_policy.get_default_device()

    def sample_overall_data(self):
        """Forward to the collector's ``sample_overall_data``; no-op without a collector."""
        if self._mem_stats_collector:
            self._mem_stats_collector.sample_overall_data()

    def sample_model_data(self):
        """Forward to the collector's ``sample_model_data``; no-op without a collector."""
        if self._mem_stats_collector:
            self._mem_stats_collector.sample_model_data()

    @property
    def chunk_manager(self):
        """The chunk manager passed at construction."""
        return self._chunk_manager

    @property
    def cuda_margin_mem(self) -> Optional[float]:
        """The collector's ``cuda_margin_mem``, or ``None`` when there is no collector."""
        collector = self._mem_stats_collector
        return collector.cuda_margin_mem if collector else None

    @property
    def is_cuda_margin_mem_avail(self) -> bool:
        """Whether the placement policy uses memory statistics (its ``need_mem_stats`` flag)."""
        return self._placement_policy.need_mem_stats
|
|
@ -1,5 +1,6 @@
|
|||
from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
||||
from colossalai.utils.memory import colo_device_memory_used
|
||||
from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from colossalai.tensor import ChunkManager
|
||||
|
||||
|
@ -145,3 +146,7 @@ class MemStatsCollectorV2(MemStatsCollector):
|
|||
cpu_mem = self._chunk_manager.total_mem['cpu']
|
||||
self._model_data_cuda_list.append(cuda_mem)
|
||||
self._model_data_cpu_list.append(cpu_mem)
|
||||
|
||||
@property
def cuda_margin_mem(self) -> float:
    """Capacity of the current device minus the largest recorded overall CUDA usage.

    Note: ``max`` raises ``ValueError`` if ``overall_mem_stats('cuda')`` is empty.
    """
    device_capacity = colo_device_memory_capacity(get_current_device())
    peak_overall_cuda = max(self.overall_mem_stats('cuda'))
    return device_capacity - peak_overall_cuda
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from time import time
|
||||
from typing import List, Optional, Tuple, Dict
|
||||
import torch
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
|
||||
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
|
||||
from typing import Type
|
||||
import functools
|
||||
from colossalai.tensor.chunk import Chunk, ChunkManager
|
||||
|
||||
|
||||
class PlacementPolicy(ABC):
    """Abstract base class of chunk placement policies.

    Attributes:
        need_mem_stats: class-level flag; ``True`` when the policy relies on a
            memory statistics collector. Defaults to ``False``.

    Args:
        chunk_manager (ChunkManager): manager used to move chunks between devices.
        mem_stats_collector (Optional[MemStatsCollectorV2]): source of memory
            statistics, or ``None``.
    """
    need_mem_stats: bool = False

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
        self.chunk_manager = chunk_manager
        self.mem_stats_collector: Optional[MemStatsCollectorV2] = mem_stats_collector

    @abstractmethod
    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        """Evict chunks from CUDA.

        Args:
            can_evict_chunks (List[Chunk]): chunks that are allowed to be evicted.
            **kwargs: policy-specific options.

        Returns:
            Tuple[int, float]: ``(volume, time)`` - the evicted memory volume and a
            time measurement. The concrete policies in this module all return such a pair.
        """
        raise NotImplementedError

    @staticmethod
    def get_default_device() -> torch.device:
        """Return the default device of this policy; the base implementation returns CPU."""
        return torch.device('cpu')
|
||||
|
||||
|
||||
class CPUPlacementPolicy(PlacementPolicy):
    """Placement policy that moves every evictable chunk to CPU."""

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        """Move all of ``can_evict_chunks`` to CPU.

        Args:
            can_evict_chunks (List[Chunk]): chunks to move.
            **kwargs: ignored.

        Returns:
            Tuple[int, float]: the summed ``mem`` of the moved chunks, and ``0``
            (this policy does not measure time).
        """
        volume = 0
        for chunk in can_evict_chunks:
            self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
            volume += chunk.mem
        return volume, 0
|
||||
|
||||
|
||||
class CUDAPlacementPolicy(PlacementPolicy):
    """Placement policy that never evicts anything.

    Construction asserts that CUDA is available.
    """

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
        # The message previously named a class that does not exist in this module.
        assert torch.cuda.is_available(), 'Cannot use CUDAPlacementPolicy when CUDA is not available'
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)

    def evict_tensors(self, can_evict_chunks: List[Chunk], **kwargs) -> Tuple[int, float]:
        """Do nothing and report ``(0, 0)``: no volume evicted, no time spent."""
        return 0, 0

    @staticmethod
    def get_default_device() -> torch.device:
        """Return the device given by ``get_current_device()``."""
        return get_current_device()
|
||||
|
||||
|
||||
class AutoPlacementPolicy(PlacementPolicy):
    """Placement policy that evicts only as much as the upcoming CUDA demand requires.

    The CUDA budget for model data is derived from the device capacity: a fixed ratio
    during warmup, and the collector's predicted non-model-data usage afterwards.
    """

    need_mem_stats: bool = True

    def __init__(self, chunk_manager: ChunkManager, mem_stats_collector: Optional[MemStatsCollectorV2] = None) -> None:
        super().__init__(chunk_manager, mem_stats_collector=mem_stats_collector)
        # model data will use 1-self._warmup_non_model_data_ratio CUDA memory in warmup phase
        # TODO(ver217): make these args configurable
        self._warmup_non_model_data_ratio: float = 0.8
        self._steady_cuda_cap_ratio: float = 0.9

    def evict_tensors(self,
                      can_evict_chunks: List[Chunk],
                      cuda_demand: int = 0,
                      warmup: bool = True,
                      compute_list: Optional[List[Tuple[Chunk, ...]]] = None,
                      compute_idx: int = 0,
                      **kwargs) -> Tuple[int, float]:
        """
        Evict chunks from CUDA device until ``cuda_demand`` fits into the model-data budget.

        Args:
            can_evict_chunks (List[Chunk]): the chunks that are allowed to be moved to CPU.
            cuda_demand (int, optional): the volume of data needed on cuda device. Defaults to 0.
            warmup (bool, optional): a flag indicates whether in the phase of warmup. Defaults to True.
            compute_list (List[Tuple[Chunk, ...]], optional): the chunk groups in the order they
                are computed; only used outside warmup to rank eviction candidates.
                Defaults to None, which is treated as an empty list.
            compute_idx (int, optional): index into ``compute_list`` of the current computation
                step. Defaults to 0.

        Raises:
            RuntimeError: if evicting every candidate still does not free enough CUDA memory.

        Returns:
            Tuple[int, float]: the volume of memory that is evicted, and the time spent
            deciding what to evict (measured before the chunks are actually moved).
        """
        start = time()
        cuda_capacity = colo_device_memory_capacity(get_current_device())
        used_cuda_model_data = self.chunk_manager.total_mem['cuda']
        if warmup:
            # We designate a part of CUDA memory for model data in warmup iterations.
            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
        else:
            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
            max_cuda_non_model_data_per_period = self.mem_stats_collector.next_period_non_model_data_usage('cuda')
            cuda_capacity *= self._steady_cuda_cap_ratio
        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
        freed_cuda_model_data = 0
        end = time()
        if avail_cuda_model_data < cuda_demand:
            # The shortfall that has to be moved off CUDA.
            to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
            to_free_chunks = can_evict_chunks
            if not warmup:
                # The sort is memoised, so its arguments must be hashable: pass tuples.
                compute_order = () if compute_list is None else tuple(compute_list)
                to_free_chunks = self._sort_can_evict_chunks(tuple(to_free_chunks), compute_idx, compute_order)
            end = time()
            for chunk in to_free_chunks:
                if freed_cuda_model_data >= to_free_cuda_model_data:
                    break
                freed_cuda_model_data += chunk.mem
                self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
            if freed_cuda_model_data < to_free_cuda_model_data:
                raise RuntimeError(
                    f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
                )
        return freed_cuda_model_data, end - start

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def _sort_can_evict_chunks(can_evict_chunks: tuple, compute_idx: int, compute_list: tuple) -> list:
        """Order eviction candidates so the chunk needed furthest in the future comes first.

        Each candidate is keyed by the smallest step index greater than ``compute_idx`` at
        which it appears in ``compute_list``; candidates that do not appear again get
        ``len(compute_list)`` and therefore sort to the front. Results are cached by
        ``functools.lru_cache``, so the returned list must not be mutated by callers.
        """
        next_compute_idx = {chunk: len(compute_list) for chunk in can_evict_chunks}
        # Walk backwards so that the earliest upcoming use is the value that remains.
        for i in range(len(compute_list) - 1, compute_idx, -1):
            for chunk in compute_list[i]:
                if chunk in next_compute_idx:
                    next_compute_idx[chunk] = i
        next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
        return [t for (t, idx) in next_compute_idx]
|
||||
|
||||
|
||||
class PlacementPolicyFactory:
    """Registry that maps placement policy names to policy classes."""

    # Values are classes, not instances; callers instantiate them.
    policies: Dict[str, Type[PlacementPolicy]] = {
        'cpu': CPUPlacementPolicy,
        'cuda': CUDAPlacementPolicy,
        'auto': AutoPlacementPolicy
    }

    @staticmethod
    def create(policy_name: str) -> Type[PlacementPolicy]:
        """Return the policy class registered under ``policy_name``.

        Raises:
            TypeError: if ``policy_name`` is not a registered policy name.
        """
        if policy_name not in PlacementPolicyFactory.policies:
            raise TypeError(f"Unknown tensor placement policy {policy_name}")
        return PlacementPolicyFactory.policies[policy_name]

    @staticmethod
    def get_policy_names() -> Tuple[str, ...]:
        """Return the registered policy names."""
        return tuple(PlacementPolicyFactory.policies.keys())

    @staticmethod
    def get_polocy_names() -> Tuple[str, ...]:
        """Misspelled alias of ``get_policy_names``, kept for existing callers."""
        return PlacementPolicyFactory.get_policy_names()

    @staticmethod
    def get_default_device(policy_name: str) -> torch.device:
        """Return the default device of the policy registered under ``policy_name``."""
        policy_cls = PlacementPolicyFactory.create(policy_name)
        return policy_cls.get_default_device()
|
|
@ -4,8 +4,11 @@ from colossalai.core import global_context as gpc
|
|||
from colossalai.context import ParallelMode
|
||||
from functools import partial
|
||||
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
||||
from colossalai.tensor.chunk import ChunkManager, TensorState
|
||||
from colossalai.tensor.chunk import ChunkManager, TensorState, Chunk
|
||||
from colossalai.tensor.param_op_hook import use_param_op_hooks
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
from typing import Dict
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
|
||||
def free_storage(data: torch.Tensor) -> None:
|
||||
|
@ -89,12 +92,14 @@ class ColoDDP(torch.nn.Module):
|
|||
|
||||
class ColoDDPV2(ColoDDP):
|
||||
|
||||
def __init__(self, module: torch.nn.Module, chunk_manager: ChunkManager) -> None:
|
||||
def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
|
||||
super().__init__(module)
|
||||
self.chunk_manager = chunk_manager
|
||||
self.param_op_hook = ZeROHookV2(chunk_manager)
|
||||
self.gemini_manager = gemini_manager
|
||||
self.chunk_manager = gemini_manager.chunk_manager
|
||||
self.param_op_hook = ZeROHookV2(gemini_manager)
|
||||
self.fp32_params = []
|
||||
self.overflow_counter = 0
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = {}
|
||||
# TODO: get param order and filter unused params
|
||||
for p in module.parameters():
|
||||
assert p.dtype == torch.half
|
||||
|
@ -102,22 +107,32 @@ class ColoDDPV2(ColoDDP):
|
|||
self.chunk_manager.append_tensor(p, 'fp16_param')
|
||||
self.chunk_manager.append_tensor(fp32_p, 'fp32_param')
|
||||
self.fp32_params.append(fp32_p)
|
||||
self.grads_device[p] = self.gemini_manager.default_device
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
def forward(self, *args, **kwargs):
    """Run the wrapped module's forward pass with the parameter-op hook active.

    Gradients are cleared (set to ``None``) and the gemini manager's ``pre_iter`` is
    called first; lazily released chunks are released before the result is returned.
    """
    self.module.zero_grad(set_to_none=True)
    self.gemini_manager.pre_iter()
    with use_param_op_hooks(self.param_op_hook):
        result = self.module(*args, **kwargs)
    self.chunk_manager.exec_lazy_release()
    return result
|
||||
|
||||
def _post_backward(self):
|
||||
self.chunk_manager.exec_lazy_release()
|
||||
def _setup_grads_ptr(self):
    """Point each parameter's ``grad`` at its own ``data``, or clear it.

    ``grad`` becomes ``None`` when the parameter's chunk is free or the parameter
    does not require grad; otherwise it is set to ``p.data``.
    """
    for param in self.module.parameters():
        chunk_is_free = self.chunk_manager.get_chunk(param).is_free
        drop_grad = chunk_is_free or not param.requires_grad
        param.grad = None if drop_grad else param.data
|
||||
|
||||
def _post_backward(self):
    """Finish an iteration after backward.

    Releases lazily released chunks, rebinds the gradient pointers, logs the gemini
    manager's counters and then calls its ``post_iter`` (which resets them).
    """
    self.chunk_manager.exec_lazy_release()
    self._setup_grads_ptr()
    # Build the report before post_iter() runs, since post_iter() resets the counters.
    report = f'layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, PCIE move vol: {self.gemini_manager._cpu_gpu_move_volume}B'
    self._logger.info(report)
    self.gemini_manager.post_iter()
|
||||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
with self.param_op_hook.switch_to_backward(), use_param_op_hooks(self.param_op_hook):
|
||||
loss.backward()
|
||||
|
@ -141,7 +156,12 @@ class ColoDDPV2(ColoDDP):
|
|||
self.chunk_manager.release_chunk(chunk)
|
||||
if reduced and not chunk.is_free:
|
||||
self.overflow_counter += chunk.has_inf_or_nan
|
||||
self.chunk_manager.move_chunk(chunk, self.grads_device[p])
|
||||
return empty_grad
|
||||
|
||||
def zero_grad(self, set_to_none: bool = False) -> None:
    """Clear the wrapped module's gradients.

    ``set_to_none`` is accepted for interface compatibility but ignored: the wrapped
    module is always called with ``set_to_none=True``.
    """
    self.module.zero_grad(set_to_none=True)
|
||||
|
||||
def _set_chunk_grad_device(self, chunk: Chunk, device: torch.device) -> None:
    """Record ``device`` in ``self.grads_device`` for every tensor of ``chunk``."""
    for member in chunk.get_tensors():
        self.grads_device[member] = device
|
||||
|
|
|
@ -178,6 +178,9 @@ class Chunk:
|
|||
def __eq__(self, __o: object) -> bool:
    """Chunks compare equal only to themselves (identity comparison)."""
    return __o is self
|
||||
|
||||
def get_tensors(self) -> List[torch.Tensor]:
    """Return a new list holding the keys of ``self.tensors_info``."""
    return [tensor for tensor in self.tensors_info.keys()]
|
||||
|
||||
|
||||
class ChunkManager:
|
||||
|
||||
|
@ -234,6 +237,10 @@ class ChunkManager:
|
|||
|
||||
def access_chunk(self, chunk: Chunk) -> None:
|
||||
if chunk in self.accessed_chunks:
|
||||
if chunk.device_type != 'cuda':
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
chunk.move_device(get_current_device())
|
||||
self.total_mem[chunk.device_type] += chunk.mem
|
||||
return
|
||||
if not chunk.is_free:
|
||||
self.total_mem[chunk.device_type] -= chunk.mem
|
||||
|
|
|
@ -5,6 +5,7 @@ from enum import Enum
|
|||
from typing import List
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
|
||||
|
||||
class TrainingPhase(Enum):
|
||||
|
@ -14,9 +15,10 @@ class TrainingPhase(Enum):
|
|||
|
||||
class ZeROHookV2(ParamOpHook):
|
||||
|
||||
def __init__(self, chunk_manager: ChunkManager) -> None:
|
||||
def __init__(self, gemini_manager: GeminiManager) -> None:
|
||||
super().__init__()
|
||||
self._chunk_manager = chunk_manager
|
||||
self._gemini_manager = gemini_manager
|
||||
self._chunk_manager = gemini_manager.chunk_manager
|
||||
self._training_phase = TrainingPhase.FORWARD
|
||||
|
||||
def pre_op(self, params):
|
||||
|
@ -24,9 +26,11 @@ class ZeROHookV2(ParamOpHook):
|
|||
for p in params:
|
||||
self._chunk_manager.trans_tensor_state(p, TensorState.COMPUTE)
|
||||
self._chunk_manager.exec_lazy_release()
|
||||
# TODO: evict chunks
|
||||
self._gemini_manager.sample_overall_data()
|
||||
self._gemini_manager.adjust_layout(chunks, 'fp16_param')
|
||||
for chunk in chunks:
|
||||
self._chunk_manager.access_chunk(chunk)
|
||||
self._gemini_manager.sample_model_data()
|
||||
|
||||
def post_op(self, params):
|
||||
for p in params:
|
||||
|
|
|
@ -7,6 +7,7 @@ from typing import Dict
|
|||
from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils import get_current_device, disposable
|
||||
|
||||
|
||||
class OptimState(Enum):
|
||||
|
@ -19,6 +20,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
def __init__(self,
|
||||
optim: Optimizer,
|
||||
module: ColoDDPV2,
|
||||
gpu_margin_mem_ratio: float = 0.0,
|
||||
initial_scale: float = 2**32,
|
||||
min_scale: float = 1,
|
||||
growth_factor: float = 2,
|
||||
|
@ -29,6 +31,8 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
super().__init__(optim)
|
||||
assert isinstance(module, ColoDDPV2)
|
||||
self.module = module
|
||||
self.gemini_manager = module.gemini_manager
|
||||
self.chunk_manager = self.gemini_manager.chunk_manager
|
||||
self.optim_state = OptimState.UNSCALED
|
||||
self.fp16_param_to_fp32_param: Dict[torch.Tensor, torch.Tensor] = {}
|
||||
for p, fp32_p in zip(module.parameters(), module.fp32_params):
|
||||
|
@ -45,6 +49,18 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
self._found_overflow: torch.Tensor = torch.zeros(1, dtype=torch.int64, device=torch.cuda.current_device())
|
||||
self._logger = get_dist_logger()
|
||||
|
||||
self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
|
||||
assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
|
||||
# Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
|
||||
# Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
|
||||
# and it must set `num_fp32_shards_per_param` correctly
|
||||
self._should_move_fp32_params_h2d: bool = self.gemini_manager.is_cuda_margin_mem_avail and self.gpu_margin_mem_ratio > 0.0 and getattr(
|
||||
optim, 'num_fp32_shards_per_param', 0) >= 2
|
||||
if self.gpu_margin_mem_ratio > 0.0 and not self.gemini_manager.is_cuda_margin_mem_avail:
|
||||
self._logger.warning(f'gpu_margin_mem_ratio is meaningless when placement_policy is not "auto"', ranks=[0])
|
||||
|
||||
self._register_states = disposable(self._register_states_)
|
||||
|
||||
def _update_params_ptr(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
|
@ -82,6 +98,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
return self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
def step(self, *args, **kwargs):
|
||||
self._maybe_move_fp32_params()
|
||||
# unscale grads if scaled
|
||||
if self.optim_state == OptimState.SCALED:
|
||||
self._unscale_grads()
|
||||
|
@ -94,6 +111,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
return
|
||||
self._update_params_ptr()
|
||||
ret = self.optim.step(*args, **kwargs)
|
||||
self._register_states()
|
||||
self._update_fp16_params()
|
||||
return ret
|
||||
|
||||
|
@ -109,3 +127,29 @@ class ZeroOptimizer(ColossalaiOptimizer):
|
|||
|
||||
def backward_by_grad(self, tensor: torch.Tensor, grad: torch.Tensor):
    """Delegate to the wrapped module's ``backward_by_grad`` with the same arguments."""
    self.module.backward_by_grad(tensor, grad)
|
||||
|
||||
def _maybe_move_fp32_params(self):
    """Move fp32 parameter chunks to the current device, once, within the margin budget.

    Runs only while ``self._should_move_fp32_params_h2d`` is set and clears that flag
    immediately. The budget is ``cuda_margin_mem * gpu_margin_mem_ratio`` divided by the
    inner optimizer's ``num_fp32_shards_per_param``. For every fp32 chunk that fits, the
    paired fp16 chunk is moved too and its gradient device is updated.
    """
    if not self._should_move_fp32_params_h2d:
        return
    self._should_move_fp32_params_h2d = False

    margin_mem = self.gemini_manager.cuda_margin_mem * self.gpu_margin_mem_ratio
    fp32_budget = margin_mem / self.optim.num_fp32_shards_per_param
    fp32_used = 0
    chunk_pairs = zip(self.chunk_manager.chunk_groups['fp16_param'], self.chunk_manager.chunk_groups['fp32_param'])
    for fp16_chunk, fp32_chunk in chunk_pairs:
        if fp32_chunk.is_free:
            continue
        if fp32_used + fp32_chunk.mem < fp32_budget:
            self.chunk_manager.move_chunk(fp32_chunk, get_current_device())
            # stores grad now
            self.chunk_manager.move_chunk(fp16_chunk, get_current_device())
            self.module._set_chunk_grad_device(fp16_chunk, get_current_device())
            fp32_used += fp32_chunk.mem
    # Chunks may have changed device, so rebind the gradient pointers.
    self.module._setup_grads_ptr()
|
||||
|
||||
def _register_states_(self):
    """Register every tensor held in the inner optimizer's state with the chunk manager.

    Walks all parameters of all param groups and passes each ``torch.Tensor`` state
    value to ``chunk_manager.add_extern_static_tensor``; non-tensor values are skipped.
    """
    for group in self.optim.param_groups:
        for param in group['params']:
            for value in self.optim.state[param].values():
                if not isinstance(value, torch.Tensor):
                    continue
                self.chunk_manager.add_extern_static_tensor(value)
|
||||
|
|
|
@ -14,6 +14,7 @@ from tests.components_to_test.registry import non_distributed_component_funcs
|
|||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from colossalai.nn.parallel import ColoDDPV2
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
|
@ -44,7 +45,8 @@ def run_gpt(use_chunk, use_zero):
|
|||
model = model.half()
|
||||
chunk_size = 38 * 1024**2 if use_chunk else None
|
||||
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
|
||||
model = ColoDDPV2(model, chunk_manager)
|
||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
||||
model = ColoDDPV2(model, gemini_manager)
|
||||
torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
|
||||
print(chunk_manager)
|
||||
check_param_equal(model, torch_model)
|
||||
|
|
|
@ -18,6 +18,7 @@ from colossalai.nn.optimizer import HybridAdam
|
|||
from colossalai.zero import ZeroOptimizer
|
||||
from colossalai.testing import parameterize
|
||||
from colossalai.amp import convert_to_apex_amp
|
||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
|
@ -53,7 +54,8 @@ def run_gpt(use_chunk, use_zero):
|
|||
|
||||
chunk_size = 38 * 1024**2 if use_chunk else None
|
||||
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
|
||||
model = ColoDDPV2(model, chunk_manager)
|
||||
gemini_manager = GeminiManager('cuda', chunk_manager)
|
||||
model = ColoDDPV2(model, gemini_manager)
|
||||
optim = HybridAdam(model.parameters(), lr=1e-3)
|
||||
optim = ZeroOptimizer(optim, model, initial_scale=32)
|
||||
|
||||
|
|
Loading…
Reference in New Issue