[hotfix] fix test_stateful_tensor_mgr (#762)

pull/754/head^2
ver217 2022-04-14 15:50:09 +08:00 committed by GitHub
parent 6978980f6d
commit dcca614eee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 13 deletions

View File

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Optional, Dict
from typing import List, Optional
import torch
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
@ -79,7 +79,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
to_free_tensor_list = [t for (t, idx) in next_compute_idx]
for t in to_free_tensor_list:
if freed_cuda_model_data > to_free_cuda_model_data:
if freed_cuda_model_data >= to_free_cuda_model_data:
break
freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
colo_model_data_tensor_move_inline(t, torch.device('cpu'))

View File

@ -5,7 +5,7 @@ import torch.multiprocessing as mp
from colossalai.utils.cuda import get_current_device
from colossalai.utils.memory_tracer import MemStatsCollector
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
from colossalai.utils.memory import colo_set_process_memory_fraction
from colossalai.zero.utils import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
@ -21,18 +21,22 @@ class Net(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
# each parameter is 512 MB
self.p0 = Parameter(torch.empty(1024, 1024, 128))
self.p1 = Parameter(torch.empty(1024, 1024, 128))
self.p2 = Parameter(torch.empty(1024, 1024, 128))
# each parameter is 128 MB
self.p0 = Parameter(torch.empty(1024, 1024, 32))
self.p1 = Parameter(torch.empty(1024, 1024, 32))
self.p2 = Parameter(torch.empty(1024, 1024, 32))
def limit_cuda_memory(memory_in_g: float):
cuda_capacity = torch.cuda.get_device_properties(get_current_device()).total_memory
fraction = (memory_in_g * 1024**3) / cuda_capacity
colo_set_process_memory_fraction(fraction)
def run_stm():
cuda_capacity = colo_device_memory_capacity(get_current_device())
fraction = (1.4 * 1024**3) / cuda_capacity
# limit max memory to 1.4GB
# which means only 2 parameters can be on CUDA
colo_set_process_memory_fraction(fraction)
# warmup phase use 20% CUDA memory to store params
# only 2 params can be on CUDA
limit_cuda_memory(1.26)
model = Net()
for p in model.parameters():
p.colo_attr = ShardedParamV2(p, set_data_none=True)
@ -65,6 +69,8 @@ def run_stm():
stateful_tensor_mgr.reset()
# warmup done
# only 2 params can be on CUDA
limit_cuda_memory(0.26)
# use OPT-like eviction strategy
apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
mem_collector.sample_model_data()
@ -112,7 +118,7 @@ def run_dist(rank, world_size, port):
run_stm()
@pytest.mark.skip
@pytest.mark.gpu
@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
def test_stateful_tensor_manager(world_size=1):
run_func = partial(run_dist, world_size=world_size, port=free_port())
@ -120,4 +126,5 @@ def test_stateful_tensor_manager(world_size=1):
if __name__ == '__main__':
# this unit test can pass if available CUDA memory >= 1.5G
test_stateful_tensor_manager()