mirror of https://github.com/hpcaitech/ColossalAI
[zero] add colo move inline (#521)
parent 7be397ca9c
commit 920c5889a7
@@ -10,6 +10,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
 from colossalai.zero.shard_utils import BaseShardStrategy
 
 from ._base_ophook import BaseOpHook
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
 
 
 @OPHOOKS.register_module

@@ -37,9 +38,7 @@ class ZeroHook(BaseOpHook):
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.sharded_data_tensor.device != self.computing_device:
-                param.col_attr.sharded_data_tensor.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
 
         if self._memstarts_collector:

@@ -61,9 +60,7 @@ class ZeroHook(BaseOpHook):
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.sharded_data_tensor.device != self.computing_device:
-                param.col_attr.sharded_data_tensor.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
             param.data = param.col_attr.sharded_data_tensor.payload
             # Store local accumulated grad shard
             if param.grad is not None:

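Both hunks above replace a hand-rolled device check, `.to()` call, and tracer update with a single call to the new helper, so the device move and the memory accounting can no longer drift apart. Note that the removed lines called `.to()` on the sharded tensor and discarded the result. A minimal standalone illustration (plain torch, no ColossalAI types) of why the helper instead rebinds `.data`:

import torch

if torch.cuda.is_available():
    t = torch.ones(4)              # starts on CPU
    t.to('cuda')                   # returns a NEW tensor; t itself is unchanged
    assert t.device.type == 'cpu'

    # Rebinding the underlying storage moves the tensor in place, which is
    # what colo_model_data_tensor_move_inline does with the payload below.
    t.data = t.data.to('cuda')
    assert t.device.type == 'cuda'
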
@@ -65,6 +65,34 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
     src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
 
 
+def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> None:
+    """
+    Move a tensor to the target_device in place.
+    Args:
+        t (Union[ShardedTensor, torch.Tensor]): the tensor to be moved
+    """
+    if isinstance(t, ShardedTensor):
+        t_payload = t.payload
+    elif isinstance(t, torch.Tensor):
+        t_payload = t
+    else:
+        raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')
+
+    assert isinstance(target_device, torch.device)
+
+    # deal with torch.device('cpu') and torch.device('cpu:0')
+    if t_payload.device.type == target_device.type:
+        return
+
+    if target_device.type == 'cuda':
+        GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
+    elif target_device.type == 'cpu':
+        GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
+
+    t_payload.data = t_payload.data.to(target_device)
+
+
 def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
     """colo_model_data_move_to_cpu

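A usage sketch for the new helper, mirroring the round trip exercised by the test further below; it assumes a CUDA device is available and that GLOBAL_MODEL_DATA_TRACER has an open tracing session, since the helper registers or deregisters the payload as it crosses the CPU/CUDA boundary:

import torch
from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline

t = torch.ones(3, dtype=torch.float32)                       # 12 bytes, on CPU
colo_model_data_tensor_move_inline(t, torch.device('cuda'))  # tracer add_tensor, then move
assert t.device.type == 'cuda'

colo_model_data_tensor_move_inline(t, torch.device('cpu'))   # tracer delete_tensor, then move
assert t.device.type == 'cpu'

# Same device type is a no-op: the helper returns before touching the tracer.
colo_model_data_tensor_move_inline(t, torch.device('cpu'))
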
@@ -143,7 +143,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         del self.initialized_param_list
         GLOBAL_MODEL_DATA_TRACER.close()
         model_data_cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
-        self.logger.info(f"Existing ZeRO Context: Model Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
+        self.logger.info(f"Exiting ZeRO Context.\nModel Data CUDA Memory {model_data_cuda_mem_MB} MB", ranks=[0])
         sys_cuda_mem_MB = colo_cuda_memory_used() / 1e6
         self.logger.info(f"System CUDA Memory Usage {sys_cuda_mem_MB} MB", ranks=[0])
         self.logger.info(f"Model Number Parameter {self.model_numel_tensor.numpy()[0]/1e6} M", ranks=[0])

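One detail worth noting in these log lines: dividing by 1e6 reports decimal megabytes (10^6 bytes), not binary mebibytes (2^20 bytes), so the logged figures read slightly higher than nvidia-smi-style MiB. A quick check:

bytes_used = 536_870_912        # e.g. 512 MiB of model data
print(bytes_used / 1e6)         # 536.870912 -> the "MB" figure logged here
print(bytes_used / 2**20)       # 512.0      -> the same amount in MiB
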
@@ -1,5 +1,5 @@
 from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move
+from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
 from colossalai.utils import free_port
 
 from colossalai.zero.sharded_param import ShardedTensor

@@ -40,6 +40,12 @@ def run_tensor_move(rank):
     colo_model_data_tensor_move(src_t, tgt_t)
     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
     assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
+
+    assert (tgt_t.device.type == 'cuda')
+    colo_model_data_tensor_move_inline(tgt_t, torch.device('cpu'))
+    assert (tgt_t.device.type == 'cpu')
+    assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 12), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
+
     GLOBAL_MODEL_DATA_TRACER.close()

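The tracer numbers in this test are consistent with two 3-element float32 tensors resident on CUDA (the tensors are constructed earlier in the test, outside this hunk): 24 tracked bytes before the inline move, and 12 after tgt_t's payload is evicted to CPU. The arithmetic, as a quick sanity check:

import torch

bytes_per_elem = torch.finfo(torch.float32).bits // 8   # 4 bytes per float32
before = 2 * 3 * bytes_per_elem                          # two 3-element tensors tracked: 24
after = before - 3 * bytes_per_elem                      # tgt_t (12 bytes) moved off CUDA: 12
assert (before, after) == (24, 12)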