mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix stm cuda model data size (#710)
parent 140263a394
commit 715b86eadd
@@ -6,6 +6,7 @@ from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
 from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from typing import Dict, List
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.logging import get_dist_logger
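The added import exposes the process-wide model-data tracer; the second hunk below reads its cuda_usage counter in place of a per-tensor sum.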
@@ -48,14 +49,13 @@ class StatefulTensorMgr(object):
         # find stateful tensor in state COMPUTE
         move_to_cuda_tensor_list = []
         cuda_demand = 0
-        used_cuda_model_data = 0
+        used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage
         hold_cuda_tensor_list = []
         for tensor in self._stateful_tensor_list:
             if tensor.state == TensorState.FREE:
                 continue

             if tensor.device.type == 'cuda':
-                used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
                 if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
                     hold_cuda_tensor_list.append(tensor)
             elif tensor.device.type == 'cpu':
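The gist of the change: the manager stops re-deriving used CUDA model data by summing the payloads of the tensors it owns and instead reads a globally maintained counter. Below is a minimal, self-contained sketch of that pattern; ModelDataTracer and its methods are hypothetical stand-ins, not ColossalAI's actual implementation.

# A minimal sketch of the pattern this fix adopts, under the assumption
# that the runtime keeps a global byte counter current as tensors move.
# All names here are illustrative stand-ins.

class ModelDataTracer:
    """Process-wide counter of model data resident on CUDA, in bytes."""

    def __init__(self) -> None:
        self._cuda_usage = 0

    def add_tensor(self, nbytes: int) -> None:
        # Called wherever a model tensor lands on a CUDA device.
        self._cuda_usage += nbytes

    def delete_tensor(self, nbytes: int) -> None:
        # Called wherever a model tensor leaves a CUDA device.
        self._cuda_usage -= nbytes

    @property
    def cuda_usage(self) -> int:
        return self._cuda_usage


GLOBAL_MODEL_DATA_TRACER = ModelDataTracer()

# Before the fix, each layout adjustment re-counted CUDA model data by
# summing only the tensors this manager owned, roughly:
#
#     used = sum(t.nbytes for t in my_tensors if t.device.type == 'cuda')
#
# After the fix it is an O(1) read of the shared counter, which also
# covers CUDA model data held outside this manager's tensor list:
GLOBAL_MODEL_DATA_TRACER.add_tensor(4 * 1024 * 1024)   # e.g. a 4 MiB shard
used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage
print(used_cuda_model_data)  # 4194304

Reading the shared counter presumably addresses the size bug named in the commit title: CUDA-resident model data not tracked by this manager was missing from the local sum, so the reported figure undercounted actual usage.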