mirror of https://github.com/hpcaitech/ColossalAI
[format] applied code formatting on changed files in pull request 5510 (#5517)
Co-authored-by: github-actions <github-actions@github.com>pull/4309/merge
parent
19e1a5cf16
commit
e6707a6e8d
|
@ -1302,7 +1302,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
)
|
||||
|
||||
|
||||
if not shard_config.parallel_output:
|
||||
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
|
||||
|
||||
|
|
|
@ -15,10 +15,8 @@ from transformers.utils import logging
|
|||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
|
||||
from ..layer import ColoAttention, cross_entropy_1d
|
||||
|
||||
|
||||
try:
|
||||
from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
|
||||
|
||||
|
|
|
@ -291,13 +291,17 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||
GPT2LMHeadModel: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output}
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": not self.shard_config.parallel_output},
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
if self.shard_config.parallel_output:
|
||||
addon_module[GPT2LMHeadModel].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
|
||||
addon_module[GPT2LMHeadModel].method_replacement = {
|
||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
|
|
|
@ -265,12 +265,18 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
new_item = {
|
||||
LlamaForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output})
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={"gather_output": not self.shard_config.parallel_output},
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
if self.shard_config.parallel_output:
|
||||
new_item[LlamaForCausalLM].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
|
||||
new_item[LlamaForCausalLM].method_replacement = {
|
||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
|
|
Loading…
Reference in New Issue