mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] Correct several erroneous code comments (#4794)
parent
54b3ad8924
commit
11f1e426fe
|
@ -50,7 +50,7 @@ class ModulePolicyDescription:
|
||||||
new_weight = shard_rowwise(weight, process_group)
|
new_weight = shard_rowwise(weight, process_group)
|
||||||
module.weight = torch.nn.Parameter(new_weight)
|
module.weight = torch.nn.Parameter(new_weight)
|
||||||
```
|
```
|
||||||
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a ParamReplacementDescription
|
sub_module_replacement (List[SubModuleReplacementDescription]): each element in the list is a SubModuleReplacementDescription
|
||||||
object which specifies the module to be replaced and the target module used to replacement.
|
object which specifies the module to be replaced and the target module used to replacement.
|
||||||
method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
|
method_replace (Dict[str, Callable]): key is the method name, value is the method for replacement
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -92,7 +92,7 @@ class BucketStore(BaseStore):
|
||||||
|
|
||||||
def get_flatten_grad(self) -> Tensor:
|
def get_flatten_grad(self) -> Tensor:
|
||||||
"""Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor:
|
"""Return the flattened gradients slices in the bucket, the data orginization of the flattened tensor:
|
||||||
[grad0_rank0, grad1_rank0, ..., grad_1_rank0, grad1_rank1, ....]
|
[grad0_rank0, grad1_rank0, ..., grad_0_rank1, grad1_rank1, ....]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tensor: the flattened gradients slices in the bucket
|
Tensor: the flattened gradients slices in the bucket
|
||||||
|
|
Loading…
Reference in New Issue