[polish] rename col_attr -> colo_attr (#558)

pull/563/head
Jiarui Fang 2022-03-31 12:25:45 +08:00 committed by GitHub
parent 2c45efc398
commit 7675366fce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 91 additions and 91 deletions

View File

@ -35,58 +35,58 @@ class ZeroHook(BaseOpHook):
def pre_fwd_exec(self, module: torch.nn.Module, *args):
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.sharded_data_tensor)
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
param.data = param.col_attr.sharded_data_tensor.payload
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
param.data = param.colo_attr.sharded_data_tensor.payload
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()
for param in module.parameters(recurse=False):
param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
def post_fwd_exec(self, module: torch.nn.Module, *args):
for param in module.parameters(recurse=False):
param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.sharded_data_tensor)
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
for param in module.parameters(recurse=False):
param.col_attr.remove_torch_payload()
param.colo_attr.remove_torch_payload()
def pre_bwd_exec(self, module: torch.nn.Module, input, output):
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.sharded_data_tensor)
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.gather(tensor_list, self.process_group)
for param in module.parameters(recurse=False):
colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
param.data = param.col_attr.sharded_data_tensor.payload
colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
param.data = param.colo_attr.sharded_data_tensor.payload
if self._memstarts_collector:
self._memstarts_collector.sample_memstats()
for param in module.parameters(recurse=False):
param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
def post_bwd_exec(self, module: torch.nn.Module, input):
for param in module.parameters(recurse=False):
param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
tensor_list = []
for param in module.parameters(recurse=False):
assert hasattr(param, 'col_attr')
tensor_list.append(param.col_attr.sharded_data_tensor)
assert hasattr(param, 'colo_attr')
tensor_list.append(param.colo_attr.sharded_data_tensor)
self.shard_strategy.shard(tensor_list, self.process_group)
for param in module.parameters(recurse=False):
param.col_attr.remove_torch_payload()
param.colo_attr.remove_torch_payload()
def pre_iter(self):
pass

View File

@ -45,8 +45,8 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
cuda_mem_usage = 0
cpu_mem_usage = 0
for param in model.parameters():
if hasattr(param, 'col_attr'):
t_cuda, t_cpu = param.col_attr.get_memory_usage()
if hasattr(param, 'colo_attr'):
t_cuda, t_cpu = param.colo_attr.get_memory_usage()
cuda_mem_usage += t_cuda
cpu_mem_usage += t_cpu
else:

View File

@ -162,8 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
"""
if not self.rm_torch_payload_on_the_fly:
for param in self.initialized_param_list:
assert hasattr(param, 'col_attr')
param.col_attr.remove_torch_payload()
assert hasattr(param, 'colo_attr')
param.colo_attr.remove_torch_payload()
del self.initialized_param_list
@ -178,7 +178,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
for param in module.parameters(recurse=False):
# avoid adapting a param to ShardedParam twice
if hasattr(param, 'col_attr'):
if hasattr(param, 'colo_attr'):
continue
self.model_numel_tensor += param.numel()
@ -196,10 +196,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if param.grad is not None:
param.grad = param.grad.to(target_device)
param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
param.colo_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
if self.shard_param:
self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
self.initialized_param_list.append(param)
# We must cast buffers

View File

@ -70,9 +70,9 @@ class ShardedModelV2(nn.Module):
sharded = []
unsharded = []
for param in module.parameters():
assert hasattr(param, 'col_attr'), 'You must use ZeroInitContext to init your module first.'
sharded.append(param.col_attr.param_is_sharded)
unsharded.append(not param.col_attr.param_is_sharded)
assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
sharded.append(param.colo_attr.param_is_sharded)
unsharded.append(not param.colo_attr.param_is_sharded)
assert all(sharded) or all(
unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
self.shard_param = all(sharded)
@ -103,7 +103,7 @@ class ShardedModelV2(nn.Module):
self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
for param in module.parameters():
# Init `offload_grad`
param.col_attr.offload_grad = self._cpu_offload
param.colo_attr.offload_grad = self._cpu_offload
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
# So we use 1.0 as the default gradient_predivide_factor
@ -162,13 +162,13 @@ class ShardedModelV2(nn.Module):
self._memstats_collector.start_collection()
for p in self.module.parameters():
if hasattr(p, 'col_attr'):
p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
if hasattr(p, 'colo_attr'):
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
def _post_forward_operations(self):
for p in self.module.parameters():
if hasattr(p, 'col_attr'):
p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
if hasattr(p, 'colo_attr'):
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
self._pre_forward_operations()
@ -228,10 +228,10 @@ class ShardedModelV2(nn.Module):
if self.shard_param:
tensor_list = []
for p in self.module.parameters():
if not p.col_attr.param_is_sharded:
tensor_list.append(p.col_attr.sharded_data_tensor)
p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
p.col_attr.remove_torch_payload()
if not p.colo_attr.param_is_sharded:
tensor_list.append(p.colo_attr.sharded_data_tensor)
p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
p.colo_attr.remove_torch_payload()
self.shard_strategy.shard(tensor_list, self.process_group)
# 4. move sharded param grad payload to param.grad
@ -245,27 +245,27 @@ class ShardedModelV2(nn.Module):
# We also allows to interleave no-sync pass with sync passes, if desired.
if not self._require_backward_grad_sync:
continue
# Reduced grad is saved in `p.col_attr.saved_grad`
# Reduced grad is saved in `p.colo_attr.saved_grad`
# It can be on CPU or CUDA
# It can be fp16 or fp32
# We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
if self.reuse_fp16_shard:
grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
else:
grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
assert isinstance(grad_fp16_payload, torch.Tensor)
if p.col_attr.offload_grad:
if p.colo_attr.offload_grad:
colo_model_data_move_to_cpu(grad_fp16_payload)
if not p.col_attr.saved_grad.is_null():
if not p.colo_attr.saved_grad.is_null():
assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
# Accumulate grad, saved grad must be fp32
p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
else:
p.col_attr.saved_grad.reset_payload(grad_fp16_payload)
p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)
p.grad = None
p.col_attr.fp16_grad.set_null()
p.colo_attr.fp16_grad.set_null()
@torch.no_grad()
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@ -273,7 +273,7 @@ class ShardedModelV2(nn.Module):
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
full gradient for the local batch. The reduce-scatter op will save
a single shard of the summed gradient across all
GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::
GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example::
before reduce_scatter:
param.grad (GPU #0): [1, 2, 3, 4]
@ -285,7 +285,7 @@ class ShardedModelV2(nn.Module):
The local GPU's ``optim.step`` is responsible for updating a single
shard of params, also corresponding to the current GPU's rank. This
alignment is created by `param.col_attr.grad`, which ensures that
alignment is created by `param.colo_attr.grad`, which ensures that
the local optimizer only sees the relevant parameter shard.
"""
if grad is None:
@ -323,20 +323,20 @@ class ShardedModelV2(nn.Module):
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad.data.div_(self.gradient_postdivide_factor)
if self.reuse_fp16_shard:
param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
param.col_attr.sharded_data_tensor.is_sharded = True
param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
param.colo_attr.sharded_data_tensor.is_sharded = True
else:
param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
self.process_group)
prev_params = {}
for p in self.module.parameters():
prev_params[p] = p.data
p.data = p.col_attr.sharded_data_tensor.payload
p.data = p.colo_attr.sharded_data_tensor.payload
gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
self.process_group)
for p in self.module.parameters():
p.data = prev_params[p]

View File

@ -10,10 +10,10 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
Note the other_model has to be the same as self.
"""
for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
assert hasattr(zero_param, 'col_attr')
shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
assert hasattr(zero_param, 'colo_attr')
shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
if shard_flag:
sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
if shard_flag:
sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])
sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])

View File

@ -116,18 +116,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
for group in self.optim.param_groups:
for p in group['params']:
assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded:
# TODO (ver217): we may not use shard / gather here
# Param is no sharded, which means we use ZeRO-2 here
# As we only store param shard, we shard it here
self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
if not is_param_sharded:
# In this branch, there's no need to shard param
# So we gather here
self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
ranks=[0])
@ -201,30 +201,30 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._logger.debug(
f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
ranks=[0])
# Copy master param data (fp32) to payload of col_attr (fp16)
# Copy master param data (fp32) to payload of colo_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transfering
# a chunk.
for group in self.optim.param_groups:
for p in group['params']:
is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded:
# We use ZeRO-2 here
# The `p.col_attr.sharded_data_tensor` saves full fp16 param
# The `p.colo_attr.sharded_data_tensor` saves full fp16 param
# But we only have updated fp32 param shard here
# So we first shard full fp16 param and copy fp32 param shard to it
# Then we will gather them
self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
# We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
# Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.col_attr.sharded_data_tensor.reset_payload(
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
if not is_param_sharded:
# We gather full fp16 param here
self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
p.data = p.col_attr.sharded_data_tensor.payload
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
p.data = p.colo_attr.sharded_data_tensor.payload
return ret
def backward(self, loss: Tensor) -> None:
@ -292,7 +292,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
p.grad.data = p.grad.data.to(torch.cuda.current_device())
p.col_attr.offload_grad = False
p.colo_attr.offload_grad = False
fp32_shards_used_cuda_margin_mem += shard_mem
def _prepare_grads(self):
@ -301,7 +301,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
# If we change p.grad directly
# it may raise error because of different shape/dtype/device of p.data and p.grad
# We just set p.data = p.col_attr.saved_grad.payload here
p.data = p.col_attr.saved_grad.payload
p.grad = p.col_attr.saved_grad.payload
p.col_attr.saved_grad.set_null()
# We just set p.data = p.colo_attr.saved_grad.payload here
p.data = p.colo_attr.saved_grad.payload
p.grad = p.colo_attr.saved_grad.payload
p.colo_attr.saved_grad.set_null()

View File

@ -61,22 +61,22 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
model = MoeModel()
for name, param in model.named_parameters():
assert hasattr(param, 'col_attr')
assert hasattr(param, 'colo_attr')
# the weights in the gate should be fp32
if 'gate' in name:
assert param.col_attr.sharded_data_tensor.dtype == torch.float32
assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
else:
assert param.col_attr.sharded_data_tensor.dtype == torch.half
assert param.colo_attr.sharded_data_tensor.dtype == torch.half
# the parameters in moe experts and its gate should not be sharded
if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
assert not param.col_attr.sharded_data_tensor.is_sharded
assert not param.colo_attr.sharded_data_tensor.is_sharded
else:
assert param.col_attr.sharded_data_tensor.is_sharded
assert param.colo_attr.sharded_data_tensor.is_sharded
assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
def _run_dist(rank, world_size, port):

View File

@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
rank = dist.get_rank()
for p, zero_p in zip(model.parameters(), zero_model.parameters()):
# zero_grad = zero_p.grad.clone().to(p.device)
zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
if rank >= len(chunks):
continue
@ -124,7 +124,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
if reuse_fp16_shard:
zero_p = zero_p.data.to(p.device).float()
else:
zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
chunks = torch.flatten(p).chunk(dist.get_world_size())
if rank >= len(chunks):
continue

View File

@ -45,11 +45,11 @@ def run_model_test(init_device_type, shard_strategy_class):
model = model_builder(checkpoint=True)
for param in model.parameters():
assert hasattr(param, 'col_attr')
assert param.col_attr.sharded_data_tensor.dtype == torch.half
assert param.col_attr.sharded_data_tensor.is_sharded
assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
assert hasattr(param, 'colo_attr')
assert param.colo_attr.sharded_data_tensor.dtype == torch.half
assert param.colo_attr.sharded_data_tensor.is_sharded
assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
model_data_cuda_mem_MB = cuda_mem_use / 1e6