
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
pull/5818/head^2
pre-commit-ci[bot] committed 5 months ago
commit 996c65077e

Changed files:
1. colossalai/shardformer/modeling/command.py (10 changes)
2. colossalai/shardformer/policies/command.py (1 change)
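
The fixes below are the kind of rewrite formatting hooks produce: in modeling/command.py, an over-long from-import is wrapped into a parenthesized, one-name-per-line form. The hooks actually configured for this repository are not shown on this page; purely as an illustration, the sketch below assumes isort (black profile) and black with a 120-character line limit and reproduces that rewrite through their Python APIs.

# Illustration only: assumes isort (black profile) and black as the formatting
# hooks, with a 120-character line length. The repository's actual
# .pre-commit-config.yaml may use different hooks or settings.
import black
import isort

LONG_IMPORT = (
    "from transformers.models.cohere.modeling_cohere import "
    "CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv\n"
)

# isort wraps the over-long import into a parenthesized block, one name per
# line, with a trailing comma (its "black" profile behaviour).
sorted_src = isort.code(LONG_IMPORT, profile="black", line_length=120)

# black keeps the exploded form because of the trailing comma and applies its
# usual formatting on top.
formatted = black.format_str(sorted_src, mode=black.Mode(line_length=120))
print(formatted)

The trailing comma that isort leaves behind is what keeps black from collapsing the names back onto a single line.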

colossalai/shardformer/modeling/command.py (10 changes)

@@ -3,13 +3,18 @@ import warnings
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv
+from transformers.models.cohere.modeling_cohere import (
+    CohereForCausalLM,
+    CohereModel,
+    StaticCache,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -584,6 +589,7 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
     return forward


 def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
     from transformers import CohereForCausalLM
+
colossalai/shardformer/policies/command.py (1 change)

@@ -112,7 +112,6 @@ class CommandPolicy(Policy):
                 target_key=CohereModel,
             )

         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
-
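For context, the assert visible in this hunk is CommandPolicy's tensor-parallelism sanity check: the number of attention heads must be divisible by the tensor-parallel size so that every rank holds a whole number of heads. A minimal standalone sketch of that condition follows (the helper below is hypothetical, not a ColossalAI API):

# Minimal sketch of the condition enforced by the assert above. The helper
# name is hypothetical and not part of ColossalAI.
def heads_per_tp_rank(num_attention_heads: int, tensor_parallel_size: int) -> int:
    assert (
        num_attention_heads % tensor_parallel_size == 0
    ), f"{num_attention_heads} attention heads cannot be split evenly across {tensor_parallel_size} tensor-parallel ranks"
    return num_attention_heads // tensor_parallel_size


# e.g. 64 heads sharded over 8 tensor-parallel ranks -> 8 heads per rank
print(heads_per_tp_rank(64, 8))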