mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] Llama causal lm and llama for sequence classification pipeline (#4208)
* bloom policy
* llama pipeline forward and tests
* fix the output and attention_mask
* fix name
* bind argument to policy
* Revert "bloom policy"
This reverts commit 8dee68a0a2.
This policy should be reverted and copied to feature/bloom
* revert the bloom changes
* cancel unneeded inputs
* gpt
* finish llama
* causal lm and sequence classification
* revision
pull/4445/head
commit 31bcf867ae (parent 1622031058)
@@ -162,6 +162,24 @@ class Policy(ABC):
         return policy
 
+    def append_or_create_method_replacement(
+            self, description: Dict[str, Callable], policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
+            target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        r"""
+        Append or create a new method replacement description to the policy for the given key.
+
+        Args:
+            description (Dict[str, Callable]): the method replacement description to be appended
+            policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
+            target_key (Union[str, nn.Module]): the key of the policy to be updated
+        """
+        if target_key in policy:
+            policy[target_key].method_replacement.update(description)
+        else:
+            policy[target_key] = ModulePolicyDescription(method_replacement=description)
+
+        return policy
+
     def get_held_layers(self) -> List[Module]:
         """Get layers that should be held in current stage. This method should be implemented by subclass.
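For illustration, a self-contained sketch of what the new helper does, using a stand-in dataclass instead of ColossalAI's ModulePolicyDescription (only the method_replacement field is modeled) and a string key in place of an nn.Module class:

from dataclasses import dataclass, field
from typing import Callable, Dict

# Stand-in for ModulePolicyDescription with only the field this helper touches.
@dataclass
class ModulePolicyDescription:
    method_replacement: Dict[str, Callable] = field(default_factory=dict)

def append_or_create_method_replacement(description, policy, target_key):
    # Same branching as the helper above: extend an existing entry in place,
    # or create a fresh description keyed by the target module.
    if target_key in policy:
        policy[target_key].method_replacement.update(description)
    else:
        policy[target_key] = ModulePolicyDescription(method_replacement=description)
    return policy

policy = {}
policy = append_or_create_method_replacement({'forward': print}, policy, 'LlamaModel')
policy = append_or_create_method_replacement({'generate': print}, policy, 'LlamaModel')
assert set(policy['LlamaModel'].method_replacement) == {'forward', 'generate'}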
@@ -131,17 +131,20 @@ class LlamaModelPolicy(LlamaPolicy):
         super().__init__()
 
     def module_policy(self):
-        module_policy = super().module_policy()
+        policy = super().module_policy()
         from transformers.models.llama.modeling_llama import LlamaModel
         if self.pipeline_stage_manager:
             # set None as default
             stage_manager = self.pipeline_stage_manager
             layers_per_stage = Policy.distribute_layers(len(self.model.layers), stage_manager.num_stages)
             stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-            module_policy[LlamaModel] = ModulePolicyDescription(method_replacement={
+            method_replacement = {
                 'forward': partial(llama_model_forward, stage_manager=stage_manager, stage_index=stage_index)
-            })
-        return module_policy
+            }
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=LlamaModel)
+        return policy
 
     def get_held_layers(self) -> List[Module]:
         """Get pipeline layers for current stage."""
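For intuition, a simplified sketch of the kind of split that Policy.distribute_layers and Policy.get_stage_index produce; the real helpers in ColossalAI may place remainder layers differently, so this is only illustrative:

from typing import List, Tuple

def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    # assumed strategy: even split, with remainder layers given to the first stages
    base, rem = divmod(num_layers, num_stages)
    return [base + (1 if i < rem else 0) for i in range(num_stages)]

def get_stage_index(layers_per_stage: List[int], stage: int) -> Tuple[int, int]:
    # half-open [start, end) range of decoder layers owned by this stage
    start = sum(layers_per_stage[:stage])
    return start, start + layers_per_stage[stage]

# e.g. a 32-layer Llama model on 4 stages: each stage holds layers[start:end]
layers_per_stage = distribute_layers(32, 4)   # [8, 8, 8, 8]
print(get_stage_index(layers_per_stage, 2))   # (16, 24)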
@@ -158,7 +161,7 @@ class LlamaModelPolicy(LlamaPolicy):
         return held_layers
 
     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        """No shared params in bert model"""
+        """No shared params in llama model"""
         return []
 
@@ -179,8 +182,43 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
             ])
         }
         policy.update(new_item)
 
+        if self.pipeline_stage_manager:
+            # set None as default
+            stage_manager = self.pipeline_stage_manager
+            layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {
+                'forward': partial(llama_for_causal_lm_forward, stage_manager=stage_manager, stage_index=stage_index)
+            }
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=LlamaForCausalLM)
         return policy
 
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.model.embed_tokens)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.model.layers[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            held_layers.append(module.model.norm)
+            held_layers.append(module.lm_head)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """The embedding and lm_head weights are shared when they are tied."""
+        llama_model = self.model.model
+        if id(llama_model.embed_tokens.weight) == id(self.model.lm_head.weight):
+            # tie weights
+            return [{0: llama_model.embed_tokens.weight, self.stage_manager.num_stages - 1: self.model.lm_head.weight}]
+        return []
+
 
 class LlamaForSequenceClassificationPolicy(LlamaPolicy):
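The tie-weight check in get_shared_params above can be exercised in isolation; a minimal sketch with a toy embedding and head instead of LlamaForCausalLM, assuming 4 pipeline stages:

import torch.nn as nn

embed = nn.Embedding(10, 4)
lm_head = nn.Linear(4, 10, bias=False)
lm_head.weight = embed.weight   # tie the weights, as HF does when tie_word_embeddings=True

if id(embed.weight) == id(lm_head.weight):
    # first stage holds embed_tokens, last stage holds lm_head; reporting the
    # pair lets the runtime keep the tied copies in sync across stages
    shared = [{0: embed.weight, 3: lm_head.weight}]   # 3 = num_stages - 1 (assumed 4 stages)
    print(len(shared))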
@@ -199,8 +237,42 @@ class LlamaForSequenceClassificationPolicy(LlamaPolicy):
             ])
         }
         policy.update(new_item)
+        # to be confirmed
+        if self.pipeline_stage_manager:
+            # set None as default
+            stage_manager = self.pipeline_stage_manager
+            layers_per_stage = Policy.distribute_layers(len(self.model.model.layers), stage_manager.num_stages)
+            stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
+            method_replacement = {
+                'forward':
+                    partial(llama_for_sequence_classification_forward,
+                            stage_manager=stage_manager,
+                            stage_index=stage_index)
+            }
+            self.append_or_create_method_replacement(description=method_replacement,
+                                                     policy=policy,
+                                                     target_key=LlamaForSequenceClassification)
         return policy
 
+    def get_held_layers(self) -> List[Module]:
+        """Get pipeline layers for current stage."""
+        module = self.model
+        stage_manager = self.pipeline_stage_manager
+        held_layers = []
+        layers_per_stage = self.distribute_layers(len(module.model.layers), stage_manager.num_stages)
+        if stage_manager.is_first_stage():
+            held_layers.append(module.model.embed_tokens)
+        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        held_layers.extend(module.model.layers[start_idx:end_idx])
+        if stage_manager.is_last_stage():
+            held_layers.append(module.model.norm)
+            held_layers.append(module.score)
+        return held_layers
+
+    def get_shared_params(self) -> List[Dict[int, Tensor]]:
+        """No shared params in llama for sequence classification model"""
+        return []
+
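Since the last stage of the classification policy holds module.score rather than lm_head, here is a hedged sketch of what that head does, modeled on Hugging Face's LlamaForSequenceClassification (a bias-free Linear whose output is taken at roughly the last non-padding position); all sizes are made up:

import torch
import torch.nn as nn

hidden_size, num_labels = 128, 2
score = nn.Linear(hidden_size, num_labels, bias=False)

hidden_states = torch.randn(2, 3, hidden_size)        # (batch, seq, hidden) from the decoder stack
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
last_token = attention_mask.sum(dim=-1) - 1            # index of the last real token per sequence

logits = score(hidden_states)                          # (batch, seq, num_labels)
pooled = logits[torch.arange(2), last_token]           # (batch, num_labels)
print(pooled.shape)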
def llama_model_forward(
    self: LlamaModel,
@@ -52,7 +52,7 @@ loss_fn_for_gpt2_model = lambda x: x.last_hidden_state.mean()
 loss_fn = lambda x: x.loss
 
 config = transformers.GPT2Config(n_layer=2,
-                                 n_head=2,
+                                 n_head=4,
                                  vocab_size=50258,
                                  attn_pdrop=0,
                                  embd_pdrop=0,
@@ -49,21 +49,19 @@ def run_llama_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     x = torch.randint(0, 1000, (2, 3)).cuda()
     hidden_states = torch.randint(0, 1000, (2, 3, 128)).to(torch.float32).cuda()
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        if name == 'transformers_llama':
-            org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
-                                                            enable_tensor_parallelism, use_lazy_init)
-            if stage_manager.stage == 0:
-                attention_mask = torch.ones_like(x).cuda()
-                output = sharded_model(input_ids=x, attention_mask=attention_mask)
-                assert output['hidden_states'].shape == (2, 3, 128)
-            else:
-                attention_mask = torch.ones((2, 3)).cuda()
-                output = sharded_model(
-                    hidden_states=hidden_states,
-                    attention_mask=attention_mask,
-                )
-                # print(output[0].shape)
-                assert output[0].shape == (2, 3, 128)
+        org_model, sharded_model = build_pipeline_model(model_fn, stage_manager, enable_fused_normalization,
+                                                        enable_tensor_parallelism, use_lazy_init)
+        if stage_manager.stage == 0:
+            attention_mask = torch.ones_like(x).cuda()
+            output = sharded_model(input_ids=x, attention_mask=attention_mask)
+            assert output['hidden_states'].shape == (2, 3, 128)
+        else:
+            attention_mask = torch.ones((2, 3)).cuda()
+            output = sharded_model(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+            )
+            assert output[0] is not None
 
     torch.cuda.empty_cache()
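To make the stage contract this test exercises explicit, a toy stand-in that accepts input_ids on the first stage and hidden_states on later stages; the class and the dict-shaped intermediate output are assumptions modeled on the assertions above, not the real sharded model:

import torch
import torch.nn as nn

class ToyStage(nn.Module):
    def __init__(self, vocab_size=1000, hidden=128, is_first=True, is_last=False):
        super().__init__()
        self.is_first, self.is_last = is_first, is_last
        self.embed = nn.Embedding(vocab_size, hidden)
        self.layer = nn.Linear(hidden, hidden)

    def forward(self, input_ids=None, hidden_states=None, attention_mask=None):
        # first stage embeds token ids; later stages start from incoming activations
        x = self.embed(input_ids) if self.is_first else hidden_states
        x = self.layer(x)
        # non-final stages hand intermediate activations to the next stage
        return x if self.is_last else {'hidden_states': x}

first = ToyStage(is_first=True)
out = first(input_ids=torch.randint(0, 1000, (2, 3)))
assert out['hidden_states'].shape == (2, 3, 128)

last = ToyStage(is_first=False, is_last=True)
assert last(hidden_states=out['hidden_states']).shape == (2, 3, 128)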