|
|
|
@ -3,13 +3,18 @@ import warnings
|
|
|
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
import torch.nn.functional as F |
|
|
|
|
import torch.utils.checkpoint |
|
|
|
|
from torch import nn |
|
|
|
|
from torch.nn import CrossEntropyLoss |
|
|
|
|
from transformers.cache_utils import Cache, DynamicCache |
|
|
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
|
|
|
from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv |
|
|
|
|
from transformers.models.cohere.modeling_cohere import ( |
|
|
|
|
CohereForCausalLM, |
|
|
|
|
CohereModel, |
|
|
|
|
StaticCache, |
|
|
|
|
apply_rotary_pos_emb, |
|
|
|
|
repeat_kv, |
|
|
|
|
) |
|
|
|
|
from transformers.utils import logging |
|
|
|
|
|
|
|
|
|
from colossalai.pipeline.stage_manager import PipelineStageManager |
|
|
|
@ -584,6 +589,7 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
|
|
|
|
|
|
|
|
|
|
return forward |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig): |
|
|
|
|
from transformers import CohereForCausalLM |
|
|
|
|
|
|
|
|
|