[shardformer] fix chatglm implementation (#5644)

* [shardformer] fix chatglm policy

* [shardformer] fix chatglm flash attn

* [shardformer] update readme

* [shardformer] fix chatglm init

* [shardformer] fix chatglm test

* [pipeline] fix chatglm merge batch
Hongxin Liu 2024-04-25 14:41:17 +08:00 committed by GitHub
parent 5d88ef1aaf
commit bbb2c21f16
11 changed files with 193 additions and 117 deletions

View File

@@ -7,7 +7,7 @@ from torch.nn import Module
from torch.utils._pytree import tree_map

from colossalai.accelerator import get_accelerator
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
@@ -327,7 +327,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                self.send_forward(output_obj)

        if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            batch_size_dim = getattr(model, "batch_size_dim", 0)
+            outputs = merge_batch(outputs, batch_size_dim)
        return {"loss": accum_loss, "outputs": outputs}

    def run_forward_backward(
@@ -410,7 +413,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
        assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)

        if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            batch_size_dim = getattr(model, "batch_size_dim", 0)
+            outputs = merge_batch(outputs, batch_size_dim)
        return {"loss": accum_loss, "outputs": outputs}

    def forward_backward_step(
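Both hunks teach the one-forward-one-backward schedule to ask the unwrapped model for an optional `batch_size_dim` attribute before merging micro-batch outputs, so models whose outputs are not batch-first (ChatGLM's hidden states are `[seq_len, batch, hidden]`) are concatenated along the right axis. A minimal sketch of the idea; `ToyModel` and `merge_batch_sketch` are illustrative stand-ins, not ColossalAI APIs:

```python
import torch


class ToyModel(torch.nn.Module):
    # Hypothetical: a model whose outputs are [seq_len, batch, hidden] declares
    # batch_size_dim = 1; the schedule reads it via getattr(model, "batch_size_dim", 0).
    batch_size_dim = 1

    def forward(self, x):
        return x


def merge_batch_sketch(outputs, batch_size_dim=0):
    # Illustrative stand-in for merge_batch: concatenate the per-micro-batch
    # outputs along the declared batch dimension.
    return torch.cat(outputs, dim=batch_size_dim)


micro_batch_outputs = [torch.randn(8, 2, 16) for _ in range(4)]  # 4 micro-batches of [S=8, B=2, H=16]
merged = merge_batch_sketch(micro_batch_outputs, getattr(ToyModel(), "batch_size_dim", 0))
print(merged.shape)  # torch.Size([8, 8, 16]); dim 0 would have produced [32, 2, 16]
```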

View File

@@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer:
- [x] Unit Testing
- [ ] Policy Implementation

| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:|
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |

## 💡 API Design
@@ -391,6 +391,43 @@ _POLICY_LIST = {
}
```

+#### How to support models that are on the Hugging Face model hub but not in the `transformers` library
+
+There are two cases:
+
+1. The modeling file is in the `transformers` library but the model weights are not. E.g. the model structure of "01-ai/Yi-34B" is the same as LLaMA, while its weights are not part of the `transformers` library. In this case we support LLaMA as usual, and Yi-34B is covered by the LLaMA policy; no new policy is needed for Yi-34B.
+2. The modeling file is not in the `transformers` library, such as "THUDM/chatglm2-6b".
+
+Take "THUDM/chatglm2-6b" as an example to illustrate how to support such a model in `shardformer`.
+
+Unlike LLaMA, which is in the `transformers` library, the chatglm2 model cannot be imported directly. Therefore, the key in the policy must be the class name as a string rather than the class itself.
+
+E.g. for LLaMA:
+```python
+policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
+```
+For chatglm2:
+```python
+policy["GLMBlock"] = ModulePolicyDescription(...)
+```
+
+When registering such models in the autopolicy, follow the format below:
+```python
+"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
+    file_name="<policy_filename>", class_name="<policy_class_name>"
+)
+```
+
+For the chatglm2 model, the entry is:
+```python
+"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+    file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
+)
+```
+
+When using such models, `AutoModel` works as usual and the matching policy is loaded automatically by the autopolicy (see the short sketch after this section).

### Write Your Unit Testing

This section serves as the guideline for testing the `shardformer` module.
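Why the key must be a string: classes loaded with `trust_remote_code=True` live under the dynamic `transformers_modules` package and cannot be imported when the policy module is defined. A small sketch of this, shrinking the config the same way the model-zoo test in this commit does; it needs network access to fetch the config, and the printed module path will contain a revision hash when the model comes from the hub:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Fetch the remote config, override it to a tiny size, then build the model
# through trust_remote_code.
config = AutoConfig.from_pretrained(
    "THUDM/chatglm2-6b",
    trust_remote_code=True,
    num_layers=2,
    padded_vocab_size=65024,
    hidden_size=64,
    ffn_hidden_size=214,
    num_attention_heads=8,
    kv_channels=16,
)
model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)

print(type(model).__qualname__)  # ChatGLMForConditionalGeneration
print(type(model).__module__)
# e.g. 'transformers_modules.THUDM.chatglm2-6b.<revision>.modeling_chatglm' (from the hub)
# or   'transformers_modules.chatglm2-6b.modeling_chatglm' (from a local directory)
```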
@@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.

In the case of using 2 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
|:-----:|:---------:|:-----------:|
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |

<p align="center">
@@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows.
In the case of using 4 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
|:-----:|:---------:|:-----------:|
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |
@@ -475,10 +512,10 @@ warmup_fraction = 0.03

| accuracy | f1 | loss | GPU number | model sharded |
|:--------:|:-------:|:-------:|:----------:|:-------------:|
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
| 0.84521 | 0.88700 | 0.21822 | 1 | False |

Overall, the results demonstrate that using Shardformer during model training does not affect convergence.

View File

@@ -281,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm):
            )
        LazyInitContext.materialize(module)

-        # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
-        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
-            normalized_shape = module.weight.shape[0]
-            eps = module.variance_epsilon
-            elementwise_affine = True
-        else:
-            # get the attributes of the module
-            normalized_shape = module.normalized_shape
-            eps = module.eps
-            elementwise_affine = module.elementwise_affine
+        # try to get normalized_shape, eps, elementwise_affine from the module
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+        elementwise_affine = getattr(module, "elementwise_affine", True)

        rmsnorm = FusedRMSNormWithHook(
-            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
+            normalized_shape=normalized_shape,
+            eps=eps,
+            elementwise_affine=elementwise_affine,
        )

        rmsnorm.weight = module.weight
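The getattr/hasattr fallback above lets the same code path consume either an `nn.LayerNorm`-style module or an RMSNorm that only exposes `weight` and `variance_epsilon` (LLaMA/Mistral style). A standalone sketch of that probing, with `HFStyleRMSNorm` as an illustrative stand-in rather than any real library class:

```python
import torch
from torch import nn


class HFStyleRMSNorm(nn.Module):
    # Stand-in for LlamaRMSNorm/MistralRMSNorm: only `weight` and `variance_epsilon`.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


def extract_norm_args(module):
    # Same probing order as the converter in the hunk above.
    normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
    eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
    elementwise_affine = getattr(module, "elementwise_affine", True)
    return normalized_shape, eps, elementwise_affine


print(extract_norm_args(HFStyleRMSNorm(64)))          # (64, 1e-06, True)
print(extract_norm_args(nn.LayerNorm(64, eps=1e-5)))  # ((64,), 1e-05, True)
```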

View File

@@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel


def get_flash_core_attention_forward():
@@ -31,7 +30,12 @@ def get_flash_core_attention_forward():
                device=query_layer.device,
            )
            temp_mask = (
-                torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
+                torch.ones(
+                    query_layer.shape[2],
+                    key_layer.shape[2],
+                    dtype=torch.bool,
+                    device=query_layer.device,
+                )
                .tril(diagonal=0)
                .expand(query_layer.shape[0], 1, -1, -1)
            )
@@ -49,6 +53,7 @@ def get_flash_core_attention_forward():
            attention_mask=attn_bias,
            attention_mask_type=attention_mask_type,
            dropout_p=dropout_p,
+            scale=1.0 / self.norm_factor,
        )
        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
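The added `scale=1.0 / self.norm_factor` makes the flash-attention path use ChatGLM's own softmax scaling instead of the default `1/sqrt(head_dim)`. A sanity-check sketch of why the explicit scale matters, using PyTorch's built-in `scaled_dot_product_attention` (which accepts `scale` from PyTorch 2.1 onward) rather than ColossalAI's `ColoAttention`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 16, 32) for _ in range(3))  # [batch, heads, seq, head_dim]
norm_factor = 8.0  # illustrative value that differs from sqrt(head_dim) ~= 5.66

# Manual attention with ChatGLM-style scaling: softmax(Q K^T / norm_factor) V
manual = torch.softmax(q @ k.transpose(-2, -1) / norm_factor, dim=-1) @ v

with_scale = F.scaled_dot_product_attention(q, k, v, scale=1.0 / norm_factor)
without_scale = F.scaled_dot_product_attention(q, k, v)  # defaults to 1/sqrt(head_dim)

print(torch.allclose(manual, with_scale, atol=1e-5))     # True
print(torch.allclose(manual, without_scale, atol=1e-5))  # False: the default scale is wrong here
```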
@@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:
    @staticmethod
    def chatglm_model_forward(
-        self: ChatGLMModel,
+        self: "ChatGLMModel",
        input_ids,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
@@ -194,7 +199,9 @@ class ChatGLMPipelineForwards:
        if shard_config and shard_config.enable_sequence_parallelism:
            if shard_config.sequence_parallelism_mode == "split_gather":
                hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.tensor_parallel_process_group,
                )
        for idx in range(start_idx, end_idx):
            layer = self.encoder._get_layer(idx)
@@ -224,7 +231,9 @@ class ChatGLMPipelineForwards:
        if shard_config and shard_config.enable_sequence_parallelism:
            if shard_config.sequence_parallelism_mode == "split_gather":
                hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.tensor_parallel_process_group,
                )
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
@@ -254,7 +263,7 @@ class ChatGLMPipelineForwards:
    @staticmethod
    def chatglm_for_conditional_generation_forward(
-        self: ChatGLMForConditionalGeneration,
+        self: "ChatGLMForConditionalGeneration",
        input_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,

View File

@@ -151,10 +151,10 @@ _POLICY_LIST = {
        file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
    ),
    # ChatGLM
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
+    "transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation(
        file_name="chatglm2", class_name="ChatGLMModelPolicy"
    ),
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+    "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
        file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
    ),
    # Falcon
@@ -202,6 +202,13 @@ def _fullname(obj):
    module = klass.__module__
    if module == "builtins":
        return klass.__qualname__  # avoid outputs like 'builtins.str'
+    # patch custom models which are not in transformers
+    # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
+    # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
+    if module.startswith("transformers_modules"):
+        split_module = module.split(".")
+        if len(split_module) >= 2:
+            module = f"{split_module[0]}.{split_module[-1]}"
    return module + "." + klass.__qualname__
@@ -220,7 +227,7 @@ def get_autopolicy(model: nn.Module) -> Policy:
    if policy_location is None:
        raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
        )
    else:
        policy = import_policy(policy_location)
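A standalone sketch of the normalization `_fullname` now applies; the helper name `normalize_module_name` is made up for illustration, and in the real code this logic is inlined in `_fullname`:

```python
def normalize_module_name(module: str, qualname: str) -> str:
    # Collapse the dynamic-module path to 'transformers_modules.<modeling_filename>'
    # so a single policy key matches both the hub layout and the local-directory layout.
    if module.startswith("transformers_modules"):
        split_module = module.split(".")
        if len(split_module) >= 2:
            module = f"{split_module[0]}.{split_module[-1]}"
    return module + "." + qualname


print(normalize_module_name(
    "transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm",
    "ChatGLMForConditionalGeneration",
))  # transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration

print(normalize_module_name(
    "transformers_modules.chatglm.modeling_chatglm",
    "ChatGLMModel",
))  # transformers_modules.modeling_chatglm.ChatGLMModel
```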

View File

@@ -7,7 +7,6 @@ from torch import Tensor
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel

from ..modeling.chatglm2 import (
    get_chatglm_sequence_parallel_forward_fn,
@@ -17,7 +16,11 @@ from ..modeling.chatglm2 import (
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

-__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"]
+__all__ = [
+    "ChatGLMPolicy",
+    "ChatGLMModelPolicy",
+    "ChatGLMForConditionalGenerationPolicy",
+]


class ChatGLMPolicy(Policy):
@@ -34,8 +37,6 @@ class ChatGLMPolicy(Policy):
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
-
        policy = {}

        embedding_cls = None
@@ -67,7 +68,27 @@ class ChatGLMPolicy(Policy):
            sp_partial_derived = sp_mode == "split_gather"

        if self.shard_config.enable_tensor_parallelism:
-            policy[GLMBlock] = ModulePolicyDescription(
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"num_attention_heads {self.model.config.num_attention_heads} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+            attn_kwargs = {
+                "self_attention.qkv_hidden_size": (
+                    self.model.config.kv_channels * self.model.config.num_attention_heads * 3
+                )
+                // self.shard_config.tensor_parallel_size,
+            }
+            if self.model.config.multi_query_attention:
+                assert (
+                    self.model.config.multi_query_group_num % self.shard_config.tensor_parallel_size == 0
+                ), f"multi_query_group_num {self.model.config.multi_query_group_num} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+                attn_kwargs["self_attention.num_multi_query_groups_per_partition"] = (
+                    self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
+                )
+                attn_kwargs["self_attention.qkv_hidden_size"] = (
+                    self.model.config.kv_channels * self.model.config.num_attention_heads
+                    + 2 * self.model.config.kv_channels * self.model.config.multi_query_group_num
+                ) // self.shard_config.tensor_parallel_size
+            policy["GLMBlock"] = ModulePolicyDescription(
                attribute_replacement={
                    "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
                    // self.shard_config.tensor_parallel_size,
@@ -75,22 +96,23 @@ class ChatGLMPolicy(Policy):
                        self.model.config.kv_channels * self.model.config.num_attention_heads
                    )
                    // self.shard_config.tensor_parallel_size,
-                    "self_attention.qkv_hidden_size": (
-                        self.model.config.kv_channels * self.model.config.num_attention_heads * 3
-                    )
-                    // self.shard_config.tensor_parallel_size,
                    "self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
                    // self.shard_config.tensor_parallel_size,
                    "self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels
                    * self.model.config.num_attention_heads
                    // self.shard_config.tensor_parallel_size,
+                    **attn_kwargs,
                },
                param_replacement=[],
                sub_module_replacement=[
                    SubModuleReplacementDescription(
                        suffix="self_attention.query_key_value",
                        target_module=col_nn.Linear1D_Col,
-                        kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
+                        kwargs={
+                            "seq_parallel_mode": sp_mode,
+                            "seq_parallel_dim": 0,
+                            "overlap": overlap,
+                        },
                    ),
                    SubModuleReplacementDescription(
                        suffix="self_attention.dense",
@@ -114,7 +136,7 @@ class ChatGLMPolicy(Policy):
                    ),
                ],
                policy=policy,
-                target_key=ChatGLMModel,
+                target_key="ChatGLMModel",
            )

        # optimization configuration
        self.append_or_create_submodule_replacement(
@@ -131,7 +153,7 @@ class ChatGLMPolicy(Policy):
                ),
            ],
            policy=policy,
-            target_key=GLMBlock,
+            target_key="GLMBlock",
        )

        if self.model.config.post_layer_norm:
@@ -143,7 +165,7 @@ class ChatGLMPolicy(Policy):
                    )
                ],
                policy=policy,
-                target_key=ChatGLMModel,
+                target_key="ChatGLMModel",
            )

        # use flash attention
@@ -153,7 +175,7 @@ class ChatGLMPolicy(Policy):
                    "forward": get_flash_core_attention_forward(),
                },
                policy=policy,
-                target_key=CoreAttention,
+                target_key="CoreAttention",
            )

        # use sequence parallel
@@ -161,7 +183,7 @@ class ChatGLMPolicy(Policy):
            self.append_or_create_method_replacement(
                description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
                policy=policy,
-                target_key=ChatGLMModel,
+                target_key="ChatGLMModel",
            )

        # use jit fused operator
@@ -172,7 +194,7 @@ class ChatGLMPolicy(Policy):
                    "dropout_add": get_jit_fused_dropout_add_func(),
                },
                policy=policy,
-                target_key=GLMBlock,
+                target_key="GLMBlock",
            )

        return policy
@@ -220,7 +242,10 @@ class ChatGLMPolicy(Policy):
        stage_index = stage_manager.get_stage_index(layers_per_stage)
        method_replacement = {
            "forward": partial(
-                new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                new_forward,
+                stage_manager=stage_manager,
+                stage_index=stage_index,
+                shard_config=self.shard_config,
            )
        }
        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -234,7 +259,9 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
        if self.pipeline_stage_manager is not None:
            self.set_pipeline_forward(
-                model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy
+                model_cls="ChatGLMModel",
+                new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
+                policy=policy,
            )
        return policy
@@ -252,7 +279,7 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
        if self.pipeline_stage_manager is not None:
            self.set_pipeline_forward(
-                model_cls=ChatGLMForConditionalGeneration,
+                model_cls="ChatGLMForConditionalGeneration",
                new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
                policy=policy,
            )
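For the multi-query attention branch added to this policy, the per-partition attribute values are plain integer arithmetic. A worked example, assuming ChatGLM2-6B-like config values (kv_channels=128, num_attention_heads=32, multi_query_group_num=2) and tensor_parallel_size=2:

```python
kv_channels, num_attention_heads = 128, 32
multi_query_group_num, tensor_parallel_size = 2, 2

# Without multi-query attention: Q, K and V are all full width.
qkv_plain = (kv_channels * num_attention_heads * 3) // tensor_parallel_size
print(qkv_plain)  # 6144

# With multi-query attention: full-width Q, but only multi_query_group_num K/V heads.
qkv_mqa = (
    kv_channels * num_attention_heads
    + 2 * kv_channels * multi_query_group_num
) // tensor_parallel_size
groups_per_partition = multi_query_group_num // tensor_parallel_size

print(qkv_mqa)               # 2304 == (4096 + 512) // 2
print(groups_per_partition)  # 1
```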

View File

@@ -310,13 +310,6 @@ if dist.get_world_size() > 1:

2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification` and `ViTForImageClassification`, please ensure that the total number of labels is an integer multiple of the tensor parallel size, otherwise Shardformer can't process the classifier layer correctly. A simple fix is to append dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.

-3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through
-```python
-from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
-```
-when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.

## How Shardformer Works

### Main Idea

View File

@@ -303,13 +303,6 @@ if dist.get_world_size() > 1:

2. When using Shardformer with classification models such as `GPT2ForSequenceClassification` and `ViTForImageClassification`, please make sure the total number of labels is an integer multiple of the tensor parallel size, otherwise Shardformer cannot handle the classifier layer correctly. A simple fix is to add dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.

-3. Training ChatGLM-2 6B is a special case: since Hugging Face Transformers does not officially support ChatGLM yet, please import the config/model classes via
-```python
-from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
-```
-when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.

## How Shardformer Works

View File

@@ -1,7 +1,6 @@
import torch
+from torch.nn import init

-from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+from transformers import AutoConfig, AutoModelForCausalLM

from ..registry import ModelAttribute, model_zoo
@@ -34,19 +33,26 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
)

loss_fn = lambda x: x["loss"]

-config = ChatGLMConfig(
+config = AutoConfig.from_pretrained(
+    "THUDM/chatglm2-6b",
+    trust_remote_code=True,
    num_layers=2,
    padded_vocab_size=65024,
    hidden_size=64,
+    ffn_hidden_size=214,
    num_attention_heads=8,
    kv_channels=16,
    rmsnorm=True,
    original_rope=True,
    use_cache=True,
+    multi_query_attention=False,
    torch_dtype=torch.float32,
)

-infer_config = ChatGLMConfig(
+infer_config = AutoConfig.from_pretrained(
+    "THUDM/chatglm2-6b",
+    trust_remote_code=True,
    num_layers=2,
    padded_vocab_size=65024,
    hidden_size=128,
@@ -60,18 +66,18 @@ infer_config = ChatGLMConfig(
    torch_dtype=torch.float32,
)

-model_zoo.register(
-    name="transformers_chatglm",
-    model_fn=lambda: ChatGLMModel(config, empty_init=False),
-    data_gen_fn=data_gen,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn_for_chatglm_model,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
+
+def init_chatglm():
+    model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)
+    for m in model.modules():
+        if m.__class__.__name__ == "RMSNorm":
+            init.ones_(m.weight)
+    return model
+

model_zoo.register(
    name="transformers_chatglm_for_conditional_generation",
-    model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
+    model_fn=init_chatglm,
    data_gen_fn=data_gen_for_conditional_generation,
    output_transform_fn=output_transform_fn,
    loss_fn=loss_fn,
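Since the ChatGLM classes can no longer be imported, `init_chatglm` above matches modules by class name instead of `isinstance` before re-initializing their weights. A self-contained sketch of that pattern; `RMSNorm` here is a local stand-in, not the ChatGLM implementation:

```python
import torch
from torch import nn
from torch.nn import init


class RMSNorm(nn.Module):  # stand-in with the same class name as ChatGLM's RMSNorm
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(dim))

    def forward(self, x):
        return x * self.weight


model = nn.Sequential(nn.Linear(4, 4), RMSNorm(4))
# When the target class cannot be imported, match by __class__.__name__ instead of isinstance.
for m in model.modules():
    if m.__class__.__name__ == "RMSNorm":
        init.ones_(m.weight)
print(model[1].weight)  # parameter filled with ones
```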

View File

@@ -227,7 +227,7 @@ def check_output_hidden_state(


def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
-    assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
+    assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)


def check_weight(
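The switch from a bare `assert torch.allclose(...)` to `assert_close` (presumably `torch.testing.assert_close`, already available in the test utilities since the diff adds no import) keeps the same tolerance semantics but reports which elements mismatch and by how much when the check fails. For example:

```python
import torch
from torch.testing import assert_close

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.5])

try:
    assert_close(a, b, atol=1e-5, rtol=1e-3)
except AssertionError as err:
    # e.g. "Mismatched elements: 1 / 3 ... Greatest absolute difference: 0.5 at index (2,) ..."
    print(err)
```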

View File

@@ -11,6 +11,7 @@ from tests.test_shardformer.test_model._utils import (
    build_model_from_hybrid_plugin,
    check_all_grad_tensors,
    check_loss,
+    check_output_hidden_state,
    check_weight,
    get_grad_tensors_for_check,
    run_forward_backward_with_hybrid_plugin,
@@ -103,8 +104,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
            atol, rtol = 5e-3, 5e-3
        # TODO: ChatGLMModel output is [S, B, H], merging batch of pipeline is wrong
-        # if org_model.__class__.__name__ == "ChatGLMModel":
-        #     check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
+        if org_model.__class__.__name__ == "ChatGLMModel":
+            check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)

        check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
@@ -177,14 +178,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
        {
            "tp_size": 4,
            "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp32",
        },
        {
            "tp_size": 2,
            "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
            "use_lazy_init": False,
            "precision": "fp32",
        },