mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix test_stateful_tensor_mgr (#762)
parent 6978980f6d
commit dcca614eee
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import List, Optional, Dict
+from typing import List, Optional
 import torch
 from colossalai.utils import get_current_device
 from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
@@ -79,7 +79,7 @@ class AutoTensorPlacementPolicy(TensorPlacementPolicy):
             next_compute_idx = sorted(next_compute_idx.items(), key=lambda pair: pair[1], reverse=True)
             to_free_tensor_list = [t for (t, idx) in next_compute_idx]
         for t in to_free_tensor_list:
-            if freed_cuda_model_data > to_free_cuda_model_data:
+            if freed_cuda_model_data >= to_free_cuda_model_data:
                 break
            freed_cuda_model_data += colo_tensor_mem_usage(t)[0]
             colo_model_data_tensor_move_inline(t, torch.device('cpu'))
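
The one-character fix above changes the eviction loop's stop condition from `>` to `>=`: with the strict comparison, the loop keeps going when the freed amount exactly equals the target and moves one tensor to CPU more than necessary. A minimal standalone sketch of that boundary behaviour (plain ints stand in for colo_tensor_mem_usage and the CPU move; this is not ColossalAI's actual code):

    def evict(tensor_sizes, to_free, strict):
        # Returns the indices evicted before the stop condition fires.
        freed, evicted = 0, []
        for i, size in enumerate(tensor_sizes):
            if (freed > to_free) if strict else (freed >= to_free):
                break
            freed += size
            evicted.append(i)
        return evicted

    sizes = [128, 128, 128]          # three 128 MB tensors, as in the fixed test
    print(evict(sizes, 256, True))   # old '>': [0, 1, 2] -- one eviction too many
    print(evict(sizes, 256, False))  # new '>=': [0, 1] -- stops once 256 is freed

The remaining hunks below are from the test module itself.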
@@ -5,7 +5,7 @@ import torch.multiprocessing as mp
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils.memory_tracer import MemStatsCollector
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory import colo_device_memory_capacity, colo_set_process_memory_fraction
+from colossalai.utils.memory import colo_set_process_memory_fraction
 from colossalai.zero.utils import StatefulTensorMgr
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.zero.sharded_param.tensorful_state import TensorState
@@ -21,18 +21,22 @@ class Net(torch.nn.Module):
 
     def __init__(self) -> None:
         super().__init__()
-        # each parameter is 512 MB
-        self.p0 = Parameter(torch.empty(1024, 1024, 128))
-        self.p1 = Parameter(torch.empty(1024, 1024, 128))
-        self.p2 = Parameter(torch.empty(1024, 1024, 128))
+        # each parameter is 128 MB
+        self.p0 = Parameter(torch.empty(1024, 1024, 32))
+        self.p1 = Parameter(torch.empty(1024, 1024, 32))
+        self.p2 = Parameter(torch.empty(1024, 1024, 32))
 
 
+def limit_cuda_memory(memory_in_g: float):
+    cuda_capacity = torch.cuda.get_device_properties(get_current_device()).total_memory
+    fraction = (memory_in_g * 1024**3) / cuda_capacity
+    colo_set_process_memory_fraction(fraction)
+
+
 def run_stm():
-    cuda_capacity = colo_device_memory_capacity(get_current_device())
-    fraction = (1.4 * 1024**3) / cuda_capacity
-    # limit max memory to 1.4GB
-    # which means only 2 parameters can be on CUDA
-    colo_set_process_memory_fraction(fraction)
+    # the warmup phase uses 20% of CUDA memory to store params
+    # only 2 params can be on CUDA
+    limit_cuda_memory(1.26)
     model = Net()
     for p in model.parameters():
         p.colo_attr = ShardedParamV2(p, set_data_none=True)
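
The new limit_cuda_memory helper converts a budget in GB into a fraction of the device's total memory before handing it to colo_set_process_memory_fraction. Assuming that wrapper forwards to PyTorch's torch.cuda.set_per_process_memory_fraction (an assumption; the wrapper's body is not part of this diff), a self-contained equivalent looks like:

    import torch

    def limit_cuda_memory_plain(memory_in_g: float, device=None):
        # Turn a budget in GiB into a fraction of the device's total memory.
        device = torch.cuda.current_device() if device is None else device
        total = torch.cuda.get_device_properties(device).total_memory
        fraction = (memory_in_g * 1024**3) / total
        # Assumed to match what colo_set_process_memory_fraction does:
        # cap this process's caching allocator at fraction * total_memory.
        torch.cuda.set_per_process_memory_fraction(fraction, device)

    limit_cuda_memory_plain(1.26)    # roughly the 1.26 GB warmup cap used above

Capping by fraction rather than by an absolute byte count is what the PyTorch allocator API exposes, which is why the helper does the division itself.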
@@ -65,6 +69,8 @@ def run_stm():
     stateful_tensor_mgr.reset()
 
     # warmup done
+    # only 2 params can be on CUDA
+    limit_cuda_memory(0.26)
     # use OPT-like eviction strategy
     apply_adjust(model, model.p0, [model.p0, model.p1], stateful_tensor_mgr)
     mem_collector.sample_model_data()
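
The 0.26 GB cap is sized so that exactly two of the test's parameters fit on the device, which is what the "only 2 params can be on CUDA" comment relies on. A quick check of the arithmetic (fp32 assumed, since torch.empty defaults to float32):

    param_bytes = 1024 * 1024 * 32 * 4       # 134_217_728 B = 128 MiB per parameter
    cap_bytes = int(0.26 * 1024**3)          # ~266 MiB process-wide cap
    assert 2 * param_bytes <= cap_bytes < 3 * param_bytes   # two fit, three do not

So the manager must evict a parameter to admit the third one, exercising the OPT-like eviction path that the `>=` fix above touches.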
@@ -112,7 +118,7 @@ def run_dist(rank, world_size, port):
     run_stm()
 
 
-@pytest.mark.skip
+@pytest.mark.gpu
 @rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
 def test_stateful_tensor_manager(world_size=1):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
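
Replacing @pytest.mark.skip with @pytest.mark.gpu re-enables the test under a custom marker. pytest emits PytestUnknownMarkWarning for markers it does not know about, so the marker has to be registered somewhere; one way to do that (a sketch; how the repo actually registers it is not shown in this diff) is in a conftest.py:

    # conftest.py -- register the custom 'gpu' marker (hypothetical placement)
    def pytest_configure(config):
        config.addinivalue_line("markers", "gpu: test requires a CUDA-capable device")

With that in place, GPU-only tests can be selected or excluded wholesale, e.g. pytest -m gpu or pytest -m "not gpu".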
@@ -120,4 +126,5 @@ def test_stateful_tensor_manager(world_size=1):
 
 
 if __name__ == '__main__':
+    # this unit test can pass if available CUDA memory >= 1.5 GB
     test_stateful_tensor_manager()