mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
74 lines
2.8 KiB
74 lines
2.8 KiB
3 years ago
|
import pytest
|
||
|
import torch
|
||
|
|
||
|
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
|
||
|
|
||
|
|
||
|
@pytest.mark.dist
def test_gemini_manager():
    """Exercise the memory bookkeeping of ``StatefulTensor.GST_MGR``.

    Four stateful tensors are created, then taken through payload resets,
    device moves and state transitions. After each phase the manager's tensor
    count, per-device totals and per-(device, state) totals are compared with
    hand-computed values.
    """
    tracker = StatefulTensor.GST_MGR
    # Start from a clean slate: earlier tests may have left memory
    # information behind in the shared manager.
    tracker.reset()

    def expect(count, cpu_total, cuda_total, per_state):
        # One snapshot check: tensor count, per-device totals, then the
        # listed (device, state) buckets in the order they are given.
        assert tracker.total_number == count
        assert tracker.total_mem['cpu'] == cpu_total
        assert tracker.total_mem['cuda'] == cuda_total
        for (device, state), occupied in per_state.items():
            assert tracker.state_mem[device][state] == occupied

    # --- phase 1: construction -------------------------------------------
    # 2 x 2 float16 elements on the GPU -> occupation 8
    half_on_gpu = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
    # 3 x 5 float32 elements on the host -> occupation 60
    float_on_cpu = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))

    # plain tensors used as payloads; occupation 28 and 12 respectively
    gpu_payload = torch.empty(7, device='cuda')
    cpu_payload = torch.empty(3, device='cpu')
    after_fwd = StatefulTensor(gpu_payload, TensorState.HOLD_AFTER_FWD)
    # a tensor with no payload is still counted in total_number
    initially_free = StatefulTensor(None, TensorState.FREE)

    expect(4, 60, 36, {
        ('cpu', TensorState.HOLD): 60,
        ('cuda', TensorState.HOLD): 8,
        ('cuda', TensorState.HOLD_AFTER_FWD): 28,
    })

    # --- phase 2: payload replacement ------------------------------------
    # Both tensors receive the 12-occupation CPU payload. The checks below
    # expect the formerly FREE tensor to be accounted under HOLD, while the
    # other one stays in HOLD_AFTER_FWD and its CUDA usage drops to zero.
    initially_free.payload_reset(cpu_payload)
    after_fwd.payload_reset(cpu_payload)

    expect(4, 84, 8, {
        ('cpu', TensorState.HOLD): 72,
        ('cuda', TensorState.HOLD): 8,
        ('cpu', TensorState.HOLD_AFTER_FWD): 12,
        ('cuda', TensorState.HOLD_AFTER_FWD): 0,
    })

    # --- phase 3: device moves -------------------------------------------
    half_on_gpu.move_to(torch.device('cpu'))
    # already resides on the CPU, so the totals must not change because of it
    float_on_cpu.move_to(torch.device('cpu'))
    after_fwd.move_to(torch.device('cuda', 0))

    expect(4, 80, 12, {
        ('cpu', TensorState.HOLD): 80,
        ('cuda', TensorState.HOLD): 0,
        ('cpu', TensorState.HOLD_AFTER_FWD): 0,
        ('cuda', TensorState.HOLD_AFTER_FWD): 12,
    })

    # --- phase 4: state transitions --------------------------------------
    half_on_gpu.trans_state(TensorState.COMPUTE)
    float_on_cpu.trans_state(TensorState.COMPUTE)
    float_on_cpu.trans_state(TensorState.HOLD_AFTER_BWD)

    # Device totals are unchanged; only the per-state buckets shift.
    expect(4, 80, 12, {
        ('cpu', TensorState.HOLD): 12,
        ('cuda', TensorState.HOLD): 0,
        ('cpu', TensorState.HOLD_AFTER_FWD): 0,
        ('cuda', TensorState.HOLD_AFTER_FWD): 12,
        ('cpu', TensorState.HOLD_AFTER_BWD): 60,
        ('cuda', TensorState.HOLD_AFTER_BWD): 0,
        ('cpu', TensorState.COMPUTE): 8,
        ('cuda', TensorState.COMPUTE): 0,
    })
|
||
|
|
||
|
|
||
|
# Allow the test to be launched directly as a script, without the pytest runner.
if __name__ == '__main__':
    test_gemini_manager()
|