fix some typo colossalai/shardformer (#4160)

pull/4167/head
digger yu 2023-07-04 17:53:39 +08:00 committed by GitHub
parent c77b3b19be
commit 2ac24040eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 9 additions and 9 deletions

View File

@@ -252,7 +252,7 @@ class ModelSharder:
def shard(self) -> None:
"""
Shard model with parallelelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
Shard model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
"""
...

View File

@@ -48,13 +48,13 @@ class DistCrossEntropy(Function):
# [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size
down_shreshold = rank * delta
up_shreshold = down_shreshold + delta
mask = (target < down_shreshold) | (target >= up_shreshold)
masked_target = target.clone() - down_shreshold
down_threshold = rank * delta
up_threshold = down_threshold + delta
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0
# reshape the logist and target
# reshape the logits and target
# reshape the vocab_logits to [batch_size * seq_len, vocab_size]
# reshape the labels to [batch_size * seq_len]
logits_2d = vocab_logits.view(-1, partition_vocab_size)
@@ -79,7 +79,7 @@ class DistCrossEntropy(Function):
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
# caculate the softmax
# calculate the softmax
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, mask, masked_target_1d)

View File

@@ -66,7 +66,7 @@ class Policy(ABC):
like BertPolicy for Bert Model or OPTPolicy for OPT model.
Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
built-in policies by setting `policy = None`, which is already the default arguemnt for `Shardformer.optimize`.
built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
"""

View File

@@ -73,7 +73,7 @@ class ModelSharder(object):
layer (torch.nn.Module): The object of layer to shard
origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
attr_replacement (Dict): The attribute dict to modify
param_replacement (List[Callable]): The function list to get parameter shard information in polic
param_replacement (List[Callable]): The function list to get parameter shard information in policy
sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
"""
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \