mirror of https://github.com/hpcaitech/ColossalAI
fix some typo colossalai/shardformer (#4160)
parent c77b3b19be
commit 2ac24040eb
@@ -252,7 +252,7 @@ class ModelSharder:
     def shard(self) -> None:
         """
-        Shard model with parallelelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
+        Shard model with parallelism with the help of pre-processing, replace_model_class, replace_module, and post-processing.
         """
         ...

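For orientation, here is a minimal usage sketch of the shard pipeline this docstring describes. The `ShardConfig`/`ShardFormer` entry points and the `(model, shared_params)` return value of `optimize` are assumptions based on the surrounding codebase, not verified against this exact commit; a real run also needs a distributed environment to be initialized first.

from transformers import BertModel

from colossalai.shardformer import ShardConfig, ShardFormer

# Assumption: defaults suffice here; real runs set tensor-parallel options on ShardConfig.
shard_config = ShardConfig()
shard_former = ShardFormer(shard_config=shard_config)

model = BertModel.from_pretrained("bert-base-uncased")
# policy=None selects the matching built-in policy (BertPolicy here),
# per the Policy docstring later in this diff.
sharded_model, shared_params = shard_former.optimize(model, policy=None)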
@@ -48,13 +48,13 @@ class DistCrossEntropy(Function):

         # [down, up) => false, other device and -100 => true
         delta = (global_vocab_size + world_size - 1) // world_size
-        down_shreshold = rank * delta
-        up_shreshold = down_shreshold + delta
-        mask = (target < down_shreshold) | (target >= up_shreshold)
-        masked_target = target.clone() - down_shreshold
+        down_threshold = rank * delta
+        up_threshold = down_threshold + delta
+        mask = (target < down_threshold) | (target >= up_threshold)
+        masked_target = target.clone() - down_threshold
         masked_target[mask] = 0

-        # reshape the logist and target
+        # reshape the logits and target
         # reshape the vocab_logits to [bath_size * seq_len, vocab_size]
         # reshape the labels to [bath_size * seq_len]
         logits_2d = vocab_logits.view(-1, partition_vocab_size)
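To make the arithmetic in this hunk concrete, here is a toy, single-process rerun of the masking logic: each rank owns the vocabulary slice [down_threshold, up_threshold), and any target outside that slice (including the -100 ignore value) is masked to 0. The sizes and targets below are made up for illustration.

import torch

global_vocab_size, world_size, rank = 10, 4, 1

delta = (global_vocab_size + world_size - 1) // world_size  # ceil(10 / 4) = 3 ids per rank
down_threshold = rank * delta                               # 3
up_threshold = down_threshold + delta                       # 6

target = torch.tensor([0, 3, 5, 9, -100])
# [down, up) => False; other ranks' ids and -100 => True
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0

print(mask)           # tensor([ True, False, False,  True,  True])
print(masked_target)  # tensor([0, 0, 2, 0, 0])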
@@ -79,7 +79,7 @@ class DistCrossEntropy(Function):
         loss = torch.where(target == ignore_index, 0.0, torch.log(sum_exp_logits) - pred_logits)
         loss = torch.sum(loss).div_(torch.sum(loss != 0.0))

-        # caculate the softmax
+        # calculate the softmax
         exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
         ctx.save_for_backward(exp_logits, mask, masked_target_1d)

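A toy illustration of the reduction in this hunk: torch.where zeroes the loss of ignored positions, and the division averages over the entries that remain non-zero. Values below are invented; note that a genuinely zero per-token loss would also be excluded by the != 0.0 count, a quirk of this reduction.

import torch

ignore_index = -100
target = torch.tensor([2, -100, 5])
per_token_loss = torch.tensor([0.7, 9.9, 1.3])  # made-up; 9.9 sits at an ignored position

loss = torch.where(target == ignore_index, 0.0, per_token_loss)  # -> [0.7, 0.0, 1.3]
loss = torch.sum(loss).div_(torch.sum(loss != 0.0))              # (0.7 + 1.3) / 2
print(loss)  # tensor(1.)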
@@ -66,7 +66,7 @@ class Policy(ABC):
     like BertPolicy for Bert Model or OPTPolicy for OPT model.

     Shardformer has provided many built-in sharding policies for the mainstream models. You can use the
-    built-in policies by setting `policy = None`, which is already the default arguemnt for `Shardformer.optimize`.
+    built-in policies by setting `policy = None`, which is already the default argument for `Shardformer.optimize`.
     If you want to define your own policy, you can inherit from this class and overwrite the methods you want to modify.
     """

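For the "inherit from this class" path, a minimal skeleton of a user-defined policy follows. The hook names (config_sanity_check, preprocess, module_policy, postprocess) and the import path mirror the built-in policies as I understand them, but they are assumptions rather than the verified ABC of this commit.

from colossalai.shardformer.policies.basepolicy import Policy

class MyModelPolicy(Policy):
    # Hypothetical subclass: override only the hooks whose behavior you need to change.

    def config_sanity_check(self):
        # validate the shard config against the model; no-op in this sketch
        pass

    def preprocess(self):
        # e.g. pad the embedding so vocab_size divides evenly across tensor-parallel ranks
        return self.model

    def module_policy(self):
        # map module classes to their replacement descriptions; {} means "replace nothing"
        return {}

    def postprocess(self):
        return self.model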
@@ -73,7 +73,7 @@ class ModelSharder(object):
             layer (torch.nn.Module): The object of layer to shard
             origin_cls (Union[str, torch.nn.Module]): The origin layer class or a string of layer class name.
             attr_replacement (Dict): The attribute dict to modify
-            param_replacement (List[Callable]): The function list to get parameter shard information in polic
+            param_replacement (List[Callable]): The function list to get parameter shard information in policy
             sub_module_replacement (List[Callable]): The function list to get sub module shard information in policy
         """
         if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
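The guard in the last context line above accepts origin_cls either as a class-name string or, per the docstring, as the class itself. Here is a self-contained restatement of that matching rule with a hypothetical helper name; the diff truncates the second condition after `or \`, so the class-comparison branch below is an assumption.

import torch.nn as nn

def matches_origin_cls(module: nn.Module, origin_cls) -> bool:
    # string form: compare against the module's class name, as in the diff above
    if isinstance(origin_cls, str):
        return origin_cls == module.__class__.__name__
    # class form: exact class match (assumed for the elided second condition)
    return module.__class__ is origin_cls

print(matches_origin_cls(nn.Linear(2, 2), "Linear"))   # True
print(matches_origin_cls(nn.Linear(2, 2), nn.Linear))  # True
print(matches_origin_cls(nn.Linear(2, 2), "Conv2d"))   # False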
|
|