diff --git a/colossalai/pipeline/schedule/one_f_one_b.py b/colossalai/pipeline/schedule/one_f_one_b.py index 58008b98f..bfea8b67d 100644 --- a/colossalai/pipeline/schedule/one_f_one_b.py +++ b/colossalai/pipeline/schedule/one_f_one_b.py @@ -7,7 +7,7 @@ from torch.nn import Module from torch.utils._pytree import tree_map from colossalai.accelerator import get_accelerator -from colossalai.interface import OptimizerWrapper +from colossalai.interface import ModelWrapper, OptimizerWrapper from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.utils import get_current_device @@ -327,7 +327,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): self.send_forward(output_obj) if outputs is not None: - outputs = merge_batch(outputs) + if isinstance(model, ModelWrapper): + model = model.unwrap() + batch_size_dim = getattr(model, "batch_size_dim", 0) + outputs = merge_batch(outputs, batch_size_dim) return {"loss": accum_loss, "outputs": outputs} def run_forward_backward( @@ -410,7 +413,10 @@ class OneForwardOneBackwardSchedule(PipelineSchedule): assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs) if outputs is not None: - outputs = merge_batch(outputs) + if isinstance(model, ModelWrapper): + model = model.unwrap() + batch_size_dim = getattr(model, "batch_size_dim", 0) + outputs = merge_batch(outputs, batch_size_dim) return {"loss": accum_loss, "outputs": outputs} def forward_backward_step( diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index c8670affb..d45421868 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer: - [x] Unit Testing - [ ] Policy Implementation -| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap | -| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: | -| bert | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | -| t5 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| gpt2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | -| opt | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| bloom | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | -| chatglm2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | -| vit | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| whisper | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] | -| sam | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| blip2 | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] | -| falcon | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] | -| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | -| mistral | [√] | [ ] | [ ] | [√] 
| [√] | [√] | [√] | [ ] | [ ] |
+| model       | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
+|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:|
+| bert        | [√]             | [√]               | [√]                 | [√]     | [√]         | [√]                | [√]             | [√]               | [√]     |
+| t5          | [√]             | [√]               | [√]                 | [√]     | [√]         | [√]                | [√]             | [ ]               | [ ]     |
+| llama V1/V2 | [√]             | [√]               | [√]                 | [√]     | [√]         | [√]                | [√]             | [ ]               | [ ]     |
+| gpt2        | [√]             | [√]               | [√]                 | [√]     | [√]         | [√]                | [√]             | [√]               | [√]     |
+| opt         | [√]             | [√]               | [√]                 | [√]     | [√]         | [√]                | [√]             | [ ]               | [ ]     |
+| bloom       | [√]             | [√]               | [√]                 | [√]     | [√]         | [√]                | [√]             | [√]               | [√]     |
+| chatglm2    | [√]             | [√]               | [√]                 | [√]     | [√]         | [√]                | [√]             | [√]               | [√]     |
+| vit         | [√]             | [√]               | [ ]                 | [√]     | [√]         | [√]                | [√]             | [ ]               | [ ]     |
+| whisper     | [√]             | [√]               | [√]                 | [√]     | [√]         | [ ]                | [√]             | [ ]               | [ ]     |
+| sam         | [√]             | [ ]               | [ ]                 | [√]     | [√]         | [√]                | [√]             | [ ]               | [ ]     |
+| blip2       | [√]             | [ ]               | [ ]                 | [√]     | [√]         | [√]                | [√]             | [ ]               | [ ]     |
+| falcon      | [√]             | [√]               | [√]                 | [√]     | [√]         | [ ]                | [√]             | [ ]               | [ ]     |
+| roberta     | [ ]             | [ ]               | [ ]                 | [ ]     | [ ]         | [ ]                | [ ]             | [ ]               | [ ]     |
+| albert      | [ ]             | [ ]               | [ ]                 | [ ]     | [ ]         | [ ]                | [ ]             | [ ]               | [ ]     |
+| ernie       | [ ]             | [ ]               | [ ]                 | [ ]     | [ ]         | [ ]                | [ ]             | [ ]               | [ ]     |
+| gpt-neo     | [ ]             | [ ]               | [ ]                 | [ ]     | [ ]         | [ ]                | [ ]             | [ ]               | [ ]     |
+| gpt-j       | [ ]             | [ ]               | [ ]                 | [ ]     | [ ]         | [ ]                | [ ]             | [ ]               | [ ]     |
+| beit        | [ ]             | [ ]               | [ ]                 | [ ]     | [ ]         | [ ]                | [ ]             | [ ]               | [ ]     |
+| swin        | [ ]             | [ ]               | [ ]                 | [ ]     | [ ]         | [ ]                | [ ]             | [ ]               | [ ]     |
+| swin V2     | [ ]             | [ ]               | [ ]                 | [ ]     | [ ]         | [ ]                | [ ]             | [ ]               | [ ]     |
+| qwen        | [ ]             | [ ]               | [ ]                 | [ ]     | [ ]         | [ ]                | [ ]             | [ ]               | [ ]     |
+| mistral     | [√]             | [ ]               | [ ]                 | [√]     | [√]         | [√]                | [√]             | [ ]               | [ ]     |
 
 ## 💡 API Design
 
@@ -391,6 +391,43 @@ _POLICY_LIST = {
 }
 ```
 
+#### How to support models that are on the Hugging Face Hub but not in the transformers library
+
+There are two cases:
+
+1. The modeling file is in the `transformers` library but the model weights are not. For example, "01-ai/Yi-34B" has the same model structure as LLaMA, so it is already covered by the LLaMA policy and no new policy is needed.
+2. The modeling file itself is not in the `transformers` library, e.g. "THUDM/chatglm2-6b".
+
+Taking "THUDM/chatglm2-6b" as an example, here is how to support such a model in `shardformer`.
+
+Unlike LLaMA, whose classes can be imported from the `transformers` library, the chatglm2 model classes cannot be imported directly. Therefore, the policy key should be the class name as a string rather than the class itself.
+
+For example, for LLaMA:
+```python
+policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
+```
+
+For chatglm2:
+```python
+policy["GLMBlock"] = ModulePolicyDescription(...)
+```
+
+When registering such models in the auto policy, follow the format below (placeholders in angle brackets):
+```python
+"transformers_modules.<modeling file name>.<model class name>": PolicyLocation(
+    file_name="<policy file name>", class_name="<policy class name>"
+)
+```
+
+For the chatglm2 model, this becomes:
+```python
+"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+    file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
+)
+```
+
+When using such models, `AutoModel` is supported as usual; the policy is loaded automatically by the auto policy.
+
 ### Write Your Unit Testing
 
 This section serves as the guideline for testing the `shardformer` module.
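To make the registration flow above concrete, the minimal sketch below loads a custom-code checkpoint from the Hub and resolves its auto policy. It is only an illustration, assuming network access to the Hub and the chatglm2 example used above; `get_autopolicy` is called directly here just to show that the `transformers_modules.*` lookup works, whereas in normal training the policy is resolved for you.

```python
from transformers import AutoConfig, AutoModelForCausalLM

from colossalai.shardformer.policies.auto_policy import get_autopolicy

# Custom-code checkpoints are loaded with trust_remote_code=True, so their
# classes live under the dynamically generated `transformers_modules` package.
config = AutoConfig.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

# The auto policy normalizes the module path, e.g.
# 'transformers_modules.THUDM.chatglm2-6b.<revision>.modeling_chatglm'
# -> 'transformers_modules.modeling_chatglm', and looks it up in _POLICY_LIST.
policy = get_autopolicy(model)  # resolves to ChatGLMForConditionalGenerationPolicy
print(policy)
```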
@@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
 
 We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64.
 'N_CTX' refers to the sequence length.
 
 In the case of using 2 GPUs, the training times are as follows.
 
-| N_CTX | org_model | shard_model |
-| :------: | :-----: | :-----: |
-| 256 | 11.2ms | 17.2ms |
-| 512 | 9.8ms | 19.5ms |
-| 1024 | 19.6ms | 18.9ms |
-| 2048 | 46.6ms | 30.8ms |
-| 4096 | 160.5ms | 90.4ms |
+| N_CTX | org_model | shard_model |
+|:-----:|:---------:|:-----------:|
+| 256   | 11.2ms    | 17.2ms      |
+| 512   | 9.8ms     | 19.5ms      |
+| 1024  | 19.6ms    | 18.9ms      |
+| 2048  | 46.6ms    | 30.8ms      |
+| 4096  | 160.5ms   | 90.4ms      |

@@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows. In the case of using 4 GPUs, the training times are as follows. -| N_CTX | org_model | shard_model | -| :------: | :-----: | :-----: | -| 256 | 10.0ms | 21.1ms | -| 512 | 11.5ms | 20.2ms | -| 1024 | 22.1ms | 20.6ms | -| 2048 | 46.9ms | 24.8ms | -| 4096 | 160.4ms | 68.0ms | +| N_CTX | org_model | shard_model | +|:-----:|:---------:|:-----------:| +| 256 | 10.0ms | 21.1ms | +| 512 | 11.5ms | 20.2ms | +| 1024 | 22.1ms | 20.6ms | +| 2048 | 46.9ms | 24.8ms | +| 4096 | 160.4ms | 68.0ms | @@ -475,10 +512,10 @@ warmup_fraction = 0.03 | accuracy | f1 | loss | GPU number | model sharded | -| :------: | :-----: | :-----: | :--------: | :---------: | -| 0.82971 | 0.87713 | 0.23194 | 4 | True | -| 0.83797 | 0.88006 | 0.22683 | 2 | True | -| 0.84521 | 0.88700 | 0.21822 | 1 | False | +|:--------:|:-------:|:-------:|:----------:|:-------------:| +| 0.82971 | 0.87713 | 0.23194 | 4 | True | +| 0.83797 | 0.88006 | 0.22683 | 2 | True | +| 0.84521 | 0.88700 | 0.21822 | 1 | False | Overall, the results demonstrate that using shardformers during model training does not affect the convergence. diff --git a/colossalai/shardformer/layer/normalization.py b/colossalai/shardformer/layer/normalization.py index bba4bd070..5aa212600 100644 --- a/colossalai/shardformer/layer/normalization.py +++ b/colossalai/shardformer/layer/normalization.py @@ -281,19 +281,16 @@ class FusedRMSNorm(BaseLayerNorm): ) LazyInitContext.materialize(module) - # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm - if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]: - normalized_shape = module.weight.shape[0] - eps = module.variance_epsilon - elementwise_affine = True - else: - # get the attributes of the module - normalized_shape = module.normalized_shape - eps = module.eps - elementwise_affine = module.elementwise_affine + + # try to get normalized_shape, eps, elementwise_affine from the module + normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0]) + eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps + elementwise_affine = getattr(module, "elementwise_affine", True) rmsnorm = FusedRMSNormWithHook( - normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine + normalized_shape=normalized_shape, + eps=eps, + elementwise_affine=elementwise_affine, ) rmsnorm.weight = module.weight diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py index 9207b34d0..53c151f02 100644 --- a/colossalai/shardformer/modeling/chatglm2.py +++ b/colossalai/shardformer/modeling/chatglm2.py @@ -12,7 +12,6 @@ from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.shardformer import ShardConfig from colossalai.shardformer.layer import AttnMaskType, ColoAttention from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel def get_flash_core_attention_forward(): @@ -31,7 +30,12 @@ def get_flash_core_attention_forward(): device=query_layer.device, ) temp_mask = ( - torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device) + torch.ones( + query_layer.shape[2], + key_layer.shape[2], + dtype=torch.bool, + device=query_layer.device, + ) .tril(diagonal=0) .expand(query_layer.shape[0], 1, -1, -1) ) @@ -49,6 
+53,7 @@ def get_flash_core_attention_forward(): attention_mask=attn_bias, attention_mask_type=attention_mask_type, dropout_p=dropout_p, + scale=1.0 / self.norm_factor, ) context_layer = context_layer.permute(2, 0, 1, 3) new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) @@ -115,7 +120,7 @@ class ChatGLMPipelineForwards: @staticmethod def chatglm_model_forward( - self: ChatGLMModel, + self: "ChatGLMModel", input_ids, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.BoolTensor] = None, @@ -194,7 +199,9 @@ class ChatGLMPipelineForwards: if shard_config and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = split_forward_gather_backward( - hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) for idx in range(start_idx, end_idx): layer = self.encoder._get_layer(idx) @@ -224,7 +231,9 @@ class ChatGLMPipelineForwards: if shard_config and shard_config.enable_sequence_parallelism: if shard_config.sequence_parallelism_mode == "split_gather": hidden_states = gather_forward_split_backward( - hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group + hidden_states, + dim=0, + process_group=shard_config.tensor_parallel_process_group, ) if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) @@ -254,7 +263,7 @@ class ChatGLMPipelineForwards: @staticmethod def chatglm_for_conditional_generation_forward( - self: ChatGLMForConditionalGeneration, + self: "ChatGLMForConditionalGeneration", input_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py index 0991ace2c..d2b582af5 100644 --- a/colossalai/shardformer/policies/auto_policy.py +++ b/colossalai/shardformer/policies/auto_policy.py @@ -151,10 +151,10 @@ _POLICY_LIST = { file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy" ), # ChatGLM - "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation( + "transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation( file_name="chatglm2", class_name="ChatGLMModelPolicy" ), - "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( + "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation( file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy" ), # Falcon @@ -202,6 +202,13 @@ def _fullname(obj): module = klass.__module__ if module == "builtins": return klass.__qualname__ # avoid outputs like 'builtins.str' + # patch custom models which are not in transformers + # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub) + # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory) + if module.startswith("transformers_modules"): + split_module = module.split(".") + if len(split_module) >= 2: + module = f"{split_module[0]}.{split_module[-1]}" return module + "." + klass.__qualname__ @@ -220,7 +227,7 @@ def get_autopolicy(model: nn.Module) -> Policy: if policy_location is None: raise NotImplementedError( - f"Auto policy for {model.__class__.__qualname__} is not implemented\n. 
Supported models are {list(_POLICY_LIST.keys())}" + f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}" ) else: policy = import_policy(policy_location) diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py index f205835e7..4baf89f6a 100644 --- a/colossalai/shardformer/policies/chatglm2.py +++ b/colossalai/shardformer/policies/chatglm2.py @@ -7,7 +7,6 @@ from torch import Tensor import colossalai.shardformer.layer as col_nn from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel from ..modeling.chatglm2 import ( get_chatglm_sequence_parallel_forward_fn, @@ -17,7 +16,11 @@ from ..modeling.chatglm2 import ( from ..modeling.jit import get_jit_fused_dropout_add_func from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription -__all__ = ["ChatGLMPolicy", "ChatGLMModelPolicy", "ChatGLMForConditionalGenerationPolicy"] +__all__ = [ + "ChatGLMPolicy", + "ChatGLMModelPolicy", + "ChatGLMForConditionalGenerationPolicy", +] class ChatGLMPolicy(Policy): @@ -34,8 +37,6 @@ class ChatGLMPolicy(Policy): return self.model def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMModel, CoreAttention, GLMBlock - policy = {} embedding_cls = None @@ -67,7 +68,27 @@ class ChatGLMPolicy(Policy): sp_partial_derived = sp_mode == "split_gather" if self.shard_config.enable_tensor_parallelism: - policy[GLMBlock] = ModulePolicyDescription( + assert ( + self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0 + ), f"num_attention_heads {self.model.config.num_attention_heads} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}" + attn_kwargs = { + "self_attention.qkv_hidden_size": ( + self.model.config.kv_channels * self.model.config.num_attention_heads * 3 + ) + // self.shard_config.tensor_parallel_size, + } + if self.model.config.multi_query_attention: + assert ( + self.model.config.multi_query_group_num % self.shard_config.tensor_parallel_size == 0 + ), f"multi_query_group_num {self.model.config.multi_query_group_num} should be divisible by tensor_parallel_size {self.shard_config.tensor_parallel_size}" + attn_kwargs["self_attention.num_multi_query_groups_per_partition"] = ( + self.model.config.multi_query_group_num // self.shard_config.tensor_parallel_size + ) + attn_kwargs["self_attention.qkv_hidden_size"] = ( + self.model.config.kv_channels * self.model.config.num_attention_heads + + 2 * self.model.config.kv_channels * self.model.config.multi_query_group_num + ) // self.shard_config.tensor_parallel_size + policy["GLMBlock"] = ModulePolicyDescription( attribute_replacement={ "self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, @@ -75,22 +96,23 @@ class ChatGLMPolicy(Policy): self.model.config.kv_channels * self.model.config.num_attention_heads ) // self.shard_config.tensor_parallel_size, - "self_attention.qkv_hidden_size": ( - self.model.config.kv_channels * self.model.config.num_attention_heads * 3 - ) - // self.shard_config.tensor_parallel_size, "self_attention.core_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads // 
self.shard_config.tensor_parallel_size, "self_attention.core_attention.hidden_size_per_partition": self.model.config.kv_channels * self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size, + **attn_kwargs, }, param_replacement=[], sub_module_replacement=[ SubModuleReplacementDescription( suffix="self_attention.query_key_value", target_module=col_nn.Linear1D_Col, - kwargs={"seq_parallel_mode": sp_mode, "seq_parallel_dim": 0, "overlap": overlap}, + kwargs={ + "seq_parallel_mode": sp_mode, + "seq_parallel_dim": 0, + "overlap": overlap, + }, ), SubModuleReplacementDescription( suffix="self_attention.dense", @@ -114,7 +136,7 @@ class ChatGLMPolicy(Policy): ), ], policy=policy, - target_key=ChatGLMModel, + target_key="ChatGLMModel", ) # optimization configuration self.append_or_create_submodule_replacement( @@ -131,7 +153,7 @@ class ChatGLMPolicy(Policy): ), ], policy=policy, - target_key=GLMBlock, + target_key="GLMBlock", ) if self.model.config.post_layer_norm: @@ -143,7 +165,7 @@ class ChatGLMPolicy(Policy): ) ], policy=policy, - target_key=ChatGLMModel, + target_key="ChatGLMModel", ) # use flash attention @@ -153,7 +175,7 @@ class ChatGLMPolicy(Policy): "forward": get_flash_core_attention_forward(), }, policy=policy, - target_key=CoreAttention, + target_key="CoreAttention", ) # use sequence parallel @@ -161,7 +183,7 @@ class ChatGLMPolicy(Policy): self.append_or_create_method_replacement( description={"forward": get_chatglm_sequence_parallel_forward_fn(self.shard_config)}, policy=policy, - target_key=ChatGLMModel, + target_key="ChatGLMModel", ) # use jit fused operator @@ -172,7 +194,7 @@ class ChatGLMPolicy(Policy): "dropout_add": get_jit_fused_dropout_add_func(), }, policy=policy, - target_key=GLMBlock, + target_key="GLMBlock", ) return policy @@ -220,7 +242,10 @@ class ChatGLMPolicy(Policy): stage_index = stage_manager.get_stage_index(layers_per_stage) method_replacement = { "forward": partial( - new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config + new_forward, + stage_manager=stage_manager, + stage_index=stage_index, + shard_config=self.shard_config, ) } self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) @@ -234,7 +259,9 @@ class ChatGLMModelPolicy(ChatGLMPolicy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=ChatGLMModel, new_forward=ChatGLMPipelineForwards.chatglm_model_forward, policy=policy + model_cls="ChatGLMModel", + new_forward=ChatGLMPipelineForwards.chatglm_model_forward, + policy=policy, ) return policy @@ -252,7 +279,7 @@ class ChatGLMForConditionalGenerationPolicy(ChatGLMModelPolicy): if self.pipeline_stage_manager is not None: self.set_pipeline_forward( - model_cls=ChatGLMForConditionalGeneration, + model_cls="ChatGLMForConditionalGeneration", new_forward=ChatGLMPipelineForwards.chatglm_for_conditional_generation_forward, policy=policy, ) diff --git a/docs/source/en/features/shardformer.md b/docs/source/en/features/shardformer.md index 672945ea2..68d310f5c 100644 --- a/docs/source/en/features/shardformer.md +++ b/docs/source/en/features/shardformer.md @@ -310,13 +310,6 @@ if dist.get_world_size() > 1: 2. When you use Shardformer to process classification models such as `GPT2ForSequenceClassification`, `ViTForImageClassification`, please ensure that the total number of labels should be integer multiple of tensor parallel size, otherwise Shardformer can't process the classifier layer correctly. 
A simple fix could be appending dummy labels in transformers config. This bug will be fixed in future version of Shardformer. -3. The case of training ChatGLM-2 6B is a little special: since Huggingface transformers doesn't officially support ChatGLM at present, please import the configuration/model classes through - ```python - from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel - ``` - when training ChatGLM-2 with Shardformer, and initialize your model with these imported classes. - ## How Shardformer Works ### Main Idea diff --git a/docs/source/zh-Hans/features/shardformer.md b/docs/source/zh-Hans/features/shardformer.md index a7bcbd9f2..a42c7cc2e 100644 --- a/docs/source/zh-Hans/features/shardformer.md +++ b/docs/source/zh-Hans/features/shardformer.md @@ -303,13 +303,6 @@ if dist.get_world_size() > 1: 2. 当使用Shardformer处理`GPT2ForSequenceClassification`、`ViTForImageClassification`等分类模型时,请确保labels的总数为张量并行度的整数倍,否则Shardformer无法正确地处理classifier层。一个简单的修复方法就是在transformers的config中添加虚拟的标签。这一bug将在 Shardformer的未来版本中修复。 -3. 训练ChatGLM-2 6B的情况有点特殊:由于Huggingface Transformers 目前尚未正式支持ChatGLM。在使用Shardformer训练ChatGLM-2时,请通过以下方式导入config/model的类: - ```python - from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig - from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel - ``` - 并且使用这些导入的类初始化模型。 - ## Shardformer的工作原理 diff --git a/tests/kit/model_zoo/transformers/chatglm2.py b/tests/kit/model_zoo/transformers/chatglm2.py index 0b178d58c..f443553bb 100644 --- a/tests/kit/model_zoo/transformers/chatglm2.py +++ b/tests/kit/model_zoo/transformers/chatglm2.py @@ -1,7 +1,6 @@ import torch - -from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig -from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel +from torch.nn import init +from transformers import AutoConfig, AutoModelForCausalLM from ..registry import ModelAttribute, model_zoo @@ -34,19 +33,26 @@ loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss( ) loss_fn = lambda x: x["loss"] -config = ChatGLMConfig( +config = AutoConfig.from_pretrained( + "THUDM/chatglm2-6b", + trust_remote_code=True, num_layers=2, padded_vocab_size=65024, hidden_size=64, + ffn_hidden_size=214, num_attention_heads=8, kv_channels=16, rmsnorm=True, original_rope=True, use_cache=True, + multi_query_attention=False, torch_dtype=torch.float32, ) -infer_config = ChatGLMConfig( + +infer_config = AutoConfig.from_pretrained( + "THUDM/chatglm2-6b", + trust_remote_code=True, num_layers=2, padded_vocab_size=65024, hidden_size=128, @@ -60,18 +66,18 @@ infer_config = ChatGLMConfig( torch_dtype=torch.float32, ) -model_zoo.register( - name="transformers_chatglm", - model_fn=lambda: ChatGLMModel(config, empty_init=False), - data_gen_fn=data_gen, - output_transform_fn=output_transform_fn, - loss_fn=loss_fn_for_chatglm_model, - model_attribute=ModelAttribute(has_control_flow=True), -) + +def init_chatglm(): + model = AutoModelForCausalLM.from_config(config, empty_init=False, trust_remote_code=True) + for m in model.modules(): + if m.__class__.__name__ == "RMSNorm": + init.ones_(m.weight) + return model + model_zoo.register( name="transformers_chatglm_for_conditional_generation", - model_fn=lambda: 
ChatGLMForConditionalGeneration(config, empty_init=False), + model_fn=init_chatglm, data_gen_fn=data_gen_for_conditional_generation, output_transform_fn=output_transform_fn, loss_fn=loss_fn, diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py index a77ba39a1..1835a5c8e 100644 --- a/tests/test_shardformer/test_model/_utils.py +++ b/tests/test_shardformer/test_model/_utils.py @@ -227,7 +227,7 @@ def check_output_hidden_state( def check_loss(org_loss: Tensor, sharded_loss: Tensor, atol: float = 1e-5, rtol: float = 1e-3): - assert torch.allclose(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol) + assert_close(org_loss.float(), sharded_loss.float(), atol=atol, rtol=rtol) def check_weight( diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py index 405ceba32..376d315c1 100644 --- a/tests/test_shardformer/test_model/test_shard_chatglm2.py +++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py @@ -11,6 +11,7 @@ from tests.test_shardformer.test_model._utils import ( build_model_from_hybrid_plugin, check_all_grad_tensors, check_loss, + check_output_hidden_state, check_weight, get_grad_tensors_for_check, run_forward_backward_with_hybrid_plugin, @@ -103,8 +104,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, atol, rtol = 5e-3, 5e-3 # TODO: ChatGLMModel output is [S, B, H], merging batch of pipeline is wrong - # if org_model.__class__.__name__ == "ChatGLMModel": - # check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) + if org_model.__class__.__name__ == "ChatGLMModel": + check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) @@ -177,14 +178,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, { "tp_size": 4, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", }, { "tp_size": 2, "pp_size": 1, - "enable_all_optimization": True, + "enable_all_optimization": False, "use_lazy_init": False, "precision": "fp32", },
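As context for the re-enabled `check_output_hidden_state(..., dim=1)` check above and the `batch_size_dim` handling added to `one_f_one_b.py`: ChatGLMModel emits hidden states shaped [seq_len, batch, hidden], so pipeline micro-batch outputs must be concatenated along dim 1 rather than dim 0. The toy snippet below only illustrates that idea with made-up shapes; it is not the actual `merge_batch` implementation.

```python
import torch

# Hypothetical micro-batch outputs in [seq_len, batch, hidden] layout,
# as produced by ChatGLMModel (4 micro-batches of batch size 2).
micro_batch_outputs = [torch.randn(16, 2, 64) for _ in range(4)]

# A batch-second model can expose `batch_size_dim = 1`; the 1F1B schedule
# reads it via getattr(model, "batch_size_dim", 0) and forwards it to
# merge_batch, so outputs are concatenated along the batch axis.
batch_size_dim = 1
merged = torch.cat(micro_batch_outputs, dim=batch_size_dim)
print(merged.shape)  # torch.Size([16, 8, 64])
```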