Feature/chatglm (#4240)

* [shardformer] added tests

* [shardformer] vit test finish and support

* [shardformer] chatglm ready

* import chatglm

* [shardformer] add test kit in model zoo for chatglm

* [shardformer] add first version of chatglm policy

* [shardformer] polish chatglm code

* [shardformer] polish code

* [shardformer] support chatglm without layernorm

* [shardformer] chatglm shard without mlp sharding

* [shardformer] delete some files

* [shardformer] ChatGLM support layernorm sharding

* [shardformer] register without auto policy

* [shardformer] pre-commit check files

* [shardformer] fix chatglm configuration with pre-commit
Kun Lin 2023-07-20 17:28:00 +08:00 committed by Hongxin Liu
parent 9ee4ebea83
commit ed34bb1310
6 changed files with 1672 additions and 0 deletions

@@ -0,0 +1,96 @@
from typing import Dict, Union

import torch.nn as nn

import colossalai.shardformer.layer as col_nn

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']


class ChatGLMModelPolicy(Policy):

    def config_sanity_check(self):
        pass

    def preprocess(self):
        # Resize embedding so the vocabulary size is divisible by the tensor parallel world size
        vocab_size = self.model.config.padded_vocab_size
        world_size = self.shard_config.tensor_parallel_size

        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel, GLMBlock

        policy = {}

        if self.shard_config.enable_tensor_parallelism:
            # shard the word embedding along the vocabulary dimension
            policy[ChatGLMModel] = ModulePolicyDescription(
                attribute_replacement={},
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="embedding.word_embeddings",
                        target_module=col_nn.VocabParallelEmbedding1D,
                    )
                ])

            # shard the attention projections column/row-wise and rescale the per-partition sizes
            policy[GLMBlock] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.num_attention_heads_per_partition":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "self_attention.projection_size":
                        (self.model.config.kv_channels * self.model.config.num_attention_heads) //
                        self.shard_config.tensor_parallel_size,
                    "self_attention.qkv_hidden_size":
                        (self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
                        self.shard_config.tensor_parallel_size,
                    "self_attention.core_attention.num_attention_heads_per_partition":
                        self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    "self_attention.core_attention.hidden_size_per_partition":
                        self.model.config.kv_channels * self.model.config.num_attention_heads //
                        self.shard_config.tensor_parallel_size,
                },
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=col_nn.Linear1D_Col,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense",
                        target_module=col_nn.Linear1D_Row,
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.core_attention.attention_dropout",
                        target_module=col_nn.DropoutForParallelInput,
                    ),
                ])

        # optimization configuration: use fused LayerNorm kernels when the model is not using RMSNorm
        if self.shard_config.enable_fused_normalization:
            if not self.model.config.rmsnorm:
                self.append_or_create_submodule_replacement(
                    description=[
                        SubModuleReplacementDescription(suffix="input_layernorm",
                                                        target_module=col_nn.FusedLayerNorm),
                        SubModuleReplacementDescription(suffix="post_attention_layernorm",
                                                        target_module=col_nn.FusedLayerNorm)
                    ],
                    policy=policy,
                    target_key=GLMBlock)

                if self.model.config.post_layer_norm:
                    self.append_or_create_submodule_replacement(
                        description=[
                            SubModuleReplacementDescription(suffix="encoder.final_layernorm",
                                                            target_module=col_nn.FusedLayerNorm)
                        ],
                        policy=policy,
                        target_key=ChatGLMModel)

        return policy

    def postprocess(self):
        return self.model
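For orientation, the policy above is not called directly; it is handed to ShardFormer, which rewrites the module tree in place. A minimal sketch of that flow, assuming colossalai.launch() has already set up a tensor-parallel process group (mirroring the test added at the end of this commit):

import copy

from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy
from tests.kit.model_zoo.transformers.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from tests.kit.model_zoo.transformers.chatglm2_6b.modeling_chatglm import ChatGLMModel

# a tiny ChatGLM configuration, matching the model-zoo registration below
config = ChatGLMConfig(num_layers=1, padded_vocab_size=65024, hidden_size=64, num_attention_heads=8, rmsnorm=False)
org_model = ChatGLMModel(config, empty_init=False).cuda()

# keep the original model untouched and shard a deep copy
shard_config = ShardConfig(enable_tensor_parallelism=True, enable_fused_normalization=False)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.optimize(copy.deepcopy(org_model), ChatGLMModelPolicy()).cuda()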

@@ -1,6 +1,7 @@
from .albert import *
from .bert import *
from .bloom import *
from .chatglm import *
from .gpt import *
from .llama import *
from .opt import *

@@ -0,0 +1,38 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo
from .chatglm2_6b.configuration_chatglm import ChatGLMConfig
from .chatglm2_6b.modeling_chatglm import ChatGLMModel

# ================================
# Register single-sentence ChatGLM
# ================================


def data_gen():
    input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64)
    attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
    return dict(input_ids=input_ids, attention_mask=attention_mask)


# define output transform function
output_transform_fn = lambda x: x

# define loss function
loss_fn_for_chatglm_model = lambda x: x.last_hidden_state.mean()
loss_fn = lambda x: x.loss

config = ChatGLMConfig(num_layers=1,
                       padded_vocab_size=65024,
                       hidden_size=64,
                       num_attention_heads=8,
                       rmsnorm=False,
                       original_rope=True,
                       use_cache=True)

model_zoo.register(name='transformers_chatglm',
                   model_fn=lambda: ChatGLMModel(config, empty_init=False),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   loss_fn=loss_fn_for_chatglm_model,
                   model_attribute=ModelAttribute(has_control_flow=True))
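Once registered, the entry is retrieved from the zoo by name and consumed as a (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute) tuple. A short sketch of that consumption, assuming a CUDA device is available (the shardformer test below follows the same pattern):

from tests.kit.model_zoo import model_zoo

sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
    model = model_fn().cuda()
    data = {k: v.cuda() for k, v in data_gen_fn().items()}
    output = output_transform_fn(model(**data))
    loss = loss_fn(output)    # mean of last_hidden_state for the bare ChatGLMModel
    loss.backward()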

@@ -0,0 +1,58 @@
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    model_type = "chatglm"

    def __init__(self,
                 num_layers=28,
                 padded_vocab_size=65024,
                 hidden_size=4096,
                 ffn_hidden_size=13696,
                 kv_channels=128,
                 num_attention_heads=32,
                 seq_length=2048,
                 hidden_dropout=0.0,
                 attention_dropout=0.0,
                 layernorm_epsilon=1e-5,
                 rmsnorm=True,
                 apply_residual_connection_post_layernorm=False,
                 post_layer_norm=True,
                 add_bias_linear=False,
                 add_qkv_bias=False,
                 bias_dropout_fusion=True,
                 multi_query_attention=False,
                 multi_query_group_num=1,
                 apply_query_key_layer_scaling=True,
                 attention_softmax_in_fp32=True,
                 fp32_residual_connection=False,
                 quantization_bit=0,
                 pre_seq_len=None,
                 prefix_projection=False,
                 **kwargs):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        super().__init__(**kwargs)
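Note that padded_vocab_size defaults to 65024, which divides evenly across common tensor-parallel sizes; when it does not, ChatGLMModelPolicy.preprocess rounds it up before sharding the embedding. A small worked sketch of that rounding (round_up_vocab is an illustrative helper, not part of this commit):

def round_up_vocab(vocab_size: int, world_size: int) -> int:
    # mirrors the arithmetic in ChatGLMModelPolicy.preprocess
    if vocab_size % world_size != 0:
        return vocab_size + world_size - vocab_size % world_size
    return vocab_size


assert round_up_vocab(65024, 2) == 65024    # already divisible, no resize needed
assert round_up_vocab(65024, 3) == 65025    # rounded up to the next multiple of 3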

File diff suppressed because it is too large.

@@ -0,0 +1,107 @@
import copy
import os

import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.chatglm import ChatGLMModelPolicy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.testing import (
    assert_hf_output_close,
    clear_cache_before_run,
    parameterize,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])

    # do backward
    org_loss.backward()
    shard_loss.backward()

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"

    # unwrap model
    if org_model.__class__.__name__ == 'ChatGLMModel':
        chatglm_model = org_model
        shard_chatglm_model = sharded_model
    else:
        chatglm_model = org_model.transformer
        shard_chatglm_model = sharded_model.transformer

    # check attention grad
    org_grad = chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
    shard_grad = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight.grad
    shard_weight = shard_chatglm_model.encoder.layers[0].self_attention.query_key_value.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        # all_gather fills shard_grad_list in place and returns None, so do not reassign shard_grad
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"

    # check embedding weights
    org_grad = chatglm_model.embedding.word_embeddings.weight.grad
    shard_grad = shard_chatglm_model.embedding.word_embeddings.weight.grad
    shard_weight = shard_chatglm_model.embedding.word_embeddings.weight

    if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
        shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
        torch.distributed.all_gather(shard_grad_list, shard_grad)
        all_shard_grad = torch.cat(shard_grad_list, dim=0)
    else:
        all_shard_grad = shard_grad

    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
    sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        # create new model
        org_model = model_fn().cuda()

        # shard model
        shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                                   enable_tensor_parallelism=enable_tensor_parallelism)
        model_copy = copy.deepcopy(org_model)
        shard_former = ShardFormer(shard_config=shard_config)
        if name == "transformers_chatglm":
            sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda()
            check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

    torch.cuda.empty_cache()


def check_chatglm(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_chatglm_test()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm():
    spawn(check_chatglm, 2)


if __name__ == "__main__":
    test_chatglm()
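One detail worth noting: check_forward_backward allocates exactly two all-gather buffers, which matches the fixed spawn(check_chatglm, 2) world size. A world-size-agnostic variant could look like the sketch below (a suggested generalization, not part of this commit):

import torch
import torch.distributed as dist


def gather_grad(shard_grad: torch.Tensor) -> torch.Tensor:
    # gather one gradient shard from every rank and concatenate along dim 0
    world_size = dist.get_world_size()
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    dist.all_gather(shard_grad_list, shard_grad)
    return torch.cat(shard_grad_list, dim=0)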