From 7675366fce17cf6883c7e1a22c4cc6f52241c313 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Thu, 31 Mar 2022 12:25:45 +0800
Subject: [PATCH] [polish] rename col_attr -> colo_attr (#558)

---
 colossalai/engine/ophooks/zero_hook.py        | 36 ++++++------
 .../memory_tracer/model_data_memtracer.py     |  4 +-
 colossalai/zero/init_ctx/init_context.py      | 10 ++--
 .../zero/sharded_model/sharded_model_v2.py    | 58 +++++++++----------
 colossalai/zero/sharded_model/utils.py        | 10 ++--
 .../zero/sharded_optim/sharded_optim_v2.py    | 36 ++++++------
 tests/test_moe/test_moe_zero_init.py          | 14 ++---
 tests/test_zero_data_parallel/common.py       |  4 +-
 .../test_init_context.py                      | 10 ++--
 9 files changed, 91 insertions(+), 91 deletions(-)

diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index a3e40ba8b..8f937750d 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -35,58 +35,58 @@ class ZeroHook(BaseOpHook):
     def pre_fwd_exec(self, module: torch.nn.Module, *args):
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
 
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
 
     def post_fwd_exec(self, module: torch.nn.Module, *args):
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()
 
     def pre_bwd_exec(self, module: torch.nn.Module, input, output):
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            colo_model_data_tensor_move_inline(param.col_attr.sharded_data_tensor, self.computing_device)
-            param.data = param.col_attr.sharded_data_tensor.payload
+            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
+            param.data = param.colo_attr.sharded_data_tensor.payload
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
 
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
 
     def post_bwd_exec(self, module: torch.nn.Module, input):
         for param in module.parameters(recurse=False):
-            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+            param.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
         tensor_list = []
         for param in module.parameters(recurse=False):
-            assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.sharded_data_tensor)
+            assert hasattr(param, 'colo_attr')
+            tensor_list.append(param.colo_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters(recurse=False):
-            param.col_attr.remove_torch_payload()
+            param.colo_attr.remove_torch_payload()
 
     def pre_iter(self):
         pass
diff --git a/colossalai/utils/memory_tracer/model_data_memtracer.py b/colossalai/utils/memory_tracer/model_data_memtracer.py
index c2205b693..e38587367 100644
--- a/colossalai/utils/memory_tracer/model_data_memtracer.py
+++ b/colossalai/utils/memory_tracer/model_data_memtracer.py
@@ -45,8 +45,8 @@ def colo_model_mem_usage(model: torch.nn.Module) -> Tuple[int, int]:
     cuda_mem_usage = 0
     cpu_mem_usage = 0
     for param in model.parameters():
-        if hasattr(param, 'col_attr'):
-            t_cuda, t_cpu = param.col_attr.get_memory_usage()
+        if hasattr(param, 'colo_attr'):
+            t_cuda, t_cpu = param.colo_attr.get_memory_usage()
             cuda_mem_usage += t_cuda
             cpu_mem_usage += t_cpu
         else:
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index e093ea8db..db6431f7d 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -162,8 +162,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         """
         if not self.rm_torch_payload_on_the_fly:
             for param in self.initialized_param_list:
-                assert hasattr(param, 'col_attr')
-                param.col_attr.remove_torch_payload()
+                assert hasattr(param, 'colo_attr')
+                param.colo_attr.remove_torch_payload()
 
         del self.initialized_param_list
 
@@ -178,7 +178,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
 
         for param in module.parameters(recurse=False):
             # avoid adapting a param to ShardedParam twice
-            if hasattr(param, 'col_attr'):
+            if hasattr(param, 'colo_attr'):
                 continue
 
             self.model_numel_tensor += param.numel()
@@ -196,10 +196,10 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
             if param.grad is not None:
                 param.grad = param.grad.to(target_device)
 
-            param.col_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
+            param.colo_attr = ShardedParamV2(param, rm_torch_payload=self.rm_torch_payload_on_the_fly)
 
             if self.shard_param:
-                self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
             self.initialized_param_list.append(param)
 
         # We must cast buffers
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index c4aac0001..a27da5e3b 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -70,9 +70,9 @@ class ShardedModelV2(nn.Module):
         sharded = []
         unsharded = []
         for param in module.parameters():
-            assert hasattr(param, 'col_attr'), 'You must use ZeroInitContext to init your module first.'
-            sharded.append(param.col_attr.param_is_sharded)
-            unsharded.append(not param.col_attr.param_is_sharded)
+            assert hasattr(param, 'colo_attr'), 'You must use ZeroInitContext to init your module first.'
+            sharded.append(param.colo_attr.param_is_sharded)
+            unsharded.append(not param.colo_attr.param_is_sharded)
         assert all(sharded) or all(
             unsharded), 'Parameters must be all sharded or all unsharded! Parameters are partially sharded now.'
         self.shard_param = all(sharded)
@@ -103,7 +103,7 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         for param in module.parameters():
             # Init `offload_grad`
-            param.col_attr.offload_grad = self._cpu_offload
+            param.colo_attr.offload_grad = self._cpu_offload
 
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
@@ -162,13 +162,13 @@ class ShardedModelV2(nn.Module):
             self._memstats_collector.start_collection()
 
         for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+            if hasattr(p, 'colo_attr'):
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
 
     def _post_forward_operations(self):
         for p in self.module.parameters():
-            if hasattr(p, 'col_attr'):
-                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+            if hasattr(p, 'colo_attr'):
+                p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
 
     def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
         self._pre_forward_operations()
@@ -228,10 +228,10 @@ class ShardedModelV2(nn.Module):
         if self.shard_param:
             tensor_list = []
             for p in self.module.parameters():
-                if not p.col_attr.param_is_sharded:
-                    tensor_list.append(p.col_attr.sharded_data_tensor)
-                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
-                    p.col_attr.remove_torch_payload()
+                if not p.colo_attr.param_is_sharded:
+                    tensor_list.append(p.colo_attr.sharded_data_tensor)
+                    p.colo_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+                    p.colo_attr.remove_torch_payload()
             self.shard_strategy.shard(tensor_list, self.process_group)
 
         # 4. move sharded param grad payload to param.grad
@@ -245,27 +245,27 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Reduced grad is saved in `p.col_attr.saved_grad`
+            # Reduced grad is saved in `p.colo_attr.saved_grad`
             # It can be on CPU or CUDA
             # It can be fp16 or fp32
             # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
             if self.reuse_fp16_shard:
-                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
             else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
             assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.col_attr.offload_grad:
+            if p.colo_attr.offload_grad:
                 colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.col_attr.saved_grad.is_null():
+            if not p.colo_attr.saved_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
                 # Accumulate grad, saved grad must be fp32
-                p.col_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.col_attr.saved_grad.payload))
-                p.col_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
+                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
             else:
-                p.col_attr.saved_grad.reset_payload(grad_fp16_payload)
+                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)
 
             p.grad = None
-            p.col_attr.fp16_grad.set_null()
+            p.colo_attr.fp16_grad.set_null()
 
     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -273,7 +273,7 @@ class ShardedModelV2(nn.Module):
         """
        At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
         full gradient for the local batch. The reduce-scatter op will save
         a single shard of the summed gradient across all
-        GPUs to param.col_attr.grad. This shard will align with the current GPU rank. For example::
+        GPUs to param.colo_attr.grad. This shard will align with the current GPU rank. For example::
 
             before reduce_scatter:
                 param.grad (GPU #0): [1, 2, 3, 4]
                 param.grad (GPU #1): [5, 6, 7, 8]
@@ -285,7 +285,7 @@ class ShardedModelV2(nn.Module):
             after reduce_scatter:
                 param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                 param.grad (GPU #1): [10, 12]  # 3+7, 4+8
 
         The local GPU's ``optim.step`` is responsible for updating a single
         shard of params, also corresponding to the current GPU's rank. This
-        alignment is created by `param.col_attr.grad`, which ensures that
+        alignment is created by `param.colo_attr.grad`, which ensures that
         the local optimizer only sees the relevant parameter shard.
         """
         if grad is None:
@@ -323,20 +323,20 @@ class ShardedModelV2(nn.Module):
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
         if self.reuse_fp16_shard:
-            param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
-            param.col_attr.sharded_data_tensor.is_sharded = True
+            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            param.colo_attr.sharded_data_tensor.is_sharded = True
         else:
-            param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
+            param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
                                    self.process_group)
         prev_params = {}
         for p in self.module.parameters():
             prev_params[p] = p.data
-            p.data = p.col_attr.sharded_data_tensor.payload
+            p.data = p.colo_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
                                   self.process_group)
         for p in self.module.parameters():
             p.data = prev_params[p]
diff --git a/colossalai/zero/sharded_model/utils.py b/colossalai/zero/sharded_model/utils.py
index 4489afdc9..9777e0f63 100644
--- a/colossalai/zero/sharded_model/utils.py
+++ b/colossalai/zero/sharded_model/utils.py
@@ -10,10 +10,10 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
     Note the other_model has to be the same as self.
     """
     for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
-        assert hasattr(zero_param, 'col_attr')
-        shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
+        assert hasattr(zero_param, 'colo_attr')
+        shard_flag = zero_param.colo_attr.sharded_data_tensor.is_sharded
         if shard_flag:
-            sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
-        param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
+            sharded_model.shard_strategy.gather([zero_param.colo_attr.sharded_data_tensor])
+        param.data = copy.deepcopy(zero_param.colo_attr.sharded_data_tensor.payload)
         if shard_flag:
-            sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])
+            sharded_model.shard_strategy.shard([zero_param.colo_attr.sharded_data_tensor])
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 86fa7aadd..539350101 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -116,18 +116,18 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
 
         for group in self.optim.param_groups:
             for p in group['params']:
-                assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                assert hasattr(p, 'colo_attr'), 'The parameter must be wrapped with ShardedParam'
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # TODO (ver217): we may not use shard / gather here
                     # Param is no sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
         self._logger.debug(f"After init ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory!",
                            ranks=[0])
 
@@ -201,30 +201,30 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self._logger.debug(
                 f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
                 ranks=[0])
-        # Copy master param data (fp32) to payload of col_attr (fp16)
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
         # TODO() improve efficiency by gathering tensors into a chunk and transfering
         # a chunk.
         for group in self.optim.param_groups:
             for p in group['params']:
-                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # We use ZeRO-2 here
-                    # The `p.col_attr.sharded_data_tensor` saves full fp16 param
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here
                     # So we first shard full fp16 param and copy fp32 param shard to it
                     # Then we will gather them
-                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
                 # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.sharded_data_tensor.reset_payload(
+                p.colo_attr.sharded_data_tensor.reset_payload(
                     colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
 
                 if not is_param_sharded:
                     # We gather full fp16 param here
-                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.col_attr.sharded_data_tensor.payload
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
         return ret
 
     def backward(self, loss: Tensor) -> None:
@@ -292,7 +292,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
                     self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
-                    p.col_attr.offload_grad = False
+                    p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem
 
     def _prepare_grads(self):
@@ -301,7 +301,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
-                # We just set p.data = p.col_attr.saved_grad.payload here
-                p.data = p.col_attr.saved_grad.payload
-                p.grad = p.col_attr.saved_grad.payload
-                p.col_attr.saved_grad.set_null()
+                # We just set p.data = p.colo_attr.saved_grad.payload here
+                p.data = p.colo_attr.saved_grad.payload
+                p.grad = p.colo_attr.saved_grad.payload
+                p.colo_attr.saved_grad.set_null()
diff --git a/tests/test_moe/test_moe_zero_init.py b/tests/test_moe/test_moe_zero_init.py
index 0c4e66ffa..d8f46ddea 100644
--- a/tests/test_moe/test_moe_zero_init.py
+++ b/tests/test_moe/test_moe_zero_init.py
@@ -61,22 +61,22 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
         model = MoeModel()
 
     for name, param in model.named_parameters():
-        assert hasattr(param, 'col_attr')
+        assert hasattr(param, 'colo_attr')
 
         # the weights in the gate should be fp32
         if 'gate' in name:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.float32
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
         else:
-            assert param.col_attr.sharded_data_tensor.dtype == torch.half
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
 
         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.col_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded
         else:
-            assert param.col_attr.sharded_data_tensor.is_sharded
+            assert param.colo_attr.sharded_data_tensor.is_sharded
 
-        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+        assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
 
 
 def _run_dist(rank, world_size, port):
diff --git a/tests/test_zero_data_parallel/common.py b/tests/test_zero_data_parallel/common.py
index cdac165ee..2abc8c53d 100644
--- a/tests/test_zero_data_parallel/common.py
+++ b/tests/test_zero_data_parallel/common.py
@@ -93,7 +93,7 @@ def check_grads_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
         # zero_grad = zero_p.grad.clone().to(p.device)
-        zero_grad = zero_p.col_attr.saved_grad.payload.clone().to(p.device)
+        zero_grad = zero_p.colo_attr.saved_grad.payload.clone().to(p.device)
         chunks = torch.flatten(p.grad).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
@@ -124,7 +124,7 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         if reuse_fp16_shard:
             zero_p = zero_p.data.to(p.device).float()
         else:
-            zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
+            zero_p = zero_p.colo_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue
diff --git a/tests/test_zero_data_parallel/test_init_context.py b/tests/test_zero_data_parallel/test_init_context.py
index 061ebc6b8..bcdc51b97 100644
--- a/tests/test_zero_data_parallel/test_init_context.py
+++ b/tests/test_zero_data_parallel/test_init_context.py
@@ -45,11 +45,11 @@ def run_model_test(init_device_type, shard_strategy_class):
             model = model_builder(checkpoint=True)
 
         for param in model.parameters():
-            assert hasattr(param, 'col_attr')
-            assert param.col_attr.sharded_data_tensor.dtype == torch.half
-            assert param.col_attr.sharded_data_tensor.is_sharded
-            assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
-                f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
+            assert hasattr(param, 'colo_attr')
+            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
+            assert param.colo_attr.sharded_data_tensor.is_sharded
+            assert param.colo_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+                f'{param.colo_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'
 
         cuda_mem_use, cpu_mem_use = colo_model_mem_usage(model)
         model_data_cuda_mem_MB = cuda_mem_use / 1e6