mirror of https://github.com/hpcaitech/ColossalAI
[zero] show model data cuda memory usage after zero context init. (#515)
parent a2e61d61d4
commit 7ef3507ace
@@ -22,13 +22,24 @@ class ModelDataTracer(metaclass=SingletonMeta):
 
     def __init__(self) -> None:
         self._cuda_usage = 0
+        self._start_flag = False
 
-    def add_tensor(self, t: torch.Tensor):
+    def start(self) -> None:
+        self._start_flag = True
+
+    def close(self) -> None:
+        self._start_flag = False
+
+    def add_tensor(self, t: torch.Tensor) -> None:
+        if not self._start_flag:
+            return
         assert isinstance(t, torch.Tensor), f"ModelDataTracer add_tensor() should accept a torch.Tensor"
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage += mem_use
 
-    def delete_tensor(self, t: torch.Tensor):
+    def delete_tensor(self, t: torch.Tensor) -> None:
+        if not self._start_flag:
+            return
         assert isinstance(t, torch.Tensor), f"ModelDataTracer delete_tensor() should accept a torch.Tensor"
         mem_use = _col_tensor_mem_usage(t)
         self._cuda_usage -= mem_use
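
The hunk above gates the tracer behind an explicit start/close switch, so nothing is counted outside a tracked region. A minimal usage sketch follows; the import path is an assumption (only the GLOBAL_MODEL_DATA_TRACER name and its start/close/add_tensor/delete_tensor/cuda_usage members are confirmed by this diff and the test at the end), and it needs a CUDA device.

# Hedged usage sketch of the gated tracer; the import path below is an assumption.
import torch
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER  # assumed path

GLOBAL_MODEL_DATA_TRACER.start()              # enable accounting; add/delete are no-ops before this
t = torch.ones(2, 3).cuda()                   # 6 fp32 elements -> 24 bytes of CUDA memory
GLOBAL_MODEL_DATA_TRACER.add_tensor(t)        # cuda_usage grows by the tensor's byte size
print(GLOBAL_MODEL_DATA_TRACER.cuda_usage)    # 24
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t)     # cuda_usage shrinks by the same amount
GLOBAL_MODEL_DATA_TRACER.close()              # stop accounting again
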
@@ -10,6 +10,7 @@ from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
+from colossalai.logging import get_dist_logger, disable_existing_loggers
 
 # Inserts _post_init_method at the end of init method
 
@@ -126,8 +127,15 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.model_numel_tensor = model_numel_tensor
         self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
 
+    def _pre_context_exec(self):
+        """
+        The Callback function when entering the context
+        """
+        self.logger = get_dist_logger("ZeroInitContext")
+        GLOBAL_MODEL_DATA_TRACER.start()
+
     def _post_context_exec(self):
-        """The callback function when the context exits.
+        """The callback function when exiting context.
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:
@@ -135,9 +143,14 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
                 param.col_attr.remove_torch_payload()
 
         del self.initialized_param_list
+        GLOBAL_MODEL_DATA_TRACER.close()
+        cuda_mem_MB = GLOBAL_MODEL_DATA_TRACER.cuda_usage / 1e6
+        self.logger.info(f"Existing ZeRO Context Model Data CUDA Memory Usage {cuda_mem_MB} MB", [0])
 
-    def _post_init_method(self, module):
-        r"""The function to call at the end of the constructor of each nn.Module.
+    def _post_init_method(self, module: torch.nn.Module):
+        """
+        The function to call at the end of the constructor of each module.
+        NOTE() The module may be passed to this function multiple times.
         """
         for param in module.parameters():
             # avoid adapting a param to ShardedParam twice
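
Taken together, the two hunks above switch the tracer on when the ZeroInitContext is entered and switch it off, then report the total, when it exits. A rough sketch of the resulting flow; the constructor arguments, import paths and TensorShardStrategy below are assumptions, not part of this diff, and a prior colossalai.launch(...) is assumed so that the default data-parallel group exists.

# Hedged control-flow sketch; constructor arguments and import paths are assumptions.
import torch
import torch.nn as nn
from colossalai.zero.init_ctx import ZeroInitContext          # assumed import path
from colossalai.zero.shard_utils import TensorShardStrategy   # assumed shard strategy

with ZeroInitContext(target_device=torch.device('cuda', 0),   # assumed arguments
                     shard_strategy=TensorShardStrategy(),
                     shard_param=True):
    # _pre_context_exec has already run here: the logger exists and
    # GLOBAL_MODEL_DATA_TRACER.start() has enabled accounting.
    model = nn.Linear(1024, 1024)   # parameters built inside the context are sharded and,
                                    # when resident on CUDA, added to the tracer
# _post_context_exec has now run: the tracer is closed and the accumulated model-data
# CUDA usage is logged in MB on rank 0.
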
@@ -165,7 +178,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
             if self.shard_param:
                 self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
-            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            if param.col_attr.sharded_data_tensor.device.type == 'cuda':
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
             # if param.col_attr.grad and self.shard_grad:
             #     self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
             #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
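
The new device check matters because a sharded parameter's payload is not necessarily on the GPU at this point (for instance when the target device is the CPU), and only CUDA-resident payloads should feed the CUDA counter. A small standalone illustration of the same accounting rule; the helper below is hypothetical and only mirrors the added check.

# Hypothetical helper mirroring the guard above: count bytes only for GPU-resident tensors.
import torch

def count_if_cuda(t: torch.Tensor, usage: int) -> int:
    if t.device.type == 'cuda':
        usage += t.numel() * t.element_size()   # byte count, consistent with the 24-byte test below
    return usage

usage = 0
usage = count_if_cuda(torch.ones(2, 3), usage)             # CPU tensor -> not counted
if torch.cuda.is_available():
    usage = count_if_cuda(torch.ones(2, 3).cuda(), usage)  # CUDA tensor -> +24 bytes
print(usage)
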
@@ -23,9 +23,11 @@ from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tenso
 
 
 class ShardedModelV2(nn.Module):
-    """A wrapper for a sharded module, which implements Zero Redundancy Optimizer (ZeRO) stage 3.
-    Parameter, gradient and optimizer states are sharded, so memory efficiency is boosted drastically
-    compared to classic data parallelism while the computational granularity and communication efficiency are retained.
+    """
+    A wrapper for the PyTorch module shards the model parameters among multiple GPU memory.
+    Only 1/#nproc of parameters, gradients are stored in local CUDA memory, so forward and backward
+    passes can be executed with limited CUDA memory budget.
 
     Note that you must use `ShardedModelV2` with `ShardedOptimizerV2`.
 
     Args:
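
To make the "1/#nproc of parameters" claim in the reworded docstring concrete, a quick back-of-the-envelope calculation (model size and world size are made-up numbers for illustration):

# Illustration of the sharding memory claim with made-up numbers.
params     = 1_000_000_000          # 1B parameters
bytes_fp16 = 2                      # fp16 parameter payload
nproc      = 8                      # data-parallel world size

full_copy = params * bytes_fp16     # 2.0e9 bytes if every rank kept all parameters
per_rank  = full_copy / nproc       # 2.5e8 bytes per rank with ZeRO-style sharding
print(full_copy / 1e9, "GB vs", per_rank / 1e9, "GB per rank")
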
@@ -16,6 +16,7 @@ def run_tensor_move(rank):
     colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
 
     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 0)
+    GLOBAL_MODEL_DATA_TRACER.start()
 
     src_t = torch.ones(2, 3).cuda()
     GLOBAL_MODEL_DATA_TRACER.add_tensor(src_t)
@@ -39,6 +40,7 @@ def run_tensor_move(rank):
     colo_model_data_tensor_move(src_t, tgt_t)
     assert (GLOBAL_MODEL_DATA_TRACER.cuda_usage == 24), f"cuda usage {GLOBAL_MODEL_DATA_TRACER.cuda_usage}"
     assert (torch.sum(tgt_t.payload) == 6.0), f"{torch.sum(tgt_t.payload)} vs. 6.0"
+    GLOBAL_MODEL_DATA_TRACER.close()
 
 
 def test_tensor_move():
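
The expected value of 24 in the assertion above follows directly from the tensor shape: torch.ones(2, 3) has 6 float32 elements of 4 bytes each. The same arithmetic, runnable on CPU without the tracer:

# Why the test expects cuda_usage == 24.
import torch

t = torch.ones(2, 3)                   # default dtype is float32
print(t.numel() * t.element_size())    # 6 * 4 = 24 bytes
# The log line added in _post_context_exec divides the same byte count by 1e6 to report MB.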