fix some typo colossalai/shardformer (#4160)

pull/4167/head
digger yu 2023-07-04 17:53:39 +08:00 committed by GitHub
parent c77b3b19be
commit 2ac24040eb
4 changed files with 9 additions and 9 deletions


@@ -252,7 +252,7 @@ class ModelSharder:
     def shard(self) -> None:
         """
-        Shard model with parallelelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
+        Shard model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
         """
         ...


@@ -48,13 +48,13 @@ class DistCrossEntropy(Function):
         # [down, up) => false, other device and -100 => true
         delta = (global_vocab_size + world_size - 1) // world_size
-        down_shreshold = rank * delta
-        up_shreshold = down_shreshold + delta
-        mask = (target < down_shreshold) | (target >= up_shreshold)
-        masked_target = target.clone() - down_shreshold
+        down_threshold = rank * delta
+        up_threshold = down_threshold + delta
+        mask = (target < down_threshold) | (target >= up_threshold)
+        masked_target = target.clone() - down_threshold
         masked_target[mask] = 0
-        # reshape the logist and target
+        # reshape the logits and target
         # reshape the vocab_logits to [bath_size * seq_len, vocab_size]
         # reshape the labels to [bath_size * seq_len]
         logits_2d = vocab_logits.view(-1, partition_vocab_size)
@@ -79,7 +79,7 @@ class DistCrossEntropy(Function):
         loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
         loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
-        # caculate the softmax
+        # calculate the softmax
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
         ctx.save_for_backward(exp_logits, mask, masked_target_1d)
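
For reference, the thresholding and loss math that these two hunks touch can be condensed into a short single-process sketch. This is an illustration only, not the library's DistCrossEntropy kernel: the helper name is hypothetical, and the real distributed version would reduce pred_logits and sum_exp_logits across the tensor-parallel ranks instead of computing them locally.

import torch

def vocab_parallel_ce_sketch(vocab_logits, target, rank, world_size,
                             global_vocab_size, ignore_index=-100):
    # Each rank owns the vocab slice [down_threshold, up_threshold).
    partition_vocab_size = vocab_logits.size(-1)
    delta = (global_vocab_size + world_size - 1) // world_size
    down_threshold = rank * delta
    up_threshold = down_threshold + delta
    # Targets outside this rank's slice (and ignore_index) are masked out.
    mask = (target < down_threshold) | (target >= up_threshold)
    masked_target = target.clone() - down_threshold
    masked_target[mask] = 0
    # Reshape to [batch_size * seq_len, partition_vocab_size] and [batch_size * seq_len].
    logits_2d = vocab_logits.view(-1, partition_vocab_size)
    masked_target_1d = masked_target.view(-1)
    # Pick each target token's logit; zero it where the target lives on another rank.
    rows = torch.arange(logits_2d.size(0), device=logits_2d.device)
    pred_logits = logits_2d[rows, masked_target_1d]
    pred_logits[mask.view(-1)] = 0.0
    # Computed locally here; the distributed kernel would all-reduce both terms.
    sum_exp_logits = torch.exp(logits_2d).sum(dim=-1)
    loss = torch.where(target.view(-1) == ignore_index, 0.0,
                       torch.log(sum_exp_logits) - pred_logits)
    # Average over the non-ignored tokens, mirroring the hunk above.
    return torch.sum(loss).div_(torch.sum(loss != 0.0))

With world_size == 1 and rank == 0 this collapses to ordinary cross entropy over the full vocabulary, which is a convenient way to sanity-check the masking arithmetic.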


@@ -66,7 +66,7 @@ class Policy(ABC):
     like BertPolicy for Bert Model or OPTPolicy for OPT model.
     Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
-    built-in policies by setting `policy = None`, which is already the default arguemnt for `Shardformer.optimize`.
+    built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
     If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
     """


@@ -73,7 +73,7 @@ class ModelSharder(object):
             layer (torch.nn.Module): The object of layer to shard
             origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
             attr_replacement (Dict): The attribute dict to modify
-            param_replacement (List[Callable]): The function list to get parameter shard information in polic
+            param_replacement (List[Callable]): The function list to get parameter shard information in policy
             sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
         """
         if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
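
The condition in the last line above matches a module either by its class-name string or, in the continuation that is not shown here, by the class itself. A tiny standalone illustration of that dual check, using a hypothetical helper name and an assumed second branch:

from typing import Union

import torch.nn as nn

def matches_origin(module: nn.Module, origin_cls: Union[str, type]) -> bool:
    # String form: compare against the class name, as in the hunk above.
    if isinstance(origin_cls, str):
        return origin_cls == module.__class__.__name__
    # Class form: compare the class object directly (assumed second branch).
    return module.__class__ == origin_cls

layer = nn.Linear(8, 8)
assert matches_origin(layer, "Linear")     # match by class-name string
assert matches_origin(layer, nn.Linear)    # match by class object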