[pipeline] Llama causal lm and llama for sequence classification pipeline (#4208)

* bloom policy

* llama pipeline forward and tests

* fix the output and attention_mask

* fix name

* bind argument to policy

* Revert "bloom policy"

This reverts commit 8dee68a0a2.

This policy should be reverted and copied to feature/bloom.

* revert the bloom changes

* cancel unneeded inputs

* gpt

* finish llama

* causal lm and sequence classification

* revision
pull/4445/head
Jianghai 2023-07-11 15:23:33 +08:00 committed by Hongxin Liu
parent 1622031058
commit 31bcf867ae
4 changed files with 109 additions and 21 deletions


@@ -162,6 +162,24 @@ class Policy(ABC):
         return policy
 
+    def append_or_create_method_replacement(
+            self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
+            target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        r"""
+        Append or create a new method replacement description to the policy for the given key.
+
+        Args:
+            description (Dict[str, Callable]): the method replacement description to be appended
+            policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
+            target_key (Union[str, nn.Module]): the key of the policy to be updated
+        """
+        if target_key in policy:
+            policy[target_key].method_replacement.update(description)
+        else:
+            policy[target_key] = ModulePolicyDescription(method_replacement=description)
+        return policy
+
     def get_held_layers(self) -> List[Module]:
         """Get layers that should be held in current stage. This method should be implemented by subclass.


@@ -131,17 +131,20 @@ class LlamaModelPolicy(LlamaPolicy):
         super().__init__()
 
     def module_policy(self):
-        module_policy = super().module_policy()
+        policy = super().module_policy()
         from transformers.models.llama.modeling_llama import LlamaModel
 
         if self.pipeline_stage_manager:
             # set None as default
             stage_manager = self.pipeline_stage_manager
             layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages)
             stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-            module_policy[LlamaModel] = ModulePolicyDescription(method_replacement={
-                'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index)
-            })
-        return module_policy
+            method_replacement = {
+                'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index)
+            }
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=LlamaModel)
+        return policy
 
     def get_held_layers(self) -> List[Module]:
         """Get pipeline layers for current stage."""
@@ -158,7 +161,7 @@ class LlamaModelPolicy(LlamaPolicy):
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in bert model"""
+        """No shared params in llama model"""
         return []
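
Note: the forward replacement relies on Policy.distribute_layers and Policy.get_stage_index to hand each pipeline stage a contiguous slice of decoder layers. A sketch of that contract, assuming an even split with any remainder absorbed by the earliest stages (the real implementation may place remainder layers differently):

from typing import List, Tuple

# Sketch of the partitioning contract used by the policies above.
def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    base, remainder = divmod(num_layers, num_stages)
    # Assumption: earlier stages absorb the remainder.
    return [base + (1 if i < remainder else 0) for i in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

# A 32-layer Llama over 4 stages: each stage owns one contiguous slice, and
# the (start, end) pair is what gets bound into the forward via partial().
layers_per_stage = distribute_layers(32, 4)     # -> [8, 8, 8, 8]
assert get_stage_index(layers_per_stage, 1) == (8, 16)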
@@ -179,8 +182,43 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
             ])
         }
         policy.update(new_item)
+
+        if self.pipeline_stage_manager:
+            # set None as default
+            stage_manager = self.pipeline_stage_manager
+            layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {
+                'forward': partial(llama_for_causal_lm_forward, stage_manager=stage_manager, stage_index=stage_index)
+            }
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=LlamaForCausalLM)
         return policy
 
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.model.embed_tokens)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.model.layers[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            held_layers.append(module.model.norm)
+            held_layers.append(module.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """Share the tied embed_tokens and lm_head weights between the first and last stages."""
+        llama_model = self.model.model
+        if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight):
+            # tie weights: the first stage holds embed_tokens, the last stage holds lm_head
+            return [{0: llama_model.embed_tokens.weight, self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight}]
+        return []
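
Note: the id() comparison in get_shared_params works because HuggingFace weight tying (tie_word_embeddings=True) makes embed_tokens and lm_head share a single tensor. A toy illustration:

import torch.nn as nn

# Toy version of the tie-weights check above.
embed_tokens = nn.Embedding(100, 16)
lm_head = nn.Linear(16, 100, bias=False)
lm_head.weight = embed_tokens.weight    # what HuggingFace weight tying does
assert id(embed_tokens.weight) == id(lm_head.weight)
# The policy then reports {first stage: embed weight, last stage: lm_head weight}
# so the pipeline runtime can keep the copies on the two stages synchronized.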
 
 class LlamaForSequenceClassificationPolicy(LlamaPolicy):
@@ -199,8 +237,42 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
             ])
         }
         policy.update(new_item)
+
+        # to be confirmed
+        if self.pipeline_stage_manager:
+            # set None as default
+            stage_manager = self.pipeline_stage_manager
+            layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {
+                'forward':
+                    partial(llama_for_sequence_classification_forward,
+                            stage_manager=stage_manager,
+                            stage_index=stage_index)
+            }
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=LlamaForSequenceClassification)
         return policy
 
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.model.embed_tokens)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.model.layers[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            held_layers.append(module.model.norm)
+            held_layers.append(module.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in llama for sequence classification model"""
+        return []
+
 
 def llama_model_forward(
     self: LlamaModel,
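
Note: llama_model_forward and the causal-LM / sequence-classification wrappers follow a stage-aware contract: the first stage consumes input_ids, intermediate stages consume the previous stage's hidden_states, and only the last stage produces the final model output. A schematic sketch (the signature is illustrative only; the real functions accept the full set of HuggingFace kwargs):

# Schematic of the pipeline forward contract exercised by the tests below.
def staged_forward(embed, layers, norm, stage_manager, input_ids=None, hidden_states=None):
    if stage_manager.is_first_stage():
        hidden_states = embed(input_ids)          # only stage 0 sees input_ids
    for layer in layers:                          # this stage's slice of decoder layers
        hidden_states = layer(hidden_states)
    if stage_manager.is_last_stage():
        return norm(hidden_states)                # final output on the last stage
    return {'hidden_states': hidden_states}       # handed to the next stage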


@@ -52,7 +52,7 @@ loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
 loss_fn = lambda x: x.loss
 
 config = transformers.GPT2Config(n_layer=2,
-                                 n_head=2,
+                                 n_head=4,
                                  vocab_size=50258,
                                  attn_pdrop=0,
                                  embd_pdrop=0,


@@ -49,21 +49,19 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     x = torch.randint(0, 1000, (2, 3)).cuda()
     hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name == 'transformers_llama':
-            org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                            enable_tensor_parallelism, use_lazy_init)
-            if stage_manager.stage == 0:
-                attention_mask = torch.ones_like(x).cuda()
-                output = sharded_model(input_ids=x, attention_mask=attention_mask)
-                assert output['hidden_states'].shape == (2, 3, 128)
-            else:
-                attention_mask = torch.ones((2, 3)).cuda()
-                output = sharded_model(
-                    hidden_states=hidden_states,
-                    attention_mask=attention_mask,
-                )
-                # print(output[0].shape)
-                assert output[0].shape == (2, 3, 128)
+        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+                                                        enable_tensor_parallelism, use_lazy_init)
+        if stage_manager.stage == 0:
+            attention_mask = torch.ones_like(x).cuda()
+            output = sharded_model(input_ids=x, attention_mask=attention_mask)
+            assert output['hidden_states'].shape == (2, 3, 128)
+        else:
+            attention_mask = torch.ones((2, 3)).cuda()
+            output = sharded_model(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+            )
+            assert output[0] is not None
 
     torch.cuda.empty_cache()
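
Note: the asserted shapes follow from the test inputs: batch size 2, sequence length 3, and hidden size 128 in the model_zoo Llama config. A quick sanity sketch of the stage-0 expectation:

import torch
import torch.nn as nn

batch, seq_len, hidden = 2, 3, 128
input_ids = torch.randint(0, 1000, (batch, seq_len))
embed = nn.Embedding(1000, hidden)    # stand-in for the stage-0 embed_tokens
# Stage 0 maps token ids to hidden states of shape (batch, seq_len, hidden),
# which is exactly what the stage-0 assert checks.
assert embed(input_ids).shape == (batch, seq_len, hidden)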