From 11f1e426fe2b549f8745d5036d4e20bc3dc411ed Mon Sep 17 00:00:00 2001 From: littsk <1214689160@qq.com> Date: Wed, 27 Sep 2023 10:43:03 +0800 Subject: [PATCH] [hotfix] Correct several erroneous code comments (#4794) --- colossalai/shardformer/policies/base_policy.py | 2 +- colossalai/zero/low_level/bookkeeping/bucket_store.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py index e7f199129..eb0350053 100644 --- a/colossalai/shardformer/policies/base_policy.py +++ b/colossalai/shardformer/policies/base_policy.py @@ -50,7 +50,7 @@ class ModulePolicyDescription: new_weight = shard_rowwise(weight, process_group) module.weight = torch.nn.Parameter(new_weight) ``` - sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription + sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription object which specifies the module to be replaced and the target module used to replacement. method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement """ diff --git a/colossalai/zero/low_level/bookkeeping/bucket_store.py b/colossalai/zero/low_level/bookkeeping/bucket_store.py index 2a75d7047..2828d5175 100644 --- a/colossalai/zero/low_level/bookkeeping/bucket_store.py +++ b/colossalai/zero/low_level/bookkeeping/bucket_store.py @@ -92,7 +92,7 @@ class BucketStore(BaseStore): def get_flatten_grad(self) -> Tensor: """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor: - [grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....] + [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....] Returns: Tensor: the flattened gradients slices in the bucket