|
|
|
@ -10,6 +10,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
|
|
|
|
|
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
|
|
|
|
|
from colossalai.zero.sharded_param import ShardedParamV2
|
|
|
|
|
from torch.distributed import ProcessGroup
|
|
|
|
|
from colossalai.logging import get_dist_logger, disable_existing_loggers
|
|
|
|
|
|
|
|
|
|
# Inserts _post_init_method at the end of init method
|
|
|
|
|
|
|
|
|
@ -126,8 +127,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|
|
|
|
self.model_numel_tensor = model_numel_tensor
|
|
|
|
|
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
|
|
|
|
|
|
|
|
|
|
def _pre_context_exec(self):
    """Callback invoked when the context is entered.

    Stores the logger returned by ``get_dist_logger("ZeroInitContext")`` on
    ``self.logger`` and then calls ``GLOBAL_MODEL_DATA_TRACER.start()``.
    """
    # The logger is created first so it is available to later callbacks.
    self.logger = get_dist_logger("ZeroInitContext")
    GLOBAL_MODEL_DATA_TRACER.start()
|
|
|
|
|
|
|
|
|
|
def _post_context_exec(self):
|
|
|
|
|
"""The callback function when the context exits.
|
|
|
|
|
"""The callback function when exiting context.
|
|
|
|
|
"""
|
|
|
|
|
if not self.rm_torch_payload_on_the_fly:
|
|
|
|
|
for param in self.initialized_param_list:
|
|
|
|
@ -135,9 +143,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|
|
|
|
param.col_attr.remove_torch_payload()
|
|
|
|
|
|
|
|
|
|
del self.initialized_param_list
|
|
|
|
|
GLOBAL_MODEL_DATA_TRACER.close()
|
|
|
|
|
cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
|
|
|
|
|
self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
|
|
|
|
|
|
|
|
|
|
def _post_init_method(self, module):
|
|
|
|
|
r"""The function to call at the end of the constructor of each nn.Module.
|
|
|
|
|
def _post_init_method(self, module: torch.nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
The function to call at the end of the constructor of each module.
|
|
|
|
|
NOTE: The module may be passed to this function multiple times.
|
|
|
|
|
"""
|
|
|
|
|
for param in module.parameters():
|
|
|
|
|
# avoid adapting a param to ShardedParam twice
|
|
|
|
@ -165,7 +178,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|
|
|
|
|
|
|
|
|
if self.shard_param:
|
|
|
|
|
self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
|
|
|
|
|
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
|
|
|
|
|
if param.col_attr.sharded_data_tensor.device.type == 'cuda':
|
|
|
|
|
GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
|
|
|
|
|
# if param.col_attr.grad and self.shard_grad:
|
|
|
|
|
# self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
|
|
|
|
|
# GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
|
|
|
|
|