mirror of https://github.com/hpcaitech/ColossalAI
Feature/chatglm (#4240)
* [shardformer] added tests
* [shardformer] vit test finish and support
* [shardformer] chatglm ready
* import chatglm
* [shardformer] add test kit in model zoo for chatglm
* [shardformer] add first version of policy of chatglm
* [shardformer] polish chatglm code
* [shardformer] polish code
* [shardformer] support chatglm without layernorm
* [shardformer] chatglm shard without mlp sharding
* [shardformer] delete some file
* [shardformer] ChatGLM support layernorm sharding
* [shardformer] register without auto policy
* [shardformer] pre-commit check files
* [shardformer] fix chatglm configuration with pre-commit
parent 9ee4ebea83
commit ed34bb1310
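In short, this commit wires ChatGLM into ShardFormer through a hand-registered policy instead of the auto-policy lookup. A rough usage sketch (not part of the diff, distilled from the new test below; it assumes a distributed process group has already been initialized via colossalai.launch and that the apex dependency behind FusedLayerNorm is available):

import copy

from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy
from tests.kit.model_zoo.transformers.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel

# toy-sized config mirroring the model zoo entry added in this commit
config = ChatGLMConfig(num_layers=1, hidden_size=64, num_attention_heads=8, rmsnorm=False, original_rope=True)
org_model = ChatGLMModel(config, empty_init=False).cuda()

# assumes colossalai.launch(...) has already set up the NCCL process group, as in the new test
shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)
# the policy is passed explicitly because ChatGLM is registered without an auto policy
sharded_model = shard_former.optimize(copy.deepcopy(org_model), ChatGLMModelPolicy()).cuda()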
colossalai/shardformer/policies/chatglm.py (new file)
@@ -0,0 +1,96 @@
from typing import Dict, Union

import torch.nn as nn

import colossalai.shardformer.layer as col_nn

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']


class ChatGLMModelPolicy(Policy):

    def config_sanity_check(self):
        pass

    def preprocess(self):
        # Resize embedding
        vocab_size = self.model.config.padded_vocab_size
        world_size = self.shard_config.tensor_parallel_size

        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock

        policy = {}

        if self.shard_config.enable_tensor_parallelism:

            policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={},
                                                           sub_module_replacement=[
                                                               SubModuleReplacementDescription(
                                                                   suffix="embedding.word_embeddings",
                                                                   target_module=col_nn.VocabParallelEmbedding1D,
                                                               )
                                                           ])

            policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={
                "self_attention.num_attention_heads_per_partition":
                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "self_attention.projection_size":
                    (self.model.config.kv_channels * self.model.config.num_attention_heads) //
                    self.shard_config.tensor_parallel_size,
                "self_attention.qkv_hidden_size":
                    (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
                    self.shard_config.tensor_parallel_size,
                "self_attention.core_attention.num_attention_heads_per_partition":
                    self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                "self_attention.core_attention.hidden_size_per_partition":
                    self.model.config.kv_channels * self.model.config.num_attention_heads //
                    self.shard_config.tensor_parallel_size,
            },
                                                       param_replacement=[],
                                                       sub_module_replacement=[
                                                           SubModuleReplacementDescription(
                                                               suffix="self_attention.query_key_value",
                                                               target_module=col_nn.Linear1D_Col,
                                                           ),
                                                           SubModuleReplacementDescription(
                                                               suffix="self_attention.dense",
                                                               target_module=col_nn.Linear1D_Row,
                                                           ),
                                                           SubModuleReplacementDescription(
                                                               suffix="self_attention.core_attention.attention_dropout",
                                                               target_module=col_nn.DropoutForParallelInput,
                                                           ),
                                                       ])
        # optimization configuration
        if self.shard_config.enable_fused_normalization:
            if not self.model.config.rmsnorm:

                self.append_or_create_submodule_replacement(description=[
                    SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
                    SubModuleReplacementDescription(suffix="post_attention_layernorm",
                                                    target_module=col_nn.FusedLayerNorm)
                ],
                                                            policy=policy,
                                                            target_key=GLMBlock)

                if self.model.config.post_layer_norm:
                    self.append_or_create_submodule_replacement(description=[
                        SubModuleReplacementDescription(suffix="encoder.final_layernorm",
                                                        target_module=col_nn.FusedLayerNorm)
                    ],
                                                                policy=policy,
                                                                target_key=ChatGLMModel)

        return policy

    def postprocess(self):
        return self.model
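As a sanity check on the attribute_replacement arithmetic above, a small illustrative computation (hypothetical numbers only, using the ChatGLMConfig defaults added later in this commit and a tensor-parallel size of 2):

# ChatGLMConfig defaults: kv_channels=128, num_attention_heads=32; assume 2 tensor-parallel ranks
kv_channels, num_attention_heads, tp_size = 128, 32, 2

num_heads_per_partition = num_attention_heads // tp_size              # 16 heads per rank
projection_size = (kv_channels * num_attention_heads) // tp_size      # 4096 // 2 = 2048
qkv_hidden_size = (kv_channels * num_attention_heads * 3) // tp_size  # 12288 // 2 = 6144

# preprocess() pads the vocabulary so the embedding splits evenly across ranks;
# the real default (65024) already divides by 2, so use a hypothetical odd size here
padded_vocab_size = 65025
new_vocab_size = padded_vocab_size + tp_size - padded_vocab_size % tp_size  # 65026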
tests/kit/model_zoo/transformers/__init__.py
@@ -1,6 +1,7 @@
from .albert import *
from .bert import *
from .bloom import *
from .chatglm import *
from .gpt import *
from .llama import *
from .opt import *
tests/kit/model_zoo/transformers/chatglm.py (new file)
@@ -0,0 +1,38 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo
from .chatglm2_6b.configuration_chatglm import ChatGLMConfig
from .chatglm2_6b.modeling_chatglm import ChatGLMModel

# ================================
# Register single-sentence ChatGLM
# ================================


def data_gen():
    input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
    return dict(input_ids=input_ids, attention_mask=attention_mask)


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean()
loss_fn = lambda x: x.loss
config = ChatGLMConfig(num_layers=1,
                       padded_vocab_size=65024,
                       hidden_size=64,
                       num_attention_heads=8,
                       rmsnorm=False,
                       original_rope=True,
                       use_cache=True)

model_zoo.register(name='transformers_chatglm',
                   model_fn=lambda: ChatGLMModel(config, empty_init=False),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_chatglm_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
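For reference, a minimal single-device sketch (not part of the diff) of how the registered pieces compose; it reuses the names defined in the file above:

# exercise the toy model exactly as the zoo wires it together
model = ChatGLMModel(config, empty_init=False)
outputs = model(**data_gen())              # forward pass on the 6-token toy batch
loss = loss_fn_for_chatglm_model(outputs)  # mean of last_hidden_state
loss.backward()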
tests/kit/model_zoo/transformers/chatglm2_6b/configuration_chatglm.py (new file)
@@ -0,0 +1,58 @@
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    model_type = "chatglm"

    def __init__(self,
                 num_layers=28,
                 padded_vocab_size=65024,
                 hidden_size=4096,
                 ffn_hidden_size=13696,
                 kv_channels=128,
                 num_attention_heads=32,
                 seq_length=2048,
                 hidden_dropout=0.0,
                 attention_dropout=0.0,
                 layernorm_epsilon=1e-5,
                 rmsnorm=True,
                 apply_residual_connection_post_layernorm=False,
                 post_layer_norm=True,
                 add_bias_linear=False,
                 add_qkv_bias=False,
                 bias_dropout_fusion=True,
                 multi_query_attention=False,
                 multi_query_group_num=1,
                 apply_query_key_layer_scaling=True,
                 attention_softmax_in_fp32=True,
                 fp32_residual_connection=False,
                 quantization_bit=0,
                 pre_seq_len=None,
                 prefix_projection=False,
                 **kwargs):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        super().__init__(**kwargs)
File diff suppressed because it is too large (the vendored chatglm2_6b modeling file, modeling_chatglm.py, judging by the imports elsewhere in this diff).
New ChatGLM ShardFormer test under tests/test_shardformer/test_model/ (new file)
@@ -0,0 +1,107 @@
import copy
import os

import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
    assert_hf_output_close,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])
    # do backward
    org_loss.backward()
    shard_loss.backward()

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"

    # unwrap model
    if org_model.__class__.__name__ == 'ChatGLMModel':
        chatglm_model = org_model
        shard_chatglm_model = sharded_model
    else:
        chatglm_model = org_model.transformer
        shard_chatglm_model = sharded_model.transformer

    # check attention grad
    org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
    shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
    shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        # all_gather returns None; the gathered shards land in shard_grad_list
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"

    # check embedding weights
    org_grad = chatglm_model.embedding.word_embeddings.weight.grad
    shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad
    shard_weight = shard_chatglm_model.embedding.word_embeddings.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad

    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        # create new model
        org_model = model_fn().cuda()

        # shard model
        shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                   enable_tensor_parallelism=enable_tensor_parallelism)
        model_copy = copy.deepcopy(org_model)
        shard_former = ShardFormer(shard_config=shard_config)
        if name == "transformers_chatglm":
            sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda()

        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
    torch.cuda.empty_cache()


def check_chatglm(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_chatglm_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm():
    spawn(check_chatglm, 2)


if __name__ == "__main__":
    test_chatglm()
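Note that the test spawns two NCCL ranks via spawn(check_chatglm, 2), so it presumably needs two GPUs. It can be run either through its __main__ guard or with pytest, e.g. pytest -k chatglm tests/test_shardformer/test_model/ (assuming the new file lives in that directory, as its _utils import suggests).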