mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] fix chatglm implementation (#5644)
* [shardformer] fix chatglm policy
* [shardformer] fix chatglm flash attn
* [shardformer] update readme
* [shardformer] fix chatglm init
* [shardformer] fix chatglm test
* [pipeline] fix chatglm merge batch

pull/5654/head
parent 5d88ef1aaf
commit bbb2c21f16
@@ -7,7 +7,7 @@ from torch.nn import Module
 from torch.utils._pytree import tree_map
 
 from colossalai.accelerator import get_accelerator
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils import get_current_device
@@ -327,7 +327,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
                 self.send_forward(output_obj)
 
         if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            batch_size_dim = getattr(model, "batch_size_dim", 0)
+            outputs = merge_batch(outputs, batch_size_dim)
         return {"loss": accum_loss, "outputs": outputs}
 
     def run_forward_backward(
@@ -410,7 +413,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
         assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)
 
         if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            batch_size_dim = getattr(model, "batch_size_dim", 0)
+            outputs = merge_batch(outputs, batch_size_dim)
         return {"loss": accum_loss, "outputs": outputs}
 
     def forward_backward_step(
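The new `merge_batch(outputs, batch_size_dim)` call merges the per-micro-batch outputs along a model-specific batch dimension instead of always dim 0. Below is a minimal sketch of the implied behavior; it is an illustration, not ColossalAI's `merge_batch`, and the `batch_size_dim` attribute on the unwrapped model is assumed to be set by model-specific code (e.g. 1 for ChatGLMModel, whose hidden states are laid out as `[seq, batch, hidden]`).

```python
# Illustrative sketch only: concatenate matching tensor leaves of each micro-batch output
# along a configurable batch dimension (ColossalAI's real merge_batch may differ).
from typing import Any, List

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten


def merge_batch_sketch(outputs: List[Any], batch_size_dim: int = 0) -> Any:
    leaves, specs = zip(*(tree_flatten(o) for o in outputs))
    merged = [
        torch.cat(group, dim=batch_size_dim) if isinstance(group[0], torch.Tensor) else group[0]
        for group in zip(*leaves)
    ]
    return tree_unflatten(merged, specs[0])


# Two micro-batches of [seq, batch, hidden] hidden states merge along dim 1:
chunks = [torch.randn(4, 2, 8), torch.randn(4, 2, 8)]
assert merge_batch_sketch(chunks, batch_size_dim=1).shape == (4, 4, 8)
```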
@@ -115,7 +115,7 @@ We will follow this roadmap to develop Shardformer:
 - [ ] Policy Implementation
 
 | model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
-| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
+|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:|
 | bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
 | t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
 | llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
@@ -391,6 +391,43 @@ _POLICY_LIST = {
 }
 ```
 
+
+#### How to support models that are on the Hugging Face model hub but not in the `transformers` library
+
+There are two cases:
+
+1. The modeling file is in the `transformers` library but the model weights are not. E.g. the model structure of "01-ai/Yi-34B" is the same as LLaMA, so supporting llama as usual also covers Yi-34B through the llama policy; no new policy is needed.
+2. The modeling file is not in the `transformers` library, such as "THUDM/chatglm2-6b".
+
+Taking "THUDM/chatglm2-6b" as an example, we illustrate how to support this model in `shardformer`.
+
+Unlike llama, which lives in the `transformers` library, the chatglm2 model classes cannot be imported directly. Thus the policy key should be the class name as a string rather than the class itself.
+
+E.g. for llama:
+```python
+policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
+```
+
+For chatglm2:
+```python
+policy["GLMBlock"] = ModulePolicyDescription(...)
+```
+
+Then, when registering such models in the autopolicy, follow the format below:
+```python
+"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
+    file_name="<policy_filename>", class_name="<policy_class_name>"
+)
+```
+
+For the chatglm2 model, it is:
+```python
+"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+    file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
+)
+```
+
+When using such models, `AutoModel` is supported as usual; the policy will be loaded automatically by the autopolicy.
 
 ### Write Your Unit Testing
 
 This section serves as the guideline for testing the `shardformer` module.
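A quick way to see why the registered keys start with `transformers_modules`: classes pulled in via `trust_remote_code` are materialized under that synthetic package. The sketch below only loads the config (no weights); the exact printed path depends on how the checkpoint was downloaded.

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
print(config.__class__.__module__)    # e.g. transformers_modules.THUDM.chatglm2-6b.<revision>.configuration_chatglm
print(config.__class__.__qualname__)  # ChatGLMConfig
```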
@@ -425,7 +462,7 @@ We set the batch size to 4, the number of attention heads to 8, and the head dim
 
 In the case of using 2 GPUs, the training times are as follows.
 | N_CTX | org_model | shard_model |
-| :------: | :-----: | :-----: |
+|:-----:|:---------:|:-----------:|
 | 256  | 11.2ms | 17.2ms |
 | 512  | 9.8ms  | 19.5ms |
 | 1024 | 19.6ms | 18.9ms |
@@ -441,7 +478,7 @@ In the case of using 2 GPUs, the training times are as follows.
 In the case of using 4 GPUs, the training times are as follows.
 
 | N_CTX | org_model | shard_model |
-| :------: | :-----: | :-----: |
+|:-----:|:---------:|:-----------:|
 | 256  | 10.0ms | 21.1ms |
 | 512  | 11.5ms | 20.2ms |
 | 1024 | 22.1ms | 20.6ms |
@@ -475,7 +512,7 @@ warmup_fraction = 0.03
 
 
 | accuracy | f1 | loss | GPU number | model sharded |
-| :------: | :-----: | :-----: | :--------: | :---------: |
+|:--------:|:-------:|:-------:|:----------:|:-------------:|
 | 0.82971 | 0.87713 | 0.23194 | 4 | True |
 | 0.83797 | 0.88006 | 0.22683 | 2 | True |
 | 0.84521 | 0.88700 | 0.21822 | 1 | False |
@@ -281,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm):
             )
 
         LazyInitContext.materialize(module)
-        # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
-        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
-            normalized_shape = module.weight.shape[0]
-            eps = module.variance_epsilon
-            elementwise_affine = True
-        else:
-            # get the attributes of the module
-            normalized_shape = module.normalized_shape
-            eps = module.eps
-            elementwise_affine = module.elementwise_affine
+
+        # try to get normalized_shape, eps, elementwise_affine from the module
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+        elementwise_affine = getattr(module, "elementwise_affine", True)
 
         rmsnorm = FusedRMSNormWithHook(
-            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
+            normalized_shape=normalized_shape,
+            eps=eps,
+            elementwise_affine=elementwise_affine,
         )
 
         rmsnorm.weight = module.weight
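The getattr/hasattr fallbacks exist because the two RMSNorm flavors expose different attributes: HuggingFace's `LlamaRMSNorm`/`MistralRMSNorm` carry only `weight` and `variance_epsilon`, while nn/apex-style norms carry `normalized_shape`, `eps`, and `elementwise_affine`. A small illustration, assuming a `transformers` version that ships `LlamaRMSNorm` at the path below:

```python
from transformers.models.llama.modeling_llama import LlamaRMSNorm

m = LlamaRMSNorm(hidden_size=64, eps=1e-6)
# LlamaRMSNorm has no normalized_shape/eps/elementwise_affine, so each lookup falls back:
normalized_shape = getattr(m, "normalized_shape", m.weight.shape[0])    # -> 64
eps = m.variance_epsilon if hasattr(m, "variance_epsilon") else m.eps   # -> 1e-6
elementwise_affine = getattr(m, "elementwise_affine", True)             # -> True
print(normalized_shape, eps, elementwise_affine)
```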
@@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
 
 
 def get_flash_core_attention_forward():
@@ -31,7 +30,12 @@ def get_flash_core_attention_forward():
                 device=query_layer.device,
             )
             temp_mask = (
-                torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
+                torch.ones(
+                    query_layer.shape[2],
+                    key_layer.shape[2],
+                    dtype=torch.bool,
+                    device=query_layer.device,
+                )
                 .tril(diagonal=0)
                 .expand(query_layer.shape[0], 1, -1, -1)
             )
@@ -49,6 +53,7 @@ def get_flash_core_attention_forward():
             attention_mask=attn_bias,
             attention_mask_type=attention_mask_type,
             dropout_p=dropout_p,
+            scale=1.0 / self.norm_factor,
         )
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
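Fused attention kernels default to a softmax scale of `1/sqrt(head_dim)`, so the explicit `scale=1.0 / self.norm_factor` keeps the flash path aligned with ChatGLM's `CoreAttention` scaling. PyTorch's SDPA exposes the same knob; the sketch below is an assumption-level illustration (PyTorch >= 2.1 for the `scale` argument) that simply treats `norm_factor` as `sqrt(head_dim)`, whereas ChatGLM may fold in additional coefficients.

```python
import math

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, seq, head_dim)
k, v = torch.randn_like(q), torch.randn_like(q)
norm_factor = math.sqrt(q.shape[-1])

out = F.scaled_dot_product_attention(q, k, v, scale=1.0 / norm_factor)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```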
@@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:
 
     @staticmethod
     def chatglm_model_forward(
-        self: ChatGLMModel,
+        self: "ChatGLMModel",
        input_ids,
         position_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,
@@ -194,7 +199,9 @@ class ChatGLMPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.tensor_parallel_process_group,
                 )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
@@ -224,7 +231,9 @@ class ChatGLMPipelineForwards:
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.tensor_parallel_process_group,
                 )
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -254,7 +263,7 @@ class ChatGLMPipelineForwards:
 
     @staticmethod
     def chatglm_for_conditional_generation_forward(
-        self: ChatGLMForConditionalGeneration,
+        self: "ChatGLMForConditionalGeneration",
         input_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
@@ -151,10 +151,10 @@ _POLICY_LIST = {
         file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
     ),
     # ChatGLM
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
+    "transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation(
         file_name="chatglm2", class_name="ChatGLMModelPolicy"
     ),
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+    "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
         file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
     ),
     # Falcon
@@ -202,6 +202,13 @@ def _fullname(obj):
     module = klass.__module__
     if module == "builtins":
         return klass.__qualname__  # avoid outputs like 'builtins.str'
+    # patch custom models which are not in transformers
+    # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
+    # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
+    if module.startswith("transformers_modules"):
+        split_module = module.split(".")
+        if len(split_module) >= 2:
+            module = f"{split_module[0]}.{split_module[-1]}"
     return module + "." + klass.__qualname__
 
 
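The two module paths mentioned in the comments normalize to the same key. A standalone illustration that mirrors the patched `_fullname` logic, shown only to make the mapping explicit:

```python
def normalize_module(module: str) -> str:
    # Keep only the synthetic package root and the modeling file name.
    if module.startswith("transformers_modules"):
        parts = module.split(".")
        if len(parts) >= 2:
            module = f"{parts[0]}.{parts[-1]}"
    return module


# Hub download (revision hash in the path) and local directory both collapse to the key
# format used in _POLICY_LIST:
print(normalize_module("transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm"))
print(normalize_module("transformers_modules.chatglm.modeling_chatglm"))
# both print: transformers_modules.modeling_chatglm
```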
@@ -220,7 +227,7 @@ def get_autopolicy(model: nn.Module) -> Policy:
 
     if policy_location is None:
         raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
         )
     else:
         policy = import_policy(policy_location)
@@ -7,7 +7,6 @@ from torch import Tensor
 
 import colossalai.shardformer.layer as col_nn
 from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
 
 from ..modeling.chatglm2 import (
     get_chatglm_sequence_parallel_forward_fn,
@@ -17,7 +16,11 @@ from ..modeling.chatglm2 import (
 from ..modeling.jit import get_jit_fused_dropout_add_func
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
-__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"]
+__all__ = [
+    "ChatGLMPolicy",
+    "ChatGLMModelPolicy",
+    "ChatGLMForConditionalGenerationPolicy",
+]
 
 
 class ChatGLMPolicy(Policy):
@@ -34,8 +37,6 @@ class ChatGLMPolicy(Policy):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock
-
         policy = {}
 
         embedding_cls = None
@@ -67,7 +68,27 @@ class ChatGLMPolicy(Policy):
         sp_partial_derived = sp_mode == "split_gather"
 
         if self.shard_config.enable_tensor_parallelism:
-            policy[GLMBlock] = ModulePolicyDescription(
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"num_attention_heads {self.model.config.num_attention_heads} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+            attn_kwargs = {
+                "self_attention.qkv_hidden_size": (
+                    self.model.config.kv_channels * self.model.config.num_attention_heads * 3
+                )
+                // self.shard_config.tensor_parallel_size,
+            }
+            if self.model.config.multi_query_attention:
+                assert (
+                    self.model.config.multi_query_group_num % self.shard_config.tensor_parallel_size == 0
+                ), f"multi_query_group_num {self.model.config.multi_query_group_num} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}"
+                attn_kwargs["self_attention.num_multi_query_groups_per_partition"] = (
+                    self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size
+                )
+                attn_kwargs["self_attention.qkv_hidden_size"] = (
+                    self.model.config.kv_channels * self.model.config.num_attention_heads
+                    + 2 * self.model.config.kv_channels * self.model.config.multi_query_group_num
+                ) // self.shard_config.tensor_parallel_size
+            policy["GLMBlock"] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
                     // self.shard_config.tensor_parallel_size,
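To make the multi-query split concrete, here is a worked example with tensor_parallel_size=2 and the values commonly reported for the chatglm2-6b config (kv_channels=128, num_attention_heads=32, multi_query_group_num=2; treat these numbers as assumptions, the formula is the one used in the policy above):

```python
kv_channels = 128
num_attention_heads = 32
multi_query_group_num = 2
tensor_parallel_size = 2

heads_per_partition = num_attention_heads // tensor_parallel_size     # 16
groups_per_partition = multi_query_group_num // tensor_parallel_size  # 1
qkv_hidden_size = (
    kv_channels * num_attention_heads              # 4096 rows for the Q projection
    + 2 * kv_channels * multi_query_group_num      # 512 rows for the shared K and V heads
) // tensor_parallel_size                          # -> 2304 per tensor-parallel rank

print(heads_per_partition, groups_per_partition, qkv_hidden_size)  # 16 1 2304
```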
|
@ -75,22 +96,23 @@ class ChatGLMPolicy(Policy):
|
||||||
self.model.config.kv_channels * self.model.config.num_attention_heads
|
self.model.config.kv_channels * self.model.config.num_attention_heads
|
||||||
)
|
)
|
||||||
// self.shard_config.tensor_parallel_size,
|
// self.shard_config.tensor_parallel_size,
|
||||||
"self_attention.qkv_hidden_size": (
|
|
||||||
self.model.config.kv_channels * self.model.config.num_attention_heads * 3
|
|
||||||
)
|
|
||||||
// self.shard_config.tensor_parallel_size,
|
|
||||||
"self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
|
"self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
|
||||||
// self.shard_config.tensor_parallel_size,
|
// self.shard_config.tensor_parallel_size,
|
||||||
"self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels
|
"self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels
|
||||||
* self.model.config.num_attention_heads
|
* self.model.config.num_attention_heads
|
||||||
// self.shard_config.tensor_parallel_size,
|
// self.shard_config.tensor_parallel_size,
|
||||||
|
**attn_kwargs,
|
||||||
},
|
},
|
||||||
param_replacement=[],
|
param_replacement=[],
|
||||||
sub_module_replacement=[
|
sub_module_replacement=[
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attention.query_key_value",
|
suffix="self_attention.query_key_value",
|
||||||
target_module=col_nn.Linear1D_Col,
|
target_module=col_nn.Linear1D_Col,
|
||||||
kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap},
|
kwargs={
|
||||||
|
"seq_parallel_mode": sp_mode,
|
||||||
|
"seq_parallel_dim": 0,
|
||||||
|
"overlap": overlap,
|
||||||
|
},
|
||||||
),
|
),
|
||||||
SubModuleReplacementDescription(
|
SubModuleReplacementDescription(
|
||||||
suffix="self_attention.dense",
|
suffix="self_attention.dense",
|
||||||
|
@@ -114,7 +136,7 @@ class ChatGLMPolicy(Policy):
                 ),
             ],
             policy=policy,
-            target_key=ChatGLMModel,
+            target_key="ChatGLMModel",
         )
         # optimization configuration
         self.append_or_create_submodule_replacement(
@@ -131,7 +153,7 @@ class ChatGLMPolicy(Policy):
                 ),
             ],
             policy=policy,
-            target_key=GLMBlock,
+            target_key="GLMBlock",
         )
 
         if self.model.config.post_layer_norm:
@@ -143,7 +165,7 @@ class ChatGLMPolicy(Policy):
                 )
             ],
             policy=policy,
-            target_key=ChatGLMModel,
+            target_key="ChatGLMModel",
         )
 
         # use flash attention
@@ -153,7 +175,7 @@ class ChatGLMPolicy(Policy):
                     "forward": get_flash_core_attention_forward(),
                 },
                 policy=policy,
-                target_key=CoreAttention,
+                target_key="CoreAttention",
             )
 
         # use sequence parallel
@@ -161,7 +183,7 @@ class ChatGLMPolicy(Policy):
             self.append_or_create_method_replacement(
                 description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
                 policy=policy,
-                target_key=ChatGLMModel,
+                target_key="ChatGLMModel",
             )
 
         # use jit fused operator
@@ -172,7 +194,7 @@ class ChatGLMPolicy(Policy):
                     "dropout_add": get_jit_fused_dropout_add_func(),
                 },
                 policy=policy,
-                target_key=GLMBlock,
+                target_key="GLMBlock",
             )
 
         return policy
@@ -220,7 +242,10 @@ class ChatGLMPolicy(Policy):
         stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
-                new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+                new_forward,
+                stage_manager=stage_manager,
+                stage_index=stage_index,
+                shard_config=self.shard_config,
             )
         }
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -234,7 +259,9 @@ class ChatGLMModelPolicy(ChatGLMPolicy):
 
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy
+                model_cls="ChatGLMModel",
+                new_forward=ChatGLMPipelineForwards.chatglm_model_forward,
+                policy=policy,
             )
         return policy
 
@@ -252,7 +279,7 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy):
 
         if self.pipeline_stage_manager is not None:
             self.set_pipeline_forward(
-                model_cls=ChatGLMForConditionalGeneration,
+                model_cls="ChatGLMForConditionalGeneration",
                 new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward,
                 policy=policy,
             )
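`module_policy` is typed `Dict[Union[str, nn.Module], ModulePolicyDescription]`, so a policy may be keyed either by a class or, as above, by a class-name string when the class cannot be imported. Below is a hedged sketch of how such mixed keys could be resolved against submodules; it is an illustration, not ColossalAI's actual sharder code.

```python
import torch.nn as nn


def match_policy(module: nn.Module, policy: dict):
    """Return the policy entry for `module`, accepting class keys or class-name string keys."""
    for key, description in policy.items():
        if isinstance(key, str):
            if module.__class__.__name__ == key:
                return description
        elif isinstance(module, key):
            return description
    return None


# e.g. match_policy(glm_block, policy) would hit the "GLMBlock" entry registered above.
```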
@@ -310,13 +310,6 @@ if dist.get_world_size() > 1:
 
 2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification` and `ViTForImageClassification`, please ensure that the total number of labels is an integer multiple of the tensor parallel size, otherwise Shardformer cannot process the classifier layer correctly. A simple fix is to append dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.
 
-3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers does not officially support ChatGLM at present, please import the configuration/model classes through
-```python
-from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
-```
-when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes.
-
 ## How Shardformer Works
 
 ### Main Idea
@@ -303,13 +303,6 @@ if dist.get_world_size() > 1:
 
 2. When using Shardformer with classification models such as `GPT2ForSequenceClassification` or `ViTForImageClassification`, make sure the total number of labels is an integer multiple of the tensor parallel size, otherwise Shardformer cannot handle the classifier layer correctly. A simple fix is to add dummy labels in the transformers config. This bug will be fixed in a future version of Shardformer.
 
-3. Training ChatGLM-2 6B is a special case: since Huggingface Transformers does not officially support ChatGLM yet, please import the config/model classes as follows when training ChatGLM-2 with Shardformer:
-```python
-from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
-```
-and initialize your model with these imported classes.
-
 
 ## How Shardformer Works
 
@@ -1,7 +1,6 @@
 import torch
-
-from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel
+from torch.nn import init
+from transformers import AutoConfig, AutoModelForCausalLM
 
 from ..registry import ModelAttribute, model_zoo
 
@@ -34,19 +33,26 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
 )
 loss_fn = lambda x: x["loss"]
 
-config = ChatGLMConfig(
+config = AutoConfig.from_pretrained(
+    "THUDM/chatglm2-6b",
+    trust_remote_code=True,
     num_layers=2,
     padded_vocab_size=65024,
     hidden_size=64,
+    ffn_hidden_size=214,
     num_attention_heads=8,
     kv_channels=16,
     rmsnorm=True,
     original_rope=True,
     use_cache=True,
+    multi_query_attention=False,
     torch_dtype=torch.float32,
 )
 
-infer_config = ChatGLMConfig(
+infer_config = AutoConfig.from_pretrained(
+    "THUDM/chatglm2-6b",
+    trust_remote_code=True,
     num_layers=2,
     padded_vocab_size=65024,
     hidden_size=128,
@@ -60,18 +66,18 @@ infer_config = ChatGLMConfig(
     torch_dtype=torch.float32,
 )
 
-model_zoo.register(
-    name="transformers_chatglm",
-    model_fn=lambda: ChatGLMModel(config, empty_init=False),
-    data_gen_fn=data_gen,
-    output_transform_fn=output_transform_fn,
-    loss_fn=loss_fn_for_chatglm_model,
-    model_attribute=ModelAttribute(has_control_flow=True),
-)
+
+def init_chatglm():
+    model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True)
+    for m in model.modules():
+        if m.__class__.__name__ == "RMSNorm":
+            init.ones_(m.weight)
+    return model
+
 
 model_zoo.register(
     name="transformers_chatglm_for_conditional_generation",
-    model_fn=lambda: ChatGLMForConditionalGeneration(config, empty_init=False),
+    model_fn=init_chatglm,
     data_gen_fn=data_gen_for_conditional_generation,
     output_transform_fn=output_transform_fn,
     loss_fn=loss_fn,
@@ -227,7 +227,7 @@ def check_output_hidden_state(
 
 
 def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3):
-    assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
+    assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol)
 
 
 def check_weight(
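`torch.testing.assert_close` reports which elements differ and by how much, instead of the bare `AssertionError` that `assert torch.allclose(...)` gives. Minimal usage with the same tolerances:

```python
import torch
from torch.testing import assert_close

a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0, 2.0 + 1e-6])
assert_close(a, b, atol=1e-5, rtol=1e-3)  # passes; a larger mismatch raises with a detailed diff
```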
@@ -11,6 +11,7 @@ from tests.test_shardformer.test_model._utils import (
     build_model_from_hybrid_plugin,
     check_all_grad_tensors,
     check_loss,
+    check_output_hidden_state,
     check_weight,
     get_grad_tensors_for_check,
     run_forward_backward_with_hybrid_plugin,
@@ -103,8 +104,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         atol, rtol = 5e-3, 5e-3
 
     # TODO: ChatGLMModel output is [S, B, H], merging batch of pipeline is wrong
-    # if org_model.__class__.__name__ == "ChatGLMModel":
-    #     check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
+    if org_model.__class__.__name__ == "ChatGLMModel":
+        check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
 
     check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
@@ -177,14 +178,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         {
             "tp_size": 4,
             "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
         {
             "tp_size": 2,
             "pp_size": 1,
-            "enable_all_optimization": True,
+            "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },