[zero] add colo move inline (#521)

Jiarui Fang 2022-03-25 14:02:55 +08:00 committed by GitHub
parent 7be397ca9c
commit 920c5889a7
4 changed files with 39 additions and 8 deletions


@@ -10,6 +10,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
 from colossalai.zero.shard_utils import BaseShardStrategy
 from ._base_ophook import BaseOpHook
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline

 @OPHOOKS.register_module
@@ -37,9 +38,7 @@ class ZeroHook(BaseOpHook):
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.sharded_data_tensor.device != self.computing_device:
-                param.col_attr.sharded_data_tensor.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload

         if self._memstarts_collector:
@@ -61,9 +60,7 @@ class ZeroHook(BaseOpHook):
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.sharded_data_tensor.device != self.computing_device:
-                param.col_attr.sharded_data_tensor.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
             # Store local accumulated grad shard
             if param.grad is not None:
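Both hook sites (pre-forward and pre-backward) collapse the manual device check, `.to()` call, and tracer update into a single helper call. A minimal sketch of the resulting per-parameter pattern, using a hypothetical `ToyShard`/`move_inline` pair rather than the real Colossal-AI classes:

```python
import torch

class ToyShard:
    """Hypothetical stand-in for a sharded data tensor exposing a .payload."""

    def __init__(self, payload: torch.Tensor):
        self.payload = payload

    @property
    def device(self) -> torch.device:
        return self.payload.device

def move_inline(shard: ToyShard, target_device: torch.device) -> None:
    # One call now owns the device comparison and the move; the real helper
    # also updates GLOBAL_MODEL_DATA_TRACER in both directions, where the
    # old inline code only ever added tensors on the move-to-compute path.
    if shard.device.type != target_device.type:
        shard.payload.data = shard.payload.data.to(target_device)

# The hook body reduces to: move the shard, then re-point param.data
# at the (possibly relocated) payload.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
param = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
shard = ToyShard(param.data)
move_inline(shard, device)
param.data = shard.payload
assert param.device.type == device.type
```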


@@ -65,6 +65,34 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_
         src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)


+def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> None:
+    """
+    Move a tensor to the target_device.
+    Args:
+        t (Union[ShardedTensor, torch.Tensor]): the tensor to be moved
+        target_device (torch.device): the device the tensor is moved to
+    """
+    if isinstance(t, ShardedTensor):
+        t_payload = t.payload
+    elif isinstance(t, torch.Tensor):
+        t_payload = t
+    else:
+        raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')
+
+    assert isinstance(target_device, torch.device)
+
+    # deal with torch.device('cpu') and torch.device('cpu:0')
+    if t_payload.device.type == target_device.type:
+        return
+
+    if target_device.type == 'cuda':
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
+    elif target_device.type == 'cpu':
+        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
+
+    t_payload.data = t_payload.data.to(target_device)
+
+
 def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
     """colo_model_data_move_to_cpu


@@ -143,7 +143,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         del self.initialized_param_list
         GLOBAL_MODEL_DATA_TRACER.close()
         model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
-        self.logger.info(f"Existing ZeRO Context: Model Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
+        self.logger.info(f"Exiting ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
         sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
         self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
         self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])


@@ -1,5 +1,5 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.utils import free_port
 from colossalai.zero.sharded_param import ShardedTensor

@@ -40,6 +40,12 @@ def run_tensor_move(rank):
     colo_model_data_tensor_move(src_t, tgt_t)
     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
     assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
+    assert (tgt_t.device.type == 'cuda')
+
+    colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu'))
+    assert (tgt_t.device.type == 'cpu')
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 12), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
+
     GLOBAL_MODEL_DATA_TRACER.close()
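The hunk only shows the per-rank body; a typical driver for such a test, sketched here with `torch.multiprocessing` (the real file's entry point is not part of this diff, and likely also initializes distributed state via `colossalai.launch` with `free_port`):

```python
import torch.multiprocessing as mp

def test_tensor_move():
    # mp.spawn passes the process index as the first argument, which
    # matches run_tensor_move(rank). One process suffices for this test.
    mp.spawn(run_tensor_move, nprocs=1)

if __name__ == '__main__':
    test_tensor_move()
```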