[format] applied code formatting on changed files in pull request 4820 (#4886)

Co-authored-by: github-actions <github-actions@github.com>
pull/4990/head
github-actions[bot] 2023-10-18 11:46:37 +08:00 committed by GitHub
parent c7aa319ba0
commit 486d06a2d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 297 additions and 258 deletions

View File

@ -1,3 +1,3 @@
from .pipeline import PPInferEngine from .pipeline import PPInferEngine
__all__ = ['PPInferEngine'] __all__ = ["PPInferEngine"]

View File

@ -1,3 +1,3 @@
from .engine import PPInferEngine from .engine import PPInferEngine
__all__ = ['PPInferEngine'] __all__ = ["PPInferEngine"]

View File

@ -1,28 +1,32 @@
import argparse
import time
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import transformers import transformers
import colossalai import colossalai
import time
from colossalai.inference import PPInferEngine from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
import argparse
GIGABYTE = 1024 ** 3 GIGABYTE = 1024**3
MEGABYTE = 1024 * 1024 MEGABYTE = 1024 * 1024
colossalai.launch_from_torch(config={}) colossalai.launch_from_torch(config={})
def data_gen(batch_size: int=4, seq_len: int=512):
def data_gen(batch_size: int = 4, seq_len: int = 512):
input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32) input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
attention_mask = torch.ones((1, seq_len), dtype=torch.int32) attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
data = dict(input_ids=input_ids, attention_mask=attention_mask) data = dict(input_ids=input_ids, attention_mask=attention_mask)
for k, v in data.items(): for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim() new_shape = [1] * v.dim()
new_shape[0] = batch_size new_shape[0] = batch_size
data[k] = v.to('cuda').repeat(*new_shape) data[k] = v.to("cuda").repeat(*new_shape)
return data return data
def print_details_info(timestamps, model_config, args, whole_end2end): def print_details_info(timestamps, model_config, args, whole_end2end):
if dist.get_rank() == 0: if dist.get_rank() == 0:
prefill = [] prefill = []
@ -31,32 +35,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
for timestamp in timestamps: for timestamp in timestamps:
prefill.append(timestamp[1] - timestamp[0]) prefill.append(timestamp[1] - timestamp[0])
encoder.append( encoder.append(
sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2)) sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
)
end2end.append(timestamp[-1] - timestamp[0]) end2end.append(timestamp[-1] - timestamp[0])
print(whole_end2end) print(whole_end2end)
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f: with open(
mb_avg_end2end = sum(end2end)/len(end2end) f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size) "w+",
whole_avg_latency = whole_end2end/(args.new_length * args.batch_size) ) as f:
mb_avg_end2end = sum(end2end) / len(end2end)
mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size)
whole_avg_latency = whole_end2end / (args.new_length * args.batch_size)
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
if args.dtype in ['fp16','bf16']: if args.dtype in ["fp16", "bf16"]:
num_bytes = 2 num_bytes = 2
else: else:
num_bytes = 4 num_bytes = 4
f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n") f.write(
f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000)) f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n"
f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000)) )
f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000)) f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000)) f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000)) f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000)) f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000)))) f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000))))
f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12)) f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
f.write("----------------------------------------------------------\n") f.write("----------------------------------------------------------\n")
if torch.cuda.is_available(): if torch.cuda.is_available():
current_device = torch.cuda.current_device() current_device = torch.cuda.current_device()
@ -66,7 +75,10 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
max_memory_allocated = torch.cuda.max_memory_allocated() max_memory_allocated = torch.cuda.max_memory_allocated()
memory_reserved = torch.cuda.memory_reserved() memory_reserved = torch.cuda.memory_reserved()
max_memory_reserved = torch.cuda.max_memory_reserved() max_memory_reserved = torch.cuda.max_memory_reserved()
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f: with open(
f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
"a",
) as f:
f.write( f.write(
f"\nCurrently using GPU: {current_device}\n" f"\nCurrently using GPU: {current_device}\n"
f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n" f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
@ -77,29 +89,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n" f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
) )
if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--model', default='toy', help='the size of model') parser.add_argument("--model", default="toy", help="the size of model")
parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size') parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length') parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
parser.add_argument('--new_length', type=int, default=4, help='new tokens length') parser.add_argument("--new_length", type=int, default=4, help="new tokens length")
parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size') parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
parser.add_argument('--pp_size', type=int, default=2, help='pipeline size') parser.add_argument("--pp_size", type=int, default=2, help="pipeline size")
parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log') parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log")
parser.add_argument('--dtype', type=str, default='fp16', help='data type') parser.add_argument("--dtype", type=str, default="fp16", help="data type")
args = parser.parse_args() args = parser.parse_args()
if args.model == 'toy': if args.model == "toy":
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8)) model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
elif args.model == '7b': elif args.model == "7b":
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf')) model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf"))
elif args.model == '13b': elif args.model == "13b":
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf')) model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf"))
else: else:
raise NotImplementedError raise NotImplementedError
engine = PPInferEngine(
engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True) pp_size=args.pp_size,
dtype=args.dtype,
micro_batch_size=args.mb_size,
new_length=args.new_length,
model=model,
model_policy=LlamaForCausalLMPipelinePolicy(),
verbose=True,
)
data = data_gen(args.batch_size, args.seq_len) data = data_gen(args.batch_size, args.seq_len)
torch.cuda.synchronize() torch.cuda.synchronize()
@ -109,4 +129,3 @@ if __name__ == '__main__':
whole_end2end = time.time() - whole_end2end whole_end2end = time.time() - whole_end2end
print_details_info(timestamps, model.config, args, whole_end2end) print_details_info(timestamps, model.config, args, whole_end2end)

View File

@ -1,5 +1,3 @@
from typing import Callable, List, Optional, Set, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -13,7 +11,7 @@ from .microbatch_manager import MicroBatchManager
class PPInferEngine: class PPInferEngine:
''' """
PPInferEngine is a class that handles the pipeline parallel inference. PPInferEngine is a class that handles the pipeline parallel inference.
Args: Args:
@ -41,12 +39,12 @@ class PPInferEngine:
output = engine.inference([tokenized_input]) output = engine.inference([tokenized_input])
``` ```
''' """
def __init__( def __init__(
self, self,
pp_size: int, pp_size: int,
dtype: str = 'fp16', dtype: str = "fp16",
pp_model: nn.Module = None, pp_model: nn.Module = None,
model: nn.Module = None, model: nn.Module = None,
model_policy: Policy = None, model_policy: Policy = None,
@ -54,7 +52,7 @@ class PPInferEngine:
micro_batch_size: int = 1, micro_batch_size: int = 1,
micro_batch_buffer_size: int = None, micro_batch_buffer_size: int = None,
verbose: bool = False, verbose: bool = False,
# TODO: implement early_stopping, and various generation options # TODO: implement early_stopping, and various generation options
early_stopping: bool = False, early_stopping: bool = False,
do_sample: bool = False, do_sample: bool = False,
num_beams: int = 1, num_beams: int = 1,
@ -63,15 +61,16 @@ class PPInferEngine:
self.pp_size = pp_size self.pp_size = pp_size
self.pg_mesh = ProcessGroupMesh(pp_size) self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, self.mb_manager = MicroBatchManager(
micro_batch_buffer_size or pp_size) self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
)
self.verbose = verbose self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'" assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
if dtype == 'fp16': if dtype == "fp16":
model.half() model.half()
elif dtype == 'bf16': elif dtype == "bf16":
model.to(torch.bfloat16) model.to(torch.bfloat16)
self.model = pp_model or self._shardformer(model, model_policy) self.model = pp_model or self._shardformer(model, model_policy)

View File

@ -3,7 +3,7 @@ from typing import Dict, Tuple
import torch import torch
__all__ = 'MicroBatchManager' __all__ = "MicroBatchManager"
class Status(Enum): class Status(Enum):
@ -13,7 +13,7 @@ class Status(Enum):
COOLDOWN = 4 COOLDOWN = 4
class MicroBatchDescription(): class MicroBatchDescription:
""" """
This is the class to record the information of each microbatch, and also do some update operation. This is the class to record the information of each microbatch, and also do some update operation.
This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more This class is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
@ -30,14 +30,14 @@ class MicroBatchDescription():
output_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
new_length: int, new_length: int,
) -> None: ) -> None:
assert output_dict.get('hidden_states') is not None assert output_dict.get("hidden_states") is not None
self.mb_length = output_dict['hidden_states'].shape[-2] self.mb_length = output_dict["hidden_states"].shape[-2]
self.target_length = self.mb_length + new_length self.target_length = self.mb_length + new_length
self.kv_cache = () self.kv_cache = ()
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
if output_dict is not None: if output_dict is not None:
self._update_kvcache(output_dict['past_key_values']) self._update_kvcache(output_dict["past_key_values"])
def _update_kvcache(self, kv_cache: Tuple): def _update_kvcache(self, kv_cache: Tuple):
assert type(kv_cache) == tuple assert type(kv_cache) == tuple
@ -64,7 +64,6 @@ class MicroBatchDescription():
Return the current sequence length of micro batch Return the current sequence length of micro batch
""" """
pass
class HeadMicroBatchDescription(MicroBatchDescription): class HeadMicroBatchDescription(MicroBatchDescription):
@ -80,13 +79,14 @@ class HeadMicroBatchDescription(MicroBatchDescription):
""" """
def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], def __init__(
new_length: int) -> None: self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
) -> None:
super().__init__(inputs_dict, output_dict, new_length) super().__init__(inputs_dict, output_dict, new_length)
assert inputs_dict is not None assert inputs_dict is not None
assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
self.input_ids = inputs_dict['input_ids'] self.input_ids = inputs_dict["input_ids"]
self.attn_mask = inputs_dict['attention_mask'] self.attn_mask = inputs_dict["attention_mask"]
self.new_tokens = None self.new_tokens = None
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@ -104,7 +104,8 @@ class HeadMicroBatchDescription(MicroBatchDescription):
def _update_attnmask(self): def _update_attnmask(self):
self.attn_mask = torch.cat( self.attn_mask = torch.cat(
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1) (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
)
@property @property
def cur_length(self): def cur_length(self):
@ -127,8 +128,9 @@ class BodyMicroBatchDescription(MicroBatchDescription):
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
""" """
def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], def __init__(
new_length: int) -> None: self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
) -> None:
super().__init__(inputs_dict, output_dict, new_length) super().__init__(inputs_dict, output_dict, new_length)
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@ -146,8 +148,8 @@ class BodyMicroBatchDescription(MicroBatchDescription):
return self.kv_cache[0][0].shape[-2] + 1 return self.kv_cache[0][0].shape[-2] + 1
class MicroBatchManager(): class MicroBatchManager:
''' """
MicroBatchManager is a class that manages the micro batch. MicroBatchManager is a class that manages the micro batch.
Args: Args:
@ -156,7 +158,7 @@ class MicroBatchManager():
micro_batch_size (int): the micro batch size. micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
''' """
def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
self.stage = stage self.stage = stage

View File

@ -1,7 +1,6 @@
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
import torch import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
from transformers.utils import logging from transformers.utils import logging
@ -10,41 +9,41 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
class GPT2PipelineForwards: class GPT2PipelineForwards:
''' """
This class serves as a micro library for forward function substitution of GPT2 models This class serves as a micro library for forward function substitution of GPT2 models
under pipeline setting. under pipeline setting.
''' """
@staticmethod @staticmethod
def gpt2_model_forward( def gpt2_model_forward(
self: GPT2Model, self: GPT2Model,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: stage_index: Optional[List[int]] = None,
) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details. # Please refer to original code of transformers for more details.
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
# Preprocess passed in arguments # Preprocess passed in arguments
if output_attentions: if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False output_attentions = False
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False output_hidden_states = False
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
@ -96,7 +95,7 @@ class GPT2PipelineForwards:
# positions we want to attend and the dtype's smallest value for masked positions. # positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is # Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely. # effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
@ -137,7 +136,8 @@ class GPT2PipelineForwards:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
logger.warning_once( logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False use_cache = False
presents = () if use_cache else None presents = () if use_cache else None
@ -166,7 +166,6 @@ class GPT2PipelineForwards:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
# None for past_key_value # None for past_key_value
return module(*inputs, use_cache, output_attentions) return module(*inputs, use_cache, output_attentions)
@ -218,61 +217,64 @@ class GPT2PipelineForwards:
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
return {'hidden_states': hidden_states, 'past_key_values': presents} return {"hidden_states": hidden_states, "past_key_values": presents}
@staticmethod @staticmethod
def gpt2_lmhead_model_forward( def gpt2_lmhead_model_forward(
self: GPT2LMHeadModel, self: GPT2LMHeadModel,
input_ids: Optional[torch.LongTensor] = None, input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None, use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None, stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: stage_index: Optional[List[int]] = None,
) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
Please refer to original code of transformers for more details. Please refer to original code of transformers for more details.
""" """
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# If is first stage and after warmup, go through lm_head first # If is first stage and after warmup, go through lm_head first
if stage_manager.is_first_stage() and hidden_states is not None: if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
return {'logits': lm_logits} return {"logits": lm_logits}
# Not first stage or before warmup, go through gpt2 model # Not first stage or before warmup, go through gpt2 model
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, outputs = GPT2PipelineForwards.gpt2_model_forward(
input_ids, self.transformer,
past_key_values=past_key_values, input_ids,
attention_mask=attention_mask, past_key_values=past_key_values,
token_type_ids=token_type_ids, attention_mask=attention_mask,
position_ids=position_ids, token_type_ids=token_type_ids,
head_mask=head_mask, position_ids=position_ids,
inputs_embeds=inputs_embeds, head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states, inputs_embeds=inputs_embeds,
encoder_attention_mask=encoder_attention_mask, encoder_hidden_states=encoder_hidden_states,
use_cache=use_cache, encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions, use_cache=use_cache,
output_hidden_states=output_hidden_states, output_attentions=output_attentions,
return_dict=return_dict, output_hidden_states=output_hidden_states,
stage_manager=stage_manager, return_dict=return_dict,
hidden_states=hidden_states, stage_manager=stage_manager,
stage_index=stage_index) hidden_states=hidden_states,
stage_index=stage_index,
)
return outputs return outputs

View File

@ -1,8 +1,6 @@
from typing import List, Optional, Tuple from typing import List, Optional
import torch import torch
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from transformers.utils import logging from transformers.utils import logging
@ -10,10 +8,10 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
class LlamaPipelineForwards: class LlamaPipelineForwards:
''' """
This class serves as a micro library for forward function substitution of Llama models This class serves as a micro library for forward function substitution of Llama models
under pipeline setting. under pipeline setting.
''' """
def llama_model_forward( def llama_model_forward(
self: LlamaModel, self: LlamaModel,
@ -34,10 +32,10 @@ class LlamaPipelineForwards:
# Preprocess passed in arguments # Preprocess passed in arguments
if output_attentions: if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False output_attentions = False
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False output_hidden_states = False
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
@ -70,10 +68,9 @@ class LlamaPipelineForwards:
seq_length_with_past = seq_length_with_past + past_key_values_length seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None: if position_ids is None:
position_ids = torch.arange(past_key_values_length, position_ids = torch.arange(
seq_length + past_key_values_length, past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
dtype=torch.long, )
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else: else:
position_ids = position_ids.view(-1, seq_length).long() position_ids = position_ids.view(-1, seq_length).long()
@ -81,16 +78,18 @@ class LlamaPipelineForwards:
# embed positions, for the first stage, hidden_states is the input embeddings, # embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage # for the other stages, hidden_states is the output of the previous stage
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), attention_mask = torch.ones(
dtype=torch.bool, (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
device=hidden_states.device) )
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, attention_mask = self._prepare_decoder_attention_mask(
past_key_values_length) attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
if use_cache: if use_cache:
logger.warning_once( logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False use_cache = False
# decoder layers # decoder layers
@ -112,7 +111,6 @@ class LlamaPipelineForwards:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
def create_custom_forward(module): def create_custom_forward(module):
def custom_forward(*inputs): def custom_forward(*inputs):
# None for past_key_value # None for past_key_value
return module(*inputs, output_attentions, None) return module(*inputs, output_attentions, None)
@ -152,7 +150,7 @@ class LlamaPipelineForwards:
next_cache = next_decoder_cache if use_cache else None next_cache = next_decoder_cache if use_cache else None
# always return dict for intermediate stage # always return dict for intermediate stage
return {'hidden_states': hidden_states, 'past_key_values': next_cache} return {"hidden_states": hidden_states, "past_key_values": next_cache}
def llama_for_causal_lm_forward( def llama_for_causal_lm_forward(
self: LlamaForCausalLM, self: LlamaForCausalLM,
@ -171,45 +169,45 @@ class LlamaPipelineForwards:
stage_index: Optional[List[int]] = None, stage_index: Optional[List[int]] = None,
): ):
r""" r"""
Args: Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns: Returns:
Example: Example:
```python ```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM >>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you consciours? Can you talk to me?" >>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt") >>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate >>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30) >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```""" ```"""
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions: if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
output_attentions = False output_attentions = False
if output_hidden_states: if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
output_hidden_states = False output_hidden_states = False
# If is first stage and after warmup, go through lm_head first # If is first stage and after warmup, go through lm_head first
if stage_manager.is_first_stage() and hidden_states is not None: if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states) lm_logits = self.lm_head(hidden_states)
return {'logits': lm_logits} return {"logits": lm_logits}
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaPipelineForwards.llama_model_forward( outputs = LlamaPipelineForwards.llama_model_forward(

View File

@ -11,7 +11,6 @@ from ..modeling.gpt2 import GPT2PipelineForwards
class GPT2LMHeadModelPipelinePolicy(GPT2Policy): class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -22,18 +21,22 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
addon_module = { addon_module = {
GPT2LMHeadModel: GPT2LMHeadModel: ModulePolicyDescription(
ModulePolicyDescription(sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
]) )
]
)
} }
module_policy.update(addon_module) module_policy.update(addon_module)
if self.pipeline_stage_manager is not None: if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPT2LMHeadModel, self.set_pipeline_forward(
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, model_cls=GPT2LMHeadModel,
policy=module_policy) new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
policy=module_policy,
)
return module_policy return module_policy
def get_held_layers(self) -> List[nn.Module]: def get_held_layers(self) -> List[nn.Module]:
@ -45,7 +48,7 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
return held_layers return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]: def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''The weights of wte and lm_head are shared.''' """The weights of wte and lm_head are shared."""
module = self.model module = self.model
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if stage_manager is not None: if stage_manager is not None:
@ -56,16 +59,16 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface """If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy.""" to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager: if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'GPT2Model': if self.model.__class__.__name__ == "GPT2Model":
module = self.model module = self.model
else: else:
module = self.model.transformer module = self.model.transformer
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage) stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

View File

@ -1,19 +1,15 @@
from functools import partial from typing import List
from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D from colossalai.shardformer.layer import Linear1D_Col
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaPolicy from colossalai.shardformer.policies.llama import LlamaPolicy
from ..modeling.llama import LlamaPipelineForwards from ..modeling.llama import LlamaPipelineForwards
class LlamaForCausalLMPipelinePolicy(LlamaPolicy): class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
@ -25,19 +21,21 @@ class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
if self.shard_config.enable_tensor_parallelism: if self.shard_config.enable_tensor_parallelism:
# add a new item for causal lm # add a new item for causal lm
new_item = { new_item = {
LlamaForCausalLM: LlamaForCausalLM: ModulePolicyDescription(
ModulePolicyDescription(sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
]) )
]
)
} }
policy.update(new_item) policy.update(new_item)
if self.pipeline_stage_manager: if self.pipeline_stage_manager:
# set None as default # set None as default
self.set_pipeline_forward(model_cls=LlamaForCausalLM, self.set_pipeline_forward(
new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
policy=policy) )
return policy return policy

View File

@ -1,4 +1,4 @@
from typing import List, Optional, Set from typing import Set
import torch.nn as nn import torch.nn as nn
@ -30,6 +30,6 @@ def get_suffix_name(suffix: str, name: str):
suffix (str): The suffix of the suffix module suffix (str): The suffix of the suffix module
name (str): The name of the current module name (str): The name of the current module
""" """
point = '' if suffix is '' else '.' point = "" if suffix is "" else "."
suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}' suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}"
return suffix_name return suffix_name

View File

@ -302,7 +302,9 @@ class PipelineP2PCommunication:
cur_rank = self.stage_manager.get_rank() cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None: def p2p_communicate(
self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
) -> None:
""" """
Sends the input tensor to the next stage in pipeline, using `P2POp` in torch. Sends the input tensor to the next stage in pipeline, using `P2POp` in torch.
@ -313,5 +315,7 @@ class PipelineP2PCommunication:
if peer is None: if peer is None:
peer = self.stage_manager.get_next_rank() peer = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank() cur_rank = self.stage_manager.get_rank()
recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype) recv_tensor = _p2p_comm(
output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
)
return recv_tensor return recv_tensor

View File

@ -1,6 +1,6 @@
import time import time
from functools import partial from functools import partial
from typing import Any, Iterable, List, Optional, Union from typing import Any, Iterable, Optional, Union
import torch import torch
import torch.cuda import torch.cuda
@ -16,7 +16,7 @@ from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule from .base import PipelineSchedule
class ActionIntervalBuffer(): class ActionIntervalBuffer:
""" """
The buffer to save the interval hidden states and new token for stage to use. The buffer to save the interval hidden states and new token for stage to use.
@ -70,8 +70,9 @@ class GenerateSchedule(PipelineSchedule):
self.batch = batch self.batch = batch
self.batch_size = get_batch_size(batch) self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0 self.microbatch_offset = 0
assert self.batch_size % self.microbatch_size == 0, \ assert (
f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}" self.batch_size % self.microbatch_size == 0
), f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}"
self.num_microbatches = self.batch_size // self.microbatch_size self.num_microbatches = self.batch_size // self.microbatch_size
self.round = self.num_microbatches // self.stage_manager.num_stages self.round = self.num_microbatches // self.stage_manager.num_stages
@ -86,26 +87,26 @@ class GenerateSchedule(PipelineSchedule):
return tree_map(partial(to_device, device=get_current_device()), micro_batch) return tree_map(partial(to_device, device=get_current_device()), micro_batch)
def _prepare_inputs_for_interval_stage(self): def _prepare_inputs_for_interval_stage(self):
''' """
Prepare inputs for interval stage, for all the interval stage, the inputs is just the past_key_values Prepare inputs for interval stage, for all the interval stage, the inputs is just the past_key_values
Returns: Returns:
dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
''' """
model_inputs = { model_inputs = (
'past_key_values': self.mb_manager.cur_kv_cache {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None
} if self.mb_manager.cur_kv_cache is not None else None )
return model_inputs return model_inputs
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
''' """
Prepare inputs for new token, the inputs is a dict with `input_ids`, `attention_mask` and `past_key_values` Prepare inputs for new token, the inputs is a dict with `input_ids`, `attention_mask` and `past_key_values`
`input_ids` is the new token, `attention_mask` is the previous mask add `1` in the end, `input_ids` is the new token, `attention_mask` is the previous mask add `1` in the end,
`past_key_values` is the past_key_values save in the micro batch manager `past_key_values` is the past_key_values save in the micro batch manager
Returns: Returns:
dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
''' """
new_mask = self.mb_manager.cur_descrption.attn_mask new_mask = self.mb_manager.cur_descrption.attn_mask
past_key_values = self.mb_manager.cur_descrption.kv_cache past_key_values = self.mb_manager.cur_descrption.kv_cache
@ -117,12 +118,12 @@ class GenerateSchedule(PipelineSchedule):
return input_ids return input_ids
def _recv_pre_stage(self) -> Any: def _recv_pre_stage(self) -> Any:
''' """
Receive the output from previous stage Receive the output from previous stage
Returns: Returns:
Any: The output from previous stage Any: The output from previous stage
''' """
if self.stage_manager.num_stages == 2: if self.stage_manager.num_stages == 2:
return self.comm.p2p_recv() return self.comm.p2p_recv()
return self.comm.recv_forward() return self.comm.recv_forward()
@ -138,7 +139,7 @@ class GenerateSchedule(PipelineSchedule):
output_dict = model_forward(model, inputs_dict, None) output_dict = model_forward(model, inputs_dict, None)
self.mb_manager.step(inputs_dict, output_dict, None) self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states'] self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _gen_token_action(self, model: Module): def _gen_token_action(self, model: Module):
""" """
@ -146,13 +147,15 @@ class GenerateSchedule(PipelineSchedule):
""" """
hidden_states = self.action_interval_buffer.hidden_states hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
hidden_states = {'hidden_states': hidden_states} hidden_states = {"hidden_states": hidden_states}
logits = model_forward(model, None, hidden_states) logits = model_forward(model, None, hidden_states)
if self.verbose and self.stage_manager.is_first_stage(): if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize() torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time()) self.timestamps[self.mb_manager.idx].append(time.time())
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" assert (
new_token = self._get_token_id(logits['logits']) "logits" in logits
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(None, None, new_token) self.mb_manager.step(None, None, new_token)
self.action_interval_buffer.new_token = new_token self.action_interval_buffer.new_token = new_token
@ -168,17 +171,17 @@ class GenerateSchedule(PipelineSchedule):
output_dict = model_forward(model, inputs_dict, None) output_dict = model_forward(model, inputs_dict, None)
self.mb_manager.step(inputs_dict, output_dict, None) self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states'] self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _body_encoding_action(self, model: Module): def _body_encoding_action(self, model: Module):
hidden_states = self.action_interval_buffer.hidden_states hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When not first stage, the hidden states should not be None" assert hidden_states is not None, "When not first stage, the hidden states should not be None"
inputs_dict = self._prepare_inputs_for_interval_stage() inputs_dict = self._prepare_inputs_for_interval_stage()
hidden_states = {'hidden_states': hidden_states} hidden_states = {"hidden_states": hidden_states}
output_dict = model_forward(model, inputs_dict, hidden_states) output_dict = model_forward(model, inputs_dict, hidden_states)
self.mb_manager.step(inputs_dict, output_dict, None) self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states'] self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
def _comm_action(self, recv_pre: bool) -> torch.Tensor: def _comm_action(self, recv_pre: bool) -> torch.Tensor:
""" """
@ -246,10 +249,13 @@ class GenerateSchedule(PipelineSchedule):
whole_timestamp = [] whole_timestamp = []
#run by round # run by round
for _ in range(self.round): for _ in range(self.round):
self.timestamps = [[] for _ in range(self.stage_manager.num_stages) self.timestamps = (
] if self.verbose and self.stage_manager.is_first_stage() else None [[] for _ in range(self.stage_manager.num_stages)]
if self.verbose and self.stage_manager.is_first_stage()
else None
)
self.action_interval_buffer.clear() self.action_interval_buffer.clear()
while self.mb_manager.is_micro_batch_done() is False: while self.mb_manager.is_micro_batch_done() is False:
actions = self._gen_action(model) actions = self._gen_action(model)
@ -286,8 +292,11 @@ class GenerateSchedule(PipelineSchedule):
whole_timestamp = [] whole_timestamp = []
# run by round # run by round
for _ in range(self.round): for _ in range(self.round):
self.timestamps = [[] for _ in range(self.stage_manager.num_stages) self.timestamps = (
] if self.verbose and self.stage_manager.is_first_stage() else None [[] for _ in range(self.stage_manager.num_stages)]
if self.verbose and self.stage_manager.is_first_stage()
else None
)
while self.mb_manager.is_micro_batch_done() is False: while self.mb_manager.is_micro_batch_done() is False:
inputs_dict = None inputs_dict = None
new_token = None new_token = None
@ -307,13 +316,17 @@ class GenerateSchedule(PipelineSchedule):
hidden_states = self.comm.recv_forward() hidden_states = self.comm.recv_forward()
if self.stage_manager.is_first_stage(): if self.stage_manager.is_first_stage():
# First just generate a new token # First just generate a new token
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" assert (
hidden_states is not None
), "When first stage in GENERATE phase, the hidden states should not be None"
logits = model_forward(model, None, hidden_states) logits = model_forward(model, None, hidden_states)
if self.verbose and self.stage_manager.is_first_stage(): if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize() torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time()) self.timestamps[self.mb_manager.idx].append(time.time())
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" assert (
new_token = self._get_token_id(logits['logits']) "logits" in logits
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits["logits"])
self.mb_manager.step(None, None, new_token) self.mb_manager.step(None, None, new_token)
# If the current micro batch is not DONE, go through blocks # If the current micro batch is not DONE, go through blocks
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
@ -327,9 +340,11 @@ class GenerateSchedule(PipelineSchedule):
self.mb_manager.step(inputs_dict, output_dict, None) self.mb_manager.step(inputs_dict, output_dict, None)
# Current microbatch is not DONE, send hidden_state to next stage # Current microbatch is not DONE, send hidden_state to next stage
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE, if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
Status.COOLDOWN): Status.GENERATE,
self.comm.send_forward({'hidden_states': output_dict['hidden_states']}) Status.COOLDOWN,
):
self.comm.send_forward({"hidden_states": output_dict["hidden_states"]})
self.mb_manager.next() self.mb_manager.next()

View File

@ -1,9 +1,6 @@
from copy import deepcopy
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn
import transformers import transformers
import colossalai import colossalai
@ -20,27 +17,29 @@ def data_gen():
inputs = data_gen() inputs = data_gen()
for k, v in inputs.items(): for k, v in inputs.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
new_shape = [1] * v.dim() new_shape = [1] * v.dim()
new_shape[0] = 16 new_shape[0] = 16
inputs[k] = v.to('cuda').repeat(*new_shape) inputs[k] = v.to("cuda").repeat(*new_shape)
def pipeline_inference_test(pp_size, new_length, micro_batch_size): def pipeline_inference_test(pp_size, new_length, micro_batch_size):
model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8)) model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
engine = PPInferEngine(pp_size=pp_size, engine = PPInferEngine(
model=model, pp_size=pp_size,
model_policy=GPT2LMHeadModelPipelinePolicy(), model=model,
new_length=new_length, model_policy=GPT2LMHeadModelPipelinePolicy(),
micro_batch_size=micro_batch_size) new_length=new_length,
micro_batch_size=micro_batch_size,
)
output = engine.inference([inputs]) output = engine.inference([inputs])
if dist.get_rank() == 0: if dist.get_rank() == 0:
assert len(output[0]) == new_length, f"{len(output)}, {new_length}" assert len(output[0]) == new_length, f"{len(output)}, {new_length}"
@parameterize('pp_size', [4]) @parameterize("pp_size", [4])
@parameterize('new_length', [4, 8, 16]) @parameterize("new_length", [4, 8, 16])
@parameterize('micro_batch_size', [1, 4]) @parameterize("micro_batch_size", [1, 4])
@clear_cache_before_run() @clear_cache_before_run()
def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
pipeline_inference_test(pp_size, new_length, micro_batch_size) pipeline_inference_test(pp_size, new_length, micro_batch_size)
@ -48,7 +47,7 @@ def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
def check_pipeline_inference(rank, world_size, port): def check_pipeline_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_pipeline_inference_test() run_pipeline_inference_test()
@ -59,5 +58,5 @@ def test_pipeline_inference():
spawn(check_pipeline_inference, nprocs=4) spawn(check_pipeline_inference, nprocs=4)
if __name__ == '__main__': if __name__ == "__main__":
test_pipeline_inference() test_pipeline_inference()