From 996c65077eb18f8e58a69ca95f01e908c780b7ea Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 18 Jun 2024 03:32:29 +0000
Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 colossalai/shardformer/modeling/command.py | 12 +++++++++---
 colossalai/shardformer/policies/command.py |  1 -
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index 83f4b97ff..07a7f6cbf 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -3,13 +3,18 @@ import warnings
 from typing import List, Optional, Tuple, Union
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, apply_rotary_pos_emb, repeat_kv
+from transformers.models.cohere.modeling_cohere import (
+    CohereForCausalLM,
+    CohereModel,
+    StaticCache,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging
 
 from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -584,6 +589,7 @@ def get_command_flash_attention_model_forward(shard_config, sp_mode=None, sp_siz
 
     return forward
 
+
 def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
     from transformers import CohereForCausalLM
 
@@ -683,4 +689,4 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
             attentions=outputs.attentions,
         )
 
-    return forward
\ No newline at end of file
+    return forward
diff --git a/colossalai/shardformer/policies/command.py b/colossalai/shardformer/policies/command.py
index 77f96e462..553436f4a 100644
--- a/colossalai/shardformer/policies/command.py
+++ b/colossalai/shardformer/policies/command.py
@@ -112,7 +112,6 @@ class CommandPolicy(Policy):
             target_key=CohereModel,
         )
 
-
         if self.shard_config.enable_tensor_parallelism:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0