mirror of https://github.com/hpcaitech/ColossalAI
[zero] add tensor placement policies (#743)
* add tensor placement policies
* polish comments
* polish comments
* update moe unit tests

pull/748/head
parent 22c4b88d56
commit e396bb71f2
@@ -23,6 +23,7 @@ from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from colossalai.zero.utils.stateful_tensor_mgr import StatefulTensorMgr
+from colossalai.zero.utils.tensor_placement_policy import TENSOR_PLACEMENT_POLICIES, TensorPlacementPolicy
 
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
@@ -48,6 +49,11 @@ class ShardedModelV2(nn.Module):
             Generally, it should be `None`, and it's the same as `process_group`. Defaults to None.
         reduce_scatter_bucket_size_mb (int, optional): Reduce-scatter bucket size in *MB*. Defaults to 25.
         fp32_reduce_scatter (bool, optional): If set to `True`, gradients are forced to FP32 before reduce-scatter. Defaults to False.
+        tensor_placement_policy (str): Which device to place *held* tensors. It can be 'cpu', 'cuda' or 'auto'.
+            If it's 'cpu', parameters, gradients and optimizer states are offloaded to CPU, which means the minimum amount of CUDA memory is used.
+            If it's 'cuda', they are not offloaded, which means the maximum amount of CUDA memory is used.
+            If it's 'auto', they are moved dynamically based on CPU and CUDA memory usage, so the heterogeneous memory space is utilized evenly.
+            Defaults to 'cuda'.
         offload_config (Optional[dict], optional): We currently only support CPU offload. Set to `{"device": "cpu"}` to enable CPU offload. Defaults to None.
         gradient_predivide_factor (Optional[float], optional): Gradient is divided by this value before reduce-scatter. Defaults to 1.0.
         use_memory_tracer (bool, optional): Whether to use the memory tracer. Defaults to False.
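Editorial note: to make the new argument concrete, here is a minimal usage sketch. It assumes a distributed environment has already been initialized (e.g. via colossalai.launch) and reuses the ZeroInitContext / TensorShardStrategy pattern from the tests further down; the nn.Linear model is only a placeholder.

import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2

shard_strategy = TensorShardStrategy()
# Parameters must be created under ZeroInitContext so they carry the sharded `colo_attr`.
with ZeroInitContext(target_device=torch.device('cuda'),
                     shard_strategy=shard_strategy,
                     shard_param=True):
    model = nn.Linear(1024, 1024)

# 'cpu' offloads held tensors to CPU, 'cuda' keeps them on the GPU,
# and 'auto' moves them dynamically based on measured memory usage.
zero_model = ShardedModelV2(model, shard_strategy, tensor_placement_policy='auto')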
@@ -65,9 +71,8 @@ class ShardedModelV2(nn.Module):
                  reduce_scatter_process_group: Optional[ProcessGroup] = None,
                  reduce_scatter_bucket_size_mb: int = 25,
                  fp32_reduce_scatter: bool = False,
-                 offload_config: Optional[dict] = None,
+                 tensor_placement_policy: str = 'cuda',
                  gradient_predivide_factor: Optional[float] = 1.0,
-                 use_memory_tracer: bool = False,
                  reuse_fp16_shard: bool = False):
         super().__init__()
         self.logger = get_dist_logger()
@@ -100,20 +105,22 @@ class ShardedModelV2(nn.Module):
         self.rank = dist.get_rank(self.process_group)
         self.shard_strategy = shard_strategy
 
+        assert tensor_placement_policy in TENSOR_PLACEMENT_POLICIES, f'Invalid tensor_placement_policy, got {tensor_placement_policy}'
         # Init Memory Statistics Collector
-        self._use_memory_tracer = use_memory_tracer
+        self._use_memory_tracer = tensor_placement_policy == 'auto'
         if self._use_memory_tracer:
             GLOBAL_MODEL_DATA_TRACER.register_model(self)
             self._memstats_collector = MemStatsCollector()
-            self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
-            for param in module.parameters():
-                if hasattr(param, 'colo_attr'):
-                    self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
             self._start_collect_memstats = disposable(self._memstats_collector.start_collection)
             self._finish_collect_memstats = disposable(self._memstats_collector.finish_collection)
         else:
             self._memstats_collector = None
-            self._stateful_tensor_mgr = None
+        self._tensor_placement_policy: TensorPlacementPolicy = TENSOR_PLACEMENT_POLICIES[tensor_placement_policy](
+            mem_stats_collector=self._memstats_collector)
+        self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
+        for param in module.parameters():
+            if hasattr(param, 'colo_attr'):
+                self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
 
         # Register hooks
         self._ophook_list = [
@@ -124,7 +131,7 @@ class ShardedModelV2(nn.Module):
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
 
         self.fp32_reduce_scatter = fp32_reduce_scatter
-        self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
+        self._cpu_offload: bool = tensor_placement_policy != 'cuda'
         for param in module.parameters():
             # Init `offload_grad`
             param.colo_attr.offload_grad = self._cpu_offload
@@ -16,12 +16,12 @@ from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_m
                                                         colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
+from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy
 
 
 class OptimState(Enum):
@@ -57,10 +57,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         sharded_model (ShardedModelV2): A sharded model initialized by class ShardedModelV2. The optimizer will use the
             shard strategy provided by sharded model to shard param fp32 tensors.
         optimizer (Optimizer): An Optimizer instance.
-        cpu_offload (bool, optional): Whether to offload optimizer states to CPU. Defaults to False.
         gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
             which will be used when using hybrid CPU optimizer.
             Make sure `reuse_fp16_shard` is enabled in `ShardedModelV2`, if `gpu_margin_mem_ratio` > `0.0`.
+            This argument is meaningless when `tensor_placement_policy` of `ShardedModelV2` is not "auto".
             Defaults to 0.0.
         initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**32.
         min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
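Editorial note: a hedged sketch of how gpu_margin_mem_ratio is meant to be combined with the new placement policy. It assumes zero_model was built with tensor_placement_policy='auto' and reuse_fp16_shard=True, and that HybridAdam is used so the inner optimizer can step on a mix of CPU and CUDA tensors.

from colossalai.nn.optimizer import HybridAdam
from colossalai.zero.sharded_optim import ShardedOptimizerV2

optim = HybridAdam(zero_model.parameters(), lr=1e-3)
# With the 'auto' policy, up to 30% of the CUDA memory left free after the first
# forward-backward pass may be used to keep fp32 master shards on the GPU.
sharded_optim = ShardedOptimizerV2(zero_model, optim, gpu_margin_mem_ratio=0.3, initial_scale=2**5)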
@@ -79,7 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def __init__(self,
                  sharded_model: ShardedModelV2,
                  optimizer: Optimizer,
-                 cpu_offload: bool = False,
                  gpu_margin_mem_ratio: float = 0.0,
                  initial_scale: float = 2**32,
                  min_scale: float = 1,
@@ -95,18 +94,15 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         super().__init__(optimizer)
         self.shard_strategy = sharded_model.shard_strategy
         self.model: ShardedModelV2 = sharded_model
-        if cpu_offload and not sharded_model.cpu_offload:
-            raise RuntimeError(
-                f"ShardedOptimizerV2 using cpu_offload, but the sharded_model used to initialize it does not use cpu_offload"
-            )
+
         self.gpu_margin_mem_ratio: float = float(gpu_margin_mem_ratio)
         assert 0.0 <= self.gpu_margin_mem_ratio <= 1.0, f'gpu_margin_mem_ratio must >=0.0 and <=1.0'
         # Only move fp32 shards from CPU to GPU when user allows and inner optimizer is valid
         # Inner optimizer must support optimizing hybrid (CPU and CUDA) tensors,
         # and it must set `num_fp32_shards_per_param` correctly
-        self._should_move_fp32_shards_h2d: bool = cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
+        self._should_move_fp32_shards_h2d: bool = sharded_model.cpu_offload and self.gpu_margin_mem_ratio > 0.0 and getattr(
             optimizer, 'num_fp32_shards_per_param', 0) >= 2
-        self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
+        self.device = sharded_model._tensor_placement_policy.device or torch.device('cpu')
         self.optim_state: OptimState = OptimState.UNSCALED
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
         self.mp_process_group = mp_process_group or gpc.get_group(ParallelMode.MODEL)
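Editorial note: a small sketch of what the new device line evaluates to for each policy name (assuming a CUDA device is available, since constructing CUDATensorPlacementPolicy asserts that).

import torch
from colossalai.zero.utils.tensor_placement_policy import TENSOR_PLACEMENT_POLICIES

# 'cpu'  -> torch.device('cpu')     : fp32 master shards are kept on CPU
# 'cuda' -> the current CUDA device : master shards are kept on GPU
# 'auto' -> None, so the `or` branch above falls back to torch.device('cpu')
for name in ('cpu', 'cuda', 'auto'):
    policy = TENSOR_PLACEMENT_POLICIES[name]()
    print(name, policy.device or torch.device('cpu'))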
@@ -123,7 +119,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 
         # Store fp32 param shards
         self._register_master_weight()
 
+        if self.gpu_margin_mem_ratio != 0.0 and isinstance(sharded_model._tensor_placement_policy,
+                                                           AutoTensorPlacementPolicy):
+            self._logger.warning(f'gpu_margin_mem_ratio is meaningless when tensor_placement_policy is not "auto"')
         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0] / 1e6} MB CUDA Memory!",
                            ranks=[0])
 
@@ -5,10 +5,8 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
-from colossalai.utils.memory import colo_device_memory_capacity
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from typing import Dict, List
-from colossalai.utils.memory_tracer import MemStatsCollector
+from colossalai.zero.utils.tensor_placement_policy import TensorPlacementPolicy
+from typing import List
 from colossalai.logging import get_dist_logger
 
 
@@ -20,13 +18,12 @@ class StatefulTensorMgr(object):
     https://arxiv.org/abs/2108.05818
     """
 
-    def __init__(self, mem_stats_collector: MemStatsCollector) -> None:
+    def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
+        self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
         self._stateful_tensor_list: List[StatefulTensor] = []
-        self._mem_stats_collector = mem_stats_collector
         self._logger = get_dist_logger("StatefulTensorMgr")
 
         self._warmup = True
-        self._warmup_cuda_available_ratio = 0.2
 
         self._compute_list: List[StatefulTensor] = []
         self._compute_idx: int = -1
@@ -47,9 +44,8 @@ class StatefulTensorMgr(object):
         It contains non-model footprint of a DNN model.
         """
         # find stateful tensor in state COMPUTE
-        move_to_cuda_tensor_list = []
         cuda_demand = 0
-        used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage
+        move_to_cuda_tensor_list = []
         hold_cuda_tensor_list = []
         for tensor in self._stateful_tensor_list:
             if tensor.state == TensorState.FREE:
@@ -64,22 +60,11 @@ class StatefulTensorMgr(object):
                 cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
             else:
                 raise RuntimeError
-        cuda_capacity = colo_device_memory_capacity(get_current_device())
-
-        if self._warmup:
-            # We designate a part of CUDA memory for model data in warmup iterations.
-            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
-        else:
-            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
-            max_cuda_non_model_data_per_period = self._mem_stats_collector.max_non_model_data('cuda')
-
-        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
-        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
-
-        if avail_cuda_model_data < cuda_demand:
-            # Move cuda_demand - avail_cuda_model_data volume of tensors
-            # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
-            self.evict_tensors(hold_cuda_tensor_list, cuda_demand - avail_cuda_model_data)
+        self._tensor_placement_policy.evict_tensors(hold_cuda_tensor_list,
+                                                    cuda_demand=cuda_demand,
+                                                    warmup=self._warmup,
+                                                    compute_list=self._compute_list,
+                                                    compute_idx=self._compute_idx)
         # move COMPUTE tensors to CUDA
         for t in move_to_cuda_tensor_list:
             colo_model_data_tensor_move_inline(t, get_current_device())
@@ -90,26 +75,6 @@ class StatefulTensorMgr(object):
         self._warmup = False
         self._compute_idx = -1
 
-    def evict_tensors(self, hold_cuda_tensor_list, to_free_cuda_model_data):
-        freed_cuda_model_data = 0
-        to_free_tensor_list = hold_cuda_tensor_list
-        if not self._warmup:
-            next_compute_idx: Dict[StatefulTensor, int] = {t: len(self._compute_list) for t in hold_cuda_tensor_list}
-            for i in range(len(self._compute_list) - 1, self._compute_idx, -1):
-                if self._compute_list[i] in next_compute_idx:
-                    next_compute_idx[self._compute_list[i]] = i
-            next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
-            to_free_tensor_list = [t for (t, idx) in next_compute_idx]
-        for t in to_free_tensor_list:
-            if freed_cuda_model_data > to_free_cuda_model_data:
-                break
-            freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
-            colo_model_data_tensor_move_inline(t, torch.device('cpu'))
-        if freed_cuda_model_data < to_free_cuda_model_data:
-            raise RuntimeError(
-                f"Adjust layout failed! Not enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
-            )
-
     def _trans_state(self, trans_state_func, stateful_tensor, state):
         trans_state_func(state)
         if state == TensorState.COMPUTE:
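Editorial note: the eviction order removed above (and re-homed in AutoTensorPlacementPolicy in the new file below) prefers tensors whose next COMPUTE use lies furthest in the future. A toy, self-contained illustration with plain strings standing in for StatefulTensors:

compute_list = ['a', 'b', 'c', 'a', 'd']   # order in which tensors will be computed
compute_idx = 1                            # we are currently at 'b'
hold_cuda = ['a', 'c', 'd']                # tensors currently held on CUDA

# Walk backwards so each tensor ends up mapped to its *next* use after compute_idx.
next_use = {t: len(compute_list) for t in hold_cuda}
for i in range(len(compute_list) - 1, compute_idx, -1):
    if compute_list[i] in next_use:
        next_use[compute_list[i]] = i

# Evict the tensor needed latest first: here 'd' (index 4), then 'a' (3), then 'c' (2).
eviction_order = [t for t, _ in sorted(next_use.items(), key=lambda kv: kv[1], reverse=True)]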
@@ -0,0 +1,94 @@
+from typing import List, Optional, Dict
+import torch
+from colossalai.utils import get_current_device
+from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
+from colossalai.utils.memory import colo_device_memory_capacity
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
+from colossalai.utils.memory_tracer import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
+
+__all__ = ['TENSOR_PLACEMENT_POLICIES']
+
+
+class TensorPlacementPolicy:
+
+    def __init__(self, device: Optional[torch.device], mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
+        self.device: Optional[torch.device] = device
+        self.mem_stats_collector: Optional[MemStatsCollector] = mem_stats_collector
+
+    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
+        raise NotImplementedError
+
+
+class CPUTensorPlacementPolicy(TensorPlacementPolicy):
+
+    def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
+        super().__init__(torch.device('cpu'), mem_stats_collector=mem_stats_collector)
+
+    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
+        for t in hold_cuda_tensor_list:
+            colo_model_data_tensor_move_inline(t, self.device)
+
+
+class CUDATensorPlacementPolicy(TensorPlacementPolicy):
+
+    def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
+        assert torch.cuda.is_available(), 'Cannot use CUDATensorPlacementPolicy when CUDA is not available'
+        super().__init__(get_current_device(), mem_stats_collector=mem_stats_collector)
+
+    def evict_tensors(self, hold_cuda_tensor_list: List[StatefulTensor], **kwargs) -> None:
+        pass
+
+
+class AutoTensorPlacementPolicy(TensorPlacementPolicy):
+
+    def __init__(self, mem_stats_collector: Optional[MemStatsCollector] = None) -> None:
+        super().__init__(None, mem_stats_collector=mem_stats_collector)
+        self._warmup_non_model_data_ratio: float = 0.2
+
+    def evict_tensors(self,
+                      hold_cuda_tensor_list: List[StatefulTensor],
+                      cuda_demand: int = 0,
+                      warmup: bool = True,
+                      compute_list: List[StatefulTensor] = [],
+                      compute_idx: int = 0,
+                      **kwargs) -> None:
+        cuda_capacity = colo_device_memory_capacity(get_current_device())
+        used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage
+        if warmup:
+            # We designate a part of CUDA memory for model data in warmup iterations.
+            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_non_model_data_ratio
+        else:
+            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
+            max_cuda_non_model_data_per_period = self.mem_stats_collector.max_non_model_data('cuda')
+        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
+        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data
+        if avail_cuda_model_data < cuda_demand:
+            # Move cuda_demand - avail_cuda_model_data volume of tensors
+            # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
+            to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
+            freed_cuda_model_data = 0
+            to_free_tensor_list = hold_cuda_tensor_list
+            if not warmup:
+                next_compute_idx: Dict[StatefulTensor, int] = {t: len(compute_list) for t in hold_cuda_tensor_list}
+                for i in range(len(compute_list) - 1, compute_idx, -1):
+                    if compute_list[i] in next_compute_idx:
+                        next_compute_idx[compute_list[i]] = i
+                next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
+                to_free_tensor_list = [t for (t, idx) in next_compute_idx]
+            for t in to_free_tensor_list:
+                if freed_cuda_model_data > to_free_cuda_model_data:
+                    break
+                freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
+                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
+            if freed_cuda_model_data < to_free_cuda_model_data:
+                raise RuntimeError(
+                    f"Adjust layout failed! Not enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
+                )
+
+
+TENSOR_PLACEMENT_POLICIES = {
+    'cpu': CPUTensorPlacementPolicy,
+    'cuda': CUDATensorPlacementPolicy,
+    'auto': AutoTensorPlacementPolicy
+}
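Editorial note: a sketch of how the registry is consumed and how the 'auto' policy sizes its eviction budget. The numbers are invented purely for illustration; in the real code cuda_capacity comes from colo_device_memory_capacity and the non-model-data bound from the MemStatsCollector (or the 20% warmup ratio).

from colossalai.zero.utils.tensor_placement_policy import TENSOR_PLACEMENT_POLICIES

policy_cls = TENSOR_PLACEMENT_POLICIES['auto']          # -> AutoTensorPlacementPolicy

GiB = 1024**3
cuda_capacity = 16 * GiB                                 # pretend 16 GiB device
max_cuda_non_model_data = int(0.2 * cuda_capacity)       # warmup reservation (20%)
used_cuda_model_data = 10 * GiB                          # model data already resident on CUDA
cuda_demand = 4 * GiB                                    # COMPUTE tensors that must move to CUDA

avail_cuda_model_data = (cuda_capacity - max_cuda_non_model_data) - used_cuda_model_data
to_free_cuda_model_data = max(0, cuda_demand - avail_cuda_model_data)
# Here avail = 16 * 0.8 - 10 = 2.8 GiB, so the policy must evict 1.2 GiB of held tensors to CPU.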
@@ -32,7 +32,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
                          shard_strategy=shard_strategy,
                          shard_param=True):
        zero_model = MoeModel(checkpoint=True)
-    zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
+    zero_model = ShardedModelV2(zero_model, shard_strategy)
 
     # check whether parameters are identical in ddp
     for name, p in zero_model.named_parameters():
@@ -69,8 +69,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
 
     zero_model = ShardedModelV2(zero_model,
                                 shard_strategy,
-                                offload_config=dict(device='cpu') if cpu_offload else None,
-                                use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+                                tensor_placement_policy='cpu' if cpu_offload else 'cuda',
                                 reuse_fp16_shard=reuse_fp16_shard)
 
     # check whether parameters are identical in ddp
@@ -88,7 +87,6 @@
     sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
     sharded_optim = ShardedOptimizerV2(zero_model,
                                        sharded_optim,
-                                       cpu_offload=cpu_offload,
                                        initial_scale=2**5,
                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio)
 
@@ -13,14 +13,12 @@ MP_PARALLEL_CONFIG = dict(fp16=dict(mode=None,), parallel=dict(pipeline=dict(siz
 
 _ZERO_MODEL_CONFIG = dict(reduce_scatter_bucket_size_mb=25,
                           fp32_reduce_scatter=False,
-                          offload_config=None,
+                          tensor_placement_policy='cuda',
                           gradient_predivide_factor=1.0,
-                          use_memory_tracer=False,
                           shard_strategy=TensorShardStrategy(),
                           reuse_fp16_shard=False)
 
-_ZERO_OPTIMIZER_CONFIG = dict(cpu_offload=False,
-                              initial_scale=2**5,
+_ZERO_OPTIMIZER_CONFIG = dict(initial_scale=2**5,
                               min_scale=1,
                               growth_factor=2,
                               backoff_factor=0.5,
@@ -37,16 +37,12 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
     zero_model = ShardedModelV2(
         zero_model,
         shard_strategy,
-        offload_config=dict(device='cpu') if cpu_offload else None,
-        use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+        tensor_placement_policy='cpu' if cpu_offload else 'cuda',
         reuse_fp16_shard=True,
     )
 
     sharded_optim = HybridAdam(zero_model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(zero_model,
-                                       sharded_optim,
-                                       cpu_offload=cpu_offload,
-                                       gpu_margin_mem_ratio=gpu_margin_mem_ratio)
+    sharded_optim = ShardedOptimizerV2(zero_model, sharded_optim, gpu_margin_mem_ratio=gpu_margin_mem_ratio)
 
     for i, (data, label) in enumerate(train_dataloader):
         if i > 1:
@@ -33,7 +33,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
                          shard_strategy=shard_strategy,
                          shard_param=True):
        zero_model = model_builder(checkpoint=True)
-    zero_model = ShardedModelV2(zero_model, shard_strategy, use_memory_tracer=True)
+    zero_model = ShardedModelV2(zero_model, shard_strategy)
 
     model = model_builder(checkpoint=True).half()
     col_model_deepcopy(zero_model, model)
@@ -64,8 +64,7 @@ def _run_test_sharded_optim_v2(cpu_offload, shard_strategy_class, use_cpuadam, g
     zero_model = ShardedModelV2(
         zero_model,
         shard_strategy,
-        offload_config=dict(device='cpu') if cpu_offload else None,
-        use_memory_tracer=gpu_margin_mem_ratio > 0.0,
+        tensor_placement_policy='cpu' if cpu_offload else 'cuda',
         reuse_fp16_shard=use_cpuadam,
     )
 
@@ -79,7 +78,6 @@
     sharded_optim = optimizer_class(zero_model.parameters(), lr=1e-3)
     sharded_optim = ShardedOptimizerV2(zero_model,
                                        sharded_optim,
-                                       cpu_offload=cpu_offload,
                                        initial_scale=2**5,
                                        gpu_margin_mem_ratio=gpu_margin_mem_ratio)
 
@@ -14,6 +14,7 @@ from colossalai.testing import rerun_on_exception
 from torch.nn.parameter import Parameter
 from typing import List
 from functools import partial
+from colossalai.zero.utils.tensor_placement_policy import AutoTensorPlacementPolicy
 
 
 class Net(torch.nn.Module):
@@ -37,7 +38,8 @@ def run_stm():
         p.colo_attr = ShardedParamV2(p, set_data_none=True)
     GLOBAL_MODEL_DATA_TRACER.register_model(model)
     mem_collector = MemStatsCollector()
-    stateful_tensor_mgr = StatefulTensorMgr(mem_collector)
+    tensor_placement_policy = AutoTensorPlacementPolicy(mem_stats_collector=mem_collector)
+    stateful_tensor_mgr = StatefulTensorMgr(tensor_placement_policy)
     for p in model.parameters():
         stateful_tensor_mgr.register_stateful_param(p.colo_attr)
 