mirror of https://github.com/hpcaitech/ColossalAI
[format] applied code formatting on changed files in pull request 4820 (#4886)
Co-authored-by: github-actions <github-actions@github.com>pull/4990/head
parent
c7aa319ba0
commit
486d06a2d5
|
@ -1,3 +1,3 @@
|
|||
from .pipeline import PPInferEngine
|
||||
|
||||
__all__ = ['PPInferEngine']
|
||||
__all__ = ["PPInferEngine"]
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
from .engine import PPInferEngine
|
||||
|
||||
__all__ = ['PPInferEngine']
|
||||
__all__ = ["PPInferEngine"]
|
||||
|
|
|
@ -1,28 +1,32 @@
|
|||
import argparse
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
|
||||
import colossalai
|
||||
import time
|
||||
from colossalai.inference import PPInferEngine
|
||||
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
|
||||
import argparse
|
||||
GIGABYTE = 1024 ** 3
|
||||
|
||||
GIGABYTE = 1024**3
|
||||
MEGABYTE = 1024 * 1024
|
||||
|
||||
colossalai.launch_from_torch(config={})
|
||||
|
||||
def data_gen(batch_size: int=4, seq_len: int=512):
|
||||
|
||||
def data_gen(batch_size: int = 4, seq_len: int = 512):
|
||||
input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
|
||||
attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
|
||||
data = dict(input_ids=input_ids, attention_mask=attention_mask)
|
||||
for k, v in data.items():
|
||||
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
|
||||
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
|
||||
new_shape = [1] * v.dim()
|
||||
new_shape[0] = batch_size
|
||||
data[k] = v.to('cuda').repeat(*new_shape)
|
||||
data[k] = v.to("cuda").repeat(*new_shape)
|
||||
return data
|
||||
|
||||
|
||||
def print_details_info(timestamps, model_config, args, whole_end2end):
|
||||
if dist.get_rank() == 0:
|
||||
prefill = []
|
||||
|
@ -31,32 +35,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
|
|||
for timestamp in timestamps:
|
||||
prefill.append(timestamp[1] - timestamp[0])
|
||||
encoder.append(
|
||||
sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2))
|
||||
sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
|
||||
)
|
||||
end2end.append(timestamp[-1] - timestamp[0])
|
||||
print(whole_end2end)
|
||||
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
|
||||
mb_avg_end2end = sum(end2end)/len(end2end)
|
||||
mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size)
|
||||
whole_avg_latency = whole_end2end/(args.new_length * args.batch_size)
|
||||
with open(
|
||||
f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
|
||||
"w+",
|
||||
) as f:
|
||||
mb_avg_end2end = sum(end2end) / len(end2end)
|
||||
mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size)
|
||||
whole_avg_latency = whole_end2end / (args.new_length * args.batch_size)
|
||||
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
|
||||
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
|
||||
if args.dtype in ['fp16','bf16']:
|
||||
if args.dtype in ["fp16", "bf16"]:
|
||||
num_bytes = 2
|
||||
else:
|
||||
num_bytes = 4
|
||||
|
||||
f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
|
||||
f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000))
|
||||
f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000))
|
||||
f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000))
|
||||
f.write(
|
||||
f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n"
|
||||
)
|
||||
f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
|
||||
f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
|
||||
f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
|
||||
f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
|
||||
f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000))
|
||||
f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
|
||||
f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
|
||||
f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000))))
|
||||
f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12))
|
||||
f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000))))
|
||||
f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
|
||||
f.write("----------------------------------------------------------\n")
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
current_device = torch.cuda.current_device()
|
||||
|
||||
|
@ -66,7 +75,10 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
|
|||
max_memory_allocated = torch.cuda.max_memory_allocated()
|
||||
memory_reserved = torch.cuda.memory_reserved()
|
||||
max_memory_reserved = torch.cuda.max_memory_reserved()
|
||||
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
|
||||
with open(
|
||||
f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
|
||||
"a",
|
||||
) as f:
|
||||
f.write(
|
||||
f"\nCurrently using GPU: {current_device}\n"
|
||||
f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
|
||||
|
@ -77,29 +89,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
|
|||
f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', default='toy', help='the size of model')
|
||||
parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size')
|
||||
parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length')
|
||||
parser.add_argument('--new_length', type=int, default=4, help='new tokens length')
|
||||
parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size')
|
||||
parser.add_argument('--pp_size', type=int, default=2, help='pipeline size')
|
||||
parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log')
|
||||
parser.add_argument('--dtype', type=str, default='fp16', help='data type')
|
||||
parser.add_argument("--model", default="toy", help="the size of model")
|
||||
parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
|
||||
parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
|
||||
parser.add_argument("--new_length", type=int, default=4, help="new tokens length")
|
||||
parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
|
||||
parser.add_argument("--pp_size", type=int, default=2, help="pipeline size")
|
||||
parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log")
|
||||
parser.add_argument("--dtype", type=str, default="fp16", help="data type")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.model == 'toy':
|
||||
if args.model == "toy":
|
||||
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
|
||||
elif args.model == '7b':
|
||||
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf'))
|
||||
elif args.model == '13b':
|
||||
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf'))
|
||||
elif args.model == "7b":
|
||||
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf"))
|
||||
elif args.model == "13b":
|
||||
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf"))
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
|
||||
|
||||
engine = PPInferEngine(
|
||||
pp_size=args.pp_size,
|
||||
dtype=args.dtype,
|
||||
micro_batch_size=args.mb_size,
|
||||
new_length=args.new_length,
|
||||
model=model,
|
||||
model_policy=LlamaForCausalLMPipelinePolicy(),
|
||||
verbose=True,
|
||||
)
|
||||
data = data_gen(args.batch_size, args.seq_len)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
@ -109,4 +129,3 @@ if __name__ == '__main__':
|
|||
whole_end2end = time.time() - whole_end2end
|
||||
|
||||
print_details_info(timestamps, model.config, args, whole_end2end)
|
||||
|
||||
|
|
|
@ -1,5 +1,3 @@
|
|||
from typing import Callable, List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -13,7 +11,7 @@ from .microbatch_manager import MicroBatchManager
|
|||
|
||||
|
||||
class PPInferEngine:
|
||||
'''
|
||||
"""
|
||||
PPInferEngine is a class that handles the pipeline parallel inference.
|
||||
|
||||
Args:
|
||||
|
@ -41,12 +39,12 @@ class PPInferEngine:
|
|||
output = engine.inference([tokenized_input])
|
||||
```
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pp_size: int,
|
||||
dtype: str = 'fp16',
|
||||
dtype: str = "fp16",
|
||||
pp_model: nn.Module = None,
|
||||
model: nn.Module = None,
|
||||
model_policy: Policy = None,
|
||||
|
@ -54,7 +52,7 @@ class PPInferEngine:
|
|||
micro_batch_size: int = 1,
|
||||
micro_batch_buffer_size: int = None,
|
||||
verbose: bool = False,
|
||||
# TODO: implement early_stopping, and various gerneration options
|
||||
# TODO: implement early_stopping, and various gerneration options
|
||||
early_stopping: bool = False,
|
||||
do_sample: bool = False,
|
||||
num_beams: int = 1,
|
||||
|
@ -63,15 +61,16 @@ class PPInferEngine:
|
|||
self.pp_size = pp_size
|
||||
self.pg_mesh = ProcessGroupMesh(pp_size)
|
||||
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
|
||||
self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
|
||||
micro_batch_buffer_size or pp_size)
|
||||
self.mb_manager = MicroBatchManager(
|
||||
self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
|
||||
)
|
||||
self.verbose = verbose
|
||||
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
|
||||
|
||||
assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'"
|
||||
if dtype == 'fp16':
|
||||
assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
|
||||
if dtype == "fp16":
|
||||
model.half()
|
||||
elif dtype == 'bf16':
|
||||
elif dtype == "bf16":
|
||||
model.to(torch.bfloat16)
|
||||
self.model = pp_model or self._shardformer(model, model_policy)
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Dict, Tuple
|
|||
|
||||
import torch
|
||||
|
||||
__all__ = 'MicroBatchManager'
|
||||
__all__ = "MicroBatchManager"
|
||||
|
||||
|
||||
class Status(Enum):
|
||||
|
@ -13,7 +13,7 @@ class Status(Enum):
|
|||
COOLDOWN = 4
|
||||
|
||||
|
||||
class MicroBatchDescription():
|
||||
class MicroBatchDescription:
|
||||
"""
|
||||
This is the class to record the infomation of each microbatch, and also do some update operation.
|
||||
This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
|
||||
|
@ -30,14 +30,14 @@ class MicroBatchDescription():
|
|||
output_dict: Dict[str, torch.Tensor],
|
||||
new_length: int,
|
||||
) -> None:
|
||||
assert output_dict.get('hidden_states') is not None
|
||||
self.mb_length = output_dict['hidden_states'].shape[-2]
|
||||
assert output_dict.get("hidden_states") is not None
|
||||
self.mb_length = output_dict["hidden_states"].shape[-2]
|
||||
self.target_length = self.mb_length + new_length
|
||||
self.kv_cache = ()
|
||||
|
||||
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
if output_dict is not None:
|
||||
self._update_kvcache(output_dict['past_key_values'])
|
||||
self._update_kvcache(output_dict["past_key_values"])
|
||||
|
||||
def _update_kvcache(self, kv_cache: Tuple):
|
||||
assert type(kv_cache) == tuple
|
||||
|
@ -64,7 +64,6 @@ class MicroBatchDescription():
|
|||
Return the current sequnence length of micro batch
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class HeadMicroBatchDescription(MicroBatchDescription):
|
||||
|
@ -80,13 +79,14 @@ class HeadMicroBatchDescription(MicroBatchDescription):
|
|||
|
||||
"""
|
||||
|
||||
def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
|
||||
new_length: int) -> None:
|
||||
def __init__(
|
||||
self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
|
||||
) -> None:
|
||||
super().__init__(inputs_dict, output_dict, new_length)
|
||||
assert inputs_dict is not None
|
||||
assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None
|
||||
self.input_ids = inputs_dict['input_ids']
|
||||
self.attn_mask = inputs_dict['attention_mask']
|
||||
assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
|
||||
self.input_ids = inputs_dict["input_ids"]
|
||||
self.attn_mask = inputs_dict["attention_mask"]
|
||||
self.new_tokens = None
|
||||
|
||||
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
|
@ -104,7 +104,8 @@ class HeadMicroBatchDescription(MicroBatchDescription):
|
|||
|
||||
def _update_attnmask(self):
|
||||
self.attn_mask = torch.cat(
|
||||
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
|
||||
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
|
||||
)
|
||||
|
||||
@property
|
||||
def cur_length(self):
|
||||
|
@ -127,8 +128,9 @@ class BodyMicroBatchDescription(MicroBatchDescription):
|
|||
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
|
||||
"""
|
||||
|
||||
def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
|
||||
new_length: int) -> None:
|
||||
def __init__(
|
||||
self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
|
||||
) -> None:
|
||||
super().__init__(inputs_dict, output_dict, new_length)
|
||||
|
||||
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
|
||||
|
@ -146,8 +148,8 @@ class BodyMicroBatchDescription(MicroBatchDescription):
|
|||
return self.kv_cache[0][0].shape[-2] + 1
|
||||
|
||||
|
||||
class MicroBatchManager():
|
||||
'''
|
||||
class MicroBatchManager:
|
||||
"""
|
||||
MicroBatchManager is a class that manages the micro batch.
|
||||
|
||||
Args:
|
||||
|
@ -156,7 +158,7 @@ class MicroBatchManager():
|
|||
micro_batch_size (int): the micro batch size.
|
||||
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
|
||||
|
||||
'''
|
||||
"""
|
||||
|
||||
def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
|
||||
self.stage = stage
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import CrossEntropyLoss
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
|
||||
from transformers.utils import logging
|
||||
|
@ -10,41 +9,41 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
|
||||
|
||||
class GPT2PipelineForwards:
|
||||
'''
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of GPT2 models
|
||||
under pipeline setting.
|
||||
'''
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def gpt2_model_forward(
|
||||
self: GPT2Model,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
|
||||
self: GPT2Model,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
||||
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
|
||||
# Please refer to original code of transformers for more details.
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Preprocess passed in arguments
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
@ -96,7 +95,7 @@ class GPT2PipelineForwards:
|
|||
# positions we want to attend and the dtype's smallest value for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
|
@ -137,7 +136,8 @@ class GPT2PipelineForwards:
|
|||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
presents = () if use_cache else None
|
||||
|
@ -166,7 +166,6 @@ class GPT2PipelineForwards:
|
|||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, use_cache, output_attentions)
|
||||
|
@ -218,61 +217,64 @@ class GPT2PipelineForwards:
|
|||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
return {'hidden_states': hidden_states, 'past_key_values': presents}
|
||||
return {"hidden_states": hidden_states, "past_key_values": presents}
|
||||
|
||||
@staticmethod
|
||||
def gpt2_lmhead_model_forward(
|
||||
self: GPT2LMHeadModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
self: GPT2LMHeadModel,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
token_type_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
head_mask: Optional[torch.FloatTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
hidden_states: Optional[torch.FloatTensor] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
||||
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
||||
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
||||
|
||||
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
|
||||
Please refer to original code of transformers for more details.
|
||||
"""
|
||||
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
|
||||
Please refer to original code of transformers for more details.
|
||||
"""
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# If is first stage and after warmup, go throught lm_head first
|
||||
if stage_manager.is_first_stage() and hidden_states is not None:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
return {'logits': lm_logits}
|
||||
return {"logits": lm_logits}
|
||||
|
||||
# Not first stage or before warmup, go through gpt2 model
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index)
|
||||
outputs = GPT2PipelineForwards.gpt2_model_forward(
|
||||
self.transformer,
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
stage_manager=stage_manager,
|
||||
hidden_states=hidden_states,
|
||||
stage_index=stage_index,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from typing import List, Optional, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
|
||||
from transformers.utils import logging
|
||||
|
||||
|
@ -10,10 +8,10 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
|
|||
|
||||
|
||||
class LlamaPipelineForwards:
|
||||
'''
|
||||
"""
|
||||
This class serves as a micro library for forward function substitution of Llama models
|
||||
under pipeline setting.
|
||||
'''
|
||||
"""
|
||||
|
||||
def llama_model_forward(
|
||||
self: LlamaModel,
|
||||
|
@ -34,10 +32,10 @@ class LlamaPipelineForwards:
|
|||
|
||||
# Preprocess passed in arguments
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
@ -70,10 +68,9 @@ class LlamaPipelineForwards:
|
|||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
|
@ -81,16 +78,18 @@ class LlamaPipelineForwards:
|
|||
# embed positions, for the first stage, hidden_states is the input embeddings,
|
||||
# for the other stages, hidden_states is the output of the previous stage
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones((batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=hidden_states.device)
|
||||
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states,
|
||||
past_key_values_length)
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
|
||||
)
|
||||
attention_mask = self._prepare_decoder_attention_mask(
|
||||
attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
|
||||
)
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
|
@ -112,7 +111,6 @@ class LlamaPipelineForwards:
|
|||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs, output_attentions, None)
|
||||
|
@ -152,7 +150,7 @@ class LlamaPipelineForwards:
|
|||
next_cache = next_decoder_cache if use_cache else None
|
||||
|
||||
# always return dict for imediate stage
|
||||
return {'hidden_states': hidden_states, 'past_key_values': next_cache}
|
||||
return {"hidden_states": hidden_states, "past_key_values": next_cache}
|
||||
|
||||
def llama_for_causal_lm_forward(
|
||||
self: LlamaForCausalLM,
|
||||
|
@ -171,45 +169,45 @@ class LlamaPipelineForwards:
|
|||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
|
||||
Returns:
|
||||
Returns:
|
||||
|
||||
Example:
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
```python
|
||||
>>> from transformers import AutoTokenizer, LlamaForCausalLM
|
||||
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
|
||||
|
||||
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
>>> prompt = "Hey, are you consciours? Can you talk to me?"
|
||||
>>> inputs = tokenizer(prompt, return_tensors="pt")
|
||||
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||
```"""
|
||||
>>> # Generate
|
||||
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
|
||||
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||
```"""
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if output_attentions:
|
||||
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
|
||||
output_attentions = False
|
||||
if output_hidden_states:
|
||||
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
|
||||
logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
|
||||
output_hidden_states = False
|
||||
|
||||
# If is first stage and after warmup, go throught lm_head first
|
||||
if stage_manager.is_first_stage() and hidden_states is not None:
|
||||
lm_logits = self.lm_head(hidden_states)
|
||||
return {'logits': lm_logits}
|
||||
return {"logits": lm_logits}
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = LlamaPipelineForwards.llama_model_forward(
|
||||
|
|
|
@ -11,7 +11,6 @@ from ..modeling.gpt2 import GPT2PipelineForwards
|
|||
|
||||
|
||||
class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
@ -22,18 +21,22 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
|
|||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
addon_module = {
|
||||
GPT2LMHeadModel:
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
GPT2LMHeadModel: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
|
||||
])
|
||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
|
||||
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
|
||||
policy=module_policy)
|
||||
self.set_pipeline_forward(
|
||||
model_cls=GPT2LMHeadModel,
|
||||
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
|
||||
policy=module_policy,
|
||||
)
|
||||
return module_policy
|
||||
|
||||
def get_held_layers(self) -> List[nn.Module]:
|
||||
|
@ -45,7 +48,7 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
|
|||
return held_layers
|
||||
|
||||
def get_shared_params(self) -> List[Dict[int, Tensor]]:
|
||||
'''The weights of wte and lm_head are shared.'''
|
||||
"""The weights of wte and lm_head are shared."""
|
||||
module = self.model
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if stage_manager is not None:
|
||||
|
@ -56,16 +59,16 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
|
|||
|
||||
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
|
||||
"""If under pipeline parallel setting, replacing the original forward method of huggingface
|
||||
to customized forward method, and add this changing to policy."""
|
||||
to customized forward method, and add this changing to policy."""
|
||||
if not self.pipeline_stage_manager:
|
||||
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
|
||||
stage_manager = self.pipeline_stage_manager
|
||||
if self.model.__class__.__name__ == 'GPT2Model':
|
||||
if self.model.__class__.__name__ == "GPT2Model":
|
||||
module = self.model
|
||||
else:
|
||||
module = self.model.transformer
|
||||
|
||||
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
|
||||
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
|
||||
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
|
||||
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
|
||||
|
|
|
@ -1,19 +1,15 @@
|
|||
from functools import partial
|
||||
from typing import Callable, Dict, List, Union
|
||||
from typing import List
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.layer import Linear1D_Col
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
|
||||
from colossalai.shardformer.policies.llama import LlamaPolicy
|
||||
|
||||
from ..modeling.llama import LlamaPipelineForwards
|
||||
|
||||
|
||||
class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
|
@ -25,19 +21,21 @@ class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
|
|||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
new_item = {
|
||||
LlamaForCausalLM:
|
||||
ModulePolicyDescription(sub_module_replacement=[
|
||||
LlamaForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
|
||||
])
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
self.set_pipeline_forward(model_cls=LlamaForCausalLM,
|
||||
new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward,
|
||||
policy=policy)
|
||||
self.set_pipeline_forward(
|
||||
model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
|
||||
)
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List, Optional, Set
|
||||
from typing import Set
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -30,6 +30,6 @@ def get_suffix_name(suffix: str, name: str):
|
|||
suffix (str): The suffix of the suffix module
|
||||
name (str): The name of the current module
|
||||
"""
|
||||
point = '' if suffix is '' else '.'
|
||||
suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}'
|
||||
point = "" if suffix is "" else "."
|
||||
suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}"
|
||||
return suffix_name
|
||||
|
|
|
@ -167,7 +167,7 @@ def _p2p_comm(
|
|||
group: ProcessGroup,
|
||||
comm_dtype: torch.dtype = torch.float16,
|
||||
):
|
||||
"""
|
||||
"""
|
||||
Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication.
|
||||
|
||||
Agrs:
|
||||
|
@ -176,7 +176,7 @@ def _p2p_comm(
|
|||
peer (int): rank of the peer
|
||||
group (ProcessGroup): process group
|
||||
comm_dtype (torch.dtype): dtype of the tensor to be sent
|
||||
|
||||
|
||||
Returns:
|
||||
torch.Tensor: tensor received from previous stage
|
||||
"""
|
||||
|
@ -302,7 +302,9 @@ class PipelineP2PCommunication:
|
|||
cur_rank = self.stage_manager.get_rank()
|
||||
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
|
||||
|
||||
def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None:
|
||||
def p2p_communicate(
|
||||
self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
|
||||
) -> None:
|
||||
"""
|
||||
Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.
|
||||
|
||||
|
@ -313,5 +315,7 @@ class PipelineP2PCommunication:
|
|||
if peer is None:
|
||||
peer = self.stage_manager.get_next_rank()
|
||||
cur_rank = self.stage_manager.get_rank()
|
||||
recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype)
|
||||
recv_tensor = _p2p_comm(
|
||||
output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
|
||||
)
|
||||
return recv_tensor
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import time
|
||||
from functools import partial
|
||||
from typing import Any, Iterable, List, Optional, Union
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.cuda
|
||||
|
@ -16,7 +16,7 @@ from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
|
|||
from .base import PipelineSchedule
|
||||
|
||||
|
||||
class ActionIntervalBuffer():
|
||||
class ActionIntervalBuffer:
|
||||
"""
|
||||
The buffer to save the interval hidden states and new token for stage to use.
|
||||
|
||||
|
@ -70,8 +70,9 @@ class GenerateSchedule(PipelineSchedule):
|
|||
self.batch = batch
|
||||
self.batch_size = get_batch_size(batch)
|
||||
self.microbatch_offset = 0
|
||||
assert self.batch_size % self.microbatch_size == 0, \
|
||||
f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}"
|
||||
assert (
|
||||
self.batch_size % self.microbatch_size == 0
|
||||
), f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}"
|
||||
self.num_microbatches = self.batch_size // self.microbatch_size
|
||||
self.round = self.num_microbatches // self.stage_manager.num_stages
|
||||
|
||||
|
@ -86,26 +87,26 @@ class GenerateSchedule(PipelineSchedule):
|
|||
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
|
||||
|
||||
def _prepare_inputs_for_interval_stage(self):
|
||||
'''
|
||||
"""
|
||||
Prepare inputs for interval stage, for all the interval stage, the inputs is just the past_key_values
|
||||
|
||||
Returns:
|
||||
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
|
||||
'''
|
||||
model_inputs = {
|
||||
'past_key_values': self.mb_manager.cur_kv_cache
|
||||
} if self.mb_manager.cur_kv_cache is not None else None
|
||||
"""
|
||||
model_inputs = (
|
||||
{"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None
|
||||
)
|
||||
return model_inputs
|
||||
|
||||
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
|
||||
'''
|
||||
"""
|
||||
Prepare inputs for new token, the inputs is a dict with `input_ids`, `attention_mask` and `past_key_values`
|
||||
`input_ids` is the new token, `attention_mask` is the previous mask add `1` in the end,
|
||||
`past_key_values` is the past_key_values save in the micro batch manager
|
||||
|
||||
Returns:
|
||||
dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
|
||||
'''
|
||||
"""
|
||||
new_mask = self.mb_manager.cur_descrption.attn_mask
|
||||
past_key_values = self.mb_manager.cur_descrption.kv_cache
|
||||
|
||||
|
@ -117,12 +118,12 @@ class GenerateSchedule(PipelineSchedule):
|
|||
return input_ids
|
||||
|
||||
def _recv_pre_stage(self) -> Any:
|
||||
'''
|
||||
"""
|
||||
Receive the output from previous stage
|
||||
|
||||
Returns:
|
||||
Any: The output from previous stage
|
||||
'''
|
||||
"""
|
||||
if self.stage_manager.num_stages == 2:
|
||||
return self.comm.p2p_recv()
|
||||
return self.comm.recv_forward()
|
||||
|
@ -138,7 +139,7 @@ class GenerateSchedule(PipelineSchedule):
|
|||
output_dict = model_forward(model, inputs_dict, None)
|
||||
|
||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
||||
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||
|
||||
def _gen_token_action(self, model: Module):
|
||||
"""
|
||||
|
@ -146,13 +147,15 @@ class GenerateSchedule(PipelineSchedule):
|
|||
"""
|
||||
hidden_states = self.action_interval_buffer.hidden_states
|
||||
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
|
||||
hidden_states = {'hidden_states': hidden_states}
|
||||
hidden_states = {"hidden_states": hidden_states}
|
||||
logits = model_forward(model, None, hidden_states)
|
||||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||
new_token = self._get_token_id(logits['logits'])
|
||||
assert (
|
||||
"logits" in logits
|
||||
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||
new_token = self._get_token_id(logits["logits"])
|
||||
|
||||
self.mb_manager.step(None, None, new_token)
|
||||
self.action_interval_buffer.new_token = new_token
|
||||
|
@ -168,17 +171,17 @@ class GenerateSchedule(PipelineSchedule):
|
|||
output_dict = model_forward(model, inputs_dict, None)
|
||||
|
||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
||||
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||
|
||||
def _body_encoding_action(self, model: Module):
|
||||
hidden_states = self.action_interval_buffer.hidden_states
|
||||
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
|
||||
inputs_dict = self._prepare_inputs_for_interval_stage()
|
||||
hidden_states = {'hidden_states': hidden_states}
|
||||
hidden_states = {"hidden_states": hidden_states}
|
||||
output_dict = model_forward(model, inputs_dict, hidden_states)
|
||||
|
||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
||||
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||
|
||||
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
|
||||
"""
|
||||
|
@ -246,10 +249,13 @@ class GenerateSchedule(PipelineSchedule):
|
|||
|
||||
whole_timestamp = []
|
||||
|
||||
#run by round
|
||||
# run by round
|
||||
for _ in range(self.round):
|
||||
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
|
||||
] if self.verbose and self.stage_manager.is_first_stage() else None
|
||||
self.timestamps = (
|
||||
[[] for _ in range(self.stage_manager.num_stages)]
|
||||
if self.verbose and self.stage_manager.is_first_stage()
|
||||
else None
|
||||
)
|
||||
self.action_interval_buffer.clear()
|
||||
while self.mb_manager.is_micro_batch_done() is False:
|
||||
actions = self._gen_action(model)
|
||||
|
@ -286,8 +292,11 @@ class GenerateSchedule(PipelineSchedule):
|
|||
whole_timestamp = []
|
||||
# run by round
|
||||
for _ in range(self.round):
|
||||
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
|
||||
] if self.verbose and self.stage_manager.is_first_stage() else None
|
||||
self.timestamps = (
|
||||
[[] for _ in range(self.stage_manager.num_stages)]
|
||||
if self.verbose and self.stage_manager.is_first_stage()
|
||||
else None
|
||||
)
|
||||
while self.mb_manager.is_micro_batch_done() is False:
|
||||
inputs_dict = None
|
||||
new_token = None
|
||||
|
@ -307,13 +316,17 @@ class GenerateSchedule(PipelineSchedule):
|
|||
hidden_states = self.comm.recv_forward()
|
||||
if self.stage_manager.is_first_stage():
|
||||
# First just generate a new token
|
||||
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
|
||||
assert (
|
||||
hidden_states is not None
|
||||
), "When first stage in GENERATE phase, the hidden states should not be None"
|
||||
logits = model_forward(model, None, hidden_states)
|
||||
if self.verbose and self.stage_manager.is_first_stage():
|
||||
torch.cuda.synchronize()
|
||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||
new_token = self._get_token_id(logits['logits'])
|
||||
assert (
|
||||
"logits" in logits
|
||||
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||
new_token = self._get_token_id(logits["logits"])
|
||||
self.mb_manager.step(None, None, new_token)
|
||||
# If the current micro batch is not DONE, go through blocks
|
||||
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
|
||||
|
@ -327,9 +340,11 @@ class GenerateSchedule(PipelineSchedule):
|
|||
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||
|
||||
# Current microbatch is not DONE, send hidden_state to next stage
|
||||
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE,
|
||||
Status.COOLDOWN):
|
||||
self.comm.send_forward({'hidden_states': output_dict['hidden_states']})
|
||||
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
|
||||
Status.GENERATE,
|
||||
Status.COOLDOWN,
|
||||
):
|
||||
self.comm.send_forward({"hidden_states": output_dict["hidden_states"]})
|
||||
|
||||
self.mb_manager.next()
|
||||
|
||||
|
|
|
@ -1,9 +1,6 @@
|
|||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
|
||||
import colossalai
|
||||
|
@ -20,27 +17,29 @@ def data_gen():
|
|||
|
||||
inputs = data_gen()
|
||||
for k, v in inputs.items():
|
||||
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
|
||||
if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
|
||||
new_shape = [1] * v.dim()
|
||||
new_shape[0] = 16
|
||||
inputs[k] = v.to('cuda').repeat(*new_shape)
|
||||
inputs[k] = v.to("cuda").repeat(*new_shape)
|
||||
|
||||
|
||||
def pipeline_inference_test(pp_size, new_length, micro_batch_size):
|
||||
model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
|
||||
engine = PPInferEngine(pp_size=pp_size,
|
||||
model=model,
|
||||
model_policy=GPT2LMHeadModelPipelinePolicy(),
|
||||
new_length=new_length,
|
||||
micro_batch_size=micro_batch_size)
|
||||
engine = PPInferEngine(
|
||||
pp_size=pp_size,
|
||||
model=model,
|
||||
model_policy=GPT2LMHeadModelPipelinePolicy(),
|
||||
new_length=new_length,
|
||||
micro_batch_size=micro_batch_size,
|
||||
)
|
||||
output = engine.inference([inputs])
|
||||
if dist.get_rank() == 0:
|
||||
assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
|
||||
|
||||
|
||||
@parameterize('pp_size', [4])
|
||||
@parameterize('new_length', [4, 8, 16])
|
||||
@parameterize('micro_batch_size', [1, 4])
|
||||
@parameterize("pp_size", [4])
|
||||
@parameterize("new_length", [4, 8, 16])
|
||||
@parameterize("micro_batch_size", [1, 4])
|
||||
@clear_cache_before_run()
|
||||
def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
|
||||
pipeline_inference_test(pp_size, new_length, micro_batch_size)
|
||||
|
@ -48,7 +47,7 @@ def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
|
|||
|
||||
|
||||
def check_pipeline_inference(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_pipeline_inference_test()
|
||||
|
||||
|
||||
|
@ -59,5 +58,5 @@ def test_pipeline_inference():
|
|||
spawn(check_pipeline_inference, nprocs=4)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
test_pipeline_inference()
|
||||
|
|
Loading…
Reference in New Issue