diff --git a/colossalai/utils/memory_tracer/model_data_memtracer.py b/colossalai/utils/memory_tracer/model_data_memtracer.py
index 05c53d741..e5a742f64 100644
--- a/colossalai/utils/memory_tracer/model_data_memtracer.py
+++ b/colossalai/utils/memory_tracer/model_data_memtracer.py
@@ -22,13 +22,24 @@ class ModelDataTracer(metaclass=SingletonMeta):
 
     def __init__(self) -> None:
         self._cuda_usage = 0
+        self._start_flag = False
 
-    def add_tensor(self, t: torch.Tensor):
+    def start(self) -> None:
+        self._start_flag = True
+
+    def close(self) -> None:
+        self._start_flag = False
+
+    def add_tensor(self, t: torch.Tensor) -> None:
+        if not self._start_flag:
+            return
         assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage += mem_use
 
-    def delete_tensor(self, t: torch.Tensor):
+    def delete_tensor(self, t: torch.Tensor) -> None:
+        if not self._start_flag:
+            return
         assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage -= mem_use
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 321548f98..be73da796 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -10,6 +10,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
+from colossalai.logging import get_dist_logger, disable_existing_loggers
 
 
 # Inserts _post_init_method at the end of init method
@@ -126,8 +127,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
+    def _pre_context_exec(self):
+        """
+        The callback function when entering the context.
+        """
+        self.logger = get_dist_logger("ZeroInitContext")
+        GLOBAL_MODEL_DATA_TRACER.start()
+
     def _post_context_exec(self):
-        """The callback function when the context exits.
+        """The callback function when exiting the context.
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:
@@ -135,9 +143,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                 param.col_attr.remove_torch_payload()
 
         del self.initialized_param_list
+        GLOBAL_MODEL_DATA_TRACER.close()
+        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Exiting ZeRO Context: Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
 
-    def _post_init_method(self, module):
-        r"""The function to call at the end of the constructor of each nn.Module.
+    def _post_init_method(self, module: torch.nn.Module):
+        """
+        The function to call at the end of the constructor of each module.
+        NOTE: The module may be passed to this function multiple times.
""" for param in module.parameters(): # avoid adapting a param to ShardedParam twice @@ -165,7 +178,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses): if self.shard_param: self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group) - GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload) + if param.col_attr.sharded_data_tensor.device.type == 'cuda': + GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload) # if param.col_attr.grad and self.shard_grad: # self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group) # GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload) diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 5bc6f3cc6..392d25226 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -23,9 +23,11 @@ from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tenso class ShardedModelV2(nn.Module): - """A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3. - Parameter, gradient and optimizer states are sharded, so memory efficiency is boosted drastically - compared to classic data parallelism while the computational granularity and communication efficiency are retained. + """ + A wrapper for the PyTorch module shards the model parameters among multiple GPU memory. + Only 1/#nproc of parameters, gradients are stored in local CUDA memory, so forward and backward + passes can be executed with limited CUDA memory budget. + Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`. Args: diff --git a/tests/test_utils/test_commons.py b/tests/test_utils/test_commons.py index a96594a5f..9b42b10f9 100644 --- a/tests/test_utils/test_commons.py +++ b/tests/test_utils/test_commons.py @@ -16,6 +16,7 @@ def run_tensor_move(rank): colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl') assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0) + GLOBAL_MODEL_DATA_TRACER.start() src_t = torch.ones(2, 3).cuda() GLOBAL_MODEL_DATA_TRACER.add_tensor(src_t) @@ -39,6 +40,7 @@ def run_tensor_move(rank): colo_model_data_tensor_move(src_t, tgt_t) assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}" assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0" + GLOBAL_MODEL_DATA_TRACER.close() def test_tensor_move():