[Pipeline Inference] Sync pipeline inference branch to main (#4820)

* [pipeline inference] pipeline inference (#4492)

* add pp stage manager as circle stage

* fix a bug when create process group

* add ppinfer basic framework

* add micro batch manager and support kvcache-pp gpt2 fwd

* add generate schedule

* use mb size to control mb number

* support generate with kv cache

* add output, remove unused code

* add test

* reuse shardformer to build model

* refactor some code and use the same attribute name of hf

* fix review and add test for generation

* remove unused file

* fix CI

* add cache clear

* fix code error

* fix typo

* [Pipeline inference] Modify to tieweight (#4599)

* add pp stage manager as circle stage

* fix a bug when create process group

* add ppinfer basic framework

* add micro batch manager and support kvcache-pp gpt2 fwd

* add generate schedule

* use mb size to control mb number

* support generate with kv cache

* add output, remove unused code

* add test

* reuse shardformer to build model

* refactor some code and use the same attribute name of hf

* fix review and add test for generation

* remove unused file

* modify the way of saving newtokens

* modify to tieweight

* modify test

* remove unused file

* solve review

* add docstring

* [Pipeline inference] support llama pipeline inference (#4647)

* support llama pipeline inference

* remove tie weight operation

* [pipeline inference] Fix the blocking of communication when ppsize is 2 (#4708)

* add benchmark verbose

* fix export tokens

* fix benchmark verbose

* add P2POp style to do p2p communication

* modify schedule as p2p type when ppsize is 2

* remove unused code and add docstring

* [Pipeline inference] Refactor code, add docsting, fix bug (#4790)

* add benchmark script

* update argparse

* fix fp16 load

* refactor code style

* add docstring

* polish code

* fix test bug

* [Pipeline inference] Add pipeline inference docs (#4817)

* add readme doc

* add a ico

* Add performance

* update table of contents

* refactor code (#4873)
Bin Jia 2023-10-11 11:40:06 +08:00 committed by GitHub
parent 652adc2215
commit 08a9f76b2f
18 changed files with 1754 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from .pipeline import PPInferEngine
__all__ = ['PPInferEngine']

View File

@@ -0,0 +1,84 @@
# 🐳 Pipeline Inference
## Table of Contents
- [💡 Introduction](#introduction)
- [🔗 Design](#design)
- [🔨 Usage](#usage)
- [Example](#example)
- [Quick start](#quick-start)
- [📊 Performance](#performance)
## Introduction
`Pipeline Inference` is a module designed to run inference in a pipelined fashion. Although an inference system does not need to store intermediate information such as activations for backward propagation, the weights of larger models still may not fit on a single GPU, which requires model parallelism and other techniques to reduce the memory footprint per GPU. Pipeline parallelism, one of the traditional model parallelism approaches, is widely used thanks to its reduced all-reduce communication requirements and simple layout. Its main drawback, pipeline bubbles, can be almost eliminated in inference, because the backward propagation that causes them no longer exists. This makes pipeline parallelism nearly bubble-free in the ideal scenario where the sequence length is the same across the pipeline.
## Design
Pipeline Inference is composed of three parts: `PPInferEngine`, `MicroBatchManager` and the `generate` [schedule](https://github.com/hpcaitech/ColossalAI/blob/feature/pipeline-infer/colossalai/pipeline/schedule/generate.py).
1. `PPInferEngine` is the high-level API for users. It is responsible for the following tasks:
    - Initializing the pipeline inference environment with `PipelineStageManager` and sharding the model with `ShardFormer`.
    - Running the pipeline inference model.
2. `MicroBatchManager` is a structure that manages micro-batch information. It is responsible for the following tasks:
    - Recording each micro batch's information, such as the generated new tokens and the KV cache.
    - Recording each micro batch's inference state: prefill, generate or done.
    - Updating the micro-batch information.
3. The `generate` schedule implements the simple pipeline inference layout. When the pipeline size is 2, we use `torch.distributed.P2POp` for inter-stage communication, mainly to avoid communication races. When the pipeline size is larger than 2, we use `torch.distributed.broadcast`, which is faster than `torch.distributed.P2POp` (see the sketch below).
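
For illustration, here is a minimal sketch of this dispatch. `send_hidden_states` is a hypothetical helper, not part of the actual API; the real logic lives in the `generate` schedule and the `_p2p_comm` function shown later in this diff:
```python
import torch.distributed as dist

def send_hidden_states(hidden_states, stage_manager, group):
    """Hypothetical sketch: pick the communication primitive by pipeline size."""
    if stage_manager.num_stages == 2:
        # Batched async P2P ops avoid the send/recv race between the two ranks.
        op = dist.P2POp(dist.isend, hidden_states, peer=stage_manager.get_next_rank(), group=group)
        for req in dist.batch_isend_irecv([op]):
            req.wait()
    else:
        # With more stages, a broadcast from the current rank is faster.
        dist.broadcast(hidden_states, src=stage_manager.get_rank(), group=group)
```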
## Usage
### Example
```python
from transformers import LlamaForCausalLM, LlamaTokenizer

from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy

# Suppose the pipeline size is 2 and fp16 is used for inference; take Llama as an example.
model = LlamaForCausalLM.from_pretrained('/path/to/model')
tokenizer = LlamaTokenizer.from_pretrained('/path/to/model')
tokenizer.pad_token = tokenizer.eos_token    # Llama has no pad token by default
inputs = tokenizer(["Hello, my dog is cute", "What a good day"], return_tensors="pt", padding=True)
engine = PPInferEngine(
    pp_size=2,
    dtype='fp16',
    micro_batch_size=1,
    new_length=10,
    model=model,
    model_policy=LlamaForCausalLMPipelinePolicy())

output = engine.inference([inputs])
```
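The `output` returned by `engine.inference` holds the newly generated token ids. A minimal decoding sketch, assuming the engine returns one id list per input sequence:
```python
generated_texts = [tokenizer.decode(ids, skip_special_tokens=True) for ids in output]
print(generated_texts)
```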
### Quick start
```shell
cd benchmark
sh run.sh
```
## Performance
We conducted multiple benchmark tests to evaluate the performance, comparing the inference `latency` and `throughput` of `Pipeline Inference` against the Hugging Face pipeline. The test environment is 2×A10 GPUs (20 GB each). In the tables below, the pair of numbers after each method denotes (input length, output length).
### Llama Throughput(tokens/s)
#### 7b, fp16
| batch_size (micro batch size) | 2 (1) | 4 (2) | 8 (4) | 16 (8) | 32 (8) | 32 (16) |
| :---: | :---: | :---: | :---: | :---: | :---: | :---:|
| Pipeline Inference(1024, 128) | 33.31 | 59.98 | 98.92 | 143.47 | 152.61 | OOM |
| Hugging Face(1024, 128) | 41.43 | 65.30 | 91.93 | 114.62 | OOM| OOM |
| Pipeline Inference(512, 512) | 43.37 | 82.81 | 148.03 | 229.06 | 238.67 | 312.82 |
| Hugging Face(512, 512) | 49.13 | 84.91 | 132.87 | 178.30 | OOM| OOM |
#### 7b, fp32
| batch_size (micro batch size) | 2 (1) | 4 (2) | 8 (4) | 16 (4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 20.61 | 31.23 | 45.20 | 47.46 |
| Hugging Face(1024, 128) | 19.80 | 29.37| OOM | OOM |
| Pipeline Inference(512, 512) | 28.07 | 46.76 | 79.35 | 81.70 |
| Hugging Face(512, 512) | 25.67 | 43.97 | 60.67 | OOM |
#### 13b, fp16
| batch_size (micro batch size) | 2 (1) | 4 (2) | 8 (4) | 16 (4) |
| :---: | :---: | :---: | :---: | :---: |
| Pipeline Inference(1024, 128) | 21.73 | 38.06 | 61.02 | 64.30 |
| Hugging Face(1024, 128) | 23.48 | 37.59 | 53.44 | OOM |
| Pipeline Inference(512, 512) | 26.65 | 49.48 | 86.11 | 88.44 |
| Hugging Face(512, 512) | 27.45 | 47.74 | 74.46 | OOM |

View File

@@ -0,0 +1,3 @@
from .engine import PPInferEngine
__all__ = ['PPInferEngine']

View File

@@ -0,0 +1,112 @@
import torch
import torch.distributed as dist
import transformers
import colossalai
import time
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
import argparse
GIGABYTE = 1024 ** 3
MEGABYTE = 1024 * 1024
colossalai.launch_from_torch(config={})
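# Build a dummy batch: random token ids plus an all-ones attention mask,
# replicated along the batch dimension via repeat().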
def data_gen(batch_size: int=4, seq_len: int=512):
input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
data = dict(input_ids=input_ids, attention_mask=attention_mask)
for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = batch_size
data[k] = v.to('cuda').repeat(*new_shape)
return data
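# On rank 0, aggregate the per-micro-batch timestamps into prefill / encode / end-to-end
# latencies and write them, together with GPU memory statistics, to a log file.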
def print_details_info(timestamps, model_config, args, whole_end2end):
if dist.get_rank() == 0:
prefill = []
encoder = []
end2end = []
for timestamp in timestamps:
prefill.append(timestamp[1] - timestamp[0])
encoder.append(
sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2))
end2end.append(timestamp[-1] - timestamp[0])
print(whole_end2end)
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
mb_avg_end2end = sum(end2end)/len(end2end)
mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size)
whole_avg_latency = whole_end2end/(args.new_length * args.batch_size)
num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
if args.dtype in ['fp16','bf16']:
num_bytes = 2
else:
num_bytes = 4
f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000))
f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000))
f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000))
f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000))
f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000))))
f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12))
f.write("----------------------------------------------------------\n")
if torch.cuda.is_available():
current_device = torch.cuda.current_device()
# free memory and the total available memory in bytes
global_free_memory, total_GPU_memory_occupied = torch.cuda.mem_get_info()
memory_allocated = torch.cuda.memory_allocated()
max_memory_allocated = torch.cuda.max_memory_allocated()
memory_reserved = torch.cuda.memory_reserved()
max_memory_reserved = torch.cuda.max_memory_reserved()
with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
f.write(
f"\nCurrently using GPU: {current_device}\n"
f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
f"total memory: {total_GPU_memory_occupied / GIGABYTE:.4f} GB,\n"
f"memory allocated: {memory_allocated / GIGABYTE:.4f} GB,\n"
f"Max CUDA memory allocated: {max_memory_allocated / GIGABYTE:.4f} GB,\n"
f"memory reserved/cached: {memory_reserved / GIGABYTE:.4f} GB,\n"
f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='toy', help='the size of model')
parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size')
parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length')
parser.add_argument('--new_length', type=int, default=4, help='new tokens length')
parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size')
parser.add_argument('--pp_size', type=int, default=2, help='pipeline size')
parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log')
parser.add_argument('--dtype', type=str, default='fp16', help='data type')
args = parser.parse_args()
if args.model == 'toy':
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
elif args.model == '7b':
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf'))
elif args.model == '13b':
model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf'))
else:
raise NotImplementedError
engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
data = data_gen(args.batch_size, args.seq_len)
torch.cuda.synchronize()
whole_end2end = time.time()
output, timestamps = engine.inference([data])
torch.cuda.synchronize()
whole_end2end = time.time() - whole_end2end
print_details_info(timestamps, model.config, args, whole_end2end)

View File

@@ -0,0 +1,50 @@
script_dir=$(cd "$(dirname "$0")" && pwd)
cd "${script_dir}"
# 7b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8 16; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=1024 \
--new_length=128 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done
# 7b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16 32; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="7b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=512 \
--new_length=512 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done
# 13b, fp16, 2 gpu, 1024, 128
for BATCH_SIZE in 2 4 8; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="13b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=1024 \
--new_length=128 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done
# 13b, fp16, 2 gpu, 512, 512
for BATCH_SIZE in 2 4 8 16; do
CUDA_VISIBLE_DEVICES=0,1 colossalai run --nproc_per_node 2 --master_port 29800 ./benchmark.py \
--model="13b" \
--dtype="fp16" \
--batch_size=${BATCH_SIZE} \
--seq_len=512 \
--new_length=512 \
--mb_size=$((${BATCH_SIZE}/2)) \
--pp_size=2
done

View File

@@ -0,0 +1,98 @@
from typing import Callable, List, Optional, Set, Union
import torch
import torch.nn as nn
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.schedule.generate import GenerateSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer.policies.base_policy import Policy
from .microbatch_manager import MicroBatchManager
class PPInferEngine:
'''
PPInferEngine is a class that handles the pipeline parallel inference.
Args:
pp_size (int): the number of pipeline stages.
pp_model (`nn.Module`): the model already in pipeline parallelism style.
        model (`nn.Module`): the model not yet in pipeline style; it will be sharded by `ShardFormer`.
        model_policy (`colossalai.shardformer.policies.base_policy.Policy`): the policy used by `ShardFormer` to shard the model.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
        new_length (int): the number of new tokens to generate.
early_stopping (bool): whether to stop early.
Example:
```python
        from transformers import GPT2LMHeadModel, GPT2Tokenizer

        from colossalai.inference import PPInferEngine

        model = GPT2LMHeadModel.from_pretrained('gpt2')
        tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
        # assume the model is inferred with 4 pipeline stages
        engine = PPInferEngine(pp_size=4, model=model, model_policy={Your own policy for pipeline sharding})
        input = ["Hello, my dog is cute, and I like"]
        tokenized_input = tokenizer(input, return_tensors='pt')
        output = engine.inference([tokenized_input])
```
'''
def __init__(
self,
pp_size: int,
dtype: str = 'fp16',
pp_model: nn.Module = None,
model: nn.Module = None,
model_policy: Policy = None,
new_length: int = 32,
micro_batch_size: int = 1,
micro_batch_buffer_size: int = None,
verbose: bool = False,
        # TODO: implement early_stopping, and various generation options
early_stopping: bool = False,
do_sample: bool = False,
num_beams: int = 1,
) -> None:
assert pp_model or (model and model_policy), "Either pp_model or model with model_policy should be provided."
self.pp_size = pp_size
self.pg_mesh = ProcessGroupMesh(pp_size)
self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
micro_batch_buffer_size or pp_size)
self.verbose = verbose
self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)
assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'"
if dtype == 'fp16':
model.half()
elif dtype == 'bf16':
model.to(torch.bfloat16)
self.model = pp_model or self._shardformer(model, model_policy)
def inference(self, input_list):
out, timestamp = self.schedule.generate_step(self.model, iter(input_list))
if self.verbose:
return out, timestamp
else:
return out
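    # Shard the model for pipeline parallelism with ShardFormer; tensor parallelism
    # and all fused-kernel optimizations are deliberately disabled here.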
def _shardformer(self, model, model_policy):
shardconfig = ShardConfig(
tensor_parallel_process_group=None,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=False,
enable_fused_normalization=False,
enable_all_optimization=False,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
)
shardformer = ShardFormer(shard_config=shardconfig)
shard_model, _ = shardformer.optimize(model, model_policy)
return shard_model.cuda()

View File

@@ -0,0 +1,236 @@
from enum import Enum
from typing import Dict, Tuple
import torch
__all__ = ['MicroBatchManager']
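# Lifecycle of a micro batch: PREFILL (first forward over the prompt) ->
# GENERATE (autoregressive decoding) -> COOLDOWN (one token left to generate) ->
# DONE (target length reached).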
class Status(Enum):
PREFILL = 1
GENERATE = 2
DONE = 3
COOLDOWN = 4
class MicroBatchDescription():
"""
    This class records the information of each micro batch and performs the corresponding update operations.
    It is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`; for more
    details, please refer to the docs of those two classes below.
    Args:
        inputs_dict (Dict[str, torch.Tensor]): the inputs of the current stage. The keys should include `input_ids` and `attention_mask`.
        output_dict (Dict[str, torch.Tensor]): the outputs of the previous stage. The keys should include `hidden_states` and `past_key_values`.
"""
def __init__(
self,
inputs_dict: Dict[str, torch.Tensor],
output_dict: Dict[str, torch.Tensor],
new_length: int,
) -> None:
assert output_dict.get('hidden_states') is not None
self.mb_length = output_dict['hidden_states'].shape[-2]
self.target_length = self.mb_length + new_length
self.kv_cache = ()
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
if output_dict is not None:
self._update_kvcache(output_dict['past_key_values'])
def _update_kvcache(self, kv_cache: Tuple):
        assert isinstance(kv_cache, tuple)
self.kv_cache = kv_cache
@property
def state(self):
"""
        Return the state of the current micro batch: DONE when the current length reaches the target length,
        COOLDOWN one step before that, and GENERATE otherwise.
"""
# TODO: add the condition for early stopping
if self.cur_length == self.target_length:
return Status.DONE
elif self.cur_length == self.target_length - 1:
return Status.COOLDOWN
else:
return Status.GENERATE
@property
def cur_length(self):
"""
        Return the current sequence length of the micro batch
"""
pass
class HeadMicroBatchDescription(MicroBatchDescription):
"""
    This class records the information of the first pipeline stage. The first stage holds the attributes `input_ids`, `attention_mask`
    and `new_tokens`, where `new_tokens` are the tokens generated by the first stage. Due to the pipeline schedule, the way this stage updates
    its information and the condition that determines its state differ from the other stages.
Args:
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
new_length (int): the new length of the input sequence.
"""
def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
new_length: int) -> None:
super().__init__(inputs_dict, output_dict, new_length)
assert inputs_dict is not None
assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None
self.input_ids = inputs_dict['input_ids']
self.attn_mask = inputs_dict['attention_mask']
self.new_tokens = None
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
super().update(output_dict, new_token)
if new_token is not None:
self._update_newtokens(new_token)
if self.state is not Status.DONE and new_token is not None:
self._update_attnmask()
def _update_newtokens(self, new_token: torch.Tensor):
if self.new_tokens is None:
self.new_tokens = new_token
else:
self.new_tokens = torch.cat([self.new_tokens, new_token], dim=-1)
def _update_attnmask(self):
self.attn_mask = torch.cat(
(self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
@property
def cur_length(self):
"""
        When no new token has been generated, the length is `mb_length`; otherwise it is `mb_length` plus the number of generated tokens.
"""
if self.new_tokens is None:
return self.mb_length
else:
return self.mb_length + len(self.new_tokens[0])
class BodyMicroBatchDescription(MicroBatchDescription):
"""
    This class records the information of the pipeline stages other than the first. These stages hold the attributes `hidden_states` and `past_key_values`.
    Args:
        inputs_dict (Dict[str, torch.Tensor]): will always be `None`; these stages only receive hidden states from the previous stage.
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
"""
def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
new_length: int) -> None:
super().__init__(inputs_dict, output_dict, new_length)
def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
super().update(output_dict, new_token)
@property
def cur_length(self):
"""
        When there is no KV cache yet, the length is `mb_length`; otherwise it is `kv_cache[0][0].shape[-2]` plus 1.
"""
if len(self.kv_cache) == 0:
return self.mb_length
else:
return self.kv_cache[0][0].shape[-2] + 1
class MicroBatchManager():
'''
    MicroBatchManager is a class that manages the micro batches.
    Args:
        stage (int): stage id of the current stage.
        new_length (int): the number of new tokens to generate.
micro_batch_size (int): the micro batch size.
micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.
'''
def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
self.stage = stage
self.new_length = new_length
self.micro_batch_size = micro_batch_size
self.buffer_size = micro_batch_buffer_size
self.mb_descrption_buffer = {}
self.new_tokens_buffer = {}
self.idx = 0
def step(self, inputs_dict=None, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
"""
        Update the state of the micro-batch manager. There are two conditions:
        1. The first stage in PREFILL receives both inputs and outputs, and `_add_descrption` saves its inputs.
        2. In every other condition, only the output of the previous stage is received, and the description is updated.
Args:
inputs_dict (Dict[str, torch.Tensor]): the inputs of current stage. The key should have `input_ids` and `attention_mask`.
output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
new_token (torch.Tensor): the new token generated by current stage.
"""
        # Add a description first if none exists for the current micro batch
if inputs_dict is None and output_dict is None and new_token is None:
return Status.PREFILL
if self.mb_descrption_buffer.get(self.idx) is None:
self._add_descrption(inputs_dict, output_dict)
self.cur_descrption.update(output_dict, new_token)
return self.cur_state
def export_new_tokens(self):
new_tokens_list = []
for i in self.mb_descrption_buffer.values():
new_tokens_list.extend(i.new_tokens.tolist())
return new_tokens_list
def is_micro_batch_done(self):
if len(self.mb_descrption_buffer) == 0:
return False
for mb in self.mb_descrption_buffer.values():
if mb.state != Status.DONE:
return False
return True
def clear(self):
self.mb_descrption_buffer.clear()
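    # Advance the cursor to the next micro batch slot (a ring buffer of size `buffer_size`).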
def next(self):
self.idx = (self.idx + 1) % self.buffer_size
def _add_descrption(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor]):
if self.stage == 0:
self.mb_descrption_buffer[self.idx] = HeadMicroBatchDescription(inputs_dict, output_dict, self.new_length)
else:
self.mb_descrption_buffer[self.idx] = BodyMicroBatchDescription(inputs_dict, output_dict, self.new_length)
def _remove_descrption(self):
self.mb_descrption_buffer.pop(self.idx)
@property
def cur_descrption(self) -> MicroBatchDescription:
return self.mb_descrption_buffer.get(self.idx)
@property
def cur_kv_cache(self):
if self.cur_descrption is None:
return None
return self.cur_descrption.kv_cache
@property
def cur_state(self):
"""
        Return the state of the current micro batch; when the current description is None, the state is PREFILL.
"""
if self.cur_descrption is None:
return Status.PREFILL
return self.cur_descrption.state
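# Typical per-micro-batch flow: step() registers or updates the description,
# cur_state drives the schedule's next action, and next() advances the
# ring-buffer cursor to the following micro batch.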

View File

@@ -0,0 +1,278 @@
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
class GPT2PipelineForwards:
'''
This class serves as a micro library for forward function substitution of GPT2 models
under pipeline setting.
'''
@staticmethod
def gpt2_model_forward(
self: GPT2Model,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details.
logger = logging.get_logger(__name__)
# Preprocess passed in arguments
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
else:
if hidden_states is None:
raise ValueError("hidden_states shouldn't be None for stages other than the first stage.")
input_shape = hidden_states.size()[:-1]
batch_size, seq_length = input_shape[0], input_shape[1]
device = hidden_states.device
# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if stage_manager.is_first_stage():
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
else:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
# Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1]
for i, layer_past in zip(range(start_idx, end_idx), past_key_values):
block = self.h[i]
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
if stage_manager.is_last_stage():
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
return {'hidden_states': hidden_states, 'past_key_values': presents}
@staticmethod
def gpt2_lmhead_model_forward(
self: GPT2LMHeadModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
Please refer to original code of transformers for more details.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # If this is the first stage after warmup, only the lm_head needs to run (it is held on the first stage because of weight tying)
if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states)
return {'logits': lm_logits}
        # Otherwise (not the first stage, or still in warmup), run the GPT2 backbone
outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
return outputs

View File

@@ -0,0 +1,231 @@
from typing import List, Optional, Tuple
import torch
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
class LlamaPipelineForwards:
'''
This class serves as a micro library for forward function substitution of Llama models
under pipeline setting.
'''
def llama_model_forward(
self: LlamaModel,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
logger = logging.get_logger(__name__)
# Preprocess passed in arguments
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# retrieve input_ids and inputs_embeds
if stage_manager.is_first_stage():
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
else:
input_shape = hidden_states.shape[:-1]
batch_size, seq_length = input_shape
device = hidden_states.device
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if position_ids is None:
position_ids = torch.arange(past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
# embed positions, for the first stage, hidden_states is the input embeddings,
# for the other stages, hidden_states is the output of the previous stage
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past),
dtype=torch.bool,
device=hidden_states.device)
attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states,
past_key_values_length)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = () if use_cache else None
start_idx, end_idx = stage_index[0], stage_index[1]
if past_key_values is None:
past_key_values = tuple([None] * (end_idx - start_idx + 1))
for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
decoder_layer = self.layers[idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)
# past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
position_ids,
None,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if stage_manager.is_last_stage():
hidden_states = self.norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
        # always return a dict for intermediate stages
return {'hidden_states': hidden_states, 'past_key_values': next_cache}
def llama_for_causal_lm_forward(
self: LlamaForCausalLM,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
>>> prompt = "Hey, are you consciours? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
logger = logging.get_logger(__name__)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
        # If this is the first stage after warmup, only the lm_head needs to run (it is held on the first stage because of weight tying)
if stage_manager.is_first_stage() and hidden_states is not None:
lm_logits = self.lm_head(hidden_states)
return {'logits': lm_logits}
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaPipelineForwards.llama_model_forward(
self.model,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
)
return outputs

View File

@@ -0,0 +1,71 @@
from functools import partial
from typing import Callable, Dict, List
from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.policies.gpt2 import GPT2Policy
from ..modeling.gpt2 import GPT2PipelineForwards
class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
module_policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
addon_module = {
GPT2LMHeadModel:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
policy=module_policy)
return module_policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
        # keep the tied lm_head and wte embedding weights on the same device to save memory
        if self.pipeline_stage_manager.is_first_stage():
held_layers.append(self.model.lm_head)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
'''The weights of wte and lm_head are shared.'''
module = self.model
stage_manager = self.pipeline_stage_manager
if stage_manager is not None:
if stage_manager.num_stages > 1 and id(module.transformer.wte.weight) == id(module.lm_head.weight):
first_stage, last_stage = 0, stage_manager.num_stages - 1
return [{first_stage: module.transformer.wte.weight, last_stage: module.lm_head.weight}]
return []
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'GPT2Model':
module = self.model
else:
module = self.model.transformer
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

View File

@@ -0,0 +1,50 @@
from functools import partial
from typing import Callable, Dict, List, Union
import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
from colossalai.shardformer.policies.llama import LlamaPolicy
from ..modeling.llama import LlamaPipelineForwards
class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import LlamaForCausalLM
policy = super().module_policy()
if self.shard_config.enable_tensor_parallelism:
            # add a new item for causal LM
new_item = {
LlamaForCausalLM:
ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
])
}
policy.update(new_item)
if self.pipeline_stage_manager:
            # replace the original forward with the pipeline-aware forward
self.set_pipeline_forward(model_cls=LlamaForCausalLM,
new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[Module]:
"""Get pipeline layers for current stage."""
stage_manager = self.pipeline_stage_manager
held_layers = super().get_held_layers()
if stage_manager.is_first_stage():
held_layers.append(self.model.lm_head)
return held_layers

View File

@@ -0,0 +1,35 @@
from typing import List, Optional, Set
import torch.nn as nn
from colossalai.shardformer._utils import getattr_, setattr_
def set_tensors_to_none(model: nn.Module, include: Set[str] = set()) -> None:
"""
    Set all parameters and buffers of the given sub-modules of the model to None.
    Args:
        model (nn.Module): The model to process
        include (Set[str]): The suffixes of the sub-modules whose tensors should be released
"""
for module_suffix in include:
set_module = getattr_(model, module_suffix)
for n, p in set_module.named_parameters():
setattr_(set_module, n, None)
for n, buf in set_module.named_buffers():
setattr_(set_module, n, None)
setattr_(model, module_suffix, None)
def get_suffix_name(suffix: str, name: str):
"""
    Get the suffix name of the module: `suffix.name` when `name` is a string, `suffix[name]` when `name` is a digit,
    and just `name` when `suffix` is empty.
Args:
suffix (str): The suffix of the suffix module
name (str): The name of the current module
"""
    point = '' if suffix == '' else '.'
suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}'
return suffix_name
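# For example: get_suffix_name('transformer.h', '3') returns 'transformer.h[3]',
# while get_suffix_name('', 'wte') returns 'wte'.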

View File

@@ -160,6 +160,86 @@ def _recv_object(src: int, dst: int, group: ProcessGroup) -> Any:
return object_list[0]
def _p2p_comm(
tensor_send_next: torch.Tensor,
recv_prev: bool,
peer: int,
group: ProcessGroup,
comm_dtype: torch.dtype = torch.float16,
):
"""
    Send and recv tensors using P2P communication; used when the pipeline size is 2 to avoid communication races.
    Args:
tensor_send_next (torch.Tensor): tensor to be sent to next stage
recv_prev (bool): whether to receive tensor from previous stage
peer (int): rank of the peer
group (ProcessGroup): process group
comm_dtype (torch.dtype): dtype of the tensor to be sent
Returns:
torch.Tensor: tensor received from previous stage
"""
# send and recv shape
send_next_shape = None
recv_prev_shape = None
if tensor_send_next is not None:
send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
if recv_prev:
recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
ops = []
if send_next_shape is not None:
send_next_op = dist.P2POp(dist.isend, send_next_shape, peer=peer, group=group)
ops.append(send_next_op)
if recv_prev_shape is not None:
recv_prev_op = dist.P2POp(
dist.irecv,
recv_prev_shape,
peer=peer,
group=group,
)
ops.append(recv_prev_op)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
if recv_prev_shape is not None:
recv_prev_shape = recv_prev_shape.tolist()
# send and recv data
tensor_recv_prev = None
if recv_prev:
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
ops = []
if tensor_send_next is not None:
send_next_op = dist.P2POp(
dist.isend,
tensor_send_next,
peer=peer,
group=group,
)
ops.append(send_next_op)
if tensor_recv_prev is not None:
recv_prev_op = dist.P2POp(
dist.irecv,
tensor_recv_prev,
peer=peer,
group=group,
)
ops.append(recv_prev_op)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
return tensor_recv_prev
class PipelineP2PCommunication:
def __init__(self, stage_manager: PipelineStageManager) -> None:
self.stage_manager = stage_manager
@@ -221,3 +301,17 @@ class PipelineP2PCommunication:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
    def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> Any:
        """
        Sends the output tensor to the next stage in the pipeline and optionally receives a tensor from the previous stage, using `P2POp` in torch.
        Args:
            output_object (Any): Object to be sent.
            recv_pre (bool): Whether to also receive a tensor from the previous stage.
            peer (int, optional): The rank of the peer; defaults to the rank of the next stage.
            comm_dtype (torch.dtype): The dtype used for the communication buffers.
"""
if peer is None:
peer = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype)
return recv_tensor

View File

@@ -0,0 +1,343 @@
import time
from functools import partial
from typing import Any, Iterable, List, Optional, Union
import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.inference.pipeline.microbatch_manager import MicroBatchManager, Status
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
from .base import PipelineSchedule
class ActionIntervalBuffer():
"""
    The buffer that holds the intermediate hidden states and new tokens for the stage to use between actions.
"""
    def __init__(self):
self.hidden_states = None
self.new_token = None
def clear(self):
self.hidden_states = None
self.new_token = None
class GenerateSchedule(PipelineSchedule):
"""
GenerateSchedule is a class that handles the pipeline parallel inference.
    In our schedule, the tied-weight layers (the embedding and the lm_head) are placed on the same device to save memory, so
    the output of each decoding step is produced on rank 0.
Args:
stage_manager (`PipelineStageManager`): Pipeline stage manager.
mb_manager (`MicroBatchManager`): Micro batch manager.
verbose (bool): Whether to verbose the information of the pipeline.
"""
def __init__(self, stage_manager: PipelineStageManager, mb_manager: MicroBatchManager, verbose: bool) -> None:
super().__init__(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager)
self.mb_manager = mb_manager
self.microbatch_size = mb_manager.micro_batch_size
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self.num_microbatches: Optional[int] = None
self.action_interval_buffer = ActionIntervalBuffer()
self.verbose = verbose
self.timestamps = None
self.comm_dtype = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
Args:
data_iter (Iterable): Data iterator.
device (Optional[torch.device], optional): Target device. Defaults to None.
"""
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0
assert self.batch_size % self.microbatch_size == 0, \
f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}"
self.num_microbatches = self.batch_size // self.microbatch_size
self.round = self.num_microbatches // self.stage_manager.num_stages
def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch.
Returns:
Any: Micro batch.
"""
micro_batch = get_micro_batch(self.batch, self.microbatch_offset, self.microbatch_size)
self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
def _prepare_inputs_for_interval_stage(self):
'''
        Prepare inputs for the intermediate stages; for all of them, the input is just the past_key_values.
        Returns:
            dict: inputs for the intermediate stage, `{'past_key_values': torch.Tensor}` or `None`
'''
model_inputs = {
'past_key_values': self.mb_manager.cur_kv_cache
} if self.mb_manager.cur_kv_cache is not None else None
return model_inputs
def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
'''
        Prepare inputs for the new token: a dict with `input_ids`, `attention_mask` and `past_key_values`, where
        `input_ids` is the new token, `attention_mask` is the previous mask with a `1` appended at the end,
        and `past_key_values` is the KV cache saved in the micro batch manager.
Returns:
dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
'''
new_mask = self.mb_manager.cur_descrption.attn_mask
past_key_values = self.mb_manager.cur_descrption.kv_cache
return dict(input_ids=new_token, attention_mask=new_mask, past_key_values=past_key_values)
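    # Greedy decoding: take the argmax over the vocabulary at the last position as the next token id.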
def _get_token_id(self, hidden_state: torch.Tensor) -> torch.Tensor:
last_hidden_state = hidden_state[:, -1]
input_ids = torch.argmax(last_hidden_state, dim=-1).unsqueeze(1)
return input_ids
def _recv_pre_stage(self) -> Any:
'''
Receive the output from previous stage
Returns:
Any: The output from previous stage
'''
if self.stage_manager.num_stages == 2:
return self.comm.p2p_recv()
return self.comm.recv_forward()
def _load_stage_action(self, model: Module) -> None:
"""
        In this action we: 1) load the micro batch, 2) run the forward pass, 3) step the micro-batch manager to update its state.
"""
inputs_dict = self.load_micro_batch()
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
output_dict = model_forward(model, inputs_dict, None)
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
def _gen_token_action(self, model: Module):
"""
In this action: 1. run the forward pass on the buffered hidden_states to generate a new token, 2. step the micro batch manager to update its state.
"""
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When the first stage is in GENERATE phase, the hidden states should not be None"
hidden_states = {'hidden_states': hidden_states}
logits = model_forward(model, None, hidden_states)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
assert 'logits' in logits, f"When the first stage is in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits['logits'])
self.mb_manager.step(None, None, new_token)
self.action_interval_buffer.new_token = new_token
self.action_interval_buffer.hidden_states = None
def _head_encoding_action(self, model: Module):
"""
In this action: 1. prepare the first stage's inputs for encoding the new token, 2. run the forward pass to get hidden states, 3. step the micro batch manager to update its state.
"""
new_token = self.action_interval_buffer.new_token
assert new_token is not None, "When the first stage is in GENERATE phase, the new token should not be None"
inputs_dict = self._prepare_inputs_for_new_token(new_token)
output_dict = model_forward(model, inputs_dict, None)
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
def _body_encoding_action(self, model: Module):
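# In this action: 1. run this stage's blocks on the received hidden_states with the cached
# past_key_values, 2. step the micro batch manager to update its state.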
hidden_states = self.action_interval_buffer.hidden_states
assert hidden_states is not None, "When not the first stage, the hidden states should not be None"
inputs_dict = self._prepare_inputs_for_interval_stage()
hidden_states = {'hidden_states': hidden_states}
output_dict = model_forward(model, inputs_dict, hidden_states)
self.mb_manager.step(inputs_dict, output_dict, None)
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
def _comm_action(self, recv_pre: bool) -> None:
"""
In this action: 1. receive the hidden_states from the previous stage, 2. send the current hidden_states to the next stage.
"""
hidden_states = self.action_interval_buffer.hidden_states
ret = self.comm.p2p_communicate(hidden_states, recv_pre, comm_dtype=self.comm_dtype)
self.action_interval_buffer.hidden_states = ret
def _gen_action(self, model: Module):
"""
In the p2p step method we use the asynchronous `P2POp` communication primitive, so communication has to be
issued at the beginning of each micro batch; expressing each step as an action list keeps this explicit.
This function generates the action sequence for the current state; the caller executes the actions one by one.
Args:
model (Module): Model to be run.
Returns:
List[Callable]: A list of actions; each action is a callable and will be called in order.
"""
actions = []
if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
actions.append(partial(self._comm_action, False))
actions.append(partial(self._load_stage_action, model))
elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.GENERATE:
actions.append(partial(self._comm_action, True))
actions.append(partial(self._gen_token_action, model))
actions.append(partial(self._head_encoding_action, model))
elif self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.COOLDOWN:
actions.append(partial(self._comm_action, True))
actions.append(partial(self._gen_token_action, model))
# other stage
else:
actions.append(partial(self._comm_action, True))
actions.append(partial(self._body_encoding_action, model))
return actions
def generate_step(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
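# Dispatch: with exactly 2 stages the broadcast-style schedule would block, so use the P2P variant.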
if self.stage_manager.num_stages == 2:
return self.generate_step_p2p(model, data_iter)
else:
return self.generate_step_broadcast(model, data_iter)
@torch.no_grad()
def generate_step_p2p(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
"""
Forward one step of the pipeline. When the pipeline size is 2, the schedule forms a circle and broadcast
communication would block, so we use the asynchronous `P2POp` communication method instead.
Args:
model (Module): Model to be run.
data_iter (Iterable): Data iterator.
Returns:
Tuple[List[torch.Tensor], List]: The newly generated token sequences (non-empty only on the first stage) together with the per-token timestamps recorded when verbose is enabled.
"""
output_sequence = []
self.load_batch(data_iter)
model.eval()
self.comm_dtype = model.dtype
whole_timestamp = []
# run round by round
for _ in range(self.round):
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] \
    if self.verbose and self.stage_manager.is_first_stage() else None
self.action_interval_buffer.clear()
while not self.mb_manager.is_micro_batch_done():
actions = self._gen_action(model)
for action in actions:
action()
self.mb_manager.next()
# All micro batches in the current round are DONE
if self.stage_manager.is_first_stage():
output_sequence.extend(self.mb_manager.export_new_tokens())
else:
self._comm_action(False)
self.mb_manager.clear()
if self.verbose and self.stage_manager.is_first_stage():
whole_timestamp.extend(self.timestamps)
return output_sequence, whole_timestamp
@torch.no_grad()
def generate_step_broadcast(self, model: Module, data_iter: Iterable) -> Union[torch.Tensor, dict]:
"""
Forward one step of the pipeline
Args:
model (Module): Model to be run.
data_iter (Iterable): Data iterator.
Returns:
Tuple[List[torch.Tensor], List]: The newly generated token sequences (non-empty only on the first stage) together with the per-token timestamps recorded when verbose is enabled.
"""
output_sequence = []
self.load_batch(data_iter)
model.eval()
whole_timestamp = []
# run by round
for _ in range(self.round):
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)] \
    if self.verbose and self.stage_manager.is_first_stage() else None
while not self.mb_manager.is_micro_batch_done():
inputs_dict = None
new_token = None
output_dict = None
# First stage and in PREFILL phase, just load the inputs
if self.stage_manager.is_first_stage() and self.mb_manager.cur_state is Status.PREFILL:
inputs_dict = self.load_micro_batch()
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
output_dict = model_forward(model, inputs_dict, None)
self.mb_manager.step(inputs_dict, output_dict, None)
# In GENERATE phase
else:
# Get hidden_states from previous stage
hidden_states = self.comm.recv_forward()
if self.stage_manager.is_first_stage():
# First just generate a new token
assert hidden_states is not None, "When the first stage is in GENERATE phase, the hidden states should not be None"
logits = model_forward(model, None, hidden_states)
if self.verbose and self.stage_manager.is_first_stage():
torch.cuda.synchronize()
self.timestamps[self.mb_manager.idx].append(time.time())
assert 'logits' in logits, f"When the first stage is in GENERATE phase, the output should have attribute `logits`, but has {logits.keys()}"
new_token = self._get_token_id(logits['logits'])
self.mb_manager.step(None, None, new_token)
# If the current micro batch is not DONE, go through blocks
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
inputs_dict = self._prepare_inputs_for_new_token(new_token)
output_dict = model_forward(model, inputs_dict, None)
self.mb_manager.step(inputs_dict, output_dict, None)
else:
assert hidden_states is not None, "When not the first stage, the hidden states should not be None"
inputs_dict = self._prepare_inputs_for_interval_stage()
output_dict = model_forward(model, inputs_dict, hidden_states)
self.mb_manager.step(inputs_dict, output_dict, None)
# Current microbatch is not DONE, send hidden_state to next stage
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
self.comm.send_forward({'hidden_states': output_dict['hidden_states']})
self.mb_manager.next()
# All micro batches in the current round are DONE
if self.stage_manager.is_first_stage():
output_sequence.extend(self.mb_manager.export_new_tokens())
self.mb_manager.clear()
if self.verbose and self.stage_manager.is_first_stage():
whole_timestamp.extend(self.timestamps)
return output_sequence, whole_timestamp
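# --- Illustrative usage (not part of the diff) ---
# A minimal driver sketch for the schedule above, assuming a GenerateSchedule instance has
# already been constructed with its stage manager and micro batch manager, as PPInferEngine
# does elsewhere in this PR. `run_generation`, `schedule` and `batches` are hypothetical
# names used only for this sketch.

def run_generation(schedule, model, batches):
    # generate_step consumes one batch from the iterator and returns
    # (output_sequence, timestamps); tokens are only produced on the first stage.
    data_iter = iter(batches)
    outputs = []
    for _ in range(len(batches)):
        tokens, _timestamps = schedule.generate_step(model, data_iter)
        outputs.extend(tokens)
    return outputs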

View File

@ -12,6 +12,7 @@ class PipelineStageManager:
Args:
pg_mesh (ProcessGroupMesh): Process group mesh.
pipeline_axis (int): The axis along which the pipeline is constructed.
is_virtual (bool): Whether to use circular p2p communication, which makes the first and last stages communicate with each other.
Attributes:
num_stages (int): Number of stages in the pipeline.
@ -24,6 +25,7 @@ class PipelineStageManager:
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
# init prev and next coord
coord = self.pg_mesh.coordinate()
# the prev rank of rank0 is the last rank

View File

@ -66,6 +66,7 @@ def check_low_level_zero_checkpointIO(stage: int, shard: bool, offload: bool):
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.optim.state_dict(), new_optimizer.optim.state_dict(), False)
torch.cuda.empty_cache()
def run_dist(rank, world_size, port):

View File

@ -0,0 +1,63 @@
from copy import deepcopy
import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
import transformers
import colossalai
from colossalai.inference.pipeline.engine import PPInferEngine
from colossalai.inference.pipeline.policy.gpt2_ppinfer import GPT2LMHeadModelPipelinePolicy
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
def data_gen():
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
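# Build a batch of 16 identical prompts: repeat each 1x8 tensor along dim 0 so that the
# batch size (16) is divisible by every micro_batch_size exercised below (1 and 4).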
inputs = data_gen()
for k, v in inputs.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 16
inputs[k] = v.to('cuda').repeat(*new_shape)
def pipeline_inference_test(pp_size, new_length, micro_batch_size):
model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
engine = PPInferEngine(pp_size=pp_size,
model=model,
model_policy=GPT2LMHeadModelPipelinePolicy(),
new_length=new_length,
micro_batch_size=micro_batch_size)
output = engine.inference([inputs])
if dist.get_rank() == 0:
assert len(output[0]) == new_length, f"{len(output[0])}, {new_length}"
@parameterize('pp_size', [4])
@parameterize('new_length', [4, 8, 16])
@parameterize('micro_batch_size', [1, 4])
@clear_cache_before_run()
def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
pipeline_inference_test(pp_size, new_length, micro_batch_size)
torch.cuda.empty_cache()
def check_pipeline_inference(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_pipeline_inference_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_pipeline_inference():
spawn(check_pipeline_inference, nprocs=4)
if __name__ == '__main__':
test_pipeline_inference()