mirror of https://github.com/hpcaitech/ColossalAI
[gemini] polish stateful_tensor_mgr (#876)
parent e43f83aa5c
commit 425b4a96b8
@@ -6,7 +6,6 @@ from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, c
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from typing import List
from colossalai.logging import get_dist_logger


class StatefulTensorMgr(object):
@@ -20,23 +19,30 @@
    def __init__(self, tensor_placement_policy: TensorPlacementPolicy) -> None:
        self._tensor_placement_policy: TensorPlacementPolicy = tensor_placement_policy
        self._stateful_tensor_list: List[StatefulTensor] = []
        self._logger = get_dist_logger("StatefulTensorMgr")

        self._warmup = True

        self._compute_list: List[StatefulTensor] = []
        self._compute_idx: int = -1

        self._cpu_gpu_move_volume = 0
        self._warmup = True

    def register_stateful_param(self, param) -> None:
        from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
        assert isinstance(param, ShardedParamV2)
        for t in param.get_payload_tensors():
    def register_stateful_tensor_list(self, tensor_list: List[StatefulTensor]) -> None:
        assert self._stateful_tensor_list == [], "Can't register stateful tensors for manager twice"
        self._stateful_tensor_list = tensor_list
        for t in self._stateful_tensor_list:
            assert isinstance(t, StatefulTensor)
            self._stateful_tensor_list.append(t)
            t.trans_state = types.MethodType(functools.partial(self._trans_state, t.trans_state), t)

    def start_iter(self):
        pass

    def finish_iter(self):
        """This function must be called when each iteration finishes
        """
        self._warmup = False
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0

    def adjust_layout(self) -> None:
        """ Adjust the layout of statefuil tensor according to the information provided
        by mem_stats_collector, which should belongs to a Sharded Model.
@@ -63,21 +69,14 @@
            compute_list=self._compute_list,
            compute_idx=self._compute_idx)
        # move COMPUTE tensors to CUDA
        self._cpu_gpu_move_volume += cuda_demand
        for t in move_to_cuda_tensor_list:
            colo_model_data_tensor_move_inline(t, get_current_device())
            self._cpu_gpu_move_volume += t.payload_size

    @property
    def cpu_gpu_move_volume(self):
        return self._cpu_gpu_move_volume

    def reset(self):
        """This function must be called when each iteration finishes
        """
        self._warmup = False
        self._compute_idx = -1
        self._cpu_gpu_move_volume = 0

    def _trans_state(self, trans_state_func, stateful_tensor, state):
        trans_state_func(state)
        if state == TensorState.COMPUTE:
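Taken together, the stateful_tensor_mgr hunks above replace the per-parameter register_stateful_param path with a single register_stateful_tensor_list call and split the old reset() into start_iter()/finish_iter(). A minimal usage sketch of that new lifecycle follows; the function name train_loop and the way the stateful tensors are obtained are placeholders, not part of the commit.

from typing import List

from colossalai.gemini.stateful_tensor import StatefulTensor


def train_loop(mgr, stateful_tensors: List[StatefulTensor], num_iters: int) -> None:
    # Registration is legal exactly once per manager (guarded by the assert above).
    mgr.register_stateful_tensor_list(stateful_tensors)
    for _ in range(num_iters):
        mgr.start_iter()
        # ... forward/backward runs here; op hooks call mgr.adjust_layout(), which
        # moves COMPUTE-state tensors to CUDA and accumulates the CPU-GPU traffic ...
        print(f"CPU-GPU traffic this iteration: {mgr.cpu_gpu_move_volume / 1e9:.3f} GB")
        mgr.finish_iter()  # clears the warmup flag, compute index and move volume

Per the diff, finish_iter() also ends the warmup phase, so from the second iteration on the placement policy can rely on the recorded compute_list and compute_idx.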
@@ -111,10 +111,10 @@ class ShardedModelV2(nn.Module):
        self._memstats_collector = None
        self._tensor_placement_policy: TensorPlacementPolicy = TensorPlacementPolicyFactory.create(
            tensor_placement_policy)(mem_stats_collector=self._memstats_collector)

        self._stateful_tensor_mgr = StatefulTensorMgr(self._tensor_placement_policy)
        for param in module.parameters():
            if hasattr(param, 'colo_attr'):
                self._stateful_tensor_mgr.register_stateful_param(param.colo_attr)
        param_tensor_list = [p.colo_attr.sharded_data_tensor for p in module.parameters() if hasattr(p, 'colo_attr')]
        self._stateful_tensor_mgr.register_stateful_tensor_list(param_tensor_list)

        # Register hooks
        self._ophook_list = [
@@ -198,6 +198,8 @@ class ShardedModelV2(nn.Module):
            if hasattr(p, 'colo_attr'):
                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

        self._stateful_tensor_mgr.start_iter()

    def _post_forward_operations(self):
        for p in self.module.parameters():
            if hasattr(p, 'colo_attr'):
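The ShardedModelV2 hunks wire the manager in two places: __init__ now registers all sharded data tensors with one list comprehension, and the pre-forward step calls start_iter() after putting every tensor back into the HOLD state. A standalone sketch of that wiring, with mgr and module as placeholder arguments rather than attributes of the class:

import torch.nn as nn

from colossalai.gemini.stateful_tensor import TensorState


def wire_up_manager(mgr, module: nn.Module) -> None:
    # Mirrors the __init__ hunk: collect each parameter's sharded data tensor
    # and register the whole list with the manager in a single call.
    tensor_list = [
        p.colo_attr.sharded_data_tensor
        for p in module.parameters()
        if hasattr(p, 'colo_attr')
    ]
    mgr.register_stateful_tensor_list(tensor_list)


def pre_forward(mgr, module: nn.Module) -> None:
    # Mirrors the _pre_forward_operations hunk: reset every tensor to HOLD,
    # then tell the manager that a new iteration is starting.
    for p in module.parameters():
        if hasattr(p, 'colo_attr'):
            p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
    mgr.start_iter()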
@@ -115,4 +115,4 @@ class ZeroHook(BaseOpHook):
        if self._stateful_tensor_mgr:
            self.logger.info(
                f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB", ranks=[0])
            self._stateful_tensor_mgr.reset()
            self._stateful_tensor_mgr.finish_iter()
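The ZeroHook hunk only swaps the end-of-iteration call from reset() to finish_iter(); the reported number is the accumulated byte count divided by 1e9, i.e. decimal gigabytes. A small sketch of that reporting step, with log_and_finish and logger as hypothetical names not taken from the commit:

def log_and_finish(mgr, logger) -> None:
    # cpu_gpu_move_volume is a running byte count for the current iteration;
    # dividing by 1e9 reports it in decimal GB, as in the hunk above.
    logger.info(f"CPU-GPU data moving this iteration {mgr.cpu_gpu_move_volume / 1e9} GB")
    # finish_iter() replaces the old reset(): it clears the warmup flag,
    # the compute index and the accumulated move volume.
    mgr.finish_iter()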