fix precommit

pull/5818/head
GuangyaoZhang 2024-06-14 08:09:24 +00:00
parent 1016bb3257
commit 9a290ab013
7 changed files with 35 additions and 86 deletions

View File

@@ -4,7 +4,7 @@ from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
-from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm, CohereLayerNorm, FusedCohereLayerNorm
+from .normalization import CohereLayerNorm, FusedCohereLayerNorm, FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import FusedLinear1D_Col, GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row

View File

@@ -250,7 +250,6 @@ class FusedLayerNorm(BaseLayerNorm):
return layernorm
class CohereLayerNorm(BaseLayerNorm):
r"""
This is a wrapper around the transformers.models.cohere.CohereLayerNorm. It is meant to be used only with the from_native_module interface.

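The docstring above references the `from_native_module` interface. A minimal usage sketch (not part of this commit; it assumes transformers' `CohereLayerNorm` accepts `hidden_size`/`eps` and that the shardformer wrapper follows the same `from_native_module(module, ...)` convention as the sibling `LayerNorm`/`FusedLayerNorm` classes):

```python
# Illustrative sketch only, not code from this commit.
import torch
from transformers.models.cohere.modeling_cohere import CohereLayerNorm as NativeCohereLayerNorm

from colossalai.shardformer.layer import CohereLayerNorm  # the wrapper touched in this PR

# A native norm as it would appear inside a loaded Cohere/Command model
# (constructor arguments assumed from the transformers implementation).
native_norm = NativeCohereLayerNorm(hidden_size=1024, eps=1e-5)

# Convert it through the from_native_module interface; the wrapper is expected
# to reuse the existing weight rather than re-initialize parameters.
sharded_norm = CohereLayerNorm.from_native_module(native_norm)

x = torch.randn(2, 16, 1024)
print(sharded_norm(x).shape)  # forward semantics should match the native module
```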
View File

@@ -3,22 +3,12 @@ import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
-from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from torch.nn import CrossEntropyLoss
from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPast,
-    CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
-)
-from transformers.models.cohere.modeling_cohere import (
-    CohereForCausalLM,
-    CohereModel,
-    StaticCache,
-    repeat_kv,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers.models.cohere.modeling_cohere import CohereForCausalLM, CohereModel, StaticCache, repeat_kv
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
@@ -343,10 +333,9 @@ class CommandPipelineForwards:
hidden_states = outputs.get("hidden_states")
return {"hidden_states": hidden_states}
-def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
-    from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb
-    from transformers.models.cohere.modeling_cohere import repeat_kv
+def get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size):
+    from transformers.models.cohere.modeling_cohere import CohereAttention, apply_rotary_pos_emb, repeat_kv
def forward(
self: CohereAttention,
@@ -728,7 +717,6 @@ def get_command_seq_parallel_attention_forward(sp_mode, sp_size, sp_group):
else:
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
if not output_attentions:

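`get_command_flash_attention_forward(shard_config, sp_mode, sp_group, sp_size)` above is a factory: it closes over the shard configuration and returns a replacement `forward` that the Command policy installs on `CohereAttention`. The self-contained sketch below illustrates that factory-plus-rebinding pattern in plain PyTorch; it is not shardformer code, and in ColossalAI the binding is performed by the policy's method-replacement machinery rather than by hand:

```python
# Generic illustration of the factory-forward pattern; not shardformer code.
from types import MethodType

import torch
from torch import nn


def make_scaled_forward(scale: float):
    """Factory closing over configuration, analogous to how the shardformer
    factories close over shard_config / sp_mode / sp_group / sp_size."""

    def forward(self, x):
        # `self` is the module instance the closure is bound to below.
        return self.proj(x) * scale

    return forward


class TinyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)


block = TinyAttention()
# Install the factory-produced closure as this instance's forward.
block.forward = MethodType(make_scaled_forward(0.5), block)
print(block(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```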
View File

@@ -7,12 +7,12 @@ from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import (
+CohereLayerNorm,
FusedCohereLayerNorm,
Linear1D_Col,
Linear1D_Row,
PaddingEmbedding,
PaddingLMHead,
-CohereLayerNorm,
VocabParallelEmbedding1D,
VocabParallelLMHead1D,
)
@@ -383,7 +383,9 @@ class CommandForCausalLMPolicy(CommandPolicy):
if self.pipeline_stage_manager:
# set None as default
self.set_pipeline_forward(
-model_cls=CohereForCausalLM, new_forward=CommandPipelineForwards.command_for_causal_lm_forward, policy=policy
+model_cls=CohereForCausalLM,
+new_forward=CommandPipelineForwards.command_for_causal_lm_forward,
+policy=policy,
)
return policy

View File

@@ -1,59 +0,0 @@
diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py
index 5aa21260..01453a05 100644
--- a/colossalai/shardformer/layer/normalization.py
+++ b/colossalai/shardformer/layer/normalization.py
@@ -165,7 +165,7 @@ class LayerNorm(BaseLayerNorm):
Raises:
AssertionError: If the provided module is not an instance of nn.LayerNorm.
"""
- assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
+ # assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
LazyInitContext.materialize(module)
@@ -174,7 +174,7 @@ class LayerNorm(BaseLayerNorm):
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(module.weight)
- SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
+ # SeqParallelUtils.marked_as_sp_partial_derived_param(module.bias)
return module
@@ -209,9 +209,12 @@ class FusedLayerNorm(BaseLayerNorm):
LazyInitContext.materialize(module)
# get the attributes of the module
- normalized_shape = module.normalized_shape
- eps = module.eps
- elementwise_affine = module.elementwise_affine
+ # normalized_shape = module.normalized_shape
+ # eps = module.eps
+ # elementwise_affine = module.elementwise_affine
+ normalized_shape = module.weight.size(0)
+ eps = module.variance_epsilon
+ elementwise_affine = True
dtype = module.weight.dtype
device = module.weight.device
@@ -244,7 +247,7 @@ class FusedLayerNorm(BaseLayerNorm):
# aggregation of these gradients is necessary during backpropagation.
# Therefore, we annotate these parameters in advance to indicate the need for gradient aggregation.
SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.weight)
- SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
+ # SeqParallelUtils.marked_as_sp_partial_derived_param(layernorm.bias)
return layernorm
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index 6075f836..a7166e38 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -210,6 +210,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
],
)
def run_command_test(test_config):
+ print(test_config)
sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_casual_lm")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():

View File

@@ -16,8 +16,6 @@ if HAS_COMMAND:
# ===============================
def data_gen():
input_ids = torch.Tensor(
[
[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082],

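The `data_gen` hunk above is cut off in this view. A hypothetical completion in the style of the model-zoo generators is sketched below; the `.long()` cast, the attention mask, and the dict return format are assumptions for illustration, not the actual file contents:

```python
# Hypothetical completion of the truncated data_gen above.
import torch


def data_gen():
    input_ids = torch.Tensor(
        [[1, 15043, 29892, 590, 11203, 338, 274, 1082, 1, 15043, 29892, 590, 11203, 338, 274, 1082]]
    ).long()
    attention_mask = torch.ones_like(input_ids)  # assumed: no padding positions
    return dict(input_ids=input_ids, attention_mask=attention_mask)
```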
View File

@@ -79,10 +79,24 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
-command_model, shard_command_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
+command_model,
+shard_command_model,
+row_layer_for_check,
+tp_group,
+atol=atol,
+rtol=rtol,
+dim=0,
+verbose=False,
)
col_layer_grads = get_grad_tensors_for_check(
-command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+command_model,
+shard_command_model,
+col_layer_for_check,
+tp_group,
+atol=atol,
+rtol=rtol,
+dim=1,
+verbose=False,
)
norm_layer_grads = get_grad_tensors_for_check(
command_model,
@@ -121,7 +135,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
else:
atol, rtol = 5e-3, 5e-3
check_weight(
-command_model, shard_command_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+command_model,
+shard_command_model,
+col_layer_for_check,
+tp_group,
+atol=atol,
+rtol=rtol,
+dim=1,
+verbose=False,
)
# check grads
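For context on the reformatted calls above: the `dim` argument (0 for the row-layer checks, 1 for the column-layer checks) indicates the dimension along which the checked tensor is split across the tensor-parallel group. Below is a simplified, single-process stand-in for what `get_grad_tensors_for_check`/`check_weight` do across that group, using the same default tolerances; it is an illustration only, not the helpers' actual implementation:

```python
# Simplified single-process analogue of the dim-aware shard comparison.
import torch


def check_sharded_against_full(full: torch.Tensor, shard: torch.Tensor,
                               rank: int, world_size: int, dim: int,
                               atol: float = 5e-3, rtol: float = 5e-3) -> bool:
    # Cut the rank-local chunk out of the full tensor along `dim` and compare.
    expected = torch.chunk(full, world_size, dim=dim)[rank]
    return torch.allclose(expected, shard, atol=atol, rtol=rtol)


full_grad = torch.randn(8, 4)
shards = torch.chunk(full_grad, 2, dim=0)  # dim=0 split, as in the row-layer checks
assert check_sharded_against_full(full_grad, shards[1], rank=1, world_size=2, dim=0)
```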