mirror of https://github.com/hpcaitech/ColossalAI
[zero] show model data cuda memory usage after zero context init. (#515)
parent a2e61d61d4
commit 7ef3507ace
@@ -22,13 +22,24 @@ class ModelDataTracer(metaclass=SingletonMeta):

     def __init__(self) -> None:
         self._cuda_usage = 0
+        self._start_flag = False

-    def add_tensor(self, t: torch.Tensor):
+    def start(self) -> None:
+        self._start_flag = True
+
+    def close(self) -> None:
+        self._start_flag = False
+
+    def add_tensor(self, t: torch.Tensor) -> None:
+        if not self._start_flag:
+            return
         assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage += mem_use

-    def delete_tensor(self, t: torch.Tensor):
+    def delete_tensor(self, t: torch.Tensor) -> None:
+        if not self._start_flag:
+            return
         assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage -= mem_use
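For orientation, here is a minimal sketch of the tracer lifecycle introduced above. The import path and the CUDA device are assumptions; only the start/add_tensor/cuda_usage/delete_tensor/close calls come from this patch.

    import torch

    # Import path assumed for this sketch; GLOBAL_MODEL_DATA_TRACER is the singleton patched above.
    from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER

    GLOBAL_MODEL_DATA_TRACER.start()            # enable tracing; add/delete are no-ops before this
    t = torch.ones(2, 3, device='cuda')         # six float32 elements = 24 bytes of model data
    GLOBAL_MODEL_DATA_TRACER.add_tensor(t)      # account for the tensor's memory
    print(GLOBAL_MODEL_DATA_TRACER.cuda_usage)  # -> 24
    GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)   # release the accounting
    GLOBAL_MODEL_DATA_TRACER.close()            # disable tracing again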
@@ -10,6 +10,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
+from colossalai.logging import get_dist_logger, disable_existing_loggers

 # Inserts _post_init_method at the end of init method

@@ -126,8 +127,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)

+    def _pre_context_exec(self):
+        """
+        The Callback function when entering the context
+        """
+        self.logger = get_dist_logger("ZeroInitContext")
+        GLOBAL_MODEL_DATA_TRACER.start()
+
     def _post_context_exec(self):
-        """The callback function when the context exits.
+        """The callback function when exiting context.
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:
@@ -135,9 +143,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                 param.col_attr.remove_torch_payload()

         del self.initialized_param_list
+        GLOBAL_MODEL_DATA_TRACER.close()
+        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])

-    def _post_init_method(self, module):
-        r"""The function to call at the end of the constructor of each nn.Module.
+    def _post_init_method(self, module: torch.nn.Module):
+        """
+        The function to call at the end of the constructor of each module.
+        NOTE() The module may be passed to this function multiple times.
         """
         for param in module.parameters():
             # avoid adapting a param to ShardedParam twice
@@ -165,7 +178,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+                if param.col_attr.sharded_data_tensor.device.type == 'cuda':
+                    GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
             #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
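Taken together, the ZeroInitContext hunks above mean that building a model inside the context now produces a model data memory report when the context exits. A rough usage sketch, assuming the import paths and constructor arguments shown below (they are not part of this diff and may differ in this version):

    import torch
    import torch.nn as nn

    # Paths and keyword arguments are assumptions for illustration only.
    from colossalai.zero.init_ctx import ZeroInitContext
    from colossalai.zero.shard_utils import TensorShardStrategy

    with ZeroInitContext(convert_fp16=True,
                         target_device=torch.device('cuda'),
                         shard_strategy=TensorShardStrategy(),
                         shard_param=True):
        model = nn.Linear(1024, 1024)   # parameters are sharded and traced as they are created

    # On exit, _post_context_exec() closes the tracer and logs a line like:
    #   Existing ZeRO Context Model Data CUDA Memory Usage <N> MB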
@@ -23,9 +23,11 @@ from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tenso


 class ShardedModelV2(nn.Module):
-    """A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3.
-    Parameter, gradient and optimizer states are sharded, so memory efficiency is boosted drastically
-    compared to classic data parallelism while the computational granularity and communication efficiency are retained.
+    """
+    A wrapper for the PyTorch module shards the model parameters among multiple GPU memory.
+    Only 1/#nproc of parameters, gradients are stored in local CUDA memory, so forward and backward
+    passes can be executed with limited CUDA memory budget.
+
+    Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.

     Args:
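Per the updated docstring, a module initialized under ZeroInitContext (as in the previous sketch) is wrapped by ShardedModelV2 and must be paired with ShardedOptimizerV2. A sketch under the same assumptions; the optimizer construction is elided because its signature is not shown in this diff:

    from colossalai.zero.shard_utils import TensorShardStrategy
    from colossalai.zero.sharded_model import ShardedModelV2   # import path assumed

    shard_strategy = TensorShardStrategy()
    # `model` is assumed to have been created inside ZeroInitContext with shard_param=True,
    # so only roughly 1/world_size of its parameters reside in local CUDA memory.
    sharded_model = ShardedModelV2(model, shard_strategy)
    # ... then wrap the optimizer with ShardedOptimizerV2, as the docstring requires.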
@@ -16,6 +16,7 @@ def run_tensor_move(rank):
     colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')

     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
+    GLOBAL_MODEL_DATA_TRACER.start()

     src_t = torch.ones(2, 3).cuda()
     GLOBAL_MODEL_DATA_TRACER.add_tensor(src_t)
@@ -39,6 +40,7 @@ def run_tensor_move(rank):
     colo_model_data_tensor_move(src_t, tgt_t)
     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
     assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
+    GLOBAL_MODEL_DATA_TRACER.close()


 def test_tensor_move():
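For reference, the expected value 24 in the assertion above follows from the tensor shape: torch.ones(2, 3) defaults to float32, so the payload occupies 2 * 3 elements * 4 bytes = 24 bytes.

    expected_bytes = 2 * 3 * 4   # six float32 elements, 4 bytes each
    assert expected_bytes == 24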