mirror of https://github.com/hpcaitech/ColossalAI
[format] applied code formatting on changed files in pull request 4820 (#4886)
Co-authored-by: github-actions <github-actions@github.com>
parent c7aa319ba0
commit 486d06a2d5
@@ -1,3 +1,3 @@
 from .pipeline import PPInferEngine

-__all__ = ['PPInferEngine']
+__all__ = ["PPInferEngine"]
@@ -1,3 +1,3 @@
 from .engine import PPInferEngine

-__all__ = ['PPInferEngine']
+__all__ = ["PPInferEngine"]
@@ -1,28 +1,32 @@
+import argparse
+import time

 import torch
 import torch.distributed as dist
 import transformers

 import colossalai
-import time
 from colossalai.inference import PPInferEngine
 from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy
-import argparse

-GIGABYTE = 1024 ** 3
+GIGABYTE = 1024**3
 MEGABYTE = 1024 * 1024

 colossalai.launch_from_torch(config={})

-def data_gen(batch_size: int=4, seq_len: int=512):
+
+def data_gen(batch_size: int = 4, seq_len: int = 512):
     input_ids = torch.randint(10, 30000, (1, seq_len), dtype=torch.int32)
     attention_mask = torch.ones((1, seq_len), dtype=torch.int32)
     data = dict(input_ids=input_ids, attention_mask=attention_mask)
     for k, v in data.items():
-        if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+        if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
             new_shape = [1] * v.dim()
             new_shape[0] = batch_size
-            data[k] = v.to('cuda').repeat(*new_shape)
+            data[k] = v.to("cuda").repeat(*new_shape)
     return data


 def print_details_info(timestamps, model_config, args, whole_end2end):
     if dist.get_rank() == 0:
         prefill = []
@@ -31,32 +35,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
         for timestamp in timestamps:
             prefill.append(timestamp[1] - timestamp[0])
             encoder.append(
-                sum(timestamp[i + 1] - timestamp[i] for i in range(1,len(timestamp) - 1)) / (len(timestamp) - 2))
+                sum(timestamp[i + 1] - timestamp[i] for i in range(1, len(timestamp) - 1)) / (len(timestamp) - 2)
+            )
             end2end.append(timestamp[-1] - timestamp[0])
         print(whole_end2end)
-        with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","w+") as f:
-            mb_avg_end2end = sum(end2end)/len(end2end)
-            mb_avg_latency = mb_avg_end2end/(args.new_length * args.mb_size)
-            whole_avg_latency = whole_end2end/(args.new_length * args.batch_size)
+        with open(
+            f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
+            "w+",
+        ) as f:
+            mb_avg_end2end = sum(end2end) / len(end2end)
+            mb_avg_latency = mb_avg_end2end / (args.new_length * args.mb_size)
+            whole_avg_latency = whole_end2end / (args.new_length * args.batch_size)
             num_layers = getattr(model_config, "num_layers", model_config.num_hidden_layers)
             num_parameters = num_layers * model_config.hidden_size * model_config.hidden_size * 12 / args.pp_size
-            if args.dtype in ['fp16','bf16']:
+            if args.dtype in ["fp16", "bf16"]:
                 num_bytes = 2
             else:
                 num_bytes = 4

-            f.write(f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n")
-            f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill)/len(prefill)*1000))
-            f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder)/len(encoder)*1000))
-            f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end*1000))
+            f.write(
+                f"llama-{args.model}{args.dtype}_pp{args.pp_size}, input_len:{args.seq_len}, output_len:{args.new_length}, bsz:{args.batch_size}, mbsz:{args.mb_size}\n"
+            )
+            f.write("Average prefill time: {0:8.2f} ms\n".format(sum(prefill) / len(prefill) * 1000))
+            f.write("Average encode time: {0:8.2f} ms\n".format(sum(encoder) / len(encoder) * 1000))
+            f.write("Average micro batch end2end time: {0:8.2f} ms\n".format(mb_avg_end2end * 1000))
             f.write("Average micro batch Per Token Latency: {0:8.2f} ms\n".format(mb_avg_latency * 1000))
-            f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end*1000))
+            f.write("Whole batch end2end time: {0:8.2f} ms\n".format(whole_end2end * 1000))
             f.write("Whole batch Per Token Latency: {0:8.2f} ms\n".format(whole_avg_latency * 1000))
-            f.write("Throughput: {} tokens/s\n".format((1000/(whole_avg_latency * 1000))))
-            f.write("flops: {0:8.2f} TFlops/s\n".format(1/whole_avg_latency * num_parameters * num_bytes / 1e12))
+            f.write("Throughput: {} tokens/s\n".format((1000 / (whole_avg_latency * 1000))))
+            f.write("flops: {0:8.2f} TFlops/s\n".format(1 / whole_avg_latency * num_parameters * num_bytes / 1e12))
             f.write("----------------------------------------------------------\n")


     if torch.cuda.is_available():
         current_device = torch.cuda.current_device()

@@ -66,7 +75,10 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
         max_memory_allocated = torch.cuda.max_memory_allocated()
         memory_reserved = torch.cuda.memory_reserved()
         max_memory_reserved = torch.cuda.max_memory_reserved()
-        with open(f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log","a") as f:
+        with open(
+            f"{args.log_path}/llama-{args.model}{args.dtype}_pp{args.pp_size}_{args.seq_len}_{args.new_length}_bsz{args.batch_size}_mbsz{args.mb_size}.log",
+            "a",
+        ) as f:
             f.write(
                 f"\nCurrently using GPU: {current_device}\n"
                 f"free memory : {global_free_memory / GIGABYTE:.4f} GB,\n"
@@ -77,29 +89,37 @@ def print_details_info(timestamps, model_config, args, whole_end2end):
                 f"Max CUDA memory reserved/cached: {max_memory_reserved / GIGABYTE:.4f} GB,\n"
             )

-if __name__ == '__main__':
+if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model', default='toy', help='the size of model')
-    parser.add_argument('-b', '--batch_size', type=int, default=8, help='batch size')
-    parser.add_argument('-s', '--seq_len', type=int, default=8, help='sequence length')
-    parser.add_argument('--new_length', type=int, default=4, help='new tokens length')
-    parser.add_argument('--mb_size', type=int, default=1, help='micro_batch_size')
-    parser.add_argument('--pp_size', type=int, default=2, help='pipeline size')
-    parser.add_argument('--log_path', type=str, default='./log' ,help='where to store the benchmark log')
-    parser.add_argument('--dtype', type=str, default='fp16', help='data type')
+    parser.add_argument("--model", default="toy", help="the size of model")
+    parser.add_argument("-b", "--batch_size", type=int, default=8, help="batch size")
+    parser.add_argument("-s", "--seq_len", type=int, default=8, help="sequence length")
+    parser.add_argument("--new_length", type=int, default=4, help="new tokens length")
+    parser.add_argument("--mb_size", type=int, default=1, help="micro_batch_size")
+    parser.add_argument("--pp_size", type=int, default=2, help="pipeline size")
+    parser.add_argument("--log_path", type=str, default="./log", help="where to store the benchmark log")
+    parser.add_argument("--dtype", type=str, default="fp16", help="data type")
     args = parser.parse_args()

-    if args.model == 'toy':
+    if args.model == "toy":
         model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))
-    elif args.model == '7b':
-        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-7b-hf'))
+    elif args.model == "7b":
+        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-7b-hf"))
-    elif args.model == '13b':
-        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained('decapoda-research/llama-13b-hf'))
+    elif args.model == "13b":
+        model = transformers.LlamaForCausalLM(transformers.AutoConfig.from_pretrained("decapoda-research/llama-13b-hf"))
     else:
         raise NotImplementedError

-    engine = PPInferEngine(pp_size=args.pp_size, dtype=args.dtype, micro_batch_size=args.mb_size, new_length=args.new_length, model=model, model_policy=LlamaForCausalLMPipelinePolicy(),verbose=True)
+    engine = PPInferEngine(
+        pp_size=args.pp_size,
+        dtype=args.dtype,
+        micro_batch_size=args.mb_size,
+        new_length=args.new_length,
+        model=model,
+        model_policy=LlamaForCausalLMPipelinePolicy(),
+        verbose=True,
+    )
     data = data_gen(args.batch_size, args.seq_len)

     torch.cuda.synchronize()
@@ -109,4 +129,3 @@ if __name__ == '__main__':
     whole_end2end = time.time() - whole_end2end

     print_details_info(timestamps, model.config, args, whole_end2end)
-
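A note on the metric arithmetic in `print_details_info` above: the throughput expression `1000 / (whole_avg_latency * 1000)` is simply the reciprocal of the per-token latency, i.e. the number of generated tokens divided by the whole-batch wall time. A minimal, self-contained check of that identity (the numbers are invented purely for illustration and are not taken from this commit):

```python
# Hypothetical values, chosen only to illustrate the arithmetic used in print_details_info.
new_length = 4        # tokens generated per sequence
batch_size = 8        # sequences in the whole batch
whole_end2end = 0.5   # measured wall-clock time for the whole batch, in seconds

whole_avg_latency = whole_end2end / (new_length * batch_size)  # seconds per generated token
throughput = 1000 / (whole_avg_latency * 1000)                 # as written in the script above

# The same quantity, written directly as tokens per second.
assert abs(throughput - new_length * batch_size / whole_end2end) < 1e-9
print(f"{whole_avg_latency * 1000:.2f} ms/token, {throughput:.1f} tokens/s")
```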
@@ -1,5 +1,3 @@
-from typing import Callable, List, Optional, Set, Union
-
 import torch
 import torch.nn as nn

@@ -13,7 +11,7 @@ from .microbatch_manager import MicroBatchManager


 class PPInferEngine:
-    '''
+    """
     PPInferEngine is a class that handles the pipeline parallel inference.

     Args:
@@ -41,12 +39,12 @@ class PPInferEngine:
         output = engine.inference([tokenized_input])
     ```

-    '''
+    """

     def __init__(
         self,
         pp_size: int,
-        dtype: str = 'fp16',
+        dtype: str = "fp16",
         pp_model: nn.Module = None,
         model: nn.Module = None,
         model_policy: Policy = None,
@@ -54,7 +52,7 @@ class PPInferEngine:
         micro_batch_size: int = 1,
         micro_batch_buffer_size: int = None,
         verbose: bool = False,
         # TODO: implement early_stopping, and various gerneration options
         early_stopping: bool = False,
         do_sample: bool = False,
         num_beams: int = 1,
@@ -63,15 +61,16 @@ class PPInferEngine:
         self.pp_size = pp_size
         self.pg_mesh = ProcessGroupMesh(pp_size)
         self.stage_manager = PipelineStageManager(self.pg_mesh, 0, True)
-        self.mb_manager = MicroBatchManager(self.stage_manager.stage, new_length, micro_batch_size,
-                                            micro_batch_buffer_size or pp_size)
+        self.mb_manager = MicroBatchManager(
+            self.stage_manager.stage, new_length, micro_batch_size, micro_batch_buffer_size or pp_size
+        )
         self.verbose = verbose
         self.schedule = GenerateSchedule(self.stage_manager, self.mb_manager, verbose)

-        assert dtype in ['fp16', 'fp32', 'bf16'], "dtype should be one of 'fp16', 'fp32', 'bf16'"
-        if dtype == 'fp16':
+        assert dtype in ["fp16", "fp32", "bf16"], "dtype should be one of 'fp16', 'fp32', 'bf16'"
+        if dtype == "fp16":
             model.half()
-        elif dtype == 'bf16':
+        elif dtype == "bf16":
             model.to(torch.bfloat16)
         self.model = pp_model or self._shardformer(model, model_policy)

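Reading the reformatted constructor together with the usage example quoted in the class docstring (`output = engine.inference([tokenized_input])`), a minimal sketch of driving `PPInferEngine` could look as follows. This is an illustration assembled from the hunks above, not part of the commit; the choice of a toy Llama model and the commented-out tokenizer input are assumptions:

```python
import transformers

import colossalai
from colossalai.inference import PPInferEngine
from colossalai.inference.pipeline.policy.llama_ppinfer import LlamaForCausalLMPipelinePolicy

colossalai.launch_from_torch(config={})

# A small randomly initialized Llama, mirroring the benchmark's "toy" branch above.
model = transformers.LlamaForCausalLM(transformers.LlamaConfig(num_hidden_layers=8))

engine = PPInferEngine(
    pp_size=2,                # number of pipeline stages (one process per stage)
    dtype="fp16",
    micro_batch_size=1,
    new_length=4,             # number of tokens to generate
    model=model,
    model_policy=LlamaForCausalLMPipelinePolicy(),
    verbose=True,
)

# `tokenized_input` is assumed to be a dict with "input_ids" and "attention_mask",
# e.g. the output of a Hugging Face tokenizer called with return_tensors="pt".
# output = engine.inference([tokenized_input])
```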
@@ -3,7 +3,7 @@ from typing import Dict, Tuple

 import torch

-__all__ = 'MicroBatchManager'
+__all__ = "MicroBatchManager"


 class Status(Enum):
@@ -13,7 +13,7 @@ class Status(Enum):
     COOLDOWN = 4


-class MicroBatchDescription():
+class MicroBatchDescription:
     """
     This is the class to record the infomation of each microbatch, and also do some update operation.
     This clase is the base class of `HeadMicroBatchDescription` and `BodyMicroBatchDescription`, for more
@@ -30,14 +30,14 @@ class MicroBatchDescription():
         output_dict: Dict[str, torch.Tensor],
         new_length: int,
     ) -> None:
-        assert output_dict.get('hidden_states') is not None
-        self.mb_length = output_dict['hidden_states'].shape[-2]
+        assert output_dict.get("hidden_states") is not None
+        self.mb_length = output_dict["hidden_states"].shape[-2]
         self.target_length = self.mb_length + new_length
         self.kv_cache = ()

     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
         if output_dict is not None:
-            self._update_kvcache(output_dict['past_key_values'])
+            self._update_kvcache(output_dict["past_key_values"])

     def _update_kvcache(self, kv_cache: Tuple):
         assert type(kv_cache) == tuple
@@ -64,7 +64,6 @@ class MicroBatchDescription():
         Return the current sequnence length of micro batch

         """
-        pass


 class HeadMicroBatchDescription(MicroBatchDescription):
@@ -80,13 +79,14 @@ class HeadMicroBatchDescription(MicroBatchDescription):

     """

-    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
-                 new_length: int) -> None:
+    def __init__(
+        self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
+    ) -> None:
         super().__init__(inputs_dict, output_dict, new_length)
         assert inputs_dict is not None
-        assert inputs_dict.get('input_ids') is not None and inputs_dict.get('attention_mask') is not None
-        self.input_ids = inputs_dict['input_ids']
-        self.attn_mask = inputs_dict['attention_mask']
+        assert inputs_dict.get("input_ids") is not None and inputs_dict.get("attention_mask") is not None
+        self.input_ids = inputs_dict["input_ids"]
+        self.attn_mask = inputs_dict["attention_mask"]
         self.new_tokens = None

     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@@ -104,7 +104,8 @@ class HeadMicroBatchDescription(MicroBatchDescription):

     def _update_attnmask(self):
         self.attn_mask = torch.cat(
-            (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device='cuda')), dim=-1)
+            (self.attn_mask, torch.ones((self.attn_mask.shape[0], 1), dtype=torch.int64, device="cuda")), dim=-1
+        )

     @property
     def cur_length(self):
@@ -127,8 +128,9 @@ class BodyMicroBatchDescription(MicroBatchDescription):
         output_dict (Dict[str, torch.Tensor]): the outputs of previous stage. The key should have `hidden_states` and `past_key_values`.
     """

-    def __init__(self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor],
-                 new_length: int) -> None:
+    def __init__(
+        self, inputs_dict: Dict[str, torch.Tensor], output_dict: Dict[str, torch.Tensor], new_length: int
+    ) -> None:
         super().__init__(inputs_dict, output_dict, new_length)

     def update(self, output_dict: Dict[str, torch.Tensor] = None, new_token: torch.Tensor = None):
@@ -146,8 +148,8 @@ class BodyMicroBatchDescription(MicroBatchDescription):
         return self.kv_cache[0][0].shape[-2] + 1


-class MicroBatchManager():
-    '''
+class MicroBatchManager:
+    """
     MicroBatchManager is a class that manages the micro batch.

     Args:
@@ -156,7 +158,7 @@ class MicroBatchManager():
         micro_batch_size (int): the micro batch size.
         micro_batch_buffer_size (int): the buffer size for micro batch. Normally, it should be the same as the number of pipeline stages.

-    '''
+    """

     def __init__(self, stage: int, new_length: int, micro_batch_size: int, micro_batch_buffer_size: int):
         self.stage = stage
@@ -1,7 +1,6 @@
 from typing import Dict, List, Optional, Tuple, Union

 import torch
-from torch.nn import CrossEntropyLoss
 from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions
 from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model
 from transformers.utils import logging
@@ -10,41 +9,41 @@ from colossalai.pipeline.stage_manager import PipelineStageManager


 class GPT2PipelineForwards:
-    '''
+    """
     This class serves as a micro library for forward function substitution of GPT2 models
     under pipeline setting.
-    '''
+    """

     @staticmethod
     def gpt2_model_forward(
         self: GPT2Model,
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
-        stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        stage_index: Optional[List[int]] = None,
+    ) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         # This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
         # Please refer to original code of transformers for more details.
         logger = logging.get_logger(__name__)

         # Preprocess passed in arguments
         if output_attentions:
-            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
             output_attentions = False
         if output_hidden_states:
-            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
             output_hidden_states = False

         use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -96,7 +95,7 @@ class GPT2PipelineForwards:
             # positions we want to attend and the dtype's smallest value for masked positions.
             # Since we are adding it to the raw scores before the softmax, this is
             # effectively the same as removing these entirely.
             attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
             attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

         # If a 2D or 3D attention mask is provided for the cross-attention
@@ -137,7 +136,8 @@ class GPT2PipelineForwards:
         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
                 use_cache = False

         presents = () if use_cache else None
@@ -166,7 +166,6 @@ class GPT2PipelineForwards:
             if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
-
                     def custom_forward(*inputs):
                         # None for past_key_value
                         return module(*inputs, use_cache, output_attentions)
@@ -218,61 +217,64 @@ class GPT2PipelineForwards:
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        return {'hidden_states': hidden_states, 'past_key_values': presents}
+        return {"hidden_states": hidden_states, "past_key_values": presents}

     @staticmethod
     def gpt2_lmhead_model_forward(
         self: GPT2LMHeadModel,
         input_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         token_type_ids: Optional[torch.LongTensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.FloatTensor] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         stage_manager: Optional[PipelineStageManager] = None,
         hidden_states: Optional[torch.FloatTensor] = None,
-        stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
+        stage_index: Optional[List[int]] = None,
+    ) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`

         This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel.forward.
         Please refer to original code of transformers for more details.
         """
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # If is first stage and after warmup, go throught lm_head first
         if stage_manager.is_first_stage() and hidden_states is not None:
             lm_logits = self.lm_head(hidden_states)
-            return {'logits': lm_logits}
+            return {"logits": lm_logits}

         # Not first stage or before warmup, go through gpt2 model
-        outputs = GPT2PipelineForwards.gpt2_model_forward(self.transformer,
-                                                          input_ids,
-                                                          past_key_values=past_key_values,
-                                                          attention_mask=attention_mask,
-                                                          token_type_ids=token_type_ids,
-                                                          position_ids=position_ids,
-                                                          head_mask=head_mask,
-                                                          inputs_embeds=inputs_embeds,
-                                                          encoder_hidden_states=encoder_hidden_states,
-                                                          encoder_attention_mask=encoder_attention_mask,
-                                                          use_cache=use_cache,
-                                                          output_attentions=output_attentions,
-                                                          output_hidden_states=output_hidden_states,
-                                                          return_dict=return_dict,
-                                                          stage_manager=stage_manager,
-                                                          hidden_states=hidden_states,
-                                                          stage_index=stage_index)
+        outputs = GPT2PipelineForwards.gpt2_model_forward(
+            self.transformer,
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            stage_manager=stage_manager,
+            hidden_states=hidden_states,
+            stage_index=stage_index,
+        )

         return outputs
@@ -1,8 +1,6 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional

 import torch
-from torch.nn import CrossEntropyLoss, MSELoss
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
 from transformers.utils import logging

@@ -10,10 +8,10 @@ from colossalai.pipeline.stage_manager import PipelineStageManager


 class LlamaPipelineForwards:
-    '''
+    """
     This class serves as a micro library for forward function substitution of Llama models
     under pipeline setting.
-    '''
+    """

     def llama_model_forward(
         self: LlamaModel,
@@ -34,10 +32,10 @@ class LlamaPipelineForwards:

         # Preprocess passed in arguments
         if output_attentions:
-            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
             output_attentions = False
         if output_hidden_states:
-            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
             output_hidden_states = False

         use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -70,10 +68,9 @@ class LlamaPipelineForwards:
         seq_length_with_past = seq_length_with_past + past_key_values_length

         if position_ids is None:
-            position_ids = torch.arange(past_key_values_length,
-                                        seq_length + past_key_values_length,
-                                        dtype=torch.long,
-                                        device=device)
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
             position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
         else:
             position_ids = position_ids.view(-1, seq_length).long()
@@ -81,16 +78,18 @@ class LlamaPipelineForwards:
         # embed positions, for the first stage, hidden_states is the input embeddings,
         # for the other stages, hidden_states is the output of the previous stage
         if attention_mask is None:
-            attention_mask = torch.ones((batch_size, seq_length_with_past),
-                                        dtype=torch.bool,
-                                        device=hidden_states.device)
-        attention_mask = self._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), hidden_states,
-                                                              past_key_values_length)
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=hidden_states.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), hidden_states, past_key_values_length
+        )

         if self.gradient_checkpointing and self.training:
             if use_cache:
                 logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
                 use_cache = False

         # decoder layers
@@ -112,7 +111,6 @@ class LlamaPipelineForwards:
             if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
-
                     def custom_forward(*inputs):
                         # None for past_key_value
                         return module(*inputs, output_attentions, None)
@@ -152,7 +150,7 @@ class LlamaPipelineForwards:
         next_cache = next_decoder_cache if use_cache else None

         # always return dict for imediate stage
-        return {'hidden_states': hidden_states, 'past_key_values': next_cache}
+        return {"hidden_states": hidden_states, "past_key_values": next_cache}

     def llama_for_causal_lm_forward(
         self: LlamaForCausalLM,
@@ -171,45 +169,45 @@ class LlamaPipelineForwards:
         stage_index: Optional[List[int]] = None,
     ):
         r"""
         Args:
             labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                 config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                 (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

         Returns:

         Example:

         ```python
         >>> from transformers import AutoTokenizer, LlamaForCausalLM

         >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
         >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

         >>> prompt = "Hey, are you consciours? Can you talk to me?"
         >>> inputs = tokenizer(prompt, return_tensors="pt")

         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
         "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
         ```"""
         logger = logging.get_logger(__name__)

         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if output_attentions:
-            logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
             output_attentions = False
         if output_hidden_states:
-            logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
+            logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
             output_hidden_states = False

         # If is first stage and after warmup, go throught lm_head first
         if stage_manager.is_first_stage() and hidden_states is not None:
             lm_logits = self.lm_head(hidden_states)
-            return {'logits': lm_logits}
+            return {"logits": lm_logits}

         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = LlamaPipelineForwards.llama_model_forward(
@@ -11,7 +11,6 @@ from ..modeling.gpt2 import GPT2PipelineForwards


 class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
-
     def __init__(self) -> None:
         super().__init__()

@@ -22,18 +21,22 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy):

         if self.shard_config.enable_tensor_parallelism:
             addon_module = {
-                GPT2LMHeadModel:
-                    ModulePolicyDescription(sub_module_replacement=[
+                GPT2LMHeadModel: ModulePolicyDescription(
+                    sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True})
-                    ])
+                            suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs={"gather_output": True}
+                        )
+                    ]
+                )
             }
             module_policy.update(addon_module)

         if self.pipeline_stage_manager is not None:
-            self.set_pipeline_forward(model_cls=GPT2LMHeadModel,
-                                      new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
-                                      policy=module_policy)
+            self.set_pipeline_forward(
+                model_cls=GPT2LMHeadModel,
+                new_forward=GPT2PipelineForwards.gpt2_lmhead_model_forward,
+                policy=module_policy,
+            )
         return module_policy

     def get_held_layers(self) -> List[nn.Module]:
@@ -45,7 +48,7 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy):
         return held_layers

     def get_shared_params(self) -> List[Dict[int, Tensor]]:
-        '''The weights of wte and lm_head are shared.'''
+        """The weights of wte and lm_head are shared."""
         module = self.model
         stage_manager = self.pipeline_stage_manager
         if stage_manager is not None:
@@ -56,16 +59,16 @@ class GPT2LMHeadModelPipelinePolicy(GPT2Policy):

     def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
         """If under pipeline parallel setting, replacing the original forward method of huggingface
         to customized forward method, and add this changing to policy."""
         if not self.pipeline_stage_manager:
             raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
         stage_manager = self.pipeline_stage_manager
-        if self.model.__class__.__name__ == 'GPT2Model':
+        if self.model.__class__.__name__ == "GPT2Model":
             module = self.model
         else:
             module = self.model.transformer

         layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
         stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
-        method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
+        method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -1,19 +1,15 @@
-from functools import partial
-from typing import Callable, Dict, List, Union
+from typing import List

-import torch.nn as nn
-from torch import Tensor
 from torch.nn import Module

-from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
-from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+from colossalai.shardformer.layer import Linear1D_Col
+from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.llama import LlamaPolicy

 from ..modeling.llama import LlamaPipelineForwards


 class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
-
     def __init__(self) -> None:
         super().__init__()

@@ -25,19 +21,21 @@ class LlamaForCausalLMPipelinePolicy(LlamaPolicy):
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for casual lm
             new_item = {
-                LlamaForCausalLM:
-                    ModulePolicyDescription(sub_module_replacement=[
+                LlamaForCausalLM: ModulePolicyDescription(
+                    sub_module_replacement=[
                         SubModuleReplacementDescription(
-                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True))
-                    ])
+                            suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
+                        )
+                    ]
+                )
             }
             policy.update(new_item)

         if self.pipeline_stage_manager:
             # set None as default
-            self.set_pipeline_forward(model_cls=LlamaForCausalLM,
-                                      new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward,
-                                      policy=policy)
+            self.set_pipeline_forward(
+                model_cls=LlamaForCausalLM, new_forward=LlamaPipelineForwards.llama_for_causal_lm_forward, policy=policy
+            )

         return policy

@@ -1,4 +1,4 @@
-from typing import List, Optional, Set
+from typing import Set

 import torch.nn as nn

@@ -30,6 +30,6 @@ def get_suffix_name(suffix: str, name: str):
         suffix (str): The suffix of the suffix module
         name (str): The name of the current module
     """
-    point = '' if suffix is '' else '.'
-    suffix_name = suffix + f'[{name}]' if name.isdigit() else suffix + f'{point}{name}'
+    point = "" if suffix is "" else "."
+    suffix_name = suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}"
     return suffix_name
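The two reformatted lines of `get_suffix_name` above encode a simple rule: a purely numeric `name` is appended as an index, anything else is dot-joined (with no leading dot when `suffix` is empty). A standalone restatement for illustration, not the library code itself:

```python
def get_suffix_name(suffix: str, name: str) -> str:
    # Same logic as in the hunk above: list indices become "[i]", attributes become ".name".
    point = "" if suffix == "" else "."
    return suffix + f"[{name}]" if name.isdigit() else suffix + f"{point}{name}"

assert get_suffix_name("model.layers", "0") == "model.layers[0]"
assert get_suffix_name("model.layers[0]", "mlp") == "model.layers[0].mlp"
assert get_suffix_name("", "lm_head") == "lm_head"
```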
@@ -302,7 +302,9 @@ class PipelineP2PCommunication:
         cur_rank = self.stage_manager.get_rank()
         _send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))

-    def p2p_communicate(self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16) -> None:
+    def p2p_communicate(
+        self, output_object: Any, recv_pre: bool, peer: int = None, comm_dtype: torch.dtype = torch.float16
+    ) -> None:
         """
         Sends the input tensor to the next stage in pipeline, using `P2Pop` in torch.

@@ -313,5 +315,7 @@ class PipelineP2PCommunication:
         if peer is None:
             peer = self.stage_manager.get_next_rank()
         cur_rank = self.stage_manager.get_rank()
-        recv_tensor = _p2p_comm(output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype)
+        recv_tensor = _p2p_comm(
+            output_object, recv_pre, peer, self.stage_manager.get_p2p_process_group(cur_rank, peer), comm_dtype
+        )
         return recv_tensor
@@ -1,6 +1,6 @@
 import time
 from functools import partial
-from typing import Any, Iterable, List, Optional, Union
+from typing import Any, Iterable, Optional, Union

 import torch
 import torch.cuda
@@ -16,7 +16,7 @@ from ._utils import get_batch_size, get_micro_batch, model_forward, to_device
 from .base import PipelineSchedule


-class ActionIntervalBuffer():
+class ActionIntervalBuffer:
     """
     The buffer to save the interval hidden states and new token for stage to use.

@@ -70,8 +70,9 @@ class GenerateSchedule(PipelineSchedule):
         self.batch = batch
         self.batch_size = get_batch_size(batch)
         self.microbatch_offset = 0
-        assert self.batch_size % self.microbatch_size == 0, \
-            f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}"
+        assert (
+            self.batch_size % self.microbatch_size == 0
+        ), f"Batch size should divided by the number of microbatches, {self.batch_size}, {self.num_microbatches}"
         self.num_microbatches = self.batch_size // self.microbatch_size
         self.round = self.num_microbatches // self.stage_manager.num_stages

@@ -86,26 +87,26 @@ class GenerateSchedule(PipelineSchedule):
         return tree_map(partial(to_device, device=get_current_device()), micro_batch)

     def _prepare_inputs_for_interval_stage(self):
-        '''
+        """
         Prepare inputs for interval stage, for all the interval stage, the inputs is just the past_key_values

         Returns:
             dict: inputs for interval stage, `{'past_key_values': torch.Tensor}` or `None`
-        '''
-        model_inputs = {
-            'past_key_values': self.mb_manager.cur_kv_cache
-        } if self.mb_manager.cur_kv_cache is not None else None
+        """
+        model_inputs = (
+            {"past_key_values": self.mb_manager.cur_kv_cache} if self.mb_manager.cur_kv_cache is not None else None
+        )
         return model_inputs

     def _prepare_inputs_for_new_token(self, new_token: torch.Tensor):
-        '''
+        """
         Prepare inputs for new token, the inputs is a dict with `input_ids`, `attention_mask` and `past_key_values`
         `input_ids` is the new token, `attention_mask` is the previous mask add `1` in the end,
         `past_key_values` is the past_key_values save in the micro batch manager

         Returns:
             dict: inputs for new token, `{'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'past_key_values': torch.Tensor}`
-        '''
+        """
         new_mask = self.mb_manager.cur_descrption.attn_mask
         past_key_values = self.mb_manager.cur_descrption.kv_cache

@ -117,12 +118,12 @@ class GenerateSchedule(PipelineSchedule):
|
||||||
return input_ids
|
return input_ids
|
||||||
|
|
||||||
def _recv_pre_stage(self) -> Any:
|
def _recv_pre_stage(self) -> Any:
|
||||||
'''
|
"""
|
||||||
Receive the output from previous stage
|
Receive the output from previous stage
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Any: The output from previous stage
|
Any: The output from previous stage
|
||||||
'''
|
"""
|
||||||
if self.stage_manager.num_stages == 2:
|
if self.stage_manager.num_stages == 2:
|
||||||
return self.comm.p2p_recv()
|
return self.comm.p2p_recv()
|
||||||
return self.comm.recv_forward()
|
return self.comm.recv_forward()
|
||||||
|
@ -138,7 +139,7 @@ class GenerateSchedule(PipelineSchedule):
|
||||||
output_dict = model_forward(model, inputs_dict, None)
|
output_dict = model_forward(model, inputs_dict, None)
|
||||||
|
|
||||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||||
|
|
||||||
def _gen_token_action(self, model: Module):
|
def _gen_token_action(self, model: Module):
|
||||||
"""
|
"""
|
||||||
|
@ -146,13 +147,15 @@ class GenerateSchedule(PipelineSchedule):
|
||||||
"""
|
"""
|
||||||
hidden_states = self.action_interval_buffer.hidden_states
|
hidden_states = self.action_interval_buffer.hidden_states
|
||||||
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
|
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
|
||||||
hidden_states = {'hidden_states': hidden_states}
|
hidden_states = {"hidden_states": hidden_states}
|
||||||
logits = model_forward(model, None, hidden_states)
|
logits = model_forward(model, None, hidden_states)
|
||||||
if self.verbose and self.stage_manager.is_first_stage():
|
if self.verbose and self.stage_manager.is_first_stage():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||||
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
assert (
|
||||||
new_token = self._get_token_id(logits['logits'])
|
"logits" in logits
|
||||||
|
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||||
|
new_token = self._get_token_id(logits["logits"])
|
||||||
|
|
||||||
self.mb_manager.step(None, None, new_token)
|
self.mb_manager.step(None, None, new_token)
|
||||||
self.action_interval_buffer.new_token = new_token
|
self.action_interval_buffer.new_token = new_token
|
||||||
|
@ -168,17 +171,17 @@ class GenerateSchedule(PipelineSchedule):
|
||||||
output_dict = model_forward(model, inputs_dict, None)
|
output_dict = model_forward(model, inputs_dict, None)
|
||||||
|
|
||||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||||
|
|
||||||
def _body_encoding_action(self, model: Module):
|
def _body_encoding_action(self, model: Module):
|
||||||
hidden_states = self.action_interval_buffer.hidden_states
|
hidden_states = self.action_interval_buffer.hidden_states
|
||||||
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
|
assert hidden_states is not None, "When not first stage, the hidden states should not be None"
|
||||||
inputs_dict = self._prepare_inputs_for_interval_stage()
|
inputs_dict = self._prepare_inputs_for_interval_stage()
|
||||||
hidden_states = {'hidden_states': hidden_states}
|
hidden_states = {"hidden_states": hidden_states}
|
||||||
output_dict = model_forward(model, inputs_dict, hidden_states)
|
output_dict = model_forward(model, inputs_dict, hidden_states)
|
||||||
|
|
||||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||||
self.action_interval_buffer.hidden_states = output_dict['hidden_states']
|
self.action_interval_buffer.hidden_states = output_dict["hidden_states"]
|
||||||
|
|
||||||
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
|
def _comm_action(self, recv_pre: bool) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
|
@ -246,10 +249,13 @@ class GenerateSchedule(PipelineSchedule):
|
||||||
|
|
||||||
whole_timestamp = []
|
whole_timestamp = []
|
||||||
|
|
||||||
#run by round
|
# run by round
|
||||||
for _ in range(self.round):
|
for _ in range(self.round):
|
||||||
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
|
self.timestamps = (
|
||||||
] if self.verbose and self.stage_manager.is_first_stage() else None
|
[[] for _ in range(self.stage_manager.num_stages)]
|
||||||
|
if self.verbose and self.stage_manager.is_first_stage()
|
||||||
|
else None
|
||||||
|
)
|
||||||
self.action_interval_buffer.clear()
|
self.action_interval_buffer.clear()
|
||||||
while self.mb_manager.is_micro_batch_done() is False:
|
while self.mb_manager.is_micro_batch_done() is False:
|
||||||
actions = self._gen_action(model)
|
actions = self._gen_action(model)
|
||||||
|
@ -286,8 +292,11 @@ class GenerateSchedule(PipelineSchedule):
|
||||||
whole_timestamp = []
|
whole_timestamp = []
|
||||||
# run by round
|
# run by round
|
||||||
for _ in range(self.round):
|
for _ in range(self.round):
|
||||||
self.timestamps = [[] for _ in range(self.stage_manager.num_stages)
|
self.timestamps = (
|
||||||
] if self.verbose and self.stage_manager.is_first_stage() else None
|
[[] for _ in range(self.stage_manager.num_stages)]
|
||||||
|
if self.verbose and self.stage_manager.is_first_stage()
|
||||||
|
else None
|
||||||
|
)
|
||||||
while self.mb_manager.is_micro_batch_done() is False:
|
while self.mb_manager.is_micro_batch_done() is False:
|
||||||
inputs_dict = None
|
inputs_dict = None
|
||||||
new_token = None
|
new_token = None
|
||||||
|
@ -307,13 +316,17 @@ class GenerateSchedule(PipelineSchedule):
|
||||||
hidden_states = self.comm.recv_forward()
|
hidden_states = self.comm.recv_forward()
|
||||||
if self.stage_manager.is_first_stage():
|
if self.stage_manager.is_first_stage():
|
||||||
# First just generate a new token
|
# First just generate a new token
|
||||||
assert hidden_states is not None, "When first stage in GENERATE phase, the hidden states should not be None"
|
assert (
|
||||||
|
hidden_states is not None
|
||||||
|
), "When first stage in GENERATE phase, the hidden states should not be None"
|
||||||
logits = model_forward(model, None, hidden_states)
|
logits = model_forward(model, None, hidden_states)
|
||||||
if self.verbose and self.stage_manager.is_first_stage():
|
if self.verbose and self.stage_manager.is_first_stage():
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
self.timestamps[self.mb_manager.idx].append(time.time())
|
self.timestamps[self.mb_manager.idx].append(time.time())
|
||||||
assert 'logits' in logits, f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
assert (
|
||||||
new_token = self._get_token_id(logits['logits'])
|
"logits" in logits
|
||||||
|
), f"When first stage in GENERATE phase, the ouput should have attribute `logits`, but has {logits.keys()}"
|
||||||
|
new_token = self._get_token_id(logits["logits"])
|
||||||
self.mb_manager.step(None, None, new_token)
|
self.mb_manager.step(None, None, new_token)
|
||||||
# If the current micro batch is not DONE, go through blocks
|
# If the current micro batch is not DONE, go through blocks
|
||||||
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
|
if self.mb_manager.cur_state in (Status.GENERATE, Status.COOLDOWN):
|
||||||
|
@ -327,9 +340,11 @@ class GenerateSchedule(PipelineSchedule):
|
||||||
self.mb_manager.step(inputs_dict, output_dict, None)
|
self.mb_manager.step(inputs_dict, output_dict, None)
|
||||||
|
|
||||||
# Current microbatch is not DONE, send hidden_state to next stage
|
# Current microbatch is not DONE, send hidden_state to next stage
|
||||||
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (Status.GENERATE,
|
if not self.stage_manager.is_first_stage() or self.mb_manager.cur_state in (
|
||||||
Status.COOLDOWN):
|
Status.GENERATE,
|
||||||
self.comm.send_forward({'hidden_states': output_dict['hidden_states']})
|
Status.COOLDOWN,
|
||||||
|
):
|
||||||
|
self.comm.send_forward({"hidden_states": output_dict["hidden_states"]})
|
||||||
|
|
||||||
self.mb_manager.next()
|
self.mb_manager.next()
|
||||||
|
|
||||||
|
|
|
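Since this commit only re-applies code formatting, every reflowed construct above is behavior-preserving. As a quick illustration, here is a minimal, self-contained Python check (not taken from the repository; `cur_kv_cache` stands in for `self.mb_manager.cur_kv_cache`) that the old and new spellings of the conditional expression in `_prepare_inputs_for_interval_stage` evaluate to the same value:

# Stand-in values for self.mb_manager.cur_kv_cache: a populated cache and None.
for cur_kv_cache in ({"layer_0": "kv"}, None):
    # Old formatting: dict literal opened before the conditional expression.
    old_style = {
        'past_key_values': cur_kv_cache
    } if cur_kv_cache is not None else None

    # New formatting: the whole conditional expression wrapped in parentheses.
    new_style = (
        {"past_key_values": cur_kv_cache} if cur_kv_cache is not None else None
    )

    assert old_style == new_style  # identical result for both spellings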
@@ -1,9 +1,6 @@
-from copy import deepcopy
-
 import pytest
 import torch
 import torch.distributed as dist
-import torch.nn as nn
 import transformers

 import colossalai
@@ -20,27 +17,29 @@ def data_gen():

 inputs = data_gen()
 for k, v in inputs.items():
-    if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
+    if torch.is_tensor(v) or "Tensor" in v.__class__.__name__:
         new_shape = [1] * v.dim()
         new_shape[0] = 16
-        inputs[k] = v.to('cuda').repeat(*new_shape)
+        inputs[k] = v.to("cuda").repeat(*new_shape)


 def pipeline_inference_test(pp_size, new_length, micro_batch_size):
     model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
-    engine = PPInferEngine(pp_size=pp_size,
-                           model=model,
-                           model_policy=GPT2LMHeadModelPipelinePolicy(),
-                           new_length=new_length,
-                           micro_batch_size=micro_batch_size)
+    engine = PPInferEngine(
+        pp_size=pp_size,
+        model=model,
+        model_policy=GPT2LMHeadModelPipelinePolicy(),
+        new_length=new_length,
+        micro_batch_size=micro_batch_size,
+    )
     output = engine.inference([inputs])
     if dist.get_rank() == 0:
         assert len(output[0]) == new_length, f"{len(output)}, {new_length}"


-@parameterize('pp_size', [4])
-@parameterize('new_length', [4, 8, 16])
-@parameterize('micro_batch_size', [1, 4])
+@parameterize("pp_size", [4])
+@parameterize("new_length", [4, 8, 16])
+@parameterize("micro_batch_size", [1, 4])
 @clear_cache_before_run()
 def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):
     pipeline_inference_test(pp_size, new_length, micro_batch_size)
@@ -48,7 +47,7 @@ def run_pipeline_inference_test(pp_size, new_length, micro_batch_size):


 def check_pipeline_inference(rank, world_size, port):
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
     run_pipeline_inference_test()

@@ -59,5 +58,5 @@ def test_pipeline_inference():
     spawn(check_pipeline_inference, nprocs=4)


-if __name__ == '__main__':
+if __name__ == "__main__":
     test_pipeline_inference()
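For orientation, the reformatted test above drives the engine roughly as sketched below. This is distilled only from the calls visible in this diff: it assumes a distributed environment has already been launched (as `check_pipeline_inference` does via `colossalai.launch`) and that `GPT2LMHeadModelPipelinePolicy` is importable from its policy module, whose import is not shown in these hunks; the sequence length used here is illustrative.

import torch
import transformers

from colossalai.inference import PPInferEngine


def gpt2_pipeline_inference(model_policy, pp_size=4, new_length=8, micro_batch_size=1):
    # Small GPT-2 plus random token ids on GPU, mirroring data_gen() in the test
    # (the test repeats a single sample to a batch of 16).
    model = transformers.GPT2LMHeadModel(transformers.GPT2Config(n_layer=8))
    inputs = {
        "input_ids": torch.randint(10, 30000, (16, 64), dtype=torch.int32).to("cuda"),
        "attention_mask": torch.ones((16, 64), dtype=torch.int32).to("cuda"),
    }

    # Same keyword arguments as the reformatted PPInferEngine call in the test.
    engine = PPInferEngine(
        pp_size=pp_size,
        model=model,
        model_policy=model_policy,  # e.g. GPT2LMHeadModelPipelinePolicy()
        new_length=new_length,
        micro_batch_size=micro_batch_size,
    )
    return engine.inference([inputs])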