[polish] rename col_attr -> colo_attr (#558)

pull/563/head
Jiarui Fang 2022-03-31 12:25:45 +08:00 committed by GitHub
parent 2c45efc398
commit 7675366fce
9 changed files with 91 additions and 91 deletions

View File

@@ -35,58 +35,58 @@ class ZeroHook(BaseOpHook):
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()

     def pre_iter(self):
         pass
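The hook's lifecycle is unchanged by this commit; only the attribute name differs, so any user code that walked parameters through the old `col_attr` handle must switch to `colo_attr`. A minimal sketch of the post-rename access pattern, mirroring `pre_fwd_exec` above (the helper name and its arguments are illustrative, not part of the library):

import torch

def gather_module_shards(module: torch.nn.Module, shard_strategy, process_group):
    # Collect every sharded data tensor through the renamed `colo_attr` handle,
    # then gather them in a single call, as pre_fwd_exec does above.
    tensor_list = []
    for param in module.parameters(recurse=False):
        assert hasattr(param, 'colo_attr'), 'param was not initialized by ZeroInitContext'
        tensor_list.append(param.colo_attr.sharded_data_tensor)
    shard_strategy.gather(tensor_list, process_group)
    # Point param.data at the gathered payload so ordinary torch ops can use it.
    for param in module.parameters(recurse=False):
        param.data = param.colo_attr.sharded_data_tensor.payload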

View File

@@ -45,8 +45,8 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
     cuda_mem_usage = 0
     cpu_mem_usage = 0
     for param in model.parameters():
-        if hasattr(param, 'col_attr'):
-            t_cuda, t_cpu = param.col_attr.get_memory_usage()
+        if hasattr(param, 'colo_attr'):
+            t_cuda, t_cpu = param.colo_attr.get_memory_usage()
             cuda_mem_usage += t_cuda
             cpu_mem_usage += t_cpu
         else:
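`colo_model_mem_usage` sums per-parameter usage through `colo_attr.get_memory_usage()` and returns CUDA and CPU byte counts. A hedged usage sketch (the `model` variable is illustrative, and the import is omitted because the file path is not shown in this view):

# `model` must have been built under ZeroInitContext so its parameters carry `colo_attr`.
cuda_bytes, cpu_bytes = colo_model_mem_usage(model)
print(f'model data: {cuda_bytes / 1e6:.1f} MB CUDA, {cpu_bytes / 1e6:.1f} MB CPU')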

View File

@@ -162,8 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:
-                assert hasattr(param, 'col_attr')
-                param.col_attr.remove_torch_payload()
+                assert hasattr(param, 'colo_attr')
+                param.colo_attr.remove_torch_payload()

         del self.initialized_param_list

@@ -178,7 +178,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
-            if hasattr(param, 'col_attr'):
+            if hasattr(param, 'colo_attr'):
                 continue

             self.model_numel_tensor += param.numel()

@@ -196,10 +196,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)

-            param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+            param.colo_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

             if self.shard_param:
-                self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
             self.initialized_param_list.append(param)

             # We must cast buffers
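The hunk above is where the renamed attribute is created: each freshly initialized parameter gets a `colo_attr` of type `ShardedParamV2`, and its data tensor may be sharded immediately. A condensed sketch of that flow, assuming the objects named in the diff (the standalone helper itself is hypothetical; the real logic lives inside `ZeroInitContext._post_init_method`):

def wrap_param(param, shard_strategy, dp_process_group, shard_param: bool,
               rm_torch_payload_on_the_fly: bool):
    # Avoid adapting a param to ShardedParam twice.
    if hasattr(param, 'colo_attr'):
        return
    # Attach the renamed attribute and optionally shard its data tensor right away.
    param.colo_attr = ShardedParamV2(param, rm_torch_payload=rm_torch_payload_on_the_fly)
    if shard_param:
        shard_strategy.shard([param.colo_attr.sharded_data_tensor], dp_process_group)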

View File

@@ -70,9 +70,9 @@ class ShardedModelV2(nn.Module):
         sharded = []
         unsharded = []
         for param in module.parameters():
-            assert hasattr(param, 'col_attr'), 'You must use ZeroInitContext to init your module first.'
-            sharded.append(param.col_attr.param_is_sharded)
-            unsharded.append(not param.col_attr.param_is_sharded)
+            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
+            sharded.append(param.colo_attr.param_is_sharded)
+            unsharded.append(not param.colo_attr.param_is_sharded)
         assert all(sharded) or all(
             unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
         self.shard_param = all(sharded)
@@ -103,7 +103,7 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         for param in module.parameters():
             # Init `offload_grad`
-            param.col_attr.offload_grad = self._cpu_offload
+            param.colo_attr.offload_grad = self._cpu_offload

         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
@@ -162,13 +162,13 @@ class ShardedModelV2(nn.Module):
             self._memstats_collector.start_collection()
         for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+            if hasattr(p, 'colo_attr'):
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

     def _post_forward_operations(self):
         for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+            if hasattr(p, 'colo_attr'):
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         self._pre_forward_operations()
@@ -228,10 +228,10 @@ class ShardedModelV2(nn.Module):
         if self.shard_param:
             tensor_list = []
             for p in self.module.parameters():
-                if not p.col_attr.param_is_sharded:
-                    tensor_list.append(p.col_attr.sharded_data_tensor)
-                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                    p.col_attr.remove_torch_payload()
+                if not p.colo_attr.param_is_sharded:
+                    tensor_list.append(p.colo_attr.sharded_data_tensor)
+                    p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                    p.colo_attr.remove_torch_payload()
             self.shard_strategy.shard(tensor_list, self.process_group)

         # 4. move sharded param grad payload to param.grad
@@ -245,27 +245,27 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Reduced grad is saved in `p.col_attr.saved_grad`
+            # Reduced grad is saved in `p.colo_attr.saved_grad`
             # It can be on CPU or CUDA
             # It can be fp16 or fp32
             # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
             if self.reuse_fp16_shard:
-                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
             else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
             assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.col_attr.offload_grad:
+            if p.colo_attr.offload_grad:
                 colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.col_attr.saved_grad.is_null():
+            if not p.colo_attr.saved_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
                 # Accumulate grad, saved grad must be fp32
-                p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
-                p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
             else:
-                p.col_attr.saved_grad.reset_payload(grad_fp16_payload)
+                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)

             p.grad = None
-            p.col_attr.fp16_grad.set_null()
+            p.colo_attr.fp16_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -273,7 +273,7 @@ class ShardedModelV2(nn.Module):
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
         full gradient for the local batch. The reduce-scatter op will save
         a single shard of the summed gradient across all
-        GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::
+        GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example::

             before reduce_scatter:
                 param.grad (GPU #0): [1, 2, 3, 4]
@@ -285,7 +285,7 @@ class ShardedModelV2(nn.Module):
         The local GPU's ``optim.step`` is responsible for updating a single
         shard of params, also corresponding to the current GPU's rank. This
-        alignment is created by `param.col_attr.grad`, which ensures that
+        alignment is created by `param.colo_attr.grad`, which ensures that
         the local optimizer only sees the relevant parameter shard.
         """
         if grad is None:
@@ -323,20 +323,20 @@ class ShardedModelV2(nn.Module):
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
         if self.reuse_fp16_shard:
-            param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
-            param.col_attr.sharded_data_tensor.is_sharded = True
+            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            param.colo_attr.sharded_data_tensor.is_sharded = True
         else:
-            param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
+            param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
                                    self.process_group)
         prev_params = {}
         for p in self.module.parameters():
             prev_params[p] = p.data
-            p.data = p.col_attr.sharded_data_tensor.payload
+            p.data = p.colo_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
                                   self.process_group)
         for p in self.module.parameters():
             p.data = prev_params[p]
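`state_dict` gathers the full fp16 payloads, temporarily points each `param.data` at them, snapshots the module, then re-shards and restores the saved references, so callers receive an unsharded state dict. A hedged caller-side sketch (the checkpoint path is illustrative):

import torch

# `sharded_model` is a ShardedModelV2 instance; state_dict() gathers shards
# internally, so the returned tensors are full-size and can be saved directly.
full_state_dict = sharded_model.state_dict()
torch.save(full_state_dict, 'checkpoint.pt')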

View File

@@ -10,10 +10,10 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
     Note the other_model has to be the same as self.
     """
     for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
-        assert hasattr(zero_param, 'col_attr')
-        shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
+        assert hasattr(zero_param, 'colo_attr')
+        shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
         if shard_flag:
-            sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
-        param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
+            sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
+        param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
         if shard_flag:
-            sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])
+            sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
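`col_model_deepcopy` copies each gathered payload into an identically structured plain module, re-sharding afterwards so the sharded model is left untouched. A hedged usage sketch (`build_model()` stands in for whatever builder produced both models and is not a library function):

# Both models must share the same architecture; only the payloads are copied.
plain_model = build_model()
col_model_deepcopy(sharded_model, plain_model)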

View File

@@ -116,18 +116,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
-                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # TODO (ver217): we may not use shard / gather here
                     # Param is no sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)

         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                            ranks=[0])
@@ -201,30 +201,30 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
             ranks=[0])

-        # Copy master param data (fp32) to payload of col_attr (fp16)
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
         for group in self.optim.param_groups:
             for p in group['params']:
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # We use ZeRO-2 here
-                    # The `p.col_attr.sharded_data_tensor` saves full fp16 param
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here
                     # So we first shard full fp16 param and copy fp32 param shard to it
                     # Then we will gather them
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.sharded_data_tensor.reset_payload(
+                p.colo_attr.sharded_data_tensor.reset_payload(
                     colo_model_tensor_clone(p.half(), torch.cuda.current_device()))

                 if not is_param_sharded:
                     # We gather full fp16 param here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.col_attr.sharded_data_tensor.payload
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
         return ret

     def backward(self, loss: Tensor) -> None:
@@ -292,7 +292,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                     self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                    p.col_attr.offload_grad = False
+                    p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem

     def _prepare_grads(self):
@@ -301,7 +301,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
-                # We just set p.data = p.col_attr.saved_grad.payload here
-                p.data = p.col_attr.saved_grad.payload
-                p.grad = p.col_attr.saved_grad.payload
-                p.col_attr.saved_grad.set_null()
+                # We just set p.data = p.colo_attr.saved_grad.payload here
+                p.data = p.colo_attr.saved_grad.payload
+                p.grad = p.colo_attr.saved_grad.payload
+                p.colo_attr.saved_grad.set_null()
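After reduction the fp32 gradient shard lives in `p.colo_attr.saved_grad` rather than `p.grad`, so `_prepare_grads` re-points both `p.data` and `p.grad` at that payload before `step()`. A minimal sketch of the same re-pointing, using only the attributes shown in the diff (the standalone function is hypothetical):

def expose_saved_grad(p):
    # Make the reduced gradient shard visible to a vanilla torch optimizer,
    # then clear the colo_attr reference, as _prepare_grads does above.
    p.data = p.colo_attr.saved_grad.payload
    p.grad = p.colo_attr.saved_grad.payload
    p.colo_attr.saved_grad.set_null()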

View File

@@ -61,22 +61,22 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
         model = MoeModel()

     for name, param in model.named_parameters():
-        assert hasattr(param, 'col_attr')
+        assert hasattr(param, 'colo_attr')

         # the weights in the gate should be fp32
         if 'gate' in name:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
         else:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.half

         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.col_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded
         else:
-            assert param.col_attr.sharded_data_tensor.is_sharded
+            assert param.colo_attr.sharded_data_tensor.is_sharded

-        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'


 def _run_dist(rank, world_size, port):

View File

@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
         # zero_grad = zero_p.grad.clone().to(p.device)
-        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
+        zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
@@ -124,7 +124,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         if reuse_fp16_shard:
             zero_p = zero_p.data.to(p.device).float()
         else:
-            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue

View File

@@ -45,11 +45,11 @@ def run_model_test(init_device_type, shard_strategy_class):
             model = model_builder(checkpoint=True)

     for param in model.parameters():
-        assert hasattr(param, 'col_attr')
-        assert param.col_attr.sharded_data_tensor.dtype == torch.half
-        assert param.col_attr.sharded_data_tensor.is_sharded
-        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        assert hasattr(param, 'colo_attr')
+        assert param.colo_attr.sharded_data_tensor.dtype == torch.half
+        assert param.colo_attr.sharded_data_tensor.is_sharded
+        assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'

     cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
     model_data_cuda_mem_MB = cuda_mem_use / 1e6
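Because every call site now asserts on `colo_attr`, external code that still reads the old `col_attr` name will fail immediately after upgrading. A purely hypothetical migration shim, not part of this commit, that aliases the old name while downstream code is being updated:

import torch

def alias_legacy_col_attr(model: torch.nn.Module) -> None:
    # Temporary bridge: let legacy code that still reads `param.col_attr`
    # keep working against the renamed `colo_attr` object.
    for param in model.parameters():
        if hasattr(param, 'colo_attr') and not hasattr(param, 'col_attr'):
            param.col_attr = param.colo_attr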