[shardformer] fix chatglm implementation (#5644)

* [shardformer] fix chatglm policy

* [shardformer] fix chatglm flash attn

* [shardformer] update readme

* [shardformer] fix chatglm init

* [shardformer] fix chatglm test

* [pipeline] fix chatglm merge batch
Hongxin Liu 2024-04-25 14:41:17 +08:00 committed by GitHub
parent 5d88ef1aaf
commit bbb2c21f16
11 changed files with 193 additions and 117 deletions


@ -7,7 +7,7 @@ from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.accelerator import get_accelerator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils import get_current_device
@ -327,7 +327,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.send_forward(output_obj)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
batch_size_dim = getattr(model, "batch_size_dim", 0)
outputs = merge_batch(outputs, batch_size_dim)
return {"loss": accum_loss, "outputs": outputs}
def run_forward_backward(
@ -410,7 +413,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
if outputs is not None:
if isinstance(model, ModelWrapper):
model = model.unwrap()
batch_size_dim = getattr(model, "batch_size_dim", 0)
outputs = merge_batch(outputs, batch_size_dim)
return {"loss": accum_loss, "outputs": outputs}
def forward_backward_step(


@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer:
- [x] Unit Testing
- [ ] Policy Implementation
| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:|
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
## 💡 API Design
@ -391,6 +391,43 @@ _POLICY_LIST = {
}
```
#### How to support models that are on the Hugging Face model hub but not in the `transformers` library

There are two cases:

1. The modeling file is in the `transformers` library but the model weights are not. E.g. the model structure of "01-ai/Yi-34B" is the same as LLaMA's, so the checkpoint loads into the in-library LLaMA classes. In this case, we support llama as usual, and Yi-34B is then covered by the llama policy; no new policy is needed (see the sketch after this list).
2. The modeling file itself is not in the `transformers` library, e.g. "THUDM/chatglm2-6b".
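As a quick sketch of case 1 (the repo ID and its declared architecture are assumptions based on the published checkpoint, not something this commit touches):

```python
from transformers import AutoConfig

# the hub config declares an in-library class, so AutoModelForCausalLM would
# instantiate transformers' own LlamaForCausalLM, and shardformer's existing
# llama policy matches it with no extra code
config = AutoConfig.from_pretrained("01-ai/Yi-34B")
print(config.architectures)  # expected: ['LlamaForCausalLM']
```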
Take "THUDM/chatglm2-6b" as an example, we clearly illustrate how to support this model in the `shardformer`.
Unlike llama which is in `transformers` library, we cannot import chatglm2 model directly. Thus, the key in policy should be str of class name, rather than class itself.
E.g. for llama:
```python
policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
```
for chatglm2:
```python
policy["GLMBlock"] = ModulePolicyDescription(...)
```
Then, when registering such models in the autopolicy, we should follow the format below:
```python
"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
file_name="<policy_filename>", class_name="<policy_class_name>"
)
```
For the chatglm2 model, this is:
```python
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
)
```
When using such models, `AutoModel` is supported as usual, and the matching policy is loaded automatically by the autopolicy. A minimal usage sketch follows.
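This sketch assumes network access, `trust_remote_code=True`, and an initialized ColossalAI context where required; the tiny `num_layers` override just keeps the example light:

```python
from transformers import AutoConfig, AutoModelForCausalLM

from colossalai.shardformer import ShardConfig, ShardFormer

# the class comes from the repo's own modeling file, so its module resolves to
# 'transformers_modules....modeling_chatglm' and the string-keyed policy is found
config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True, num_layers=2)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

shardformer = ShardFormer(shard_config=ShardConfig(enable_tensor_parallelism=False))
sharded_model, shared_params = shardformer.optimize(model)
```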
### Write Your Unit Testing
This section serves as the guideline for testing the `shardformer` module.
@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. `N_CTX` refers to the sequence length.
In the case of using 2 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
|:-----:|:---------:|:-----------:|
| 256 | 11.2ms | 17.2ms |
| 512 | 9.8ms | 19.5ms |
| 1024 | 19.6ms | 18.9ms |
| 2048 | 46.6ms | 30.8ms |
| 4096 | 160.5ms | 90.4ms |
<p align="center">
@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows.
In the case of using 4 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
|:-----:|:---------:|:-----------:|
| 256 | 10.0ms | 21.1ms |
| 512 | 11.5ms | 20.2ms |
| 1024 | 22.1ms | 20.6ms |
| 2048 | 46.9ms | 24.8ms |
| 4096 | 160.4ms | 68.0ms |
@ -475,10 +512,10 @@ warmup_fraction = 0.03
| accuracy | f1      | loss    | GPU number | model sharded |
|:--------:|:-------:|:-------:|:----------:|:-------------:|
| 0.82971  | 0.87713 | 0.23194 | 4          | True          |
| 0.83797  | 0.88006 | 0.22683 | 2          | True          |
| 0.84521  | 0.88700 | 0.21822 | 1          | False         |
Overall, the results demonstrate that using Shardformer during model training does not affect convergence.


@ -281,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm):
)
LazyInitContext.materialize(module)
# try to get normalized_shape, eps, elementwise_affine from the module
normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
elementwise_affine = getattr(module, "elementwise_affine", True)
rmsnorm = FusedRMSNormWithHook(
normalized_shape=normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
)
rmsnorm.weight = module.weight
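A small sketch of the duck-typed attribute lookup above (the `FakeLlamaRMSNorm` class is illustrative, standing in for huggingface's `LlamaRMSNorm`):

```python
import torch
import torch.nn as nn

class FakeLlamaRMSNorm(nn.Module):
    # mimics LlamaRMSNorm: only `weight` and `variance_epsilon`, no LayerNorm attributes
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

for module in (nn.LayerNorm(64), FakeLlamaRMSNorm(64)):
    normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
    eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
    elementwise_affine = getattr(module, "elementwise_affine", True)
    print(type(module).__name__, normalized_shape, eps, elementwise_affine)
```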


@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
def get_flash_core_attention_forward():
@ -31,7 +30,12 @@ def get_flash_core_attention_forward():
device=query_layer.device,
)
temp_mask = (
torch.ones(
query_layer.shape[2],
key_layer.shape[2],
dtype=torch.bool,
device=query_layer.device,
)
.tril(diagonal=0)
.expand(query_layer.shape[0], 1, -1, -1)
)
@ -49,6 +53,7 @@ def get_flash_core_attention_forward():
attention_mask=attn_bias,
attention_mask_type=attention_mask_type,
dropout_p=dropout_p,
scale=1.0 / self.norm_factor,
)
context_layer = context_layer.permute(2, 0, 1, 3)
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:
@staticmethod
def chatglm_model_forward(
self: "ChatGLMModel",
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
@ -194,7 +199,9 @@ class ChatGLMPipelineForwards:
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
@ -224,7 +231,9 @@ class ChatGLMPipelineForwards:
if shard_config and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
hidden_states = gather_forward_split_backward(
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@ -254,7 +263,7 @@ class ChatGLMPipelineForwards:
@staticmethod
def chatglm_for_conditional_generation_forward(
self: "ChatGLMForConditionalGeneration",
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,


@ -151,10 +151,10 @@ _POLICY_LIST = {
file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
),
# ChatGLM
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
"transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMModelPolicy"
),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
),
# Falcon
@ -202,6 +202,13 @@ def _fullname(obj):
module = klass.__module__
if module == "builtins":
return klass.__qualname__ # avoid outputs like 'builtins.str'
# patch custom models that are not in the transformers library
# the module name can look like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (loaded from the huggingface hub)
# or like 'transformers_modules.chatglm.modeling_chatglm' (loaded from a local directory)
if module.startswith("transformers_modules"):
split_module = module.split(".")
if len(split_module) >= 2:
module = f"{split_module[0]}.{split_module[-1]}"
return module + "." + klass.__qualname__
@ -220,7 +227,7 @@ def get_autopolicy(model: nn.Module) -> Policy:
if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location)


@ -7,7 +7,6 @@ from torch import Tensor
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
from ..modeling.chatglm2 import (
get_chatglm_sequence_parallel_forward_fn,
@ -17,7 +16,11 @@ from ..modeling.chatglm2 import (
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"]
__all__ = [
"ChatGLMPolicy",
"ChatGLMModelPolicy",
"ChatGLMForConditionalGenerationPolicy",
]
class ChatGLMPolicy(Policy):
@ -34,8 +37,6 @@ class ChatGLMPolicy(Policy):
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
policy = {}
embedding_cls = None
@ -67,7 +68,27 @@ class ChatGLMPolicy(Policy):
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"num_attention_heads {self.model.config.num_attention_heads} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
attn_kwargs = {
"self_attention.qkv_hidden_size": (
self.model.config.kv_channels * self.model.config.num_attention_heads * 3
)
// self.shard_config.tensor_parallel_size,
}
if self.model.config.multi_query_attention:
assert (
self.model.config.multi_query_group_num % self.shard_config.tensor_parallel_size == 0
), f"multi_query_group_num {self.model.config.multi_query_group_num} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
attn_kwargs["self_attention.num_multi_query_groups_per_partition"] = (
self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
)
attn_kwargs["self_attention.qkv_hidden_size"] = (
self.model.config.kv_channels * self.model.config.num_attention_heads
+ 2 * self.model.config.kv_channels * self.model.config.multi_query_group_num
) // self.shard_config.tensor_parallel_size
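# illustrative numbers, assuming the published chatglm2-6b config (kv_channels=128,
# num_attention_heads=32, multi_query_group_num=2) and tensor_parallel_size=2:
# qkv_hidden_size = (128 * 32 + 2 * 128 * 2) // 2 = (4096 + 512) // 2 = 2304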
policy["GLMBlock"] = ModulePolicyDescription(
attribute_replacement={
"self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
@ -75,22 +96,23 @@ class ChatGLMPolicy(Policy):
self.model.config.kv_channels * self.model.config.num_attention_heads
)
// self.shard_config.tensor_parallel_size,
"self_attention.qkv_hidden_size": (
self.model.config.kv_channels * self.model.config.num_attention_heads * 3
)
// self.shard_config.tensor_parallel_size,
"self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
"self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels
* self.model.config.num_attention_heads
// self.shard_config.tensor_parallel_size,
**attn_kwargs,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
kwargs={
"seq_parallel_mode": sp_mode,
"seq_parallel_dim": 0,
"overlap": overlap,
},
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
@ -114,7 +136,7 @@ class ChatGLMPolicy(Policy):
),
],
policy=policy,
target_key="ChatGLMModel",
)
# optimization configuration
self.append_or_create_submodule_replacement(
@ -131,7 +153,7 @@ class ChatGLMPolicy(Policy):
),
],
policy=policy,
target_key="GLMBlock",
)
if self.model.config.post_layer_norm:
@ -143,7 +165,7 @@ class ChatGLMPolicy(Policy):
)
],
policy=policy,
target_key="ChatGLMModel",
)
# use flash attention
@ -153,7 +175,7 @@ class ChatGLMPolicy(Policy):
"forward": get_flash_core_attention_forward(),
},
policy=policy,
target_key="CoreAttention",
)
# use sequence parallel
@ -161,7 +183,7 @@ class ChatGLMPolicy(Policy):
self.append_or_create_method_replacement(
description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
target_key="ChatGLMModel",
)
# use jit fused operator
@ -172,7 +194,7 @@ class ChatGLMPolicy(Policy):
"dropout_add": get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key="GLMBlock",
)
return policy
@ -220,7 +242,10 @@ class ChatGLMPolicy(Policy):
stage_index = stage_manager.get_stage_index(layers_per_stage)
method_replacement = {
"forward": partial(
new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config,
)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@ -234,7 +259,9 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
model_cls="ChatGLMModel",
new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
policy=policy,
)
return policy
@ -252,7 +279,7 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
model_cls="ChatGLMForConditionalGeneration",
new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
policy=policy,
)


@ -310,13 +310,6 @@ if dist.get_world_size() > 1:
2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification` or `ViTForImageClassification`, please ensure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer cannot process the classifier layer correctly. A simple fix is to append dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.
## How Shardformer Works
### Main Idea


@ -303,13 +303,6 @@ if dist.get_world_size() > 1:
2. When using Shardformer to process classification models such as `GPT2ForSequenceClassification` or `ViTForImageClassification`, please ensure that the total number of labels is an integer multiple of the tensor parallel size; otherwise Shardformer cannot process the classifier layer correctly. A simple fix is to add dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.
## Shardformer的工作原理


@ -1,7 +1,6 @@
import torch
from torch.nn import init
from transformers import AutoConfig, AutoModelForCausalLM
from ..registry import ModelAttribute, model_zoo
@ -34,19 +33,26 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
)
loss_fn = lambda x: x["loss"]
config = AutoConfig.from_pretrained(
"THUDM/chatglm2-6b",
trust_remote_code=True,
num_layers=2,
padded_vocab_size=65024,
hidden_size=64,
ffn_hidden_size=214,
num_attention_heads=8,
kv_channels=16,
rmsnorm=True,
original_rope=True,
use_cache=True,
multi_query_attention=False,
torch_dtype=torch.float32,
)
infer_config = AutoConfig.from_pretrained(
"THUDM/chatglm2-6b",
trust_remote_code=True,
num_layers=2,
padded_vocab_size=65024,
hidden_size=128,
@ -60,18 +66,18 @@ infer_config = ChatGLMConfig(
torch_dtype=torch.float32,
)
def init_chatglm():
model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)
for m in model.modules():
if m.__class__.__name__ == "RMSNorm":
init.ones_(m.weight)
return model
model_zoo.register(
name="transformers_chatglm_for_conditional_generation",
model_fn=init_chatglm,
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,


@ -227,7 +227,7 @@ def check_output_hidden_state(
def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
def check_weight(


@ -11,6 +11,7 @@ from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
@ -103,8 +104,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
# ChatGLM hidden states are sequence-first ([S, B, H]), so compare along batch dim=1
if org_model.__class__.__name__ == "ChatGLMModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
@ -177,14 +178,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},