mirror of https://github.com/hpcaitech/ColossalAI
[polish] rename col_attr -> colo_attr (#558)
parent 2c45efc398
commit 7675366fce
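Note (not part of the commit): the hunks below are a mechanical rename of the per-parameter attribute col_attr to colo_attr across the ZeRO hook, init context, sharded model, sharded optimizer, and their tests. As a rough sketch of what downstream callers have to do, a tolerant accessor (hypothetical helper, not in ColossalAI) can bridge both spellings during migration:

    import torch

    def get_colo_attr(param: torch.nn.Parameter):
        # Hypothetical migration helper: prefer the new attribute name and
        # fall back to the old one. Illustration only; not part of this commit.
        for name in ('colo_attr', 'col_attr'):
            if hasattr(param, name):
                return getattr(param, name)
        raise AttributeError('parameter was not initialized by ZeroInitContext')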
@@ -35,58 +35,58 @@ class ZeroHook(BaseOpHook):
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+            param.data = param.colo_attr.sharded_data_tensor.payload

         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)

         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()

     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)

         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)

         for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()

     def pre_iter(self):
         pass
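Note (sketch only, not part of the diff; gather and shard stand in for the shard-strategy calls above): the hook methods implement a gather -> compute -> shard lifecycle around each wrapped submodule. Shards are gathered and moved to the computing device before the op runs, marked COMPUTE while it runs, and re-sharded with their torch payload released afterwards. A toy version of that control flow:

    import torch
    import torch.nn as nn

    def run_with_gather_shard(module: nn.Module, inp: torch.Tensor, gather, shard) -> torch.Tensor:
        # Toy version of the ZeroHook lifecycle: materialize full parameters
        # before the op executes, then free them again afterwards.
        params = list(module.parameters(recurse=False))
        gather(params)          # cf. pre_fwd_exec / pre_bwd_exec
        try:
            return module(inp)  # parameters are in the COMPUTE state here
        finally:
            shard(params)       # cf. post_fwd_exec / post_bwd_exec

    # No-op strategies make the sketch runnable without a process group.
    layer = nn.Linear(4, 4)
    out = run_with_gather_shard(layer, torch.randn(2, 4), gather=lambda ps: None, shard=lambda ps: None)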
@@ -45,8 +45,8 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
     cuda_mem_usage = 0
     cpu_mem_usage = 0
     for param in model.parameters():
-        if hasattr(param, 'col_attr'):
-            t_cuda, t_cpu = param.col_attr.get_memory_usage()
+        if hasattr(param, 'colo_attr'):
+            t_cuda, t_cpu = param.colo_attr.get_memory_usage()
             cuda_mem_usage += t_cuda
             cpu_mem_usage += t_cpu
         else:
@@ -162,8 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:
-                assert hasattr(param, 'col_attr')
-                param.col_attr.remove_torch_payload()
+                assert hasattr(param, 'colo_attr')
+                param.colo_attr.remove_torch_payload()

         del self.initialized_param_list
@@ -178,7 +178,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):

         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
-            if hasattr(param, 'col_attr'):
+            if hasattr(param, 'colo_attr'):
                 continue

             self.model_numel_tensor += param.numel()
@@ -196,10 +196,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)

-            param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+            param.colo_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)

             if self.shard_param:
-                self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
             self.initialized_param_list.append(param)

             # We must cast buffers
@@ -70,9 +70,9 @@ class ShardedModelV2(nn.Module):
         sharded = []
         unsharded = []
         for param in module.parameters():
-            assert hasattr(param, 'col_attr'), 'You must use ZeroInitContext to init your module first.'
-            sharded.append(param.col_attr.param_is_sharded)
-            unsharded.append(not param.col_attr.param_is_sharded)
+            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
+            sharded.append(param.colo_attr.param_is_sharded)
+            unsharded.append(not param.colo_attr.param_is_sharded)
         assert all(sharded) or all(
             unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
         self.shard_param = all(sharded)
@@ -103,7 +103,7 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         for param in module.parameters():
             # Init `offload_grad`
-            param.col_attr.offload_grad = self._cpu_offload
+            param.colo_attr.offload_grad = self._cpu_offload

         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
@@ -162,13 +162,13 @@ class ShardedModelV2(nn.Module):
             self._memstats_collector.start_collection()

         for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+            if hasattr(p, 'colo_attr'):
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

     def _post_forward_operations(self):
         for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+            if hasattr(p, 'colo_attr'):
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)

     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         self._pre_forward_operations()
@@ -228,10 +228,10 @@ class ShardedModelV2(nn.Module):
         if self.shard_param:
             tensor_list = []
             for p in self.module.parameters():
-                if not p.col_attr.param_is_sharded:
-                    tensor_list.append(p.col_attr.sharded_data_tensor)
-                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                    p.col_attr.remove_torch_payload()
+                if not p.colo_attr.param_is_sharded:
+                    tensor_list.append(p.colo_attr.sharded_data_tensor)
+                    p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                    p.colo_attr.remove_torch_payload()
             self.shard_strategy.shard(tensor_list, self.process_group)

         # 4. move sharded param grad payload to param.grad
@@ -245,27 +245,27 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Reduced grad is saved in `p.col_attr.saved_grad`
+            # Reduced grad is saved in `p.colo_attr.saved_grad`
             # It can be on CPU or CUDA
             # It can be fp16 or fp32
             # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
             if self.reuse_fp16_shard:
-                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
             else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
             assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.col_attr.offload_grad:
+            if p.colo_attr.offload_grad:
                 colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.col_attr.saved_grad.is_null():
+            if not p.colo_attr.saved_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
                 # Accumulate grad, saved grad must be fp32
-                p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
-                p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
             else:
-                p.col_attr.saved_grad.reset_payload(grad_fp16_payload)
+                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)

             p.grad = None
-            p.col_attr.fp16_grad.set_null()
+            p.colo_attr.fp16_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -273,7 +273,7 @@ class ShardedModelV2(nn.Module):
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
         full gradient for the local batch. The reduce-scatter op will save
         a single shard of the summed gradient across all
-        GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::
+        GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example::

             before reduce_scatter:
                 param.grad (GPU #0): [1, 2, 3, 4]
@@ -285,7 +285,7 @@ class ShardedModelV2(nn.Module):

         The local GPU's ``optim.step`` is responsible for updating a single
         shard of params, also corresponding to the current GPU's rank. This
-        alignment is created by `param.col_attr.grad`, which ensures that
+        alignment is created by `param.colo_attr.grad`, which ensures that
         the local optimizer only sees the relevant parameter shard.
         """
         if grad is None:
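Note (not part of the diff): the docstring's reduce-scatter example can be checked with plain tensor arithmetic. The single-process sketch below assumes a world size of 2 with GPU #1 holding [5, 6, 7, 8]; real training performs this across ranks with torch.distributed.reduce_scatter.

    import torch

    # Full per-rank gradients before reduce-scatter.
    grads = [torch.tensor([1., 2., 3., 4.]),    # param.grad on GPU #0 (from the docstring)
             torch.tensor([5., 6., 7., 8.])]    # param.grad on GPU #1 (assumed here)

    summed = torch.stack(grads).sum(dim=0)      # element-wise sum across ranks: [6, 8, 10, 12]
    shards = summed.chunk(len(grads))           # each rank keeps only its own chunk

    for rank, shard in enumerate(shards):
        print(f'rank {rank} keeps', shard)      # rank 0 -> [6., 8.], rank 1 -> [10., 12.]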
@@ -323,20 +323,20 @@ class ShardedModelV2(nn.Module):
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
         if self.reuse_fp16_shard:
-            param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
-            param.col_attr.sharded_data_tensor.is_sharded = True
+            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            param.colo_attr.sharded_data_tensor.is_sharded = True
         else:
-            param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
+            param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
                                    self.process_group)
         prev_params = {}
         for p in self.module.parameters():
             prev_params[p] = p.data
-            p.data = p.col_attr.sharded_data_tensor.payload
+            p.data = p.colo_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
                                   self.process_group)
         for p in self.module.parameters():
             p.data = prev_params[p]
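Note (not part of the diff): the state_dict override above follows a gather -> snapshot -> re-shard pattern: full payloads are gathered, each p.data is temporarily pointed at the full payload while the wrapped module's state_dict is taken, then the shards and the original p.data are restored. A generic swap-and-restore sketch of that pattern (toy code, not the ColossalAI API):

    import copy
    import torch
    import torch.nn as nn

    def snapshot_with_swapped_data(module: nn.Module, full_tensors: dict) -> dict:
        # Temporarily point each parameter's .data at a substitute tensor,
        # take the snapshot, then restore the original storage.
        prev = {}
        try:
            for p in module.parameters():
                prev[p] = p.data
                p.data = full_tensors.get(p, p.data)
            return copy.deepcopy(module.state_dict())
        finally:
            for p, old in prev.items():
                p.data = old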
@@ -10,10 +10,10 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
     Note the other_model has to be the same as self.
     """
     for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
-        assert hasattr(zero_param, 'col_attr')
-        shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
+        assert hasattr(zero_param, 'colo_attr')
+        shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
         if shard_flag:
-            sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
-        param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
+            sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
+        param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
         if shard_flag:
-            sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])
+            sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
@@ -116,18 +116,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):

         for group in self.optim.param_groups:
             for p in group['params']:
-                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # TODO (ver217): we may not use shard / gather here
                     # Param is no sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)

         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                            ranks=[0])
@@ -201,30 +201,30 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
             ranks=[0])
-        # Copy master param data (fp32) to payload of col_attr (fp16)
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
         for group in self.optim.param_groups:
             for p in group['params']:
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # We use ZeRO-2 here
-                    # The `p.col_attr.sharded_data_tensor` saves full fp16 param
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here
                     # So we first shard full fp16 param and copy fp32 param shard to it
                     # Then we will gather them
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16

                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.sharded_data_tensor.reset_payload(
+                p.colo_attr.sharded_data_tensor.reset_payload(
                     colo_model_tensor_clone(p.half(), torch.cuda.current_device()))

                 if not is_param_sharded:
                     # We gather full fp16 param here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.col_attr.sharded_data_tensor.payload
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
         return ret

     def backward(self, loss: Tensor) -> None:
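Note (not part of the diff): the step() hunk above shows the usual mixed-precision bookkeeping: the optimizer updates fp32 master shards, and the result is copied back into the fp16 payload the model computes with. A minimal single-tensor sketch of that pattern (illustration only; no sharding or offload):

    import torch

    # fp16 working copy used by the model, fp32 master copy owned by the optimizer.
    param_fp16 = torch.randn(4, dtype=torch.half)
    master_fp32 = param_fp16.float().requires_grad_()
    opt = torch.optim.SGD([master_fp32], lr=0.1)

    # Pretend backward produced an fp16 gradient; promote it to fp32 for the update.
    master_fp32.grad = torch.randn(4, dtype=torch.half).float()

    opt.step()                       # the update happens in fp32
    with torch.no_grad():
        param_fp16.copy_(master_fp32)  # cast the fp32 result back into the fp16 payload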
@@ -292,7 +292,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                     self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                    p.col_attr.offload_grad = False
+                    p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem

     def _prepare_grads(self):
@@ -301,7 +301,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
-                # We just set p.data = p.col_attr.saved_grad.payload here
-                p.data = p.col_attr.saved_grad.payload
-                p.grad = p.col_attr.saved_grad.payload
-                p.col_attr.saved_grad.set_null()
+                # We just set p.data = p.colo_attr.saved_grad.payload here
+                p.data = p.colo_attr.saved_grad.payload
+                p.grad = p.colo_attr.saved_grad.payload
+                p.colo_attr.saved_grad.set_null()
@@ -61,22 +61,22 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
         model = MoeModel()

     for name, param in model.named_parameters():
-        assert hasattr(param, 'col_attr')
+        assert hasattr(param, 'colo_attr')

         # the weights in the gate should be fp32
         if 'gate' in name:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
         else:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.half

         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.col_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded
         else:
-            assert param.col_attr.sharded_data_tensor.is_sharded
+            assert param.colo_attr.sharded_data_tensor.is_sharded

-        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'


 def _run_dist(rank, world_size, port):
@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
         # zero_grad = zero_p.grad.clone().to(p.device)
-        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
+        zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
@@ -124,7 +124,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         if reuse_fp16_shard:
             zero_p = zero_p.data.to(p.device).float()
         else:
-            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
@@ -45,11 +45,11 @@ def run_model_test(init_device_type, shard_strategy_class):
         model = model_builder(checkpoint=True)

     for param in model.parameters():
-        assert hasattr(param, 'col_attr')
-        assert param.col_attr.sharded_data_tensor.dtype == torch.half
-        assert param.col_attr.sharded_data_tensor.is_sharded
-        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        assert hasattr(param, 'colo_attr')
+        assert param.colo_attr.sharded_data_tensor.dtype == torch.half
+        assert param.colo_attr.sharded_data_tensor.is_sharded
+        assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'

     cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
     model_data_cuda_mem_MB = cuda_mem_use / 1e6