[hotfix] Correct several erroneous code comments (#4794)

pull/4815/head
littsk 2023-09-27 10:43:03 +08:00 committed by GitHub
parent 54b3ad8924
commit 11f1e426fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -50,7 +50,7 @@ class ModulePolicyDescription:
new_weight = shard_rowwise(weight, process_group) new_weight = shard_rowwise(weight, process_group)
module.weight = torch.nn.Parameter(new_weight) module.weight = torch.nn.Parameter(new_weight)
``` ```
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
object which specifies the module to be replaced and the target module used to replacement. object which specifies the module to be replaced and the target module used to replacement.
method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
""" """

View File

@ -92,7 +92,7 @@ class BucketStore(BaseStore):
def get_flatten_grad(self) -> Tensor: def get_flatten_grad(self) -> Tensor:
"""Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor: """Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor:
[grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....] [grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]
Returns: Returns:
Tensor: the flattened gradients slices in the bucket Tensor: the flattened gradients slices in the bucket