Mirror of https://github.com/hpcaitech/ColossalAI

Commit d84d68601a (parent a83a2336e8): change 'xxx if xxx else None' to 'xxx or None'
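The commit title describes a small Python rewrite: 'x if x else None' and 'x or None' evaluate to the same value, because 'or' returns its right-hand operand whenever the left-hand one is falsy. A minimal sketch of the idiom (the 'mode' variable is hypothetical, not taken from the diff):

    # Both spellings yield the value itself when it is truthy, and None otherwise.
    mode = ""                            # hypothetical falsy value

    old_style = mode if mode else None   # pre-commit spelling
    new_style = mode or None             # post-commit spelling

    assert old_style is None and new_style is None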
@@ -3,13 +3,18 @@ import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv
+from transformers.models.cohere.modeling_cohere import (
+    CohereForCausalLM,
+    CohereModel,
+    StaticCache,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -584,6 +589,7 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
 
     return forward
 
+
 def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
     from transformers import CohereForCausalLM
 
@@ -67,7 +67,7 @@ class BertPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm
 
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for Bert"
         if sp_mode == "ring":
             warnings.warn(
@@ -50,7 +50,7 @@ class BloomPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm
 
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for BLOOM"
         if sp_mode == "ring":
             warnings.warn(
@@ -57,7 +57,7 @@ class ChatGLMPolicy(Policy):
         else:
             norm_cls = col_nn.LayerNorm
 
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
        assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for ChatGLM2"
         if sp_mode == "ring":
             warnings.warn(
@@ -73,11 +73,9 @@ class CommandPolicy(Policy):
             warnings.warn(
                 f"For Command, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
             )
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
-        sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
-        sp_group = (
-            self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None
-        )
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
+        sp_size = self.shard_config.sequence_parallel_size or None
+        sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
 
         if sp_mode == "all_to_all":
@@ -112,7 +110,6 @@ class CommandPolicy(Policy):
                 target_key=CohereModel,
             )
 
-
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
@@ -65,7 +65,7 @@ class GPT2Policy(Policy):
         else:
             norm_cls = col_nn.LayerNorm
 
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
         assert sp_mode != "all_to_all", "all_to_all sequence parallelism is not supported for GPT2"
         if sp_mode == "ring":
             warnings.warn(
@@ -73,11 +73,9 @@ class LlamaPolicy(Policy):
             warnings.warn(
                 f"For llama, sequence parallelism is currently not compatible with pipeline parallelism, set to be False"
             )
-        sp_mode = self.shard_config.sequence_parallelism_mode if self.shard_config.enable_sequence_parallelism else None
-        sp_size = self.shard_config.sequence_parallel_size if self.shard_config.enable_sequence_parallelism else None
-        sp_group = (
-            self.shard_config.sequence_parallel_process_group if self.shard_config.enable_sequence_parallelism else None
-        )
+        sp_mode = self.shard_config.sequence_parallelism_mode or None
+        sp_size = self.shard_config.sequence_parallel_size or None
+        sp_group = self.shard_config.sequence_parallel_process_group or None
         sp_partial_derived = sp_mode in ["split_gather", "ring"]
 
         if sp_mode == "all_to_all":
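As a rough sketch of what the three-line simplification in CommandPolicy and LlamaPolicy amounts to, assuming a stand-in config object (not ColossalAI's real ShardConfig API) whose sequence-parallel fields stay None while sequence parallelism is disabled:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class StubShardConfig:
        # Hypothetical stand-in for ColossalAI's ShardConfig;
        # only the field names that appear in the diff are mirrored here.
        enable_sequence_parallelism: bool = False
        sequence_parallelism_mode: Optional[str] = None
        sequence_parallel_size: Optional[int] = None
        sequence_parallel_process_group: Optional[object] = None

    cfg = StubShardConfig()

    # Pre-commit spelling: guard each attribute on the enable flag.
    sp_mode_old = cfg.sequence_parallelism_mode if cfg.enable_sequence_parallelism else None

    # Post-commit spelling: rely on the attribute itself being None/falsy
    # whenever sequence parallelism is disabled.
    sp_mode_new = cfg.sequence_parallelism_mode or None

    assert sp_mode_old is None and sp_mode_new is None

Under that assumption the two spellings agree; they would only differ if a mode were set while the enable flag stayed False.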