From 31bcf867aeb8efb5859683e3c727c646063748dc Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Tue, 11 Jul 2023 15:23:33 +0800
Subject: [PATCH] [pipeline] Llama causal lm and llama for sequence
 classification pipeline (#4208)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a22568dbeed6d4563372b25e1e825fb0.

This policy should be revert and copied to feature/bloom

* revert the bloom changes

* cancel unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision
---
 .../shardformer/policies/base_policy.py      | 18 ++++
 colossalai/shardformer/policies/llama.py     | 82 +++++++++++++++++--
 tests/kit/model_zoo/transformers/gpt.py      |  2 +-
 .../test_model/test_shard_llama_pipeline.py  | 28 +++----
 4 files changed, 109 insertions(+), 21 deletions(-)

diff --git a/colossalai/shardformer/policies/base_policy.py b/colossalai/shardformer/policies/base_policy.py
index aac86eb20..68fde0115 100644
--- a/colossalai/shardformer/policies/base_policy.py
+++ b/colossalai/shardformer/policies/base_policy.py
@@ -162,6 +162,24 @@ class Policy(ABC):
 
         return policy
 
+    def append_or_create_method_replacement(
+            self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
+            target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        r"""
+        Append or create a new method replacement description to the policy for the given key.
+
+        Args:
+            description (Dict[str, Callable]): the method replacement description to be appended
+            policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
+            target_key (Union[str, nn.Module]): the key of the policy to be updated
+        """
+        if target_key in policy:
+            policy[target_key].method_replacement.update(description)
+        else:
+            policy[target_key] = ModulePolicyDescription(method_replacement=description)
+
+        return policy
+
     def get_held_layers(self) -> List[Module]:
         """Get layers that should be held in current stage. This method should be implemented by subclass.
 
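Reviewer note (not part of the patch): the snippet below is a minimal, self-contained sketch of the merge-or-create behavior of the new append_or_create_method_replacement helper. ModulePolicyDescription is a simplified stand-in here (only method_replacement is modeled), and staged_forward plus the string keys are placeholders, not the real colossalai objects.

    from dataclasses import dataclass, field
    from typing import Callable, Dict


    @dataclass
    class ModulePolicyDescription:
        # Simplified stand-in: the real class also carries sub-module and attribute replacements.
        method_replacement: Dict[str, Callable] = field(default_factory=dict)


    def append_or_create_method_replacement(description, policy, target_key):
        # Same merge-or-create logic as the helper added to Policy in this patch.
        if target_key in policy:
            policy[target_key].method_replacement.update(description)
        else:
            policy[target_key] = ModulePolicyDescription(method_replacement=description)
        return policy


    def staged_forward(*args, **kwargs):
        # Placeholder callable used only to populate the policy.
        return None


    policy = {}
    append_or_create_method_replacement({'forward': staged_forward}, policy, 'LlamaModel')
    # A second call with the same key merges into the existing description instead of overwriting it.
    append_or_create_method_replacement({'another_method': staged_forward}, policy, 'LlamaModel')
    assert set(policy['LlamaModel'].method_replacement) == {'forward', 'another_method'}

The llama policies below use this helper so that the pipeline forward replacement is layered on top of whatever replacements a parent policy may already have registered for the same class.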
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index b2b647018..a3ea80726 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -131,17 +131,20 @@ class LlamaModelPolicy(LlamaPolicy):
         super().__init__()
 
     def module_policy(self):
-        module_policy = super().module_policy()
+        policy = super().module_policy()
         from transformers.models.llama.modeling_llama import LlamaModel
         if self.pipeline_stage_manager:
             # set None as default
             stage_manager = self.pipeline_stage_manager
             layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages)
             stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-            module_policy[LlamaModel] = ModulePolicyDescription(method_replacement={
+            method_replacement = {
                 'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index)
-            })
-        return module_policy
+            }
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=LlamaModel)
+        return policy
 
     def get_held_layers(self) -> List[Module]:
         """Get pipeline layers for current stage."""
@@ -158,7 +161,7 @@ class LlamaModelPolicy(LlamaPolicy):
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in bert model"""
+        """No shared params in llama model"""
         return []
 
 
@@ -179,8 +182,43 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
             ])
         }
         policy.update(new_item)
+
+        if self.pipeline_stage_manager:
+            # set None as default
+            stage_manager = self.pipeline_stage_manager
+            layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {
+                'forward': partial(llama_for_causal_lm_forward, stage_manager=stage_manager, stage_index=stage_index)
+            }
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=LlamaForCausalLM)
         return policy
 
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.model.embed_tokens)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.model.layers[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            held_layers.append(module.model.norm)
+            held_layers.append(module.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """The embedding and lm_head weights are shared between the first and last stage when they are tied."""
+        llama_model = self.model.model
+        if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight):
+            # tie weights
+            return [{0: llama_model.embed_tokens.weight, self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight}]
+        return []
+
 
 
 class LlamaForSequenceClassificationPolicy(LlamaPolicy):
@@ -199,8 +237,42 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
             ])
         }
         policy.update(new_item)
+        # to be confirmed
+        if self.pipeline_stage_manager:
+            # set None as default
+            stage_manager = self.pipeline_stage_manager
+            layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {
+                'forward':
+                    partial(llama_for_sequence_classification_forward,
+                            stage_manager=stage_manager,
+                            stage_index=stage_index)
+            }
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=LlamaForSequenceClassification)
         return policy
 
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.model.embed_tokens)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.model.layers[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            held_layers.append(module.model.norm)
+            held_layers.append(module.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in llama for sequence classification model"""
+        return []
+
 
 def llama_model_forward(
     self: LlamaModel,
diff --git a/tests/kit/model_zoo/transformers/gpt.py b/tests/kit/model_zoo/transformers/gpt.py
index ac70138e3..b9e031078 100644
--- a/tests/kit/model_zoo/transformers/gpt.py
+++ b/tests/kit/model_zoo/transformers/gpt.py
@@ -52,7 +52,7 @@ loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
 loss_fn = lambda x: x.loss
 
 config = transformers.GPT2Config(n_layer=2,
-                                 n_head=2,
+                                 n_head=4,
                                  vocab_size=50258,
                                  attn_pdrop=0,
                                  embd_pdrop=0,
diff --git a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py
index 81c183d32..8fd9ed099 100644
--- a/tests/test_shardformer/test_model/test_shard_llama_pipeline.py
+++ b/tests/test_shardformer/test_model/test_shard_llama_pipeline.py
@@ -49,21 +49,19 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_la
     x = torch.randint(0, 1000, (2, 3)).cuda()
     hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name == 'transformers_llama':
-            org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                            enable_tensor_parallelism, use_lazy_init)
-            if stage_manager.stage == 0:
-                attention_mask = torch.ones_like(x).cuda()
-                output = sharded_model(input_ids=x, attention_mask=attention_mask)
-                assert output['hidden_states'].shape == (2, 3, 128)
-            else:
-                attention_mask = torch.ones((2, 3)).cuda()
-                output = sharded_model(
-                    hidden_states=hidden_states,
-                    attention_mask=attention_mask,
-                )
-                # print(output[0].shape)
-                assert output[0].shape == (2, 3, 128)
+        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+                                                        enable_tensor_parallelism, use_lazy_init)
+        if stage_manager.stage == 0:
+            attention_mask = torch.ones_like(x).cuda()
+            output = sharded_model(input_ids=x, attention_mask=attention_mask)
+            assert output['hidden_states'].shape == (2, 3, 128)
+        else:
+            attention_mask = torch.ones((2, 3)).cuda()
+            output = sharded_model(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+            )
+            assert output[0] is not None
 
     torch.cuda.empty_cache()
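Reviewer note (not part of the patch): the new get_held_layers implementations split the decoder stack with Policy.distribute_layers and Policy.get_stage_index. The snippet below is a self-contained sketch of that partitioning idea; distribute_layers and get_stage_index here are simplified stand-ins (the real static methods may place remainder layers differently) and are shown only to illustrate which slice of layers a given pipeline stage ends up holding.

    from typing import List, Tuple


    def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
        # Spread layers as evenly as possible; in this stand-in, earlier stages absorb the remainder.
        base, remainder = divmod(num_layers, num_stages)
        return [base + (1 if stage < remainder else 0) for stage in range(num_stages)]


    def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
        # Convert per-stage layer counts into the [start, end) slice owned by `stage`.
        start = sum(layers_per_stage[:stage])
        return start, start + layers_per_stage[stage]


    if __name__ == '__main__':
        layers_per_stage = distribute_layers(num_layers=32, num_stages=4)
        for stage in range(4):
            start, end = get_stage_index(layers_per_stage, stage)
            # On top of its decoder slice, the first stage also holds embed_tokens and the last
            # stage holds norm plus lm_head (causal LM) or score (sequence classification).
            print(f'stage {stage}: decoder layers [{start}, {end})')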