[hotfix] fix stm cuda model data size (#710)

pull/715/head
ver217 2022-04-11 15:10:39 +08:00 committed by GitHub
parent 140263a394
commit 715b86eadd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -6,6 +6,7 @@ from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.utils.memory_utils.utils import colo_cuda_memory_capacity
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from typing import Dict, List
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.logging import get_dist_logger
@ -48,14 +49,13 @@ class StatefulTensorMgr(object):
# find stateful tensor in state COMPUTE
move_to_cuda_tensor_list = []
cuda_demand = 0
used_cuda_model_data = 0
used_cuda_model_data = GLOBAL_MODEL_DATA_TRACER.cuda_usage
hold_cuda_tensor_list = []
for tensor in self._stateful_tensor_list:
if tensor.state == TensorState.FREE:
continue
if tensor.device.type == 'cuda':
used_cuda_model_data += colo_tensor_mem_usage(tensor.payload)[0]
if tensor.state in [TensorState.HOLD, TensorState.HOLD_AFTER_BWD, TensorState.HOLD_AFTER_FWD]:
hold_cuda_tensor_list.append(tensor)
elif tensor.device.type == 'cpu':