mirror of https://github.com/hpcaitech/ColossalAI
[zero] stateful tensor manager (#687)
* [WIP] stateful tensor manager
* add eviction strategy
* polish code
* polish code
* polish comment
* add unit test
* fix sampler bug
* polish code
* fix max sampling cnt resetting bug
* fix sampler bug
* polish code
* fix bug
* fix unit test

Co-authored-by: jiaruifang <fangjiarui123@gmail.com>

pull/705/head
parent 70e8dd418b
commit 3c9cd5bb5e
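For orientation before the diff: this commit adds a StatefulTensorMgr that decides, before every sharded op, which stateful tensors stay on CUDA and which are evicted to CPU, driven by memory statistics collected during a warmup iteration. The sketch below is assembled only from names that appear in the diff (MemStatsCollector, StatefulTensorMgr, register_stateful_param) and mirrors the registration loop added to ShardedModelV2.__init__; it is an illustrative sketch, not the library's actual code.

# Illustrative sketch only; it assumes a working ColossalAI installation and
# parameters already wrapped with ShardedParamV2 (exposed as param.colo_attr).
import torch

from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.zero.shard_utils import StatefulTensorMgr


def wire_stateful_tensor_mgr(module: torch.nn.Module) -> StatefulTensorMgr:
    """Register every sharded parameter of `module` with a StatefulTensorMgr."""
    mem_stats_collector = MemStatsCollector()
    stateful_tensor_mgr = StatefulTensorMgr(mem_stats_collector)
    for submodule in module.modules():
        for param in submodule.parameters(recurse=False):
            if hasattr(param, 'colo_attr'):    # only parameters wrapped by ShardedParamV2
                stateful_tensor_mgr.register_stateful_param(param.colo_attr)
    return stateful_tensor_mgr

ZeroHook then calls adjust_layout() on the manager before each forward and backward op and reset() at the end of every iteration, as the first diff below shows.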
@@ -7,6 +7,7 @@ from colossalai.utils import get_current_device
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr

from ._base_ophook import BaseOpHook

@@ -21,31 +22,41 @@ class ZeroHook(BaseOpHook):

    def __init__(self,
                 shard_strategy: BaseShardStrategy,
                 memstarts_collector: Optional[MemStatsCollector],
                 memstarts_collector: Optional[MemStatsCollector] = None,
                 stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
                 process_group: Optional[dist.ProcessGroup] = None):
        super().__init__()
        self.shard_strategy = shard_strategy
        self.process_group = process_group

        # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
        self.computing_device = torch.device(f'cuda:{get_current_device()}')

        self._memstarts_collector = memstarts_collector
        self._stateful_tensor_mgr = stateful_tensor_mgr

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

        if self._stateful_tensor_mgr:
            self._stateful_tensor_mgr.adjust_layout()
        else:
            for param in module.parameters(recurse=False):
                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)

        tensor_list = []
        for param in module.parameters(recurse=False):
            assert hasattr(param, 'colo_attr')
            tensor_list.append(param.colo_attr.sharded_data_tensor)
        self.shard_strategy.gather(tensor_list, self.process_group)
        for param in module.parameters(recurse=False):
            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
            param.data = param.colo_attr.sharded_data_tensor.payload

        if self._memstarts_collector:
            self._memstarts_collector.sample_memstats()

        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
            param.data = param.colo_attr.sharded_data_tensor.payload
            assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters(recurse=False):
@@ -60,19 +71,27 @@ class ZeroHook(BaseOpHook):
            param.colo_attr.remove_torch_payload()

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

        if self._stateful_tensor_mgr:
            self._stateful_tensor_mgr.adjust_layout()
        else:
            for param in module.parameters(recurse=False):
                colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)

        tensor_list = []
        for param in module.parameters(recurse=False):
            assert hasattr(param, 'colo_attr')
            tensor_list.append(param.colo_attr.sharded_data_tensor)
        self.shard_strategy.gather(tensor_list, self.process_group)
        for param in module.parameters(recurse=False):
            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
            param.data = param.colo_attr.sharded_data_tensor.payload

        if self._memstarts_collector:
            self._memstarts_collector.sample_memstats()

        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
            param.data = param.colo_attr.sharded_data_tensor.payload
            assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"

    def post_bwd_exec(self, module: torch.nn.Module, input):
        for param in module.parameters(recurse=False):
@@ -91,4 +110,5 @@ class ZeroHook(BaseOpHook):
        pass

    def post_iter(self):
        pass
        if self._stateful_tensor_mgr:
            self._stateful_tensor_mgr.reset()

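The essence of the hook change above is just-in-time placement: decide where a submodule's parameters live immediately before that submodule runs. The toy below shows the same idea with plain torch.nn.Module forward pre-hooks; it is a conceptual stand-in for illustration only, not ColossalAI's BaseOpHook/ZeroHook machinery, and none of its helper names come from the patch.

import torch
import torch.nn as nn


def make_pre_op_hook(computing_device: torch.device):
    def hook(module: nn.Module, inputs):
        # move this submodule's own parameters to the computing device just in time
        for param in module.parameters(recurse=False):
            param.data = param.data.to(computing_device)
    return hook


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for submodule in model.modules():
    submodule.register_forward_pre_hook(make_pre_op_hook(device))
out = model(torch.randn(4, 8, device=device))

The real ZeroHook additionally gathers the shards and consults the StatefulTensorMgr (or falls back to colo_model_data_tensor_move_inline) before setting param.data, as the hunks above show.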
@@ -20,10 +20,12 @@ class SamplingCounter:
        assert self._max_sampling_cnt is not None
        return (self._samplint_cnt + 1) % self._max_sampling_cnt

    @property
    def sampling_cnt(self):
    def current(self):
        return self._samplint_cnt

    def max(self):
        return self._max_sampling_cnt

    def reset(self):
        self._max_sampling_cnt = self._samplint_cnt
        self._samplint_cnt = 0
@@ -37,7 +39,7 @@ class MemStatsCollector:
    The first iteration of DNN training.
    Phase 2. Runtime Phase: use the read-only collected stats
    The rest iterations of DNN training.


    It has a Sampling counter which is reset after DNN training iteration.
    """

@@ -50,6 +52,8 @@ class MemStatsCollector:
        self._model_data_cpu_list = []
        self._overall_cpu_list = []

        self._non_model_data_cuda_list = []
        self._non_model_data_cpu_list = []
        self._sampling_time = []

        self._start_flag = False
@@ -96,18 +100,20 @@ class MemStatsCollector:
            raise TypeError

        if device_type == 'cuda':
            return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cuda_list, self._model_data_cuda_list)]
            return [elem / scale for elem in self._non_model_data_cuda_list]
        elif device_type == 'cpu':
            return [(v1 - v2) / scale for v1, v2 in zip(self._overall_cpu_list, self._model_data_cpu_list)]
            return [elem / scale for elem in self._non_model_data_cpu_list]
        else:
            raise TypeError

    def current_non_model_data(self, device_type: str) -> int:
        """get the non model data of current sampling moment
        """get the non model data of the current sampling moment
        """
        return self.non_model_data_list(device_type)[self._sampling_cnter.sampling_cnt]
        return self.non_model_data_list(device_type)[self._sampling_cnter.current()]

    def next_non_model_data(self, device_type: str):
        """get the non model data of the next sampling moment
        """
        return self.non_model_data_list(device_type)[self._sampling_cnter.next()]

    @property
@@ -128,18 +134,20 @@ class MemStatsCollector:
        Advance the sampling cnter.
        """
        if self._start_flag:
            sampling_cnt = self._sampling_cnter.sampling_cnt
            sampling_cnt = self._sampling_cnter.current()
            assert sampling_cnt == len(self._overall_cuda_list)
            self._model_data_cuda_list.append(GLOBAL_MODEL_DATA_TRACER.cuda_usage)
            self._overall_cuda_list.append(self._mem_monitor.finish())
            self._non_model_data_cuda_list.append(self._model_data_cuda_list[-1] - self._overall_cuda_list[-1])

            self._model_data_cpu_list.append(GLOBAL_MODEL_DATA_TRACER.cpu_usage)

            # FIXME() cpu sys used should also return from self._mem_monitor()
            # FIXME(jiaruifang) cpu sys used should also return from self._mem_monitor()
            self._overall_cpu_list.append(colo_device_memory_used(torch.device(f'cpu')))

            self._non_model_data_cpu_list.append(self._overall_cpu_list[-1] - self._model_data_cpu_list[-1])
            self._sampling_time.append(time.time())
            self._mem_monitor.start()
        # TODO(ver217): refactor sampler
        # print(f'{self._sampling_cnter.current()} / {self._sampling_cnter.max()}, len = {len(self._sampling_time)}')
        self._sampling_cnter.advance()

    def reset_sampling_cnter(self) -> None:
@@ -155,4 +163,4 @@ class MemStatsCollector:

        self._start_flag = False
        self._sampling_cnter.reset()
        self._mem_monitor.finish()
        self._mem_monitor.finish()

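To make the two-phase design above concrete: during the warmup iteration the counter only counts samples, reset() freezes that count as the period, and later iterations index the recorded statistics cyclically through current()/next(). The toy below is a self-contained illustration of that pattern; the class and the sample values are invented for the example and are not the library's SamplingCounter.

class ToySamplingCounter:

    def __init__(self):
        self._cnt = 0
        self._max_cnt = None

    def advance(self):
        self._cnt += 1

    def current(self):
        return self._cnt

    def next(self):
        assert self._max_cnt is not None, "call reset() after the warmup iteration"
        return (self._cnt + 1) % self._max_cnt

    def reset(self):
        self._max_cnt = self._cnt
        self._cnt = 0


cnter = ToySamplingCounter()
non_model_data = []
for sample in (3, 5, 2):    # pretend non-model-data samples from the warmup iteration
    non_model_data.append(sample)
    cnter.advance()
cnter.reset()               # the period is now 3 sampling moments per iteration
assert non_model_data[cnter.current()] == 3    # stats of the current moment
assert non_model_data[cnter.next()] == 5       # peek at the next moment's stats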
@@ -1,5 +1,6 @@
from .base_shard_strategy import BaseShardStrategy
from .bucket_tensor_shard_strategy import BucketTensorShardStrategy
from .tensor_shard_strategy import TensorShardStrategy
from .stateful_tensor_mgr import StatefulTensorMgr

__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy']
__all__ = ['BaseShardStrategy', 'TensorShardStrategy', 'BucketTensorShardStrategy', 'StatefulTensorMgr']

@@ -1,26 +1,43 @@
import functools
import torch
from colossalai.context.singleton_meta import SingletonMeta
import types
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from typing import Set
from typing import Dict, List
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.logging import get_dist_logger


class StatefulTensorMgr(SingletonMeta):
    _stateful_tensor_list: Set[ShardedParamV2] = set()
class StatefulTensorMgr(object):
    """
    Stateful Tensor Manager, inspired by PatrickStar.

    def register_param(self, param: ShardedParamV2) -> None:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """

    def __init__(self, mem_stats_collector: MemStatsCollector) -> None:
        self._stateful_tensor_list: List[StatefulTensor] = []
        self._mem_stats_collector = mem_stats_collector
        self._logger = get_dist_logger("StatefulTensorMgr")

        self._warmup = True
        self._warmup_cuda_available_ratio = 0.2

        self._compute_list: List[StatefulTensor] = []
        self._compute_idx: int = -1

    def register_stateful_param(self, param: ShardedParamV2) -> None:
        assert isinstance(param, ShardedParamV2)
        for t in param.get_payload_tensors():
            assert isinstance(t, StatefulTensor)
            self._stateful_tensor_list.add(t)
            self._stateful_tensor_list.append(t)
            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

    def evict_tensors(self) -> None:
        pass

    def adjust_layout(self, mem_stats_collector: MemStatsCollector) -> None:
    def adjust_layout(self) -> None:
        """ Adjust the layout of stateful tensors according to the information provided
        by mem_stats_collector, which should belong to a Sharded Model.
@@ -41,29 +58,62 @@ class StatefulTensorMgr(SingletonMeta):
                used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
                if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                    hold_cuda_tensor_list.append(tensor)
            else:
            elif tensor.device.type == 'cpu':
                if tensor.state == TensorState.COMPUTE:
                    move_to_cuda_tensor_list.append(tensor)
                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[0]

        # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
        max_cuda_non_model_data_per_period = max(mem_stats_collector.current_non_model_data('cuda'),
                                                 mem_stats_collector.next_non_model_data('cuda'))
                    cuda_demand += colo_tensor_mem_usage(tensor.payload)[1]
            else:
                raise RuntimeError
        cuda_capacity = colo_cuda_memory_capacity()
        cuda_model_data_period = cuda_capacity - max_cuda_non_model_data_per_period
        if cuda_model_data_period < used_cuda_model_data + cuda_demand:
            # move cuda_model_data_period - cuda_demand - used_cuda_model_data volume of tensor
            # Here use a naive eviction strategy.
            acc_size = 0
            for t in hold_cuda_tensor_list:
                if acc_size > cuda_demand:
                    break
                colo_model_data_tensor_move_inline(t, torch.device('cpu'))
                t_size = colo_tensor_mem_usage(t)
                acc_size += t_size
            if acc_size < cuda_demand:
                raise RuntimeError("Adjust layout failed! No enough CUDA memory!")

        if self._warmup:
            # We designate a part of CUDA memory for model data in warmup iterations.
            max_cuda_non_model_data_per_period = cuda_capacity * self._warmup_cuda_available_ratio
        else:
            # max non-model-data cuda memory consumption of this sampling moment and the next sampling moment.
            max_cuda_non_model_data_per_period = max(self._mem_stats_collector.current_non_model_data('cuda'),
                                                     self._mem_stats_collector.next_non_model_data('cuda'))

        total_cuda_model_data = cuda_capacity - max_cuda_non_model_data_per_period
        avail_cuda_model_data = total_cuda_model_data - used_cuda_model_data

        if avail_cuda_model_data < cuda_demand:
            # Move cuda_demand - avail_cuda_model_data volume of tensors
            # to_free_cuda_model_data = cuda_demand - avail_cuda_model_data
            self.evict_tensors(hold_cuda_tensor_list, cuda_demand - avail_cuda_model_data)
        # move COMPUTE tensors to CUDA
        for t in move_to_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, get_current_device())

    def reset(self):
        """This function must be called when each iteration finishes
        """
        self._warmup = False
        self._compute_idx = -1

    def evict_tensors(self, hold_cuda_tensor_list, to_free_cuda_model_data):
        freed_cuda_model_data = 0
        to_free_tensor_list = hold_cuda_tensor_list
        if not self._warmup:
            next_compute_idx: Dict[StatefulTensor, int] = {t: len(self._compute_list) for t in hold_cuda_tensor_list}
            for i in range(len(self._compute_list) - 1, self._compute_idx, -1):
                if self._compute_list[i] in next_compute_idx:
                    next_compute_idx[self._compute_list[i]] = i
            next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
            to_free_tensor_list = [t for (t, idx) in next_compute_idx]
        for t in to_free_tensor_list:
            if freed_cuda_model_data > to_free_cuda_model_data:
                break
            freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
            colo_model_data_tensor_move_inline(t, torch.device('cpu'))
        if freed_cuda_model_data < to_free_cuda_model_data:
            raise RuntimeError(
                f"Adjust layout failed! No enough CUDA memory! Need {to_free_cuda_model_data}, freed {freed_cuda_model_data}"
            )

    def _trans_state(self, trans_state_func, stateful_tensor, state):
        trans_state_func(state)
        if state == TensorState.COMPUTE:
            self._compute_idx += 1
            if self._warmup:
                self._compute_list.append(stateful_tensor)

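The post-warmup eviction order in evict_tensors() above is Belady's OPT heuristic: each CUDA-resident candidate is keyed by the index of its next use in the compute list recorded during warmup (tensors never used again get len(compute_list)), and candidates whose next use is farthest away are evicted first. The standalone sketch below replays that logic with plain strings instead of StatefulTensor objects; the variable names mirror the diff, the data is made up for the example.

compute_list = ['p0', 'p1', 'p2', 'p0', 'p1']    # access order recorded during warmup
compute_idx = 0                                  # we are currently at compute_list[0]
hold_cuda_tensor_list = ['p1', 'p2']             # eviction candidates resident on CUDA

# index of the next use of each candidate after compute_idx
next_compute_idx = {t: len(compute_list) for t in hold_cuda_tensor_list}
for i in range(len(compute_list) - 1, compute_idx, -1):
    if compute_list[i] in next_compute_idx:
        next_compute_idx[compute_list[i]] = i    # walking backwards leaves the nearest use

# evict the tensor whose next use is farthest in the future first
eviction_order = [t for t, idx in sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)]
assert eviction_order == ['p2', 'p1']    # p1 is needed again at step 1, p2 only at step 2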
@@ -23,6 +23,7 @@ from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import TensorState
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.zero.shard_utils.stateful_tensor_mgr import StatefulTensorMgr

from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                     get_gradient_predivide_factor)
@@ -36,7 +37,6 @@ class ShardedModelV2(nn.Module):

    Note:
        You must use ``ShardedModelV2`` with ``ShardedOptimizerV2``.

    Note:
        Make sure you don't use gradient accumulation and your optimizer can work with fp16 gradient and fp32 parameter,
        if you enable ``reuse_fp16_shard``.
@@ -106,12 +106,21 @@ class ShardedModelV2(nn.Module):
        if self._use_memory_tracer:
            GLOBAL_MODEL_DATA_TRACER.register_model(self)
            self._memstats_collector = MemStatsCollector()
            self._stateful_tensor_mgr = StatefulTensorMgr(self._memstats_collector)
            # for param in module.parameters():
            for submodule in module.modules():
                for param in submodule.parameters(recurse=False):
                    if hasattr(param, 'colo_attr'):
                        self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
        else:
            self._memstats_collector = None
            self._stateful_tensor_mgr = None
        self._iter_cnter = 0

        # Register hooks
        self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
        self._ophook_list = [
            ZeroHook(self.shard_strategy, self._memstats_collector, self._stateful_tensor_mgr, self.process_group)
        ]
        register_ophooks_recursively(self.module, self._ophook_list, filter_fn=lambda m: not m.param_is_sharded)
        self.param_hook_mgr = BaseParamHookMgr(self.sharded_params)
        self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
@@ -138,6 +147,9 @@ class ShardedModelV2(nn.Module):
        self._cuda_margin_space = 0
        self.reuse_fp16_shard = reuse_fp16_shard

    def adjust_stateful_tensor_layout(self) -> None:
        self._stateful_tensor_mgr.adjust_layout()

    @property
    def use_memory_tracer(self):
        return self._use_memory_tracer
@@ -150,20 +162,15 @@ class ShardedModelV2(nn.Module):
    def cpu_offload(self):
        return self._cpu_offload

    def dump_memory_stats(self, filename: str = 'dump_mem_stats.log') -> None:
        """Dummy memory tracer collected infomation to a file.

        Example::

            try:
                # forward: model(inputs)
                # backward: optimizer.backward()
            except Exception as e:
                model.dump_memory_stats()
                exit(0)

        Args:
            filename (str, optional): Output file name. Defaults to 'dump_mem_stats.log'.
    def dump_memory_stats(self, filename: Optional[str] = 'dump_mem_stats.log') -> None:
        """
        dummy memory tracer collected infomation to a file.
        try:
            # forward: model(inputs)
            # backward: optimizer.backward()
        except Exception as e:
            model.dump_memory_stats()
            exit(0)
        """
        if self._use_memory_tracer:
            self.logger.error(f'dump memort tracer collected infomation to a {filename}', ranks=[0])
@@ -172,12 +179,12 @@ class ShardedModelV2(nn.Module):
            f.write(f'cuda reserved {torch.cuda.memory_reserved(get_current_device())/1e9} GB\n')
            f.write(f'cuda max allocated {torch.cuda.max_memory_allocated(get_current_device())/1e9} GB\n')
            f.write('CUDA model data (GB)\n')
            f.write(str(self._memstats_collector.model_data_cuda_list('cuda', 'GB')))
            f.write(str(self._memstats_collector.model_data_list('cuda', 'GB')))
            f.write('\n')
            f.write('CUDA non model data (GB)\n')
            f.write(str(self._memstats_collector.non_model_data_cuda_list('cuda', 'GB')))
            f.write(str(self._memstats_collector.non_model_data_list('cuda', 'GB')))
            f.write('CPU non model data (GB)\n')
            f.write(str(self._memstats_collector.non_model_data_cuda_list('cpu', 'GB')))
            f.write(str(self._memstats_collector.non_model_data_list('cpu', 'GB')))
            f.write('\n')

    def _pre_forward_operations(self):

@@ -350,7 +350,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
                p.colo_attr.sharded_data_tensor.reset_payload(
                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
                    colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))

                if not is_param_sharded and not self.keep_unshard:
                    # We gather full fp16 param here

@@ -26,7 +26,7 @@ class ShardedParamV2(object):
    def get_payload_tensors(self) -> List[StatefulTensor]:
        """returns stateful tensors kept by this class.
        """
        return [self._sharded_data_tensor, self.saved_grad]
        return [self._sharded_data_tensor]

    def remove_torch_payload(self):
        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

@@ -0,0 +1,112 @@
import torch
import colossalai
import pytest
import torch.multiprocessing as mp
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity, colo_set_process_memory_fraction
from colossalai.zero.shard_utils import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.utils import free_port
from colossalai.testing import rerun_on_exception
from torch.nn.parameter import Parameter
from typing import List
from functools import partial


class Net(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()
        # each parameter is 512 MB
        self.p0 = Parameter(torch.empty(1024, 1024, 128))
        self.p1 = Parameter(torch.empty(1024, 1024, 128))
        self.p2 = Parameter(torch.empty(1024, 1024, 128))


def run_stm():
    cuda_capacity = colo_cuda_memory_capacity()
    fraction = (1.4 * 1024**3) / cuda_capacity
    # limit max memory to 1.4GB
    # which means only 2 parameters can be on CUDA
    colo_set_process_memory_fraction(fraction)
    model = Net()
    for p in model.parameters():
        p.colo_attr = ShardedParamV2(p, rm_torch_payload=True)
    GLOBAL_MODEL_DATA_TRACER.register_model(model)
    mem_collector = MemStatsCollector()
    stateful_tensor_mgr = StatefulTensorMgr(mem_collector)
    for p in model.parameters():
        stateful_tensor_mgr.register_stateful_param(p.colo_attr)

    mem_collector.start_collection()
    # Compute order: 0 1 2 0 1
    # warmup
    # use naive eviction strategy
    apply_adjust(model, model.p0, [model.p0], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p2, [model.p1, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    mem_collector.finish_collection()
    mem_collector.reset_sampling_cnter()
    stateful_tensor_mgr.reset()

    # warmup done
    # use OPT-like eviction strategy
    apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p1, [model.p0, model.p1], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p2, [model.p0, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p0, [model.p0, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()
    apply_adjust(model, model.p1, [model.p1, model.p2], stateful_tensor_mgr)
    mem_collector.sample_memstats()


def apply_adjust(model: torch.nn.Module, compute_param: Parameter, cuda_param_after_adjust: List[Parameter],
                 stateful_tensor_mgr: StatefulTensorMgr):
    compute_param.colo_attr._sharded_data_tensor.trans_state(TensorState.COMPUTE)
    for p in model.parameters():
        if p is not compute_param and p.colo_attr._sharded_data_tensor.state != TensorState.HOLD:
            p.colo_attr._sharded_data_tensor.trans_state(TensorState.HOLD)
    stateful_tensor_mgr.adjust_layout()
    print_stats(model)
    device = torch.device(torch.cuda.current_device())
    cuda_param_after_adjust = [hash(p) for p in cuda_param_after_adjust]
    for n, p in model.named_parameters():
        if hash(p) in cuda_param_after_adjust:
            assert p.colo_attr._sharded_data_tensor.device == device, f'{n} {p.colo_attr._sharded_data_tensor.device} vs {device}'
        else:
            assert p.colo_attr._sharded_data_tensor.device == torch.device('cpu')


def print_stats(model: torch.nn.Module):
    msgs = []
    for n, p in model.named_parameters():
        msgs.append(f'{n}: {p.colo_attr._sharded_data_tensor.state}({p.colo_attr._sharded_data_tensor.device})')
    print(f'[ {", ".join(msgs)} ]')


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_stm()


@pytest.mark.dist
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_stateful_tensor_manager(world_size=1):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_stateful_tensor_manager()
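A quick sanity check of the memory budget used by this test: each parameter is 1024 * 1024 * 128 fp32 elements, i.e. 512 MiB, and the process is capped at 1.4 GB of CUDA memory, so at most two of the three parameters can be resident on CUDA at any sampling moment. That is why every expected CUDA set passed to apply_adjust() above holds at most two parameters. The arithmetic, as a tiny self-contained check:

param_bytes = 1024 * 1024 * 128 * 4    # 536,870,912 bytes = 512 MiB per fp32 parameter
budget = 1.4 * 1024**3                 # the CUDA cap set in run_stm()
assert 2 * param_bytes <= budget < 3 * param_bytes    # two parameters fit, three do not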