diff --git a/colossalai/inference/__init__.py b/colossalai/inference/__init__.py index db33ae6fe..35891307e 100644 --- a/colossalai/inference/__init__.py +++ b/colossalai/inference/__init__.py @@ -1,3 +1,3 @@ from .pipeline import PPInferEngine -__all__ = ['PPInferEngine'] +__all__ = ["PPInferEngine"] diff --git a/colossalai/inference/pipeline/__init__.py b/colossalai/inference/pipeline/__init__.py index aff4568f7..41af9f3ef 100644 --- a/colossalai/inference/pipeline/__init__.py +++ b/colossalai/inference/pipeline/__init__.py @@ -1,3 +1,3 @@ from .engine import PPInferEngine -__all__ = ['PPInferEngine'] +__all__ = ["PPInferEngine"] diff --git a/colossalai/inference/pipeline/benchmark/benchmark.py b/colossalai/inference/pipeline/benchmark/benchmark.py index 97dfc6336..9c47909f7 100644 --- a/colossalai/inference/pipeline/benchmark/benchmark.py +++ b/colossalai/inference/pipeline/benchmark/benchmark.py @@ -1,28 +1,32 @@ +import argparse +import time + import torch import torch.distributed as dist import transformers import colossalai -import time from colossalai.inference import PPInferEngine from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy -import argparse -GIGABYTE = 1024 ** 3 + +GIGABYTE = 1024**3 MEGABYTE = 1024 * 1024 colossalai.launch_from_torch(config={}) -def data_gen(batch_size: int=4, seq_len: int=512): + +def data_gen(batch_size: int = 4, seq_len: int = 512): input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32) attention_mask = torch.ones((1, seq_len), dtype=torch.int32) data = dict(input_ids=input_ids, attention_mask=attention_mask) for k, v in data.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = batch_size - data[k] = v.to('cuda').repeat(*new_shape) + data[k] = v.to("cuda").repeat(*new_shape) return data + def print_details_info(timestamps, model_config, args, whole_end2end): if dist.get_rank() == 0: prefill = [] @@ -31,32 +35,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end): for timestamp in timestamps: prefill.append(timestamp[1] - timestamp[0]) encoder.append( - sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2)) + sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2) + ) end2end.append(timestamp[-1] - timestamp[0]) print(whole_end2end) - with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f: - mb_avg_end2end = sum(end2end)/len(end2end) - mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size) - whole_avg_latency = whole_end2end/(args.new_length * args.batch_size) + with open( + f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log", + "w+", + ) as f: + mb_avg_end2end = sum(end2end) / len(end2end) + mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size) + whole_avg_latency = whole_end2end / (args.new_length * args.batch_size) num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers) num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size - if args.dtype in ['fp16','bf16']: + if args.dtype in ["fp16", "bf16"]: num_bytes = 2 else: num_bytes = 4 - 
f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n") - f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000)) - f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000)) - f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000)) + f.write( + f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n" + ) + f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000)) + f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000)) + f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000)) f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000)) - f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000)) + f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000)) f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000)) - f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000)))) - f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12)) + f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000)))) + f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12)) f.write("----------------------------------------------------------\n") - if torch.cuda.is_available(): current_device = torch.cuda.current_device() @@ -66,7 +75,10 @@ def print_details_info(timestamps, model_config, args, whole_end2end): max_memory_allocated = torch.cuda.max_memory_allocated() memory_reserved = torch.cuda.memory_reserved() max_memory_reserved = torch.cuda.max_memory_reserved() - with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f: + with open( + f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log", + "a", + ) as f: f.write( f"\nCurrently using GPU: {current_device}\n" f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n" @@ -77,29 +89,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end): f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n" ) -if __name__ == '__main__': + +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model', default='toy', help='the size of model') - parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size') - parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length') - parser.add_argument('--new_length', type=int, default=4, help='new tokens length') - parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size') - parser.add_argument('--pp_size', type=int, default=2, help='pipeline size') - parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log') - parser.add_argument('--dtype', type=str, default='fp16', help='data type') + parser.add_argument("--model", default="toy", help="the size of model") + parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size") + 
parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length") + parser.add_argument("--new_length", type=int, default=4, help="new tokens length") + parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size") + parser.add_argument("--pp_size", type=int, default=2, help="pipeline size") + parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log") + parser.add_argument("--dtype", type=str, default="fp16", help="data type") args = parser.parse_args() - if args.model == 'toy': + if args.model == "toy": model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8)) - elif args.model == '7b': - model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf')) - elif args.model == '13b': - model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf')) + elif args.model == "7b": + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf")) + elif args.model == "13b": + model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf")) else: raise NotImplementedError - - - engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True) + + engine = PPInferEngine( + pp_size=args.pp_size, + dtype=args.dtype, + micro_batch_size=args.mb_size, + new_length=args.new_length, + model=model, + model_policy=LlamaForCausalLMPipelinePolicy(), + verbose=True, + ) data = data_gen(args.batch_size, args.seq_len) torch.cuda.synchronize() @@ -109,4 +129,3 @@ if __name__ == '__main__': whole_end2end = time.time() - whole_end2end print_details_info(timestamps, model.config, args, whole_end2end) - diff --git a/colossalai/inference/pipeline/engine.py b/colossalai/inference/pipeline/engine.py index 048ead2bc..4f42385ca 100644 --- a/colossalai/inference/pipeline/engine.py +++ b/colossalai/inference/pipeline/engine.py @@ -1,5 +1,3 @@ -from typing import Callable, List, Optional, Set, Union - import torch import torch.nn as nn @@ -13,7 +11,7 @@ from .microbatch_manager import MicroBatchManager class PPInferEngine: - ''' + """ PPInferEngine is a class that handles the pipeline parallel inference. 
Args: @@ -41,12 +39,12 @@ class PPInferEngine: output = engine.inference([tokenized_input]) ``` - ''' + """ def __init__( self, pp_size: int, - dtype: str = 'fp16', + dtype: str = "fp16", pp_model: nn.Module = None, model: nn.Module = None, model_policy: Policy = None, @@ -54,7 +52,7 @@ class PPInferEngine: micro_batch_size: int = 1, micro_batch_buffer_size: int = None, verbose: bool = False, - # TODO: implement early_stopping, and various gerneration options + # TODO: implement early_stopping, and various gerneration options early_stopping: bool = False, do_sample: bool = False, num_beams: int = 1, @@ -63,15 +61,16 @@ class PPInferEngine: self.pp_size = pp_size self.pg_mesh = ProcessGroupMesh(pp_size) self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True) - self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size, - micro_batch_buffer_size or pp_size) + self.mb_manager = MicroBatchManager( + self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size + ) self.verbose = verbose self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose) - assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'" - if dtype == 'fp16': + assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'" + if dtype == "fp16": model.half() - elif dtype == 'bf16': + elif dtype == "bf16": model.to(torch.bfloat16) self.model = pp_model or self._shardformer(model, model_policy) diff --git a/colossalai/inference/pipeline/microbatch_manager.py b/colossalai/inference/pipeline/microbatch_manager.py index b6b008442..49d1bf3f4 100644 --- a/colossalai/inference/pipeline/microbatch_manager.py +++ b/colossalai/inference/pipeline/microbatch_manager.py @@ -3,7 +3,7 @@ from typing import Dict, Tuple import torch -__all__ = 'MicroBatchManager' +__all__ = "MicroBatchManager" class Status(Enum): @@ -13,7 +13,7 @@ class Status(Enum): COOLDOWN = 4 -class MicroBatchDescription(): +class MicroBatchDescription: """ This is the class to record the infomation of each microbatch, and also do some update operation. 
This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more @@ -30,14 +30,14 @@ class MicroBatchDescription(): output_dict: Dict[str, torch.Tensor], new_length: int, ) -> None: - assert output_dict.get('hidden_states') is not None - self.mb_length = output_dict['hidden_states'].shape[-2] + assert output_dict.get("hidden_states") is not None + self.mb_length = output_dict["hidden_states"].shape[-2] self.target_length = self.mb_length + new_length self.kv_cache = () def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): if output_dict is not None: - self._update_kvcache(output_dict['past_key_values']) + self._update_kvcache(output_dict["past_key_values"]) def _update_kvcache(self, kv_cache: Tuple): assert type(kv_cache) == tuple @@ -64,7 +64,6 @@ class MicroBatchDescription(): Return the current sequnence length of micro batch """ - pass class HeadMicroBatchDescription(MicroBatchDescription): @@ -80,13 +79,14 @@ class HeadMicroBatchDescription(MicroBatchDescription): """ - def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], - new_length: int) -> None: + def __init__( + self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + ) -> None: super().__init__(inputs_dict, output_dict, new_length) assert inputs_dict is not None - assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None - self.input_ids = inputs_dict['input_ids'] - self.attn_mask = inputs_dict['attention_mask'] + assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None + self.input_ids = inputs_dict["input_ids"] + self.attn_mask = inputs_dict["attention_mask"] self.new_tokens = None def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): @@ -104,7 +104,8 @@ class HeadMicroBatchDescription(MicroBatchDescription): def _update_attnmask(self): self.attn_mask = torch.cat( - (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1) + (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1 + ) @property def cur_length(self): @@ -127,8 +128,9 @@ class BodyMicroBatchDescription(MicroBatchDescription): output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`. """ - def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], - new_length: int) -> None: + def __init__( + self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int + ) -> None: super().__init__(inputs_dict, output_dict, new_length) def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None): @@ -146,8 +148,8 @@ class BodyMicroBatchDescription(MicroBatchDescription): return self.kv_cache[0][0].shape[-2] + 1 -class MicroBatchManager(): - ''' +class MicroBatchManager: + """ MicroBatchManager is a class that manages the micro batch. Args: @@ -156,7 +158,7 @@ class MicroBatchManager(): micro_batch_size (int): the micro batch size. micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages. 
- ''' + """ def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int): self.stage = stage diff --git a/colossalai/inference/pipeline/modeling/gpt2.py b/colossalai/inference/pipeline/modeling/gpt2.py index f490710c1..d2bfcb8b6 100644 --- a/colossalai/inference/pipeline/modeling/gpt2.py +++ b/colossalai/inference/pipeline/modeling/gpt2.py @@ -1,7 +1,6 @@ from typing import Dict, List, Optional, Tuple, Union import torch -from torch.nn import CrossEntropyLoss from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model from transformers.utils import logging @@ -10,41 +9,41 @@ from colossalai.pipeline.stage_manager import PipelineStageManager class GPT2PipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of GPT2 models under pipeline setting. - ''' + """ @staticmethod def gpt2_model_forward( - self: GPT2Model, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: - + self: GPT2Model, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]: # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward. # Please refer to original code of transformers for more details. 
logger = logging.get_logger(__name__) # Preprocess passed in arguments if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -96,7 +95,7 @@ class GPT2PipelineForwards: # positions we want to attend and the dtype's smallest value for masked positions. # Since we are adding it to the raw scores before the softmax, this is # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention @@ -137,7 +136,8 @@ class GPT2PipelineForwards: if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) use_cache = False presents = () if use_cache else None @@ -166,7 +166,6 @@ class GPT2PipelineForwards: if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, use_cache, output_attentions) @@ -218,61 +217,64 @@ class GPT2PipelineForwards: if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) - return {'hidden_states': hidden_states, 'past_key_values': presents} + return {"hidden_states": hidden_states, "past_key_values": presents} @staticmethod def gpt2_lmhead_model_forward( - self: GPT2LMHeadModel, - input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, - attention_mask: Optional[torch.FloatTensor] = None, - token_type_ids: Optional[torch.LongTensor] = None, - position_ids: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.FloatTensor] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - stage_manager: Optional[PipelineStageManager] = None, - hidden_states: Optional[torch.FloatTensor] = None, - stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: + self: GPT2LMHeadModel, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: 
Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + stage_manager: Optional[PipelineStageManager] = None, + hidden_states: Optional[torch.FloatTensor] = None, + stage_index: Optional[List[int]] = None, + ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]: r""" - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set - `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` - are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` - This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. - Please refer to original code of transformers for more details. - """ + This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward. + Please refer to original code of transformers for more details. + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # If is first stage and after warmup, go throught lm_head first if stage_manager.is_first_stage() and hidden_states is not None: lm_logits = self.lm_head(hidden_states) - return {'logits': lm_logits} + return {"logits": lm_logits} # Not first stage or before warmup, go through gpt2 model - outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer, - input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - stage_manager=stage_manager, - hidden_states=hidden_states, - stage_index=stage_index) + outputs = GPT2PipelineForwards.gpt2_model_forward( + self.transformer, + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + stage_manager=stage_manager, + hidden_states=hidden_states, + stage_index=stage_index, + ) return outputs diff --git a/colossalai/inference/pipeline/modeling/llama.py b/colossalai/inference/pipeline/modeling/llama.py index eeda96df2..f46e1fbdd 100644 --- a/colossalai/inference/pipeline/modeling/llama.py +++ b/colossalai/inference/pipeline/modeling/llama.py @@ -1,8 +1,6 @@ -from typing import List, Optional, Tuple +from typing import List, Optional import torch -from torch.nn import 
CrossEntropyLoss, MSELoss -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel from transformers.utils import logging @@ -10,10 +8,10 @@ from colossalai.pipeline.stage_manager import PipelineStageManager class LlamaPipelineForwards: - ''' + """ This class serves as a micro library for forward function substitution of Llama models under pipeline setting. - ''' + """ def llama_model_forward( self: LlamaModel, @@ -34,10 +32,10 @@ class LlamaPipelineForwards: # Preprocess passed in arguments if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False use_cache = use_cache if use_cache is not None else self.config.use_cache @@ -70,10 +68,9 @@ class LlamaPipelineForwards: seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: - position_ids = torch.arange(past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device) + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: position_ids = position_ids.view(-1, seq_length).long() @@ -81,16 +78,18 @@ class LlamaPipelineForwards: # embed positions, for the first stage, hidden_states is the input embeddings, # for the other stages, hidden_states is the output of the previous stage if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length_with_past), - dtype=torch.bool, - device=hidden_states.device) - attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states, - past_key_values_length) + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length + ) if self.gradient_checkpointing and self.training: if use_cache: logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) use_cache = False # decoder layers @@ -112,7 +111,6 @@ class LlamaPipelineForwards: if self.gradient_checkpointing and self.training: def create_custom_forward(module): - def custom_forward(*inputs): # None for past_key_value return module(*inputs, output_attentions, None) @@ -152,7 +150,7 @@ class LlamaPipelineForwards: next_cache = next_decoder_cache if use_cache else None # always return dict for imediate stage - return {'hidden_states': hidden_states, 'past_key_values': next_cache} + return {"hidden_states": hidden_states, "past_key_values": next_cache} def llama_for_causal_lm_forward( self: LlamaForCausalLM, @@ -171,45 +169,45 @@ class LlamaPipelineForwards: stage_index: Optional[List[int]] = None, ): r""" - Args: - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - Returns: + Returns: - Example: + Example: - ```python - >>> from transformers import AutoTokenizer, LlamaForCausalLM + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM - >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) - >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) - >>> prompt = "Hey, are you consciours? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> prompt = "Hey, are you consciours? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." - ```""" + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." 
+ ```""" logger = logging.get_logger(__name__) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if output_attentions: - logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.') + logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.") output_attentions = False if output_hidden_states: - logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.') + logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.") output_hidden_states = False # If is first stage and after warmup, go throught lm_head first if stage_manager.is_first_stage() and hidden_states is not None: lm_logits = self.lm_head(hidden_states) - return {'logits': lm_logits} + return {"logits": lm_logits} # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = LlamaPipelineForwards.llama_model_forward( diff --git a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py index e51090200..51e6425b1 100644 --- a/colossalai/inference/pipeline/policy/gpt2_ppinfer.py +++ b/colossalai/inference/pipeline/policy/gpt2_ppinfer.py @@ -11,7 +11,6 @@ from ..modeling.gpt2 import GPT2PipelineForwards class GPT2LMHeadModelPipelinePolicy(GPT2Policy): - def __init__(self) -> None: super().__init__() @@ -22,18 +21,22 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy): if self.shard_config.enable_tensor_parallelism: addon_module = { - GPT2LMHeadModel: - ModulePolicyDescription(sub_module_replacement=[ + GPT2LMHeadModel: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}) - ]) + suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True} + ) + ] + ) } module_policy.update(addon_module) if self.pipeline_stage_manager is not None: - self.set_pipeline_forward(model_cls=GPT2LMHeadModel, - new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, - policy=module_policy) + self.set_pipeline_forward( + model_cls=GPT2LMHeadModel, + new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward, + policy=module_policy, + ) return module_policy def get_held_layers(self) -> List[nn.Module]: @@ -45,7 +48,7 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy): return held_layers def get_shared_params(self) -> List[Dict[int, Tensor]]: - '''The weights of wte and lm_head are shared.''' + """The weights of wte and lm_head are shared.""" module = self.model stage_manager = self.pipeline_stage_manager if stage_manager is not None: @@ -56,16 +59,16 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy): def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None: """If under pipeline parallel setting, replacing the original forward method of huggingface - to customized forward method, and add this changing to policy.""" + to customized forward method, and add this changing to policy.""" if not self.pipeline_stage_manager: raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.") stage_manager = self.pipeline_stage_manager - if self.model.__class__.__name__ == 'GPT2Model': + if self.model.__class__.__name__ == "GPT2Model": module = self.model else: module = self.model.transformer layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages) stage_index = 
Policy.get_stage_index(layers_per_stage, stage_manager.stage) - method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} + method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)} self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls) diff --git a/colossalai/inference/pipeline/policy/llama_ppinfer.py b/colossalai/inference/pipeline/policy/llama_ppinfer.py index bb359de0b..6e12ed61b 100644 --- a/colossalai/inference/pipeline/policy/llama_ppinfer.py +++ b/colossalai/inference/pipeline/policy/llama_ppinfer.py @@ -1,19 +1,15 @@ -from functools import partial -from typing import Callable, Dict, List, Union +from typing import List -import torch.nn as nn -from torch import Tensor from torch.nn import Module -from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D -from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription +from colossalai.shardformer.layer import Linear1D_Col +from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription from colossalai.shardformer.policies.llama import LlamaPolicy from ..modeling.llama import LlamaPipelineForwards class LlamaForCausalLMPipelinePolicy(LlamaPolicy): - def __init__(self) -> None: super().__init__() @@ -25,19 +21,21 @@ class LlamaForCausalLMPipelinePolicy(LlamaPolicy): if self.shard_config.enable_tensor_parallelism: # add a new item for casual lm new_item = { - LlamaForCausalLM: - ModulePolicyDescription(sub_module_replacement=[ + LlamaForCausalLM: ModulePolicyDescription( + sub_module_replacement=[ SubModuleReplacementDescription( - suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)) - ]) + suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True) + ) + ] + ) } policy.update(new_item) if self.pipeline_stage_manager: # set None as default - self.set_pipeline_forward(model_cls=LlamaForCausalLM, - new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, - policy=policy) + self.set_pipeline_forward( + model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy + ) return policy diff --git a/colossalai/inference/pipeline/utils.py b/colossalai/inference/pipeline/utils.py index 1a6e8a519..c26aa4e40 100644 --- a/colossalai/inference/pipeline/utils.py +++ b/colossalai/inference/pipeline/utils.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Set +from typing import Set import torch.nn as nn @@ -30,6 +30,6 @@ def get_suffix_name(suffix: str, name: str): suffix (str): The suffix of the suffix module name (str): The name of the current module """ - point = '' if suffix is '' else '.' - suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}' + point = "" if suffix is "" else "." + suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}" return suffix_name diff --git a/colossalai/pipeline/p2p.py b/colossalai/pipeline/p2p.py index 67e198ca0..f822c1819 100644 --- a/colossalai/pipeline/p2p.py +++ b/colossalai/pipeline/p2p.py @@ -167,7 +167,7 @@ def _p2p_comm( group: ProcessGroup, comm_dtype: torch.dtype = torch.float16, ): - """ + """ Send and recv tensor using P2P communication, used when pipeline size is 2 to solve the race communication. 
Agrs: @@ -176,7 +176,7 @@ def _p2p_comm( peer (int): rank of the peer group (ProcessGroup): process group comm_dtype (torch.dtype): dtype of the tensor to be sent - + Returns: torch.Tensor: tensor received from previous stage """ @@ -302,7 +302,9 @@ class PipelineP2PCommunication: cur_rank = self.stage_manager.get_rank() _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank)) - def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None: + def p2p_communicate( + self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16 + ) -> None: """ Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch. @@ -313,5 +315,7 @@ class PipelineP2PCommunication: if peer is None: peer = self.stage_manager.get_next_rank() cur_rank = self.stage_manager.get_rank() - recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype) + recv_tensor = _p2p_comm( + output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype + ) return recv_tensor diff --git a/colossalai/pipeline/schedule/generate.py b/colossalai/pipeline/schedule/generate.py index 8f6acd5fc..1f4bbe9f8 100644 --- a/colossalai/pipeline/schedule/generate.py +++ b/colossalai/pipeline/schedule/generate.py @@ -1,6 +1,6 @@ import time from functools import partial -from typing import Any, Iterable, List, Optional, Union +from typing import Any, Iterable, Optional, Union import torch import torch.cuda @@ -16,7 +16,7 @@ from ._utils import get_batch_size, get_micro_batch, model_forward, to_device from .base import PipelineSchedule -class ActionIntervalBuffer(): +class ActionIntervalBuffer: """ The buffer to save the interval hidden states and new token for stage to use. 
@@ -70,8 +70,9 @@ class GenerateSchedule(PipelineSchedule): self.batch = batch self.batch_size = get_batch_size(batch) self.microbatch_offset = 0 - assert self.batch_size % self.microbatch_size == 0, \ - f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}" + assert ( + self.batch_size % self.microbatch_size == 0 + ), f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}" self.num_microbatches = self.batch_size // self.microbatch_size self.round = self.num_microbatches // self.stage_manager.num_stages @@ -86,26 +87,26 @@ class GenerateSchedule(PipelineSchedule): return tree_map(partial(to_device, device=get_current_device()), micro_batch) def _prepare_inputs_for_interval_stage(self): - ''' + """ Prepare inputs for interval stage, for all the interval stage, the inputs is just the past_key_values Returns: dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None` - ''' - model_inputs = { - 'past_key_values': self.mb_manager.cur_kv_cache - } if self.mb_manager.cur_kv_cache is not None else None + """ + model_inputs = ( + {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None + ) return model_inputs def _prepare_inputs_for_new_token(self, new_token: torch.Tensor): - ''' + """ Prepare inputs for new token, the inputs is a dict with `input_ids`, `attention_mask` and `past_key_values` `input_ids` is the new token, `attention_mask` is the previous mask add `1` in the end, `past_key_values` is the past_key_values save in the micro batch manager Returns: dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}` - ''' + """ new_mask = self.mb_manager.cur_descrption.attn_mask past_key_values = self.mb_manager.cur_descrption.kv_cache @@ -117,12 +118,12 @@ class GenerateSchedule(PipelineSchedule): return input_ids def _recv_pre_stage(self) -> Any: - ''' + """ Receive the output from previous stage Returns: Any: The output from previous stage - ''' + """ if self.stage_manager.num_stages == 2: return self.comm.p2p_recv() return self.comm.recv_forward() @@ -138,7 +139,7 @@ class GenerateSchedule(PipelineSchedule): output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _gen_token_action(self, model: Module): """ @@ -146,13 +147,15 @@ class GenerateSchedule(PipelineSchedule): """ hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" - hidden_states = {'hidden_states': hidden_states} + hidden_states = {"hidden_states": hidden_states} logits = model_forward(model, None, hidden_states) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" - new_token = self._get_token_id(logits['logits']) + assert ( + "logits" in logits + ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits["logits"]) self.mb_manager.step(None, None, new_token) 
self.action_interval_buffer.new_token = new_token @@ -168,17 +171,17 @@ class GenerateSchedule(PipelineSchedule): output_dict = model_forward(model, inputs_dict, None) self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _body_encoding_action(self, model: Module): hidden_states = self.action_interval_buffer.hidden_states assert hidden_states is not None, "When not first stage, the hidden states should not be None" inputs_dict = self._prepare_inputs_for_interval_stage() - hidden_states = {'hidden_states': hidden_states} + hidden_states = {"hidden_states": hidden_states} output_dict = model_forward(model, inputs_dict, hidden_states) self.mb_manager.step(inputs_dict, output_dict, None) - self.action_interval_buffer.hidden_states = output_dict['hidden_states'] + self.action_interval_buffer.hidden_states = output_dict["hidden_states"] def _comm_action(self, recv_pre: bool) -> torch.Tensor: """ @@ -246,10 +249,13 @@ class GenerateSchedule(PipelineSchedule): whole_timestamp = [] - #run by round + # run by round for _ in range(self.round): - self.timestamps = [[] for _ in range(self.stage_manager.num_stages) - ] if self.verbose and self.stage_manager.is_first_stage() else None + self.timestamps = ( + [[] for _ in range(self.stage_manager.num_stages)] + if self.verbose and self.stage_manager.is_first_stage() + else None + ) self.action_interval_buffer.clear() while self.mb_manager.is_micro_batch_done() is False: actions = self._gen_action(model) @@ -286,8 +292,11 @@ class GenerateSchedule(PipelineSchedule): whole_timestamp = [] # run by round for _ in range(self.round): - self.timestamps = [[] for _ in range(self.stage_manager.num_stages) - ] if self.verbose and self.stage_manager.is_first_stage() else None + self.timestamps = ( + [[] for _ in range(self.stage_manager.num_stages)] + if self.verbose and self.stage_manager.is_first_stage() + else None + ) while self.mb_manager.is_micro_batch_done() is False: inputs_dict = None new_token = None @@ -307,13 +316,17 @@ class GenerateSchedule(PipelineSchedule): hidden_states = self.comm.recv_forward() if self.stage_manager.is_first_stage(): # First just generate a new token - assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None" + assert ( + hidden_states is not None + ), "When first stage in GENERATE phase, the hidden states should not be None" logits = model_forward(model, None, hidden_states) if self.verbose and self.stage_manager.is_first_stage(): torch.cuda.synchronize() self.timestamps[self.mb_manager.idx].append(time.time()) - assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" - new_token = self._get_token_id(logits['logits']) + assert ( + "logits" in logits + ), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}" + new_token = self._get_token_id(logits["logits"]) self.mb_manager.step(None, None, new_token) # If the current micro batch is not DONE, go through blocks if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN): @@ -327,9 +340,11 @@ class GenerateSchedule(PipelineSchedule): self.mb_manager.step(inputs_dict, output_dict, None) # Current microbatch is not DONE, send hidden_state to next stage - if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE, - 
Status.COOLDOWN): - self.comm.send_forward({'hidden_states': output_dict['hidden_states']}) + if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in ( + Status.GENERATE, + Status.COOLDOWN, + ): + self.comm.send_forward({"hidden_states": output_dict["hidden_states"]}) self.mb_manager.next() diff --git a/tests/test_infer/test_pipeline_infer.py b/tests/test_infer/test_pipeline_infer.py index 47cf9e78d..ad8e32b48 100644 --- a/tests/test_infer/test_pipeline_infer.py +++ b/tests/test_infer/test_pipeline_infer.py @@ -1,9 +1,6 @@ -from copy import deepcopy - import pytest import torch import torch.distributed as dist -import torch.nn as nn import transformers import colossalai @@ -20,27 +17,29 @@ def data_gen(): inputs = data_gen() for k, v in inputs.items(): - if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__: + if torch.is_tensor(v) or "Tensor" in v.__class__.__name__: new_shape = [1] * v.dim() new_shape[0] = 16 - inputs[k] = v.to('cuda').repeat(*new_shape) + inputs[k] = v.to("cuda").repeat(*new_shape) def pipeline_inference_test(pp_size, new_length, micro_batch_size): model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8)) - engine = PPInferEngine(pp_size=pp_size, - model=model, - model_policy=GPT2LMHeadModelPipelinePolicy(), - new_length=new_length, - micro_batch_size=micro_batch_size) + engine = PPInferEngine( + pp_size=pp_size, + model=model, + model_policy=GPT2LMHeadModelPipelinePolicy(), + new_length=new_length, + micro_batch_size=micro_batch_size, + ) output = engine.inference([inputs]) if dist.get_rank() == 0: assert len(output[0]) == new_length, f"{len(output)}, {new_length}" -@parameterize('pp_size', [4]) -@parameterize('new_length', [4, 8, 16]) -@parameterize('micro_batch_size', [1, 4]) +@parameterize("pp_size", [4]) +@parameterize("new_length", [4, 8, 16]) +@parameterize("micro_batch_size", [1, 4]) @clear_cache_before_run() def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): pipeline_inference_test(pp_size, new_length, micro_batch_size) @@ -48,7 +47,7 @@ def run_pipeline_inference_test(pp_size, new_length, micro_batch_size): def check_pipeline_inference(rank, world_size, port): - colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") run_pipeline_inference_test() @@ -59,5 +58,5 @@ def test_pipeline_inference(): spawn(check_pipeline_inference, nprocs=4) -if __name__ == '__main__': +if __name__ == "__main__": test_pipeline_inference()
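
For context, the patch above is essentially a formatting pass (double quotes, black-style wrapping) over the pipeline-parallel inference engine, so the public API is unchanged. Below is a minimal end-to-end sketch of `PPInferEngine`, stitched together from the engine docstring, `benchmark.py`, and `test_pipeline_infer.py` in this diff. It is illustrative only: the script name and `torchrun` command are assumptions, and the import path of `GPT2LMHeadModelPipelinePolicy` is inferred from the file layout of this patch rather than confirmed by it.

```python
# Hypothetical standalone sketch; launch with one process per pipeline stage, e.g.
#   torchrun --nproc_per_node 2 pp_infer_example.py
import torch
import transformers

import colossalai
from colossalai.inference import PPInferEngine
# Import path assumed from colossalai/inference/pipeline/policy/gpt2_ppinfer.py in this patch.
from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy

# Initialize distributed state from the torchrun environment, as benchmark.py does.
colossalai.launch_from_torch(config={})

# A small GPT2 model, matching the toy model used in test_pipeline_infer.py.
model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))

# pp_size must match the number of launched processes (nproc_per_node above).
engine = PPInferEngine(
    pp_size=2,
    dtype="fp16",
    model=model,
    model_policy=GPT2LMHeadModelPipelinePolicy(),
    new_length=8,        # number of tokens to generate per sequence
    micro_batch_size=1,
    verbose=True,
)

# Toy batch: random token ids plus an all-ones attention mask, mirroring data_gen().
input_ids = torch.randint(10, 30000, (4, 32), dtype=torch.int32, device="cuda")
attention_mask = torch.ones_like(input_ids)
data = dict(input_ids=input_ids, attention_mask=attention_mask)

# inference() takes a list of batch dicts; on rank 0 each returned sequence holds
# `new_length` generated tokens (see the assertion in test_pipeline_infer.py).
output = engine.inference([data])
print(output)
```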
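
As a quick reference for reading the benchmark logs, here is a worked example of the two headline numbers that `print_details_info` writes out. The timings are made up; the formulas are the ones used in `benchmark.py` above.

```python
# Made-up measurements, plugged into the formulas from print_details_info().
new_length = 4       # generated tokens per sequence
batch_size = 8       # sequences in the whole batch
whole_end2end = 2.0  # measured wall time for the whole batch, in seconds

# Per-token latency over the whole batch.
whole_avg_latency = whole_end2end / (new_length * batch_size)   # 0.0625 s
print("Whole batch Per Token Latency: {0:8.2f} ms".format(whole_avg_latency * 1000))  # 62.50 ms

# The throughput line is the reciprocal of the per-token latency.
throughput = 1000 / (whole_avg_latency * 1000)                   # 16.0 tokens/s
print("Throughput: {} tokens/s".format(throughput))
```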