mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix chatglm implementation (#5644)
* [shardformer] fix chatglm policy * [shardformer] fix chatglm flash attn * [shardformer] update readme * [shardformer] fix chatglm init * [shardformer] fix chatglm test * [pipeline] fix chatglm merge batchpull/5654/head
parent
5d88ef1aaf
commit
bbb2c21f16
|
@ -7,7 +7,7 @@ from torch.nn import Module
|
|||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.interface import OptimizerWrapper
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.utils import get_current_device
|
||||
|
@ -327,7 +327,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
self.send_forward(output_obj)
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
if isinstance(model, ModelWrapper):
|
||||
model = model.unwrap()
|
||||
batch_size_dim = getattr(model, "batch_size_dim", 0)
|
||||
outputs = merge_batch(outputs, batch_size_dim)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def run_forward_backward(
|
||||
|
@ -410,7 +413,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
|
|||
assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
|
||||
|
||||
if outputs is not None:
|
||||
outputs = merge_batch(outputs)
|
||||
if isinstance(model, ModelWrapper):
|
||||
model = model.unwrap()
|
||||
batch_size_dim = getattr(model, "batch_size_dim", 0)
|
||||
outputs = merge_batch(outputs, batch_size_dim)
|
||||
return {"loss": accum_loss, "outputs": outputs}
|
||||
|
||||
def forward_backward_step(
|
||||
|
|
|
@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer:
|
|||
- [x] Unit Testing
|
||||
- [ ] Policy Implementation
|
||||
|
||||
| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
|
||||
| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
|
||||
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
|
||||
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
|
||||
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
|
||||
|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:|
|
||||
| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
|
||||
| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
|
||||
| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
|
||||
| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
|
||||
| mistral | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
|
||||
|
||||
|
||||
## 💡 API Design
|
||||
|
@ -391,6 +391,43 @@ _POLICY_LIST = {
|
|||
}
|
||||
```
|
||||
|
||||
#### How to support those models in huggingface model hub but not in the transformers library
|
||||
|
||||
There are two cases:
|
||||
|
||||
1. the modeling file is in the `transformers` library but the model weight is not in the `transformers` library. E.g. model structure of "01-ai/Yi-34B" is the same as LLaMA but the weight is not in the `transformers` library. In this case, we should support llama as usual and Yi-34B is also supported by the llama policy. We do not need to add a new policy for Yi-34B.
|
||||
2. the modeling file is not in the `transformers` library, such as the "THUDM/chatglm2-6b".
|
||||
|
||||
Take "THUDM/chatglm2-6b" as an example, we clearly illustrate how to support this model in the `shardformer`.
|
||||
|
||||
Unlike llama which is in `transformers` library, we cannot import chatglm2 model directly. Thus, the key in policy should be str of class name, rather than class itself.
|
||||
|
||||
E.g. for llama:
|
||||
```python
|
||||
policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
|
||||
```
|
||||
|
||||
for chatglm2:
|
||||
```python
|
||||
policy["GLMBlock"] = ModulePolicyDescription(...)
|
||||
```
|
||||
|
||||
Then when registering such models in the autopolicy, we should follow below format:
|
||||
```python
|
||||
"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
|
||||
file_name="<policy_filename>", class_name="<policy_class_name>"
|
||||
)
|
||||
```
|
||||
|
||||
As for chatglm2 model, it should be:
|
||||
```python
|
||||
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
|
||||
)
|
||||
```
|
||||
|
||||
When using such models, `AutoModel` is supported as usual. The policy will be automatically loaded by the autopolicy.
|
||||
|
||||
### Write Your Unit Testing
|
||||
|
||||
This section serves as the guideline for testing the `shardformer` module.
|
||||
|
@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
|
|||
We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.
|
||||
|
||||
In the case of using 2 GPUs, the training times are as follows.
|
||||
| N_CTX | org_model | shard_model |
|
||||
| :------: | :-----: | :-----: |
|
||||
| 256 | 11.2ms | 17.2ms |
|
||||
| 512 | 9.8ms | 19.5ms |
|
||||
| 1024 | 19.6ms | 18.9ms |
|
||||
| 2048 | 46.6ms | 30.8ms |
|
||||
| 4096 | 160.5ms | 90.4ms |
|
||||
| N_CTX | org_model | shard_model |
|
||||
|:-----:|:---------:|:-----------:|
|
||||
| 256 | 11.2ms | 17.2ms |
|
||||
| 512 | 9.8ms | 19.5ms |
|
||||
| 1024 | 19.6ms | 18.9ms |
|
||||
| 2048 | 46.6ms | 30.8ms |
|
||||
| 4096 | 160.5ms | 90.4ms |
|
||||
|
||||
|
||||
<p align="center">
|
||||
|
@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows.
|
|||
|
||||
In the case of using 4 GPUs, the training times are as follows.
|
||||
|
||||
| N_CTX | org_model | shard_model |
|
||||
| :------: | :-----: | :-----: |
|
||||
| 256 | 10.0ms | 21.1ms |
|
||||
| 512 | 11.5ms | 20.2ms |
|
||||
| 1024 | 22.1ms | 20.6ms |
|
||||
| 2048 | 46.9ms | 24.8ms |
|
||||
| 4096 | 160.4ms | 68.0ms |
|
||||
| N_CTX | org_model | shard_model |
|
||||
|:-----:|:---------:|:-----------:|
|
||||
| 256 | 10.0ms | 21.1ms |
|
||||
| 512 | 11.5ms | 20.2ms |
|
||||
| 1024 | 22.1ms | 20.6ms |
|
||||
| 2048 | 46.9ms | 24.8ms |
|
||||
| 4096 | 160.4ms | 68.0ms |
|
||||
|
||||
|
||||
|
||||
|
@ -475,10 +512,10 @@ warmup_fraction = 0.03
|
|||
|
||||
|
||||
| accuracy | f1 | loss | GPU number | model sharded |
|
||||
| :------: | :-----: | :-----: | :--------: | :---------: |
|
||||
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
|
||||
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
|
||||
| 0.84521 | 0.88700 | 0.21822 | 1 | False |
|
||||
|:--------:|:-------:|:-------:|:----------:|:-------------:|
|
||||
| 0.82971 | 0.87713 | 0.23194 | 4 | True |
|
||||
| 0.83797 | 0.88006 | 0.22683 | 2 | True |
|
||||
| 0.84521 | 0.88700 | 0.21822 | 1 | False |
|
||||
|
||||
|
||||
Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
|
||||
|
|
|
@ -281,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm):
|
|||
)
|
||||
|
||||
LazyInitContext.materialize(module)
|
||||
# to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
|
||||
if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
|
||||
normalized_shape = module.weight.shape[0]
|
||||
eps = module.variance_epsilon
|
||||
elementwise_affine = True
|
||||
else:
|
||||
# get the attributes of the module
|
||||
normalized_shape = module.normalized_shape
|
||||
eps = module.eps
|
||||
elementwise_affine = module.elementwise_affine
|
||||
|
||||
# try to get normalized_shape, eps, elementwise_affine from the module
|
||||
normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
|
||||
eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
|
||||
elementwise_affine = getattr(module, "elementwise_affine", True)
|
||||
|
||||
rmsnorm = FusedRMSNormWithHook(
|
||||
normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
|
||||
normalized_shape=normalized_shape,
|
||||
eps=eps,
|
||||
elementwise_affine=elementwise_affine,
|
||||
)
|
||||
|
||||
rmsnorm.weight = module.weight
|
||||
|
|
|
@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
from colossalai.shardformer import ShardConfig
|
||||
from colossalai.shardformer.layer import AttnMaskType, ColoAttention
|
||||
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
|
||||
|
||||
def get_flash_core_attention_forward():
|
||||
|
@ -31,7 +30,12 @@ def get_flash_core_attention_forward():
|
|||
device=query_layer.device,
|
||||
)
|
||||
temp_mask = (
|
||||
torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
|
||||
torch.ones(
|
||||
query_layer.shape[2],
|
||||
key_layer.shape[2],
|
||||
dtype=torch.bool,
|
||||
device=query_layer.device,
|
||||
)
|
||||
.tril(diagonal=0)
|
||||
.expand(query_layer.shape[0], 1, -1, -1)
|
||||
)
|
||||
|
@ -49,6 +53,7 @@ def get_flash_core_attention_forward():
|
|||
attention_mask=attn_bias,
|
||||
attention_mask_type=attention_mask_type,
|
||||
dropout_p=dropout_p,
|
||||
scale=1.0 / self.norm_factor,
|
||||
)
|
||||
context_layer = context_layer.permute(2, 0, 1, 3)
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
|
||||
|
@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:
|
|||
|
||||
@staticmethod
|
||||
def chatglm_model_forward(
|
||||
self: ChatGLMModel,
|
||||
self: "ChatGLMModel",
|
||||
input_ids,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
|
@ -194,7 +199,9 @@ class ChatGLMPipelineForwards:
|
|||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = split_forward_gather_backward(
|
||||
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
for idx in range(start_idx, end_idx):
|
||||
layer = self.encoder._get_layer(idx)
|
||||
|
@ -224,7 +231,9 @@ class ChatGLMPipelineForwards:
|
|||
if shard_config and shard_config.enable_sequence_parallelism:
|
||||
if shard_config.sequence_parallelism_mode == "split_gather":
|
||||
hidden_states = gather_forward_split_backward(
|
||||
hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
|
||||
hidden_states,
|
||||
dim=0,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
)
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
@ -254,7 +263,7 @@ class ChatGLMPipelineForwards:
|
|||
|
||||
@staticmethod
|
||||
def chatglm_for_conditional_generation_forward(
|
||||
self: ChatGLMForConditionalGeneration,
|
||||
self: "ChatGLMForConditionalGeneration",
|
||||
input_ids: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
|
|
|
@ -151,10 +151,10 @@ _POLICY_LIST = {
|
|||
file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
|
||||
),
|
||||
# ChatGLM
|
||||
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
|
||||
"transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLMModelPolicy"
|
||||
),
|
||||
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
|
||||
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
|
||||
file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
|
||||
),
|
||||
# Falcon
|
||||
|
@ -202,6 +202,13 @@ def _fullname(obj):
|
|||
module = klass.__module__
|
||||
if module == "builtins":
|
||||
return klass.__qualname__ # avoid outputs like 'builtins.str'
|
||||
# patch custom models which are not in transformers
|
||||
# it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
|
||||
# or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
|
||||
if module.startswith("transformers_modules"):
|
||||
split_module = module.split(".")
|
||||
if len(split_module) >= 2:
|
||||
module = f"{split_module[0]}.{split_module[-1]}"
|
||||
return module + "." + klass.__qualname__
|
||||
|
||||
|
||||
|
@ -220,7 +227,7 @@ def get_autopolicy(model: nn.Module) -> Policy:
|
|||
|
||||
if policy_location is None:
|
||||
raise NotImplementedError(
|
||||
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
|
||||
f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
|
||||
)
|
||||
else:
|
||||
policy = import_policy(policy_location)
|
||||
|
|
|
@ -7,7 +7,6 @@ from torch import Tensor
|
|||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
|
||||
from ..modeling.chatglm2 import (
|
||||
get_chatglm_sequence_parallel_forward_fn,
|
||||
|
@ -17,7 +16,11 @@ from ..modeling.chatglm2 import (
|
|||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"]
|
||||
__all__ = [
|
||||
"ChatGLMPolicy",
|
||||
"ChatGLMModelPolicy",
|
||||
"ChatGLMForConditionalGenerationPolicy",
|
||||
]
|
||||
|
||||
|
||||
class ChatGLMPolicy(Policy):
|
||||
|
@ -34,8 +37,6 @@ class ChatGLMPolicy(Policy):
|
|||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
|
||||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
|
@ -67,7 +68,27 @@ class ChatGLMPolicy(Policy):
|
|||
sp_partial_derived = sp_mode == "split_gather"
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[GLMBlock] = ModulePolicyDescription(
|
||||
assert (
|
||||
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
|
||||
), f"num_attention_heads {self.model.config.num_attention_heads} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
|
||||
attn_kwargs = {
|
||||
"self_attention.qkv_hidden_size": (
|
||||
self.model.config.kv_channels * self.model.config.num_attention_heads * 3
|
||||
)
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
}
|
||||
if self.model.config.multi_query_attention:
|
||||
assert (
|
||||
self.model.config.multi_query_group_num % self.shard_config.tensor_parallel_size == 0
|
||||
), f"multi_query_group_num {self.model.config.multi_query_group_num} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
|
||||
attn_kwargs["self_attention.num_multi_query_groups_per_partition"] = (
|
||||
self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
|
||||
)
|
||||
attn_kwargs["self_attention.qkv_hidden_size"] = (
|
||||
self.model.config.kv_channels * self.model.config.num_attention_heads
|
||||
+ 2 * self.model.config.kv_channels * self.model.config.multi_query_group_num
|
||||
) // self.shard_config.tensor_parallel_size
|
||||
policy["GLMBlock"] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
|
@ -75,22 +96,23 @@ class ChatGLMPolicy(Policy):
|
|||
self.model.config.kv_channels * self.model.config.num_attention_heads
|
||||
)
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
"self_attention.qkv_hidden_size": (
|
||||
self.model.config.kv_channels * self.model.config.num_attention_heads * 3
|
||||
)
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
"self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
"self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels
|
||||
* self.model.config.num_attention_heads
|
||||
// self.shard_config.tensor_parallel_size,
|
||||
**attn_kwargs,
|
||||
},
|
||||
param_replacement=[],
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.query_key_value",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
|
||||
kwargs={
|
||||
"seq_parallel_mode": sp_mode,
|
||||
"seq_parallel_dim": 0,
|
||||
"overlap": overlap,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="self_attention.dense",
|
||||
|
@ -114,7 +136,7 @@ class ChatGLMPolicy(Policy):
|
|||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=ChatGLMModel,
|
||||
target_key="ChatGLMModel",
|
||||
)
|
||||
# optimization configuration
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
@ -131,7 +153,7 @@ class ChatGLMPolicy(Policy):
|
|||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=GLMBlock,
|
||||
target_key="GLMBlock",
|
||||
)
|
||||
|
||||
if self.model.config.post_layer_norm:
|
||||
|
@ -143,7 +165,7 @@ class ChatGLMPolicy(Policy):
|
|||
)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=ChatGLMModel,
|
||||
target_key="ChatGLMModel",
|
||||
)
|
||||
|
||||
# use flash attention
|
||||
|
@ -153,7 +175,7 @@ class ChatGLMPolicy(Policy):
|
|||
"forward": get_flash_core_attention_forward(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=CoreAttention,
|
||||
target_key="CoreAttention",
|
||||
)
|
||||
|
||||
# use sequence parallel
|
||||
|
@ -161,7 +183,7 @@ class ChatGLMPolicy(Policy):
|
|||
self.append_or_create_method_replacement(
|
||||
description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
|
||||
policy=policy,
|
||||
target_key=ChatGLMModel,
|
||||
target_key="ChatGLMModel",
|
||||
)
|
||||
|
||||
# use jit fused operator
|
||||
|
@ -172,7 +194,7 @@ class ChatGLMPolicy(Policy):
|
|||
"dropout_add": get_jit_fused_dropout_add_func(),
|
||||
},
|
||||
policy=policy,
|
||||
target_key=GLMBlock,
|
||||
target_key="GLMBlock",
|
||||
)
|
||||
|
||||
return policy
|
||||
|
@ -220,7 +242,10 @@ class ChatGLMPolicy(Policy):
|
|||
stage_index = stage_manager.get_stage_index(layers_per_stage)
|
||||
method_replacement = {
|
||||
"forward": partial(
|
||||
new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
|
||||
new_forward,
|
||||
stage_manager=stage_manager,
|
||||
stage_index=stage_index,
|
||||
shard_config=self.shard_config,
|
||||
)
|
||||
}
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||
|
@ -234,7 +259,9 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
|
|||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(
|
||||
model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy
|
||||
model_cls="ChatGLMModel",
|
||||
new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
|
||||
policy=policy,
|
||||
)
|
||||
return policy
|
||||
|
||||
|
@ -252,7 +279,7 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
|
|||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(
|
||||
model_cls=ChatGLMForConditionalGeneration,
|
||||
model_cls="ChatGLMForConditionalGeneration",
|
||||
new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
|
||||
policy=policy,
|
||||
)
|
||||
|
|
|
@ -310,13 +310,6 @@ if dist.get_world_size() > 1:
|
|||
|
||||
2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification`, `ViTForImageClassification`, please ensure that the total number of labels should be integer multiple of tensor parallel size, otherwise Shardformer can't process the classifier layer correctly. A simple fix could be appending dummy labels in transformers config. This bug will be fixed in future version of Shardformer.
|
||||
|
||||
3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through
|
||||
```python
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
```
|
||||
when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.
|
||||
|
||||
## How Shardformer Works
|
||||
|
||||
### Main Idea
|
||||
|
|
|
@ -303,13 +303,6 @@ if dist.get_world_size() > 1:
|
|||
|
||||
2. 当使用Shardformer处理`GPT2ForSequenceClassification`、`ViTForImageClassification`等分类模型时,请确保labels的总数为张量并行度的整数倍,否则Shardformer无法正确地处理classifier层。一个简单的修复方法就是在transformers的config中添加虚拟的标签。这一bug将在 Shardformer的未来版本中修复。
|
||||
|
||||
3. 训练ChatGLM-2 6B的情况有点特殊:由于Huggingface Transformers 目前尚未正式支持ChatGLM。在使用Shardformer训练ChatGLM-2时,请通过以下方式导入config/model的类:
|
||||
```python
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
```
|
||||
并且使用这些导入的类初始化模型。
|
||||
|
||||
|
||||
## Shardformer的工作原理
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import torch
|
||||
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
|
||||
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
|
||||
from torch.nn import init
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
|
@ -34,19 +33,26 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
|
|||
)
|
||||
loss_fn = lambda x: x["loss"]
|
||||
|
||||
config = ChatGLMConfig(
|
||||
config = AutoConfig.from_pretrained(
|
||||
"THUDM/chatglm2-6b",
|
||||
trust_remote_code=True,
|
||||
num_layers=2,
|
||||
padded_vocab_size=65024,
|
||||
hidden_size=64,
|
||||
ffn_hidden_size=214,
|
||||
num_attention_heads=8,
|
||||
kv_channels=16,
|
||||
rmsnorm=True,
|
||||
original_rope=True,
|
||||
use_cache=True,
|
||||
multi_query_attention=False,
|
||||
torch_dtype=torch.float32,
|
||||
)
|
||||
|
||||
infer_config = ChatGLMConfig(
|
||||
|
||||
infer_config = AutoConfig.from_pretrained(
|
||||
"THUDM/chatglm2-6b",
|
||||
trust_remote_code=True,
|
||||
num_layers=2,
|
||||
padded_vocab_size=65024,
|
||||
hidden_size=128,
|
||||
|
@ -60,18 +66,18 @@ infer_config = ChatGLMConfig(
|
|||
torch_dtype=torch.float32,
|
||||
)
|
||||
|
||||
model_zoo.register(
|
||||
name="transformers_chatglm",
|
||||
model_fn=lambda: ChatGLMModel(config, empty_init=False),
|
||||
data_gen_fn=data_gen,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_chatglm_model,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
||||
def init_chatglm():
|
||||
model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)
|
||||
for m in model.modules():
|
||||
if m.__class__.__name__ == "RMSNorm":
|
||||
init.ones_(m.weight)
|
||||
return model
|
||||
|
||||
|
||||
model_zoo.register(
|
||||
name="transformers_chatglm_for_conditional_generation",
|
||||
model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
|
||||
model_fn=init_chatglm,
|
||||
data_gen_fn=data_gen_for_conditional_generation,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn,
|
||||
|
|
|
@ -227,7 +227,7 @@ def check_output_hidden_state(
|
|||
|
||||
|
||||
def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
|
||||
assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
|
||||
assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
|
||||
|
||||
|
||||
def check_weight(
|
||||
|
|
|
@ -11,6 +11,7 @@ from tests.test_shardformer.test_model._utils import (
|
|||
build_model_from_hybrid_plugin,
|
||||
check_all_grad_tensors,
|
||||
check_loss,
|
||||
check_output_hidden_state,
|
||||
check_weight,
|
||||
get_grad_tensors_for_check,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
|
@ -103,8 +104,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
atol, rtol = 5e-3, 5e-3
|
||||
|
||||
# TODO: ChatGLMModel output is [S, B, H], merging batch of pipeline is wrong
|
||||
# if org_model.__class__.__name__ == "ChatGLMModel":
|
||||
# check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
|
||||
if org_model.__class__.__name__ == "ChatGLMModel":
|
||||
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
|
||||
|
||||
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
|
||||
|
||||
|
@ -177,14 +178,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
|
|||
{
|
||||
"tp_size": 4,
|
||||
"pp_size": 1,
|
||||
"enable_all_optimization": True,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
},
|
||||
{
|
||||
"tp_size": 2,
|
||||
"pp_size": 1,
|
||||
"enable_all_optimization": True,
|
||||
"enable_all_optimization": False,
|
||||
"use_lazy_init": False,
|
||||
"precision": "fp32",
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue