[zero] show model data cuda memory usage after zero context init. (#515)

Jiarui Fang committed by GitHub
parent a2e61d61d4
commit 7ef3507ace

@@ -22,13 +22,24 @@ class ModelDataTracer(metaclass=SingletonMeta):
 
     def __init__(self) -> None:
         self._cuda_usage = 0
+        self._start_flag = False
 
-    def add_tensor(self, t: torch.Tensor):
+    def start(self) -> None:
+        self._start_flag = True
+
+    def close(self) -> None:
+        self._start_flag = False
+
+    def add_tensor(self, t: torch.Tensor) -> None:
+        if not self._start_flag:
+            return
         assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage += mem_use
 
-    def delete_tensor(self, t: torch.Tensor):
+    def delete_tensor(self, t: torch.Tensor) -> None:
+        if not self._start_flag:
+            return
         assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage -= mem_use
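For context, a minimal usage sketch (not part of the diff) of the tracer's new lifecycle: add_tensor() and delete_tensor() are now no-ops unless start() has been called, which mirrors the test change further down. The import path of the GLOBAL_MODEL_DATA_TRACER singleton is an assumption.

import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER  # assumed path

# Before start(), registrations are ignored.
GLOBAL_MODEL_DATA_TRACER.add_tensor(torch.ones(2, 3).cuda())
assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0

# Between start() and close(), CUDA tensor payloads are counted.
GLOBAL_MODEL_DATA_TRACER.start()
t = torch.ones(2, 3).cuda()            # 6 fp32 elements * 4 bytes = 24 bytes
GLOBAL_MODEL_DATA_TRACER.add_tensor(t)
assert GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)
GLOBAL_MODEL_DATA_TRACER.close()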

@@ -10,6 +10,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
+from colossalai.logging import get_dist_logger, disable_existing_loggers
 
 
 # Inserts _post_init_method at the end of init method
@@ -126,8 +127,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
+    def _pre_context_exec(self):
+        """
+        The callback function when entering the context.
+        """
+        self.logger = get_dist_logger("ZeroInitContext")
+        GLOBAL_MODEL_DATA_TRACER.start()
+
     def _post_context_exec(self):
-        """The callback function when the context exits.
+        """The callback function when exiting the context.
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:
@@ -135,9 +143,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                 param.col_attr.remove_torch_payload()
 
         del self.initialized_param_list
+        GLOBAL_MODEL_DATA_TRACER.close()
+        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Exiting ZeRO Context: Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
 
-    def _post_init_method(self, module):
-        r"""The function to call at the end of the constructor of each nn.Module.
+    def _post_init_method(self, module: torch.nn.Module):
+        """
+        The function to call at the end of the constructor of each module.
+        NOTE: the module may be passed to this function multiple times.
         """
         for param in module.parameters():
             # avoid adapting a param to ShardedParam twice
@@ -165,7 +178,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            if param.col_attr.sharded_data_tensor.device.type == 'cuda':
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
             #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
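For orientation, a hedged usage sketch (not part of the diff) of how the new callbacks fire: building a model inside ZeroInitContext starts GLOBAL_MODEL_DATA_TRACER on entry and logs the model-data CUDA memory usage on exit. The import paths, TensorShardingStrategy, and the constructor arguments other than shard_strategy and shard_param are assumptions that may differ by version; MyModel is a placeholder.

import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext            # assumed import path
from colossalai.zero.shard_utils import TensorShardingStrategy  # assumed concrete strategy


class MyModel(nn.Module):      # placeholder module for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1024, 1024)


# Assumes colossalai.launch(...) has already initialized the distributed environment.
with ZeroInitContext(convert_fp16=True,                      # assumed argument
                     target_device=torch.device('cuda', 0),  # assumed argument
                     shard_strategy=TensorShardingStrategy(),
                     shard_param=True):
    model = MyModel()
# On exit, _post_context_exec() closes the tracer and rank 0 logs something like
# "Exiting ZeRO Context: Model Data CUDA Memory Usage <x> MB".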

@@ -23,9 +23,11 @@ from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tenso
 
 class ShardedModelV2(nn.Module):
-    """A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3.
-    Parameter, gradient and optimizer states are sharded, so memory efficiency is boosted drastically
-    compared to classic data parallelism while the computational granularity and communication efficiency are retained.
+    """
+    A wrapper that shards a PyTorch module's parameters across multiple GPUs (ZeRO stage 3).
+    Only 1/#nproc of the parameters and gradients is kept in local CUDA memory, so the forward
+    and backward passes can run within a limited CUDA memory budget.
 
     Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
 
     Args:
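A back-of-the-envelope illustration (not from the diff; the numbers are hypothetical) of the 1/#nproc claim in the docstring above: per-rank parameter memory for an fp16 model sharded across eight data-parallel processes.

num_params = 1_000_000_000     # hypothetical 1B-parameter model
bytes_per_param = 2            # fp16 parameters
world_size = 8                 # number of data-parallel processes (#nproc)

full_model_mb = num_params * bytes_per_param / 1e6   # 2000.0 MB if kept whole on every rank
per_rank_mb = full_model_mb / world_size             # 250.0 MB of parameters per rank
print(f"unsharded: {full_model_mb:.1f} MB, per rank: {per_rank_mb:.1f} MB")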

@@ -16,6 +16,7 @@ def run_tensor_move(rank):
     colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
 
     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
+    GLOBAL_MODEL_DATA_TRACER.start()
 
     src_t = torch.ones(2, 3).cuda()
     GLOBAL_MODEL_DATA_TRACER.add_tensor(src_t)
@@ -39,6 +40,7 @@ def run_tensor_move(rank):
     colo_model_data_tensor_move(src_t, tgt_t)
     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
     assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
+    GLOBAL_MODEL_DATA_TRACER.close()
 
 
 def test_tensor_move():
