mirror of https://github.com/hpcaitech/ColossalAI
fix some typo colossalai/shardformer (#4160)
parent
c77b3b19be
commit
2ac24040eb
|
@ -252,7 +252,7 @@ class ModelSharder:
|
|||
|
||||
def shard(self) -> None:
|
||||
"""
|
||||
Shard model with parallelelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
|
||||
Shard model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
|
||||
"""
|
||||
...
|
||||
|
||||
|
|
|
@ -48,13 +48,13 @@ class DistCrossEntropy(Function):
|
|||
|
||||
# [down, up) => false, other device and -100 => true
|
||||
delta = (global_vocab_size + world_size - 1) // world_size
|
||||
down_shreshold = rank * delta
|
||||
up_shreshold = down_shreshold + delta
|
||||
mask = (target < down_shreshold) | (target >= up_shreshold)
|
||||
masked_target = target.clone() - down_shreshold
|
||||
down_threshold = rank * delta
|
||||
up_threshold = down_threshold + delta
|
||||
mask = (target < down_threshold) | (target >= up_threshold)
|
||||
masked_target = target.clone() - down_threshold
|
||||
masked_target[mask] = 0
|
||||
|
||||
# reshape the logist and target
|
||||
# reshape the logits and target
|
||||
# reshape the vocab_logits to [bath_size * seq_len, vocab_size]
|
||||
# reshape the labels to [bath_size * seq_len]
|
||||
logits_2d = vocab_logits.view(-1, partition_vocab_size)
|
||||
|
@ -79,7 +79,7 @@ class DistCrossEntropy(Function):
|
|||
loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
|
||||
loss = torch.sum(loss).div_(torch.sum(loss != 0.0))
|
||||
|
||||
# caculate the softmax
|
||||
# calculate the softmax
|
||||
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ class Policy(ABC):
|
|||
like BertPolicy for Bert Model or OPTPolicy for OPT model.
|
||||
|
||||
Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
|
||||
built-in policies by setting `policy = None`, which is already the default arguemnt for `Shardformer.optimize`.
|
||||
built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
|
||||
If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
|
||||
"""
|
||||
|
||||
|
|
|
@ -73,7 +73,7 @@ class ModelSharder(object):
|
|||
layer (torch.nn.Module): The object of layer to shard
|
||||
origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
|
||||
attr_replacement (Dict): The attribute dict to modify
|
||||
param_replacement (List[Callable]): The function list to get parameter shard information in polic
|
||||
param_replacement (List[Callable]): The function list to get parameter shard information in policy
|
||||
sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
|
||||
"""
|
||||
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
|
||||
|
|
Loading…
Reference in New Issue