mirror of https://github.com/hpcaitech/ColossalAI
[format] applied code formatting on changed files in pull request 5510 (#5517)
Co-authored-by: github-actions <github-actions@github.com>
parent 19e1a5cf16
commit e6707a6e8d
@@ -1302,7 +1302,6 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
                 shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
             )

         if not shard_config.parallel_output:
             lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
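
The hunk above sits inside the distributed cross-entropy forward that this formatting pass touches: with shard_config.parallel_output enabled, the vocabulary-sharded logits are shifted and fed straight to cross_entropy_1d over the tensor-parallel process group; otherwise they are gathered along the last (vocabulary) dimension with gather_forward_split_backward. A minimal sketch of that branch follows, assuming a hypothetical helper that takes the two ColossalAI ops as arguments so no import path is asserted (only the call signatures shown in the hunk are reused):

import torch
import torch.distributed as dist

def lm_loss_or_logits(
    lm_logits: torch.Tensor,      # (batch, seq, vocab/tp) logits, vocab dim sharded
    labels: torch.Tensor,         # (batch, seq) token ids
    parallel_output: bool,
    tp_group: dist.ProcessGroup,
    cross_entropy_1d_fn,          # stands in for ColossalAI's cross_entropy_1d
    gather_fn,                    # stands in for gather_forward_split_backward
):
    if parallel_output:
        # Shift so that tokens < n predict token n, then compute the loss
        # directly on the vocabulary-sharded logits.
        shift_logits = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
        shift_labels = labels[..., 1:].contiguous().view(-1)
        return cross_entropy_1d_fn(
            shift_logits, shift_labels, process_group=tp_group
        )
    # Otherwise hand back full-vocabulary logits gathered along the last dim.
    return gather_fn(lm_logits, -1, tp_group)
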
@@ -15,10 +15,8 @@ from transformers.utils import logging
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.shard import ShardConfig

 from ..layer import ColoAttention, cross_entropy_1d

 try:
     from transformers.models.llama.modeling_llama import _prepare_4d_causal_attention_mask
@@ -291,13 +291,17 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
             GPT2LMHeadModel: ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
-                        suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output}
+                        suffix="lm_head",
+                        target_module=col_nn.Linear1D_Col,
+                        kwargs={"gather_output": not self.shard_config.parallel_output},
                     )
                 ],
             )
         }
         if self.shard_config.parallel_output:
-            addon_module[GPT2LMHeadModel].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+            addon_module[GPT2LMHeadModel].method_replacement = {
+                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+            }
         module_policy.update(addon_module)

         if self.pipeline_stage_manager is not None:
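
In the GPT-2 hunk above, Linear1D_Col's gather_output is the inverse of parallel_output, so the lm_head keeps its slice of the vocabulary local whenever the distributed loss path is active. A forward-only toy in plain torch (not ColossalAI code, and without the autograd handling that gather_forward_split_backward provides) of what gather_output means for a column-sharded lm_head:

import torch
import torch.distributed as dist

def column_parallel_lm_head(
    hidden: torch.Tensor,        # (batch, seq, hidden)
    weight_shard: torch.Tensor,  # (vocab / tp_size, hidden) local slice of lm_head weight
    gather_output: bool,
    tp_group: dist.ProcessGroup,
) -> torch.Tensor:
    local_logits = hidden @ weight_shard.t()          # (batch, seq, vocab / tp_size)
    if not gather_output:
        return local_logits                           # stays sharded: parallel_output path
    # gather_output=True: all-gather the vocab slices into full-vocabulary logits.
    shards = [torch.empty_like(local_logits) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, local_logits.contiguous(), group=tp_group)
    return torch.cat(shards, dim=-1)                  # (batch, seq, full vocab)
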
@@ -265,12 +265,18 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
         new_item = {
             LlamaForCausalLM: ModulePolicyDescription(
                 sub_module_replacement=[
-                    SubModuleReplacementDescription(suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": not self.shard_config.parallel_output})
+                    SubModuleReplacementDescription(
+                        suffix="lm_head",
+                        target_module=Linear1D_Col,
+                        kwargs={"gather_output": not self.shard_config.parallel_output},
+                    )
                 ],
             )
         }
         if self.shard_config.parallel_output:
-            new_item[LlamaForCausalLM].method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)}
+            new_item[LlamaForCausalLM].method_replacement = {
+                "forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
+            }
         policy.update(new_item)

         if self.pipeline_stage_manager:
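
Both the GPT-2 and Llama policies key this behaviour off the same two ShardConfig attributes read in the hunks, parallel_output and tensor_parallel_process_group. A minimal usage sketch, under the assumption that both are plain constructor fields of the ShardConfig dataclass (the diff only shows them being read as attributes):

import torch.distributed as dist

from colossalai.shardformer.shard import ShardConfig  # import path taken from the hunk above

def build_shard_config(tp_group: dist.ProcessGroup, keep_sharded_logits: bool = True) -> ShardConfig:
    # Assumption: parallel_output and tensor_parallel_process_group are fields
    # of the ShardConfig dataclass. parallel_output=True keeps the lm_head
    # output sharded (gather_output=False) and lets the policies swap in
    # get_lm_forward_with_dist_cross_entropy; parallel_output=False makes them
    # gather full-vocabulary logits instead.
    return ShardConfig(
        tensor_parallel_process_group=tp_group,
        parallel_output=keep_sharded_logits,
    )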