mirror of https://github.com/hpcaitech/ColossalAI
[zero] refactor ShardedParamV2 for convenience (#742)
parent
340e59f968
commit
22c4b88d56
|
@ -215,7 +215,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
assert hasattr(param, 'colo_attr')
|
||||
if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
|
||||
dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
|
||||
param.colo_attr.remove_torch_payload()
|
||||
param.colo_attr.set_data_none()
|
||||
|
||||
del self.param_list
|
||||
|
||||
|
@ -252,11 +252,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
if param.grad is not None:
|
||||
param.grad = param.grad.to(target_device)
|
||||
|
||||
param.colo_attr = ShardedParamV2(param, rm_torch_payload=False)
|
||||
param.colo_attr = ShardedParamV2(param, set_data_none=False)
|
||||
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
param.data = param.colo_attr.sharded_data_tensor.payload # set param.data to payload
|
||||
param.data = param.colo_attr.data_payload # set param.data to payload
|
||||
|
||||
# mark whether the param is replicated
|
||||
param.colo_attr.is_replicated = self.is_replicated
|
||||
|
|
|
@ -260,7 +260,7 @@ class ShardedModelV2(nn.Module):
|
|||
if not p.colo_attr.param_is_sharded:
|
||||
tensor_list.append(p.colo_attr.sharded_data_tensor)
|
||||
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
|
||||
p.colo_attr.remove_torch_payload()
|
||||
p.colo_attr.set_data_none()
|
||||
self.shard_strategy.shard(tensor_list, self.process_group)
|
||||
|
||||
# 4. set all parameters' grad to None
|
||||
|
@ -357,8 +357,8 @@ class ShardedModelV2(nn.Module):
|
|||
assert param.colo_attr.saved_grad.is_null(
|
||||
), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
|
||||
|
||||
param.colo_attr.saved_grad.reset_payload(grad)
|
||||
param.colo_attr.sharded_data_tensor.reset_payload(grad) # release the memory of param
|
||||
param.colo_attr.reset_grad_payload(grad)
|
||||
param.colo_attr.reset_grad_payload(grad) # release the memory of param
|
||||
|
||||
if param.colo_attr.is_replicated:
|
||||
param.colo_attr.sharded_data_tensor.is_sharded = True
|
||||
|
@ -367,9 +367,9 @@ class ShardedModelV2(nn.Module):
|
|||
fp32_grad = cast_tensor_to_fp32(grad)
|
||||
|
||||
if param.colo_attr.saved_grad.is_null():
|
||||
param.colo_attr.saved_grad.reset_payload(fp32_grad)
|
||||
param.colo_attr.reset_grad_payload(fp32_grad)
|
||||
else:
|
||||
param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))
|
||||
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))
|
||||
|
||||
# keep saved_grad in HOLD state
|
||||
param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
|
||||
|
@ -377,11 +377,11 @@ class ShardedModelV2(nn.Module):
|
|||
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
|
||||
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
|
||||
for p in self.sharded_params:
|
||||
p.data = p.colo_attr.sharded_data_tensor.payload
|
||||
p.data = p.colo_attr.data_payload
|
||||
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
|
||||
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
|
||||
for p in self.sharded_params:
|
||||
p.colo_attr.remove_torch_payload()
|
||||
p.colo_attr.set_data_none()
|
||||
return gathered_state_dict
|
||||
|
||||
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):
|
||||
|
|
|
@ -14,6 +14,6 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
|
|||
shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
|
||||
if shard_flag:
|
||||
sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
|
||||
param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
|
||||
param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
|
||||
if shard_flag:
|
||||
sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
|
||||
|
|
|
@ -266,8 +266,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
if shard_flag:
|
||||
# we always shard replicated paramters
|
||||
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
self.master_params[p] = StatefulTensor(
|
||||
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
|
||||
self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device)))
|
||||
if shard_flag:
|
||||
# In this branch, there's no need to shard param
|
||||
# So we gather here
|
||||
|
@ -296,10 +295,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
# If we change p.grad directly
|
||||
# it may raise error because of different shape/dtype/device of p.data and p.grad
|
||||
# We just set p.data = p.colo_attr.saved_grad.payload here
|
||||
p.data = p.colo_attr.saved_grad.payload
|
||||
p.grad = p.colo_attr.saved_grad.payload
|
||||
p.data = p.colo_attr.grad_payload
|
||||
p.grad = p.colo_attr.grad_payload
|
||||
# Set p.data to empty tensor, in case of memory leaking
|
||||
p.colo_attr.remove_torch_payload()
|
||||
p.colo_attr.set_data_none()
|
||||
|
||||
def _point_param_fp16_to_master_param(self):
|
||||
# assign master param pointers to p.data.
|
||||
|
@ -325,9 +324,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
|
||||
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
|
||||
p.data = self.master_params[p].payload
|
||||
p.colo_attr.sharded_data_tensor.reset_payload(
|
||||
colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
|
||||
p.colo_attr.remove_torch_payload()
|
||||
p.colo_attr.reset_data_payload(
|
||||
colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
|
||||
p.colo_attr.set_data_none()
|
||||
|
||||
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
|
||||
# We gather full fp16 param here
|
||||
|
|
|
@ -10,10 +10,20 @@ from typing import List
|
|||
# empty tensor is expected to raise error when get used
|
||||
FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')
|
||||
|
||||
EMPTY_TENSOR_DICT = {}
|
||||
|
||||
|
||||
def get_empty_tensor(device: torch.device, dtype: torch.dtype):
|
||||
key = (device, dtype)
|
||||
if key not in EMPTY_TENSOR_DICT:
|
||||
EMPTY_TENSOR_DICT[key] = FAKE_EMPTY_TENSOR.to(device, dtype)
|
||||
|
||||
return EMPTY_TENSOR_DICT[key]
|
||||
|
||||
|
||||
class ShardedParamV2(object):
|
||||
|
||||
def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
|
||||
def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
|
||||
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
|
||||
self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
|
||||
# This attribute must be initialized in ShardedModel
|
||||
|
@ -25,24 +35,47 @@ class ShardedParamV2(object):
|
|||
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
# So we can not empty the .data at this time
|
||||
self.param = param
|
||||
if rm_torch_payload:
|
||||
self.remove_torch_payload()
|
||||
if set_data_none:
|
||||
self.set_data_none()
|
||||
|
||||
def get_payload_tensors(self) -> List[StatefulTensor]:
|
||||
"""returns stateful tensors kept by this class.
|
||||
"""
|
||||
return [self._sharded_data_tensor]
|
||||
|
||||
def remove_torch_payload(self):
|
||||
self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)
|
||||
def set_data_none(self):
|
||||
self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype)
|
||||
|
||||
def set_grad_none(self):
|
||||
self.saved_grad.set_null()
|
||||
|
||||
@property
|
||||
def sharded_data_tensor(self):
|
||||
return self._sharded_data_tensor
|
||||
|
||||
@property
|
||||
def data_payload(self):
|
||||
return self.sharded_data_tensor.payload
|
||||
|
||||
@property
|
||||
def grad_payload(self):
|
||||
assert not self.saved_grad.is_null()
|
||||
return self.saved_grad.payload
|
||||
|
||||
@property
|
||||
def param_is_sharded(self):
|
||||
return self._sharded_data_tensor.is_sharded
|
||||
return self.sharded_data_tensor.is_sharded
|
||||
|
||||
def reset_data_payload(self, tensor: torch.Tensor):
|
||||
assert type(tensor) is torch.Tensor
|
||||
assert tensor.requires_grad is False
|
||||
self.sharded_data_tensor.reset_payload(tensor)
|
||||
self.set_data_none()
|
||||
|
||||
def reset_grad_payload(self, tensor: torch.Tensor):
|
||||
assert type(tensor) is torch.Tensor
|
||||
assert tensor.requires_grad is False
|
||||
self.saved_grad.reset_payload(tensor)
|
||||
|
||||
def get_memory_usage(self) -> Tuple[int, int]:
|
||||
"""
|
||||
|
@ -63,11 +96,11 @@ class ShardedParamV2(object):
|
|||
cpu_mem_use += t_cpu
|
||||
|
||||
address_set = set()
|
||||
_update_mem_use(self.sharded_data_tensor.payload)
|
||||
address_set.add(self.sharded_data_tensor.payload.data_ptr())
|
||||
_update_mem_use(self.data_payload)
|
||||
address_set.add(self.data_payload.data_ptr())
|
||||
|
||||
if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.saved_grad.payload)
|
||||
_update_mem_use(self.grad_payload)
|
||||
address_set.add(self.saved_grad.data_ptr())
|
||||
|
||||
if self.param.data is not None and self.param.data.data_ptr() not in address_set:
|
||||
|
|
|
@ -9,6 +9,7 @@ class ShardedTensor(StatefulTensor):
|
|||
r"""
|
||||
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
|
||||
"""
|
||||
assert tensor.requires_grad is False
|
||||
super().__init__(tensor, state)
|
||||
|
||||
# kept the shape, numel and dtype of the init tensor.
|
||||
|
@ -17,6 +18,11 @@ class ShardedTensor(StatefulTensor):
|
|||
self._origin_dtype = tensor.dtype
|
||||
self._is_sharded = False
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
assert self._payload.dtype == self._origin_dtype
|
||||
return self._payload.dtype
|
||||
|
||||
@property
|
||||
def origin_numel(self) -> int:
|
||||
return self._origin_numel
|
||||
|
|
|
@ -19,11 +19,11 @@ class StatefulTensor(object):
|
|||
https://arxiv.org/abs/2108.05818
|
||||
"""
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
|
||||
def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
|
||||
self._state = state
|
||||
self._payload = tensor
|
||||
if self._state == TensorState.FREE:
|
||||
assert self._payload is None, f"payload has to None if {self._state}"
|
||||
assert self._payload is None, f"payload has to None if state is {self._state}"
|
||||
|
||||
def data_ptr(self):
|
||||
if self._payload is None:
|
||||
|
@ -50,13 +50,13 @@ class StatefulTensor(object):
|
|||
self._payload = None
|
||||
|
||||
@property
|
||||
def payload(self) -> int:
|
||||
def payload(self) -> Optional[torch.Tensor]:
|
||||
return self._payload
|
||||
|
||||
def copy_payload(self, tensor) -> int:
|
||||
def copy_payload(self, tensor) -> None:
|
||||
self._payload.view(-1).copy_(tensor.view(-1))
|
||||
|
||||
def reset_payload(self, tensor) -> int:
|
||||
def reset_payload(self, tensor) -> None:
|
||||
del self._payload
|
||||
self._payload = tensor
|
||||
self.trans_state(TensorState.HOLD)
|
||||
|
@ -67,15 +67,14 @@ class StatefulTensor(object):
|
|||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
assert self._payload.dtype == self._origin_dtype
|
||||
return self._origin_dtype
|
||||
return self._payload.dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._payload.shape
|
||||
|
||||
def to(self, device: torch.device):
|
||||
raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
|
||||
|
||||
def to_(self, device: torch.device):
|
||||
raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._payload.shape
|
||||
|
|
|
@ -60,7 +60,7 @@ class ZeroHook(BaseOpHook):
|
|||
self._memstarts_collector.sample_memstats()
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
param.data = param.colo_attr.sharded_data_tensor.payload
|
||||
param.data = param.colo_attr.data_payload
|
||||
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
|
||||
|
||||
def post_fwd_exec(self, module: torch.nn.Module, *args):
|
||||
|
@ -79,7 +79,7 @@ class ZeroHook(BaseOpHook):
|
|||
|
||||
# remove torch payload
|
||||
for param in module.parameters(recurse=False):
|
||||
param.colo_attr.remove_torch_payload()
|
||||
param.colo_attr.set_data_none()
|
||||
|
||||
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
|
||||
|
||||
|
@ -105,7 +105,7 @@ class ZeroHook(BaseOpHook):
|
|||
self._memstarts_collector.sample_memstats()
|
||||
|
||||
for param in module.parameters(recurse=False):
|
||||
param.data = param.colo_attr.sharded_data_tensor.payload
|
||||
param.data = param.colo_attr.data_payload
|
||||
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
|
||||
|
||||
def post_bwd_exec(self, module: torch.nn.Module, input):
|
||||
|
@ -124,7 +124,7 @@ class ZeroHook(BaseOpHook):
|
|||
|
||||
# remove torch payload
|
||||
for param in module.parameters(recurse=False):
|
||||
param.colo_attr.remove_torch_payload()
|
||||
param.colo_attr.set_data_none()
|
||||
|
||||
def pre_iter(self):
|
||||
pass
|
||||
|
|
|
@ -77,10 +77,10 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
|
|||
assert param.colo_attr.is_replicated
|
||||
|
||||
if param.colo_attr.param_is_sharded:
|
||||
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
|
||||
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
||||
assert param.colo_attr.data_payload.device.type == init_device.type, \
|
||||
f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'
|
||||
else:
|
||||
assert param.colo_attr.sharded_data_tensor.payload.device.type == 'cuda'
|
||||
assert param.colo_attr.data_payload.device.type == 'cuda'
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
|
|
|
@ -37,7 +37,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
|
|||
# check whether parameters are identical in ddp
|
||||
for name, p in zero_model.named_parameters():
|
||||
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
|
||||
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)
|
||||
assert_equal_in_group(p.colo_attr.data_payload)
|
||||
|
||||
model = MoeModel(checkpoint=True).half()
|
||||
col_model_deepcopy(zero_model, model)
|
||||
|
|
|
@ -76,7 +76,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
|
|||
# check whether parameters are identical in ddp
|
||||
for name, p in zero_model.named_parameters():
|
||||
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
|
||||
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))
|
||||
assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device()))
|
||||
|
||||
model = MoeModel(checkpoint=True).half()
|
||||
col_model_deepcopy(zero_model, model)
|
||||
|
@ -100,7 +100,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
|
|||
for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
|
||||
if 'gate' in n:
|
||||
p.data = p.float()
|
||||
p.data.copy_(zp.colo_attr.sharded_data_tensor.payload)
|
||||
p.data.copy_(zp.colo_attr.data_payload)
|
||||
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
if i > 5:
|
||||
|
|
|
@ -94,7 +94,7 @@ def check_grads_padding(model, zero_model, loose=False):
|
|||
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
|
||||
# zero_grad = zero_p.grad.clone().to(p.device)
|
||||
if zero_p.colo_attr.is_replicated:
|
||||
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
|
||||
zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device)
|
||||
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
|
@ -102,7 +102,7 @@ def check_grads_padding(model, zero_model, loose=False):
|
|||
if zero_grad.size(0) > grad.size(0):
|
||||
zero_grad = zero_grad[:grad.size(0)]
|
||||
else:
|
||||
zero_grad = zero_p.colo_attr.saved_grad.payload
|
||||
zero_grad = zero_p.colo_attr.grad_payload
|
||||
grad = p.grad.to(zero_grad.dtype)
|
||||
|
||||
assert grad.dtype == zero_grad.dtype
|
||||
|
@ -127,7 +127,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
|
|||
rank = dist.get_rank()
|
||||
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
|
||||
if zero_p.colo_attr.param_is_sharded:
|
||||
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
|
||||
zero_p = zero_p.colo_attr.data_payload.to(p.device).float()
|
||||
chunks = torch.flatten(p).chunk(dist.get_world_size())
|
||||
if rank >= len(chunks):
|
||||
continue
|
||||
|
@ -135,7 +135,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
|
|||
if zero_p.size(0) > p.size(0):
|
||||
zero_p = zero_p[:p.size(0)]
|
||||
else:
|
||||
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)
|
||||
zero_p = zero_p.colo_attr.data_payload.to(p.device)
|
||||
|
||||
assert p.dtype == zero_p.dtype
|
||||
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
|
||||
|
|
|
@ -55,7 +55,7 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
|
|||
data, label = data.cuda(), label.cuda()
|
||||
_run_step(zero_model, sharded_optim, data, label, criterion, False)
|
||||
for param in zero_model.parameters():
|
||||
assert not has_inf_or_nan(param.colo_attr.sharded_data_tensor.payload)
|
||||
assert not has_inf_or_nan(param.colo_attr.data_payload)
|
||||
|
||||
|
||||
def _run_dist(rank, world_size, port):
|
||||
|
|
|
@ -46,8 +46,8 @@ def run_model_test(init_device_type, shard_strategy_class):
|
|||
assert hasattr(param, 'colo_attr')
|
||||
assert param.colo_attr.sharded_data_tensor.dtype == torch.half
|
||||
assert param.colo_attr.sharded_data_tensor.is_sharded
|
||||
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
|
||||
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
|
||||
assert param.colo_attr.data_payload.device.type == init_device.type, \
|
||||
f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'
|
||||
|
||||
cuda_mem_use, _ = colo_model_mem_usage(model)
|
||||
model_data_cuda_mem_MB = cuda_mem_use / 1e6
|
||||
|
|
|
@ -50,27 +50,27 @@ def _run_shard_param_v2(rank, world_size, port):
|
|||
param_ref = deepcopy(param)
|
||||
sparam = ShardedParamV2(param=param)
|
||||
|
||||
allclose(sparam.sharded_data_tensor.payload, param_ref.data)
|
||||
allclose(sparam.data_payload, param_ref.data)
|
||||
|
||||
# Test get memory usage
|
||||
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
|
||||
|
||||
sparam.remove_torch_payload()
|
||||
sparam.set_data_none()
|
||||
assert (param.data.numel() == 0)
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
# 4 is size of dummy tensor of param.data
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
|
||||
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
|
||||
sparam.remove_torch_payload()
|
||||
sparam.set_data_none()
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
assert cuda_mem_use == 0
|
||||
|
||||
# append a grad to torch param
|
||||
param.data = sparam.sharded_data_tensor.payload
|
||||
param.data = sparam.data_payload
|
||||
param.grad = torch.randn(2, 3)
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"
|
||||
|
|
|
@ -34,7 +34,7 @@ def run_stm():
|
|||
colo_set_process_memory_fraction(fraction)
|
||||
model = Net()
|
||||
for p in model.parameters():
|
||||
p.colo_attr = ShardedParamV2(p, rm_torch_payload=True)
|
||||
p.colo_attr = ShardedParamV2(p, set_data_none=True)
|
||||
GLOBAL_MODEL_DATA_TRACER.register_model(model)
|
||||
mem_collector = MemStatsCollector()
|
||||
stateful_tensor_mgr = StatefulTensorMgr(mem_collector)
|
||||
|
|
Loading…
Reference in New Issue