mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support ChatGLMForConditionalGeneration & add fusedlayernorm for vit
parent
4da05052f4
commit
8120eca0c0
|
@ -90,7 +90,31 @@ class ChatGLMModelPolicy(Policy):
|
|||
policy=policy,
|
||||
target_key=ChatGLMModel)
|
||||
|
||||
else:
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedRMSNorm),
|
||||
SubModuleReplacementDescription(suffix="post_attention_layernorm",
|
||||
target_module=col_nn.FusedRMSNorm)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=GLMBlock)
|
||||
|
||||
if self.model.config.post_layer_norm:
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(suffix="encoder.final_layernorm",
|
||||
target_module=col_nn.FusedRMSNorm)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=ChatGLMModel)
|
||||
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
||||
class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
|
||||
|
||||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
return policy
|
||||
|
|
|
@ -23,7 +23,7 @@ class ViTPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer
|
||||
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer, ViTModel
|
||||
|
||||
policy = {}
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import transformers
|
|||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
from .chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from .chatglm2_6b.modeling_chatglm import ChatGLMModel
|
||||
from .chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
|
||||
# ================================
|
||||
# Register single-sentence ChatGLM
|
||||
|
@ -21,7 +21,7 @@ output_transform_fn = lambda x: x
|
|||
|
||||
# define loss function
|
||||
loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean()
|
||||
loss_fn = lambda x: x.loss
|
||||
loss_fn = lambda x: x.logits.mean()
|
||||
config = ChatGLMConfig(num_layers=1,
|
||||
padded_vocab_size=65024,
|
||||
hidden_size=64,
|
||||
|
@ -36,3 +36,10 @@ model_zoo.register(name='transformers_chatglm',
|
|||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_chatglm_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
|
||||
model_zoo.register(name="transformers_chatglm_for_conditional_generation",
|
||||
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
model_attribute=ModelAttribute(has_control_flow=True))
|
||||
|
|
|
@ -7,7 +7,7 @@ import torch
|
|||
import colossalai
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy
|
||||
from colossalai.shardformer.policies.chatglm import ChatGLMForConditionalGenerationPolicy, ChatGLMModelPolicy
|
||||
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
||||
from colossalai.testing import (
|
||||
assert_hf_output_close,
|
||||
|
@ -85,6 +85,8 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
|
|||
shard_former = ShardFormer(shard_config=shard_config)
|
||||
if name == "transformers_chatglm":
|
||||
sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda()
|
||||
else:
|
||||
sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda()
|
||||
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
|
Loading…
Reference in New Issue