[zero] refactor ShardedParamV2 for convenience (#742)

pull/748/head
HELSON 2022-04-13 14:54:26 +08:00 committed by GitHub
parent 340e59f968
commit 22c4b88d56
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 98 additions and 61 deletions

View File

@ -215,7 +215,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
assert hasattr(param, 'colo_attr')
if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
param.colo_attr.remove_torch_payload()
param.colo_attr.set_data_none()
del self.param_list
@ -252,11 +252,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if param.grad is not None:
param.grad = param.grad.to(target_device)
param.colo_attr = ShardedParamV2(param, rm_torch_payload=False)
param.colo_attr = ShardedParamV2(param, set_data_none=False)
if self.shard_param:
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
param.data = param.colo_attr.sharded_data_tensor.payload # set param.data to payload
param.data = param.colo_attr.data_payload # set param.data to payload
# mark whether the param is replicated
param.colo_attr.is_replicated = self.is_replicated

View File

@ -260,7 +260,7 @@ class ShardedModelV2(nn.Module):
if not p.colo_attr.param_is_sharded:
tensor_list.append(p.colo_attr.sharded_data_tensor)
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
p.colo_attr.remove_torch_payload()
p.colo_attr.set_data_none()
self.shard_strategy.shard(tensor_list, self.process_group)
# 4. set all parameters' grad to None
@ -357,8 +357,8 @@ class ShardedModelV2(nn.Module):
assert param.colo_attr.saved_grad.is_null(
), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
param.colo_attr.saved_grad.reset_payload(grad)
param.colo_attr.sharded_data_tensor.reset_payload(grad) # release the memory of param
# store the gradient as the saved gradient
param.colo_attr.reset_grad_payload(grad)
# make the data payload point at the same tensor, which releases the memory of param
# NOTE(review): reset_data_payload() also calls set_data_none() on the parameter,
# which the removed sharded_data_tensor.reset_payload(grad) call did not do.
param.colo_attr.reset_data_payload(grad)
if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True
@ -367,9 +367,9 @@ class ShardedModelV2(nn.Module):
fp32_grad = cast_tensor_to_fp32(grad)
if param.colo_attr.saved_grad.is_null():
param.colo_attr.saved_grad.reset_payload(fp32_grad)
param.colo_attr.reset_grad_payload(fp32_grad)
else:
param.colo_attr.saved_grad.payload.add_(fp32_grad.view_as(param.colo_attr.saved_grad.payload))
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))
# keep saved_grad in HOLD state
param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
@ -377,11 +377,11 @@ class ShardedModelV2(nn.Module):
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
for p in self.sharded_params:
p.data = p.colo_attr.sharded_data_tensor.payload
p.data = p.colo_attr.data_payload
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.sharded_params], self.process_group)
for p in self.sharded_params:
p.colo_attr.remove_torch_payload()
p.colo_attr.set_data_none()
return gathered_state_dict
def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True):

View File

@ -14,6 +14,6 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
if shard_flag:
sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
param.data = copy.deepcopy(zero_param.colo_attr.data_payload)
if shard_flag:
sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])

View File

@ -266,8 +266,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if shard_flag:
# we always shard replicated parameters
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = StatefulTensor(
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload.to(self.device)))
self.master_params[p] = StatefulTensor(cast_tensor_to_fp32(p.colo_attr.data_payload.to(self.device)))
if shard_flag:
# In this branch, there's no need to shard param
# So we gather here
@ -296,10 +295,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# If we change p.grad directly
# it may raise error because of different shape/dtype/device of p.data and p.grad
# We just set p.data = p.colo_attr.grad_payload (the payload of saved_grad) here
p.data = p.colo_attr.saved_grad.payload
p.grad = p.colo_attr.saved_grad.payload
p.data = p.colo_attr.grad_payload
p.grad = p.colo_attr.grad_payload
# Set p.data to empty tensor, in case of memory leaking
p.colo_attr.remove_torch_payload()
p.colo_attr.set_data_none()
def _point_param_fp16_to_master_param(self):
# assign master param pointers to p.data.
@ -325,9 +324,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.data = self.master_params[p].payload
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.remove_torch_payload()
p.colo_attr.reset_data_payload(
colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.set_data_none()
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
# We gather full fp16 param here

View File

@ -10,10 +10,20 @@ from typing import List
# empty tensor is expected to raise error when get used
FAKE_EMPTY_TENSOR = torch.BoolTensor([], device='cpu')
EMPTY_TENSOR_DICT = {}
def get_empty_tensor(device: torch.device, dtype: torch.dtype):
    """Return the shared placeholder tensor for the given ``(device, dtype)`` pair.

    On the first request for a pair, ``FAKE_EMPTY_TENSOR`` is converted with
    ``.to(device, dtype)`` and the result is stored in ``EMPTY_TENSOR_DICT``.
    Every later request for the same pair returns that same stored tensor object.
    """
    cache_key = (device, dtype)
    try:
        return EMPTY_TENSOR_DICT[cache_key]
    except KeyError:
        # first time this (device, dtype) pair is requested: build and remember it
        placeholder = FAKE_EMPTY_TENSOR.to(device, dtype)
        EMPTY_TENSOR_DICT[cache_key] = placeholder
        return placeholder
class ShardedParamV2(object):
def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
def __init__(self, param: torch.nn.Parameter, set_data_none: bool = False) -> None:
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
# This attribute must be initialized in ShardedModel
@ -25,24 +35,47 @@ class ShardedParamV2(object):
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
# So we can not empty the .data at this time
self.param = param
if rm_torch_payload:
self.remove_torch_payload()
if set_data_none:
self.set_data_none()
def get_payload_tensors(self) -> List[StatefulTensor]:
"""Return the stateful tensors kept by this class.

The returned list holds only ``self._sharded_data_tensor``; ``saved_grad`` is not included.
"""
return [self._sharded_data_tensor]
def remove_torch_payload(self):
self.param.data = FAKE_EMPTY_TENSOR.to(self._sharded_data_tensor.device, self._sharded_data_tensor.dtype)
# Point ``self.param.data`` at the tensor returned by ``get_empty_tensor`` for the device
# and dtype of ``self.sharded_data_tensor``. Only ``param.data`` is reassigned here; the
# sharded data tensor itself is not modified.
def set_data_none(self):
self.param.data = get_empty_tensor(self.sharded_data_tensor.device, self.sharded_data_tensor.dtype)
# Mark the saved gradient as null by calling ``set_null()`` on ``self.saved_grad``.
def set_grad_none(self):
self.saved_grad.set_null()
# Accessor for the private ``_sharded_data_tensor`` attribute.
@property
def sharded_data_tensor(self):
return self._sharded_data_tensor
# Shortcut for ``self.sharded_data_tensor.payload``.
@property
def data_payload(self):
return self.sharded_data_tensor.payload
# Payload of ``self.saved_grad``.
# Asserts first that the saved gradient is not null, so reading this property while no
# gradient is stored raises AssertionError instead of returning ``None``.
@property
def grad_payload(self):
assert not self.saved_grad.is_null()
return self.saved_grad.payload
@property
def param_is_sharded(self):
return self._sharded_data_tensor.is_sharded
return self.sharded_data_tensor.is_sharded
def reset_data_payload(self, tensor: torch.Tensor):
    """Replace the payload of the sharded data tensor with ``tensor``.

    ``tensor`` must be exactly a ``torch.Tensor`` (the ``type(...) is`` check rejects
    subclasses) and its ``requires_grad`` must be ``False``; otherwise an
    ``AssertionError`` is raised before anything is modified. After the payload is
    swapped, ``set_data_none()`` is called.
    """
    is_plain_tensor = type(tensor) is torch.Tensor
    assert is_plain_tensor
    assert tensor.requires_grad is False
    data_tensor = self.sharded_data_tensor
    data_tensor.reset_payload(tensor)
    self.set_data_none()
def reset_grad_payload(self, tensor: torch.Tensor):
    """Replace the payload of ``self.saved_grad`` with ``tensor``.

    ``tensor`` must be exactly a ``torch.Tensor`` (the ``type(...) is`` check rejects
    subclasses) and its ``requires_grad`` must be ``False``; otherwise an
    ``AssertionError`` is raised before anything is modified.
    """
    is_plain_tensor = type(tensor) is torch.Tensor
    assert is_plain_tensor
    assert tensor.requires_grad is False
    grad_holder = self.saved_grad
    grad_holder.reset_payload(tensor)
def get_memory_usage(self) -> Tuple[int, int]:
"""
@ -63,11 +96,11 @@ class ShardedParamV2(object):
cpu_mem_use += t_cpu
address_set = set()
_update_mem_use(self.sharded_data_tensor.payload)
address_set.add(self.sharded_data_tensor.payload.data_ptr())
_update_mem_use(self.data_payload)
address_set.add(self.data_payload.data_ptr())
if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
_update_mem_use(self.saved_grad.payload)
_update_mem_use(self.grad_payload)
address_set.add(self.saved_grad.data_ptr())
if self.param.data is not None and self.param.data.data_ptr() not in address_set:

View File

@ -9,6 +9,7 @@ class ShardedTensor(StatefulTensor):
r"""
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
"""
assert tensor.requires_grad is False
super().__init__(tensor, state)
# kept the shape, numel and dtype of the init tensor.
@ -17,6 +18,11 @@ class ShardedTensor(StatefulTensor):
self._origin_dtype = tensor.dtype
self._is_sharded = False
# dtype of the current payload.
# Asserts that it still equals ``_origin_dtype``, the dtype recorded from the tensor
# passed to the constructor.
@property
def dtype(self) -> torch.dtype:
assert self._payload.dtype == self._origin_dtype
return self._payload.dtype
@property
def origin_numel(self) -> int:
return self._origin_numel

View File

@ -19,11 +19,11 @@ class StatefulTensor(object):
https://arxiv.org/abs/2108.05818
"""
def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
self._state = state
self._payload = tensor
if self._state == TensorState.FREE:
assert self._payload is None, f"payload has to None if {self._state}"
assert self._payload is None, f"payload has to None if state is {self._state}"
def data_ptr(self):
if self._payload is None:
@ -50,13 +50,13 @@ class StatefulTensor(object):
self._payload = None
@property
def payload(self) -> int:
def payload(self) -> Optional[torch.Tensor]:
return self._payload
def copy_payload(self, tensor) -> int:
def copy_payload(self, tensor) -> None:
self._payload.view(-1).copy_(tensor.view(-1))
def reset_payload(self, tensor) -> int:
def reset_payload(self, tensor) -> None:
del self._payload
self._payload = tensor
self.trans_state(TensorState.HOLD)
@ -67,15 +67,14 @@ class StatefulTensor(object):
@property
def dtype(self) -> torch.dtype:
assert self._payload.dtype == self._origin_dtype
return self._origin_dtype
return self._payload.dtype
@property
def shape(self):
return self._payload.shape
def to(self, device: torch.device):
    """Always raise: moving this tensor with ``.to()`` is not supported.

    Args:
        device: ignored; present only to mirror the ``torch.Tensor.to`` call shape.

    Raises:
        RuntimeError: unconditionally; the message points to ``colo_model_tensor_move``.
    """
    raise RuntimeError("Use colo_model_tensor_move instead of calling .to() on ShardedTensor")
def to_(self, device: torch.device):
    """Always raise: moving this tensor in place with ``.to_()`` is not supported.

    Args:
        device: ignored.

    Raises:
        RuntimeError: unconditionally; the message points to ``colo_model_tensor_move``.
    """
    raise RuntimeError("Use colo_model_tensor_move instead of calling .to_() on ShardedTensor")
# Shape of the current payload (``self._payload.shape``).
# NOTE(review): there is no guard for a ``None`` payload here, although the constructor
# accepts ``None`` as the tensor — confirm callers only read this while a payload is set.
@property
def shape(self):
return self._payload.shape

View File

@ -60,7 +60,7 @@ class ZeroHook(BaseOpHook):
self._memstarts_collector.sample_memstats()
for param in module.parameters(recurse=False):
param.data = param.colo_attr.sharded_data_tensor.payload
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE FWD param.data must be on CUDA"
def post_fwd_exec(self, module: torch.nn.Module, *args):
@ -79,7 +79,7 @@ class ZeroHook(BaseOpHook):
# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.remove_torch_payload()
param.colo_attr.set_data_none()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
@ -105,7 +105,7 @@ class ZeroHook(BaseOpHook):
self._memstarts_collector.sample_memstats()
for param in module.parameters(recurse=False):
param.data = param.colo_attr.sharded_data_tensor.payload
param.data = param.colo_attr.data_payload
assert param.data.device.type == 'cuda', f"PRE BWD param.data must be on CUDA"
def post_bwd_exec(self, module: torch.nn.Module, input):
@ -124,7 +124,7 @@ class ZeroHook(BaseOpHook):
# remove torch payload
for param in module.parameters(recurse=False):
param.colo_attr.remove_torch_payload()
param.colo_attr.set_data_none()
def pre_iter(self):
pass

View File

@ -77,10 +77,10 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
assert param.colo_attr.is_replicated
if param.colo_attr.param_is_sharded:
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
assert param.colo_attr.data_payload.device.type == init_device.type, \
f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'
else:
assert param.colo_attr.sharded_data_tensor.payload.device.type == 'cuda'
assert param.colo_attr.data_payload.device.type == 'cuda'
def _run_dist(rank, world_size, port):

View File

@ -37,7 +37,7 @@ def run_model_test(enable_autocast, shard_strategy_class):
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload)
assert_equal_in_group(p.colo_attr.data_payload)
model = MoeModel(checkpoint=True).half()
col_model_deepcopy(zero_model, model)

View File

@ -76,7 +76,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
# check whether parameters are identical in ddp
for name, p in zero_model.named_parameters():
if not p.colo_attr.param_is_sharded and p.colo_attr.is_replicated:
assert_equal_in_group(p.colo_attr.sharded_data_tensor.payload.to(get_current_device()))
assert_equal_in_group(p.colo_attr.data_payload.to(get_current_device()))
model = MoeModel(checkpoint=True).half()
col_model_deepcopy(zero_model, model)
@ -100,7 +100,7 @@ def _run_test_sharded_optim_v2(cpu_offload,
for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
if 'gate' in n:
p.data = p.float()
p.data.copy_(zp.colo_attr.sharded_data_tensor.payload)
p.data.copy_(zp.colo_attr.data_payload)
for i, (data, label) in enumerate(train_dataloader):
if i > 5:

View File

@ -94,7 +94,7 @@ def check_grads_padding(model, zero_model, loose=False):
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
# zero_grad = zero_p.grad.clone().to(p.device)
if zero_p.colo_attr.is_replicated:
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
zero_grad = zero_p.colo_attr.grad_payload.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
@ -102,7 +102,7 @@ def check_grads_padding(model, zero_model, loose=False):
if zero_grad.size(0) > grad.size(0):
zero_grad = zero_grad[:grad.size(0)]
else:
zero_grad = zero_p.colo_attr.saved_grad.payload
zero_grad = zero_p.colo_attr.grad_payload
grad = p.grad.to(zero_grad.dtype)
assert grad.dtype == zero_grad.dtype
@ -127,7 +127,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
rank = dist.get_rank()
for (name, p), (zero_name, zero_p) in zip(model.named_parameters(), zero_model.named_parameters()):
if zero_p.colo_attr.param_is_sharded:
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
zero_p = zero_p.colo_attr.data_payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
@ -135,7 +135,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
if zero_p.size(0) > p.size(0):
zero_p = zero_p[:p.size(0)]
else:
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device)
zero_p = zero_p.colo_attr.data_payload.to(p.device)
assert p.dtype == zero_p.dtype
assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'

View File

@ -55,7 +55,7 @@ def _run_test_found_inf(cpu_offload, shard_strategy_class, gpu_margin_mem_ratio)
data, label = data.cuda(), label.cuda()
_run_step(zero_model, sharded_optim, data, label, criterion, False)
for param in zero_model.parameters():
assert not has_inf_or_nan(param.colo_attr.sharded_data_tensor.payload)
assert not has_inf_or_nan(param.colo_attr.data_payload)
def _run_dist(rank, world_size, port):

View File

@ -46,8 +46,8 @@ def run_model_test(init_device_type, shard_strategy_class):
assert hasattr(param, 'colo_attr')
assert param.colo_attr.sharded_data_tensor.dtype == torch.half
assert param.colo_attr.sharded_data_tensor.is_sharded
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
assert param.colo_attr.data_payload.device.type == init_device.type, \
f'{param.colo_attr.data_payload.device.type} vs. {init_device.type}'
cuda_mem_use, _ = colo_model_mem_usage(model)
model_data_cuda_mem_MB = cuda_mem_use / 1e6

View File

@ -50,27 +50,27 @@ def _run_shard_param_v2(rank, world_size, port):
param_ref = deepcopy(param)
sparam = ShardedParamV2(param=param)
allclose(sparam.sharded_data_tensor.payload, param_ref.data)
allclose(sparam.data_payload, param_ref.data)
# Test get memory usage
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
sparam.remove_torch_payload()
sparam.set_data_none()
assert (param.data.numel() == 0)
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
# 4 is size of dummy tensor of param.data
assert cpu_mem_use == 2 * 3 * 4 * 2
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
sparam.remove_torch_payload()
sparam.set_data_none()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2
assert cuda_mem_use == 0
# append a grad to torch param
param.data = sparam.sharded_data_tensor.payload
param.data = sparam.data_payload
param.grad = torch.randn(2, 3)
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 + 2 * 3 * 4, f"cpu_mem_use {cpu_mem_use}"

View File

@ -34,7 +34,7 @@ def run_stm():
colo_set_process_memory_fraction(fraction)
model = Net()
for p in model.parameters():
p.colo_attr = ShardedParamV2(p, rm_torch_payload=True)
p.colo_attr = ShardedParamV2(p, set_data_none=True)
GLOBAL_MODEL_DATA_TRACER.register_model(model)
mem_collector = MemStatsCollector()
stateful_tensor_mgr = StatefulTensorMgr(mem_collector)