Merge pull request #4612 from hpcaitech/feature/shardformer

[shardformer] update hybrid parallel plugin and fix bugs
pull/4623/head
Hongxin Liu 2023-09-05 23:20:00 +08:00 committed by GitHub
commit efba0f44b9
77 changed files with 6298 additions and 1277 deletions

View File

@ -208,7 +208,7 @@ jobs:
- name: Execute Unit Testing
run: |
CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest --testmon --testmon-cov=. --durations=10 tests/
CURL_CA_BUNDLE="" PYTHONPATH=$PWD pytest -m "not largedist" --testmon --testmon-forceselect --testmon-cov=. --durations=10 tests/
env:
DATA: /data/scratch/cifar-10
NCCL_SHM_DISABLE: 1

View File

@ -44,7 +44,7 @@ jobs:
name: Test for PyTorch Compatibility
needs: matrix_preparation
if: github.repository == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
runs-on: [self-hosted, 8-gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}

View File

@ -35,7 +35,7 @@ jobs:
name: Test for PyTorch Compatibility
needs: matrix_preparation
if: github.repository == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
runs-on: [self-hosted, 8-gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}

View File

@ -32,7 +32,7 @@ jobs:
name: Test for PyTorch Compatibility
needs: matrix_preparation
if: github.repository == 'hpcaitech/ColossalAI'
runs-on: [self-hosted, gpu]
runs-on: [self-hosted, 8-gpu]
strategy:
fail-fast: false
matrix: ${{fromJson(needs.matrix_preparation.outputs.matrix)}}

View File

@ -14,29 +14,43 @@ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
SFT_DATASET = [
{
"instruction": "Provide a list of the top 10 most popular mobile games in Asia",
"input": "",
"output": "The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
"id": 0
"instruction":
"Provide a list of the top 10 most popular mobile games in Asia",
"input":
"",
"output":
"The top 10 most popular mobile games in Asia are:\n1) PUBG Mobile\n2) Pokemon Go\n3) Candy Crush Saga\n4) Free Fire\n5) Clash of Clans\n6) Mario Kart Tour\n7) Arena of Valor\n8) Fantasy Westward Journey\n9) Subway Surfers\n10) ARK Survival Evolved",
"id":
0
},
{
"instruction": "Please provide an action plan for reducing carbon footprint on a corporate level",
"input": "",
"output": "An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
"id": 1
"instruction":
"Please provide an action plan for reducing carbon footprint on a corporate level",
"input":
"",
"output":
"An action plan for reducing carbon footprint on a corporate level could include switching to renewable energy sources (e.g., solar or wind powered), investing in energy efficient infrastructure, implementing green production practices (e.g., reduce water and paper usage, dispose of materials responsibly, recycle/reuse materials), reducing or eliminating travel, encouraging employees to be conscious of and reduce their own energy usage, and incentivizing employees to participate in sustainable work practices.",
"id":
1
},
{
"instruction": "Write a persuasive email to your boss explaining why you should have a pay raise",
"input": "",
"output": "Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
"id": 2
"instruction":
"Write a persuasive email to your boss explaining why you should have a pay raise",
"input":
"",
"output":
"Dear [Boss], \n\nI hope this message finds you well. I am writing to request a pay raise.\n\nAs you know, I have been a dedicated and hardworking member of the team since I started working here [insert number] of months/years ago. My enthusiasm and passion for my job has remained consistent over the years, and I have always given 100% to my role. \n\nI understand that the current financial situation is challenging, however, I would sincerely appreciate you taking the time to consider my request. I believe that my dedication to the job and the value that I bring to the organization warrants a raise. I work diligently and am confident that I can continue to be an asset to the company. \n\nI hope my request is taken into account and I thank you in advance for your understanding. I look forward to our conversation. \n\nSincerely,\n[Your Name]",
"id":
2
},
]
PROMPT_DATASET = [
{
"instruction": "Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
"id": 0
"instruction":
"Edit this paragraph to make it more concise: \"Yesterday, I went to the store and bought some things. Then, I came home and put them away. After that, I went for a walk and met some friends.\"",
"id":
0
},
{
"instruction": "Write a descriptive paragraph about a memorable vacation you went on",
@ -73,9 +87,7 @@ def make_tokenizer(model: str):
return tokenizer
def check_content(input_ids_stripped: torch.Tensor,
tokenizer: PreTrainedTokenizer,
model: str):
def check_content(input_ids_stripped: torch.Tensor, tokenizer: PreTrainedTokenizer, model: str):
if model == "opt":
# NOTE: Contrary to GPT2, OPT adds the EOS token </s> to the beginning of every prompt.
assert input_ids_stripped[0] == tokenizer.eos_token_id
@ -98,13 +110,10 @@ def check_content(input_ids_stripped: torch.Tensor,
assert input_ids_stripped != tokenizer.mask_token_id
@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize("max_length", [32, 1024])
@pytest.mark.parametrize("max_datasets_size", [2])
def test_prompt_dataset(model: str,
max_datasets_size: int,
max_length: int):
def test_prompt_dataset(model: str, max_datasets_size: int, max_length: int):
with tempfile.TemporaryDirectory() as tmp_dir:
dataset_name = "prompt_dataset.json"
with open(os.path.join(tmp_dir, dataset_name), "w") as f:
@ -127,19 +136,12 @@ def test_prompt_dataset(model: str,
check_content(input_ids.masked_select(attention_mask), tokenizer, model)
@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama"])
@pytest.mark.parametrize(["dataset_path", "subset"], [
("Anthropic/hh-rlhf", "harmless-base"),
("Dahoas/rm-static", None)
])
@pytest.mark.parametrize(["dataset_path", "subset"], [("Anthropic/hh-rlhf", "harmless-base"),
("Dahoas/rm-static", None)])
@pytest.mark.parametrize("max_datasets_size", [32])
@pytest.mark.parametrize("max_length", [32, 1024])
def test_reward_dataset(model: str,
dataset_path: str,
subset: Optional[str],
max_datasets_size: int,
max_length: int):
def test_reward_dataset(model: str, dataset_path: str, subset: Optional[str], max_datasets_size: int, max_length: int):
data = load_dataset(dataset_path, data_dir=subset)
assert max_datasets_size <= len(data["train"]) \
and max_datasets_size <= len(data["test"])
@ -196,15 +198,12 @@ def test_reward_dataset(model: str,
assert torch.all(r_mask)
@pytest.mark.cpu
@pytest.mark.parametrize("model", ["gpt2", "bloom", "opt", "llama", "chatglm"])
@pytest.mark.parametrize("dataset_path", ["yizhongw/self_instruct", None])
@pytest.mark.parametrize("max_dataset_size", [2])
@pytest.mark.parametrize("max_length", [32, 1024])
def test_sft_dataset(model: str,
dataset_path: Optional[str],
max_dataset_size: int,
max_length: int):
def test_sft_dataset(model: str, dataset_path: Optional[str], max_dataset_size: int, max_length: int):
tokenizer = make_tokenizer(model)
if dataset_path == "yizhongw/self_instruct":
data = load_dataset(dataset_path, "super_natural_instructions")
@ -253,10 +252,7 @@ def test_sft_dataset(model: str,
if __name__ == "__main__":
test_sft_dataset(model="bloom",
dataset_path="yizhongw/self_instruct",
max_dataset_size=2,
max_length=256)
test_sft_dataset(model="bloom", dataset_path="yizhongw/self_instruct", max_dataset_size=2, max_length=256)
test_reward_dataset(model="gpt2",
dataset_path="Anthropic/hh-rlhf",
@ -266,4 +262,5 @@ if __name__ == "__main__":
test_prompt_dataset(model="opt",
max_datasets_size=2,
max_length=128)
max_length=128)

View File

@ -16,17 +16,19 @@ from coati.models.opt import OPTRM, OPTActor, OPTCritic
from coati.models.utils import calc_action_log_probs, compute_reward, masked_mean
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
@pytest.mark.gpu
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seq_len", [32])
@pytest.mark.parametrize("actor_maker", [
lambda: BLOOMActor(),
lambda: GPTActor(),
@pytest.mark.parametrize(
"actor_maker",
[
lambda: BLOOMActor(),
lambda: GPTActor(),
# HACK: skip llama due to long execution time
# lambda: LlamaActor(),
lambda: OPTActor(),
# lambda: ChatGLMActor(),
])
@pytest.mark.parametrize("generate_kwargs", [{
"max_length": 64,
"use_cache": True,
@ -34,23 +36,15 @@ from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
"temperature": 1.0,
"top_k": 50,
}])
def test_generation(actor_maker: Callable[[], Actor],
batch_size: int,
seq_len: int,
generate_kwargs: Dict[str, Any]
):
def test_generation(actor_maker: Callable[[], Actor], batch_size: int, seq_len: int, generate_kwargs: Dict[str, Any]):
actor = actor_maker()
input_ids = torch.randint(0, 100, (batch_size, seq_len)).cuda()
sequences = generate(actor.cuda(), input_ids, **generate_kwargs)
assert sequences.shape == (batch_size, generate_kwargs["max_length"])
@pytest.mark.cpu
def test_utils():
fn_input = {
"tensor": torch.ones((10, )),
"mask": torch.randint(0, 2, (10, ))
}
fn_input = {"tensor": torch.ones((10,)), "mask": torch.randint(0, 2, (10,))}
fn_output = masked_mean(dim=0, **fn_input)
assert fn_output.dim() == 0
assert torch.allclose(fn_output, torch.tensor(1.0))
@ -58,14 +52,14 @@ def test_utils():
batch_size = 4
num_labels = 10
fn_input = {
"r": torch.ones((batch_size, )),
"r": torch.ones((batch_size,)),
"kl_coef": 1.0,
"log_probs": torch.randn((batch_size, num_labels)),
"log_probs_base": torch.randn((batch_size, num_labels)),
"action_mask": torch.randint(0, 2, (batch_size, num_labels))
}
fn_output = compute_reward(**fn_input)
assert fn_output.shape == (batch_size, )
assert fn_output.shape == (batch_size,)
batch_size = 4
seq_len = 32
@ -82,17 +76,11 @@ def test_utils():
assert fn_output.shape == (batch_size, num_actions)
@pytest.mark.cpu
@pytest.mark.parametrize("lora_rank", [4])
@pytest.mark.parametrize("num_dim", [32])
@pytest.mark.parametrize("num_layers", [4])
def test_lora(lora_rank: int,
num_dim: int,
num_layers: int):
model = nn.ModuleList(
[nn.Linear(num_dim, num_dim)
for _ in range(num_layers)]
)
def test_lora(lora_rank: int, num_dim: int, num_layers: int):
model = nn.ModuleList([nn.Linear(num_dim, num_dim) for _ in range(num_layers)])
lora_model = convert_to_lora_module(model, lora_rank)
assert isinstance(lora_model, nn.ModuleList)
for i in range(num_layers):
@ -105,8 +93,7 @@ def test_lora(lora_rank: int,
assert isinstance(lora_model[i], LoraLinear)
assert torch.allclose(old_model[i].weight, lora_model[i].weight)
assert torch.allclose(old_model[i].bias, lora_model[i].bias)
assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A,
lora_model[i].lora_B @ lora_model[i].lora_A)
assert torch.allclose(old_model[i].lora_B @ old_model[i].lora_A, lora_model[i].lora_B @ lora_model[i].lora_A)
optimizer = torch.optim.Adam(lora_model.parameters())
x = torch.randn(8, num_dim)
for i in range(num_layers):
@ -122,12 +109,13 @@ def test_lora(lora_rank: int,
lora_model[i].lora_B @ lora_model[i].lora_A)
@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("models_maker", [
lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
lambda: (GPTActor(), GPTCritic(), GPTRM()),
@pytest.mark.parametrize(
"models_maker",
[
lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()),
lambda: (GPTActor(), GPTCritic(), GPTRM()),
# HACK: skip llama due to long execution time
# lambda: (LlamaActor(), LlamaCritic(), LlamaRM()),
lambda: (OPTActor(), OPTCritic(), OPTRM()),
@ -178,13 +166,10 @@ def test_models(models_maker: Callable[[], Tuple[Actor, Critic, RewardModel]],
assert rm_output.shape == (batch_size, )
@pytest.mark.cpu
@pytest.mark.parametrize("batch_size", [16])
@pytest.mark.parametrize("seq_len", [128])
@pytest.mark.parametrize("num_labels", [100])
def test_loss(batch_size: int,
seq_len: int,
num_labels: int):
def test_loss(batch_size: int, seq_len: int, num_labels: int):
loss = GPTLMLoss()
loss_input = {
"logits": torch.randn(batch_size, seq_len, num_labels),
@ -194,54 +179,43 @@ def test_loss(batch_size: int,
loss = PolicyLoss()
loss_input = {
"log_probs": torch.randn(batch_size, ),
"old_log_probs": torch.randn(batch_size, ),
"advantages": torch.randn(batch_size, )
"log_probs": torch.randn(batch_size,),
"old_log_probs": torch.randn(batch_size,),
"advantages": torch.randn(batch_size,)
}
loss_output = loss(**loss_input)
loss = ValueLoss()
loss_input = {
"values": torch.randn(batch_size, ),
"old_values": torch.randn(batch_size, ),
"reward": torch.randn(batch_size, )
"values": torch.randn(batch_size,),
"old_values": torch.randn(batch_size,),
"reward": torch.randn(batch_size,)
}
loss_output = loss(**loss_input)
loss = LogSigLoss()
loss_input = {
"chosen_reward": torch.randn(batch_size, ),
"reject_reward": torch.randn(batch_size, ),
"chosen_reward": torch.randn(batch_size,),
"reject_reward": torch.randn(batch_size,),
}
loss_output = loss(**loss_input)
loss = LogExpLoss()
loss_input = {
"chosen_reward": torch.randn(batch_size, ),
"reject_reward": torch.randn(batch_size, ),
"chosen_reward": torch.randn(batch_size,),
"reject_reward": torch.randn(batch_size,),
}
loss_output = loss(**loss_input)
if __name__ == "__main__":
generate_kwargs = dict(max_length=40,
use_cache=True,
do_sample=True,
temperature=1.0,
top_k=50)
test_generation(lambda: LlamaActor(),
batch_size=4,
seq_len=32,
generate_kwargs=generate_kwargs)
generate_kwargs = dict(max_length=40, use_cache=True, do_sample=True, temperature=1.0, top_k=50)
test_generation(lambda: LlamaActor(), batch_size=4, seq_len=32, generate_kwargs=generate_kwargs)
test_utils()
test_lora(lora_rank=2, num_dim=8, num_layers=2)
test_models(models_maker=lambda: (BLOOMActor(),
BLOOMCritic(),
BLOOMRM()),
batch_size=8,
seq_len=128)
test_models(models_maker=lambda: (BLOOMActor(), BLOOMCritic(), BLOOMRM()), batch_size=8, seq_len=128)
test_loss(batch_size=8, seq_len=128, num_labels=100)

View File

@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import (
get_model_base_filenames,
get_optimizer_base_filenames,
load_shard_state_dict,
save_config_file,
save_state_dict,
save_state_dict_shards,
)
@ -107,6 +108,7 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
if self.coordinator.is_master():
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model.module, checkpoint_path)
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameter has been saved in the "
f"index located at {save_index_file}.")

View File

@ -1,19 +1,22 @@
import random
from contextlib import nullcontext
from typing import Any, Callable, Iterator, List, Optional, Tuple, Union
from functools import partial
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.nn import Module
from torch.nn import Module, SyncBatchNorm
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
from colossalai.checkpoint_io import CheckpointIO
from colossalai.checkpoint_io import CheckpointIO, HypridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.schedule import OneForwardOneBackwardSchedule
@ -26,26 +29,52 @@ from .pp_plugin_base import PipelinePluginBase
DP_AXIS, PP_AXIS, TP_AXIS = 0, 1, 2
def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
return x.to(dtype)
return x
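# Illustrative sketch (not part of this diff): how the convert_fn built from _convert_floating_point
# is applied via tree_map — floating-point inputs are cast to the mixed-precision dtype while
# integer tensors such as token ids are left untouched. Example data is hypothetical.
#
#   from functools import partial
#   convert_fn = partial(_convert_floating_point, dtype=torch.float16)
#   batch = {"input_ids": torch.randint(0, 100, (2, 8)), "pixel_values": torch.rand(2, 3)}
#   casted = tree_map(convert_fn, batch)
#   # casted["pixel_values"].dtype == torch.float16; casted["input_ids"].dtype stays torch.int64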
class HybridParallelModule(ModelWrapper):
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup) -> None:
def __init__(self, module: Module, precision: str, shard_config: ShardConfig, dp_group: ProcessGroup, use_ddp: bool,
ddp_config: dict) -> None:
self.stage_manager = shard_config.pipeline_stage_manager
self.dp_group = dp_group
shardformer = ShardFormer(shard_config)
module, self.shared_params = shardformer.optimize(module)
# TODO(ver217): add input type cast
# setting process groups for shared parameters
self.shared_param_process_groups = []
for shared_param in self.shared_params:
if len(shared_param) > 0:
self.shared_param_process_groups.append(
self.stage_manager.init_process_group_by_stages(list(shared_param.keys())))
# setting mixed_precision
self.mixed_precision = None
if precision == 'fp16':
module = module.half().cuda()
self.mixed_precision = torch.float16
elif precision == 'bf16':
module = module.to(dtype=torch.bfloat16).cuda()
else:
module = module.cuda() # train without AMP
# TODO(ver217): support TP+DP
self.mixed_precision = torch.bfloat16
if self.mixed_precision is not None:
module = module.to(self.mixed_precision)
module = module.cuda()
# setting input type cast when using mixed precision
self.convert_fn = None
if self.mixed_precision is not None:
self.convert_fn = partial(_convert_floating_point, dtype=self.mixed_precision)
# setting ddp configs
if use_ddp:
# convert model to sync bn
module = SyncBatchNorm.convert_sync_batchnorm(module, dp_group)
# wrap the model with PyTorch DDP
module = DDP(module, process_group=dp_group, **ddp_config)
super().__init__(module)
def sync_shared_params(self):
@ -68,19 +97,62 @@ class HybridParallelModule(ModelWrapper):
dist.all_reduce(p.grad, group=self.dp_group)
p.grad.div_(self.dp_group.size())
def forward(self, *args, **kwargs):
if self.convert_fn is not None:
args = tree_map(self.convert_fn, args)
kwargs = tree_map(self.convert_fn, kwargs)
return super().forward(*args, **kwargs)
def unwrap(self):
module = super().unwrap()
if isinstance(module, DDP):
module = module.module
return module
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A complete param_group, with params in the form of param_id
# 2. A mapping from param address (obtained using id(param)) to integer param_id
# 3. A mapping from integer param_id to param address.
# 4. A mapping from param_address (obtained using id(param)) to the original shape of parameter before sharding.
# When Zero is used, the params here are fp16/bf16 model params rather than fp32 master params in optimizer.
if optim is None:
return {}
param_info = {'param_groups': [], 'param2id': {}, 'id2param': {}, 'param2shape': {}}
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != 'params'}
packed_group['params'] = []
for param_id, param in enumerate(group['params'], start_index):
original_shape = param.shape if isinstance(param, torch.Tensor) else None
packed_group['params'].append(param_id)
param_info['param2id'][id(param)] = param_id
param_info['id2param'][param_id] = id(param)
param_info['param2shape'][id(param)] = original_shape
param_info['param_groups'].append(packed_group)
start_index += len(group['params'])
return param_info
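# Illustrative sketch (not part of this diff): the structure get_param_info records for a tiny,
# hypothetical model before any sharding happens.
#
#   model = torch.nn.Linear(4, 2)
#   optim = torch.optim.SGD(model.parameters(), lr=0.1)
#   info = get_param_info(optim)
#   # info['param_groups'][0]['params'] == [0, 1]      (integer ids replace the tensors)
#   # info['param2id'] == {id(model.weight): 0, id(model.bias): 1}
#   # info['id2param'] == {0: id(model.weight), 1: id(model.bias)}
#   # info['param2shape'] == {id(model.weight): torch.Size([2, 4]), id(model.bias): torch.Size([2])}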
def init_pipeline_optimizer(optim: Optimizer, model: Module):
params = set(model.parameters())
model_params = set(model.parameters())
new_param_groups = []
for group in optim.param_groups:
params = [p for p in group['params'] if p in params]
params = [p for p in group['params'] if p in model_params]
new_param_groups.append({**group, 'params': params})
optim.__setstate__({'param_groups': new_param_groups})
class HybridParallelNaiveOptimizer(OptimizerWrapper):
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool):
def __init__(self, optim: Optimizer, model: Module, use_pipeline: bool, param_info: OrderedDict):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim)
@ -92,6 +164,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
optim: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
precision: str = 'fp16',
initial_scale: float = 2**16,
min_scale: float = 1,
@ -101,6 +174,7 @@ class HybridParallelAMPOptimizer(MixedPrecisionOptimizer):
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optim, model)
super().__init__(optim, precision, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
@ -114,6 +188,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
optimizer: Optimizer,
model: Module,
use_pipeline: bool,
param_info: OrderedDict,
initial_scale: int = 2**16, # grad scaler config
min_scale: int = 1,
growth_factor: float = 2.,
@ -131,6 +206,7 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
dp_process_group: Optional[ProcessGroup] = None, # the dp pg for comm
tp_process_group: Optional[ProcessGroup] = None, # if using tp
forced_dtype: Optional[torch.dtype] = None):
self.param_info = param_info
if use_pipeline:
init_pipeline_optimizer(optimizer, model)
super().__init__(optimizer, initial_scale, min_scale, growth_factor, backoff_factor, growth_interval,
@ -140,34 +216,100 @@ class HybridParallelZeroOptimizer(LowLevelZeroOptimizer):
class HybridParallelPlugin(PipelinePluginBase):
"""
Plugin for Hybrid Parallel Training.
Tensor parallelism, pipeline parallelism and data parallelism (DDP/ZeRO) can be picked and combined in this plugin.
The sizes of tp and pp should be passed in by the user; the size of dp is then derived automatically as dp_size = world_size / (tp_size * pp_size).
Example:
>>> from colossalai.booster import Booster
>>> from colossalai.booster.plugin import HybridParallelPlugin
>>> model, train_dataset, optimizer, criterion = ...
>>> plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
>>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
>>> booster = Booster(plugin=plugin)
>>> model, optimizer, criterion, train_dataloader, _ = booster.boost(model, optimizer, criterion, train_dataloader)
Args:
tp_size (int): The size of tensor parallelism. Tensor parallelism will not be used when tp_size is set to 1.
pp_size (int): The number of pipeline stages in pipeline parallelism. Pipeline parallelism will not be used when pp_size is set to 1.
precision (str, optional): Specifies the precision of parameters during training.
Auto-mixed precision will be used when this argument is set to 'fp16' or 'bf16'; otherwise the model is trained with 'fp32'.
Defaults to 'fp16'.
zero_stage (int, optional): The stage of ZeRO for data parallelism. Can only be chosen from [0, 1, 2].
When set to 0, ZeRO will not be used. Defaults to 0.
enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
Currently all the optimization methods include fused normalization, flash attention and JIT.
Defaults to False.
enable_fused_normalization (bool, optional): Whether to switch on fused normalization. Defaults to False.
enable_flash_attention (bool, optional): Whether to switch on flash attention. Defaults to False.
enable_jit_fused (bool, optional): Whether to switch on JIT. Defaults to False.
num_microbatches (int, optional): Number of microbatches when using pipeline parallelism. Defaults to None.
microbatch_size (int, optional): Microbatch size when using pipeline parallelism.
Either ``num_microbatches`` or ``microbatch_size`` should be provided if using pipeline.
If ``num_microbatches`` is provided, this will be ignored. Defaults to None.
initial_scale (float, optional): The initial loss scale of AMP. Defaults to 2**16.
min_scale (float, optional): The minimum loss scale of AMP. Defaults to 1.
growth_factor (float, optional): The multiplication factor for increasing loss scale when using AMP. Defaults to 2.
backoff_factor (float, optional): The multiplication factor for decreasing loss scale when using AMP. Defaults to 0.5.
growth_interval (int, optional): The number of steps to increase loss scale when no overflow occurs when using AMP. Defaults to 1000.
hysteresis (int, optional): The number of overflows before decreasing loss scale when using AMP. Defaults to 2.
max_scale (float, optional): The maximum loss scale of AMP. Defaults to 2**32.
max_norm (float, optional): Maximum norm for gradient clipping. Defaults to 0.
broadcast_buffers (bool, optional): Whether to broadcast buffers in the beginning of training when using DDP. Defaults to True.
ddp_bucket_cap_mb (int, optional): The bucket size in MB when using DDP. Defaults to 25.
find_unused_parameters (bool, optional): Whether to find unused parameters when using DDP. Defaults to False.
check_reduction (bool, optional): Whether to check reduction when using DDP. Defaults to False.
gradient_as_bucket_view (bool, optional): Whether to use gradient as bucket view when using DDP. Defaults to False.
static_graph (bool, optional): Whether to use static graph when using DDP. Defaults to False.
zero_bucket_size_in_m (int, optional): Gradient reduce bucket size in million elements when using ZeRO. Defaults to 12.
cpu_offload (bool, optional): Whether to open cpu_offload when using ZeRO. Defaults to False.
communication_dtype (torch.dtype, optional): Communication dtype when using ZeRO. If not specified, the dtype of param will be used. Defaults to None.
overlap_communication (bool, optional): Whether to overlap communication and computation when using ZeRO. Defaults to True.
"""
def __init__(self,
tp_size: int,
pp_size: int,
precision: str = 'fp16',
zero_stage: int = 0,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
enable_sequence_parallelism: bool = False,
enable_sequence_overlap: bool = False,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
broadcast_buffers: bool = True,
ddp_bucket_cap_mb: int = 25,
find_unused_parameters: bool = False,
check_reduction: bool = False,
gradient_as_bucket_view: bool = False,
static_graph: bool = False,
zero_bucket_size_in_m: int = 12,
cpu_offload: bool = False,
communication_dtype: Optional[torch.dtype] = None,
overlap_communication: bool = True) -> None:
def __init__(
self,
tp_size: int,
pp_size: int,
precision: str = 'fp16',
zero_stage: int = 0,
cpu_offload: bool = False,
enable_all_optimization: bool = False,
enable_fused_normalization: bool = False,
enable_flash_attention: bool = False,
enable_jit_fused: bool = False,
num_microbatches: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
max_norm: float = 0,
) -> None:
super().__init__()
assert dist.get_world_size() % (
tp_size * pp_size
) == 0, f'world size {dist.get_world_size()} is not divisible by tp_size {tp_size} * pp_size {pp_size}'
# TODO(ver217): support zero
assert zero_stage == 0, 'zero is not support yet'
if enable_sequence_parallelism:
assert tp_size > 1, 'Tensor parallelism must be enabled when using sequence parallelism'
self.tp_size = tp_size
self.pp_size = pp_size
self.dp_size = dist.get_world_size() // (tp_size * pp_size)
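# Sizing sketch (hypothetical numbers, not from this diff): on a world size of 8 with
# tp_size=2 and pp_size=2, dp_size = 8 // (2 * 2) = 2, so DDP/ZeRO replicates across 2 groups.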
@ -178,24 +320,30 @@ class HybridParallelPlugin(PipelinePluginBase):
self.enable_fused_normalization = enable_fused_normalization
self.enable_flash_attention = enable_flash_attention
self.enable_jit_fused = enable_jit_fused
self.enable_sequence_parallelism = enable_sequence_parallelism
self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size)
self.stage_manager = None
self.schedule = None
assert zero_stage in (0, 1, 2)
if self.pp_size > 1:
assert num_microbatches is not None, 'num_microbatches must be specified when using pipeline parallelism'
assert num_microbatches is not None or microbatch_size is not None, 'num_microbatches or microbatch_size must be specified when using pipeline parallelism'
assert self.zero_stage <= 1, 'zero stage must be 0 or 1 when using pipeline parallelism'
self.stage_manager = PipelineStageManager(self.pg_mesh, PP_AXIS)
self.schedule = OneForwardOneBackwardSchedule(num_microbatches, self.stage_manager)
self.schedule = OneForwardOneBackwardSchedule(self.stage_manager,
num_microbatches=num_microbatches,
microbatch_size=microbatch_size)
self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
self.pp_group = self.pg_mesh.get_group_along_axis(PP_AXIS)
self.shard_config = ShardConfig(tensor_parallel_process_group=self.tp_group,
pipeline_stage_manager=self.stage_manager,
enable_tensor_parallelism=self.tp_size > 1,
enable_all_optimization=self.enable_all_optimization,
enable_fused_normalization=self.enable_fused_normalization,
enable_flash_attention=self.enable_flash_attention,
enable_jit_fused=self.enable_jit_fused)
enable_jit_fused=self.enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism,
enable_sequence_overlap=enable_sequence_overlap)
self.amp_config = dict(
initial_scale=initial_scale,
growth_factor=growth_factor,
@ -205,6 +353,20 @@ class HybridParallelPlugin(PipelinePluginBase):
min_scale=min_scale,
max_scale=max_scale,
)
self.ddp_config = dict(broadcast_buffers=broadcast_buffers,
bucket_cap_mb=ddp_bucket_cap_mb,
find_unused_parameters=find_unused_parameters,
check_reduction=check_reduction,
gradient_as_bucket_view=gradient_as_bucket_view,
static_graph=static_graph)
self.zero_config = dict(reduce_bucket_size=zero_bucket_size_in_m * 1024 * 1024,
communication_dtype=communication_dtype,
overlap_communication=overlap_communication,
cpu_offload=cpu_offload,
partition_grad=(self.zero_stage == 2))
self.max_norm = max_norm
@property
@ -237,32 +399,44 @@ class HybridParallelPlugin(PipelinePluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
param_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group)
use_ddp = self.dp_size > 1 and self.pp_size == 1 and self.zero_stage == 0
model = HybridParallelModule(model, self.precision, self.shard_config, self.dp_group, use_ddp,
self.ddp_config)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.zero_stage == 0:
if self.precision in ['fp16', 'bf16']:
optimizer = HybridParallelAMPOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
**self.amp_config)
self.checkpoint_io.link_master_and_working_param(optimizer.working_to_master_map,
optimizer.master_to_working_map)
else:
optimizer = HybridParallelNaiveOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism)
use_pipeline=self.enable_pipeline_parallelism,
param_info=param_info)
else:
assert self.dp_size > 1, "Please use Zero when data parallel size is greater than 1."
assert self.precision != 'fp32', "Please set precision to 'fp16' or 'bf16' when using ZeRO."
optimizer = HybridParallelZeroOptimizer(optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
partition_grad=(self.zero_stage == 2),
cpu_offload=self.cpu_offload,
param_info=param_info,
dp_process_group=self.dp_group,
tp_process_group=self.tp_group,
verbose=True,
clip_grad_norm=self.max_norm,
**self.zero_config,
**self.amp_config)
self.checkpoint_io.link_master_and_working_param(optimizer._param_store.working_to_master_param,
optimizer._param_store.master_to_working_param)
return model, optimizer, criterion, dataloader, lr_scheduler
def execute_pipeline(self,
@ -339,7 +513,8 @@ class HybridParallelPlugin(PipelinePluginBase):
**_kwargs)
def get_checkpoint_io(self) -> CheckpointIO:
return None
self.checkpoint_io = HypridParallelCheckpointIO(self.dp_group, self.pp_group, self.tp_group, self.zero_stage)
return self.checkpoint_io
def no_sync(self, model: Module) -> Iterator[None]:
raise NotImplementedError

View File

@ -1,5 +1,6 @@
from .checkpoint_io_base import CheckpointIO
from .general_checkpoint_io import GeneralCheckpointIO
from .hybrid_parallel_checkpoint_io import HypridParallelCheckpointIO
from .index_file import CheckpointIndexFile
__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO']
__all__ = ['CheckpointIO', 'CheckpointIndexFile', 'GeneralCheckpointIO', 'HypridParallelCheckpointIO']

View File

@ -23,6 +23,7 @@ from .utils import (
load_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict,
save_state_dict_shards,
@ -183,6 +184,7 @@ class GeneralCheckpointIO(CheckpointIO):
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint_path, is_master=True)
logging.info(f"The model is going to be split into checkpoint shards. "
f"You can find where each parameter has been saved in the "
f"index located at {save_index_file}.")

View File

@ -0,0 +1,702 @@
import copy
import gc
import logging
import os
from pathlib import Path
from shutil import rmtree
from typing import Dict, Iterator, Optional, OrderedDict, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.interface import OptimizerWrapper
from .general_checkpoint_io import GeneralCheckpointIO
from .index_file import CheckpointIndexFile
from .utils import (
StateDictSharder,
gather_distributed_param,
get_model_base_filenames,
get_optimizer_base_filenames,
is_safetensors_available,
load_shard_state_dict,
load_state_dict_into_model,
load_states_into_optimizer,
save_config_file,
save_param_groups,
save_state_dict_shards,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
class HypridParallelCheckpointIO(GeneralCheckpointIO):
"""
CheckpointIO for Hybrid Parallel Training.
Args:
dp_group (ProcessGroup): Process group along data parallel dimension.
pp_group (ProcessGroup): Process group along pipeline parallel dimension.
tp_group (ProcessGroup): Process group along tensor parallel dimension.
zero_stage (int): The zero stage of plugin. Should be in [0, 1, 2].
verbose (bool, optional): Whether to print a logging message when saving/loading has been successfully executed. Defaults to True.
"""
def __init__(self,
dp_group: ProcessGroup,
pp_group: ProcessGroup,
tp_group: ProcessGroup,
zero_stage: int,
verbose: bool = True) -> None:
super().__init__()
self.dp_group = dp_group
self.pp_group = pp_group
self.tp_group = tp_group
self.dp_rank = dist.get_rank(self.dp_group)
self.tp_rank = dist.get_rank(self.tp_group)
self.pp_rank = dist.get_rank(self.pp_group)
self.dp_size = dist.get_world_size(dp_group)
self.pp_size = dist.get_world_size(pp_group)
self.tp_size = dist.get_world_size(tp_group)
self.use_zero = (zero_stage > 0)
self.verbose = verbose
self.working_to_master_map = None
self.master_to_working_map = None
@staticmethod
def _model_sharder(model: nn.Module,
prefix: str = '',
keep_vars: bool = False,
size_per_shard: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
# An internal method that breaks the model's state_dict into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
# Save parameters.
for name, param in model.named_parameters():
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
yield block, block_size
# Save buffers.
for name, buf in model.named_buffers():
if buf is not None and name not in model._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = state_dict_sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
# Save extra states.
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(model.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
extra_state = model.get_extra_state()
block, block_size = state_dict_sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
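# Illustrative usage (not part of this diff): the sharder is a generator of (state_dict_block, size)
# pairs, each kept under roughly size_per_shard MB; save_state_dict_shards consumes it to write
# the shard files one by one.
#
#   for block, block_size in HypridParallelCheckpointIO._model_sharder(model, size_per_shard=1024):
#       ...  # each `block` maps a subset of param/buffer names to gathered tensors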
@staticmethod
def _optimizer_sharder(optimizer: OptimizerWrapper,
use_zero: bool,
dp_group: ProcessGroup,
tp_group: ProcessGroup,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None,
size_per_shard: int = 1024):
# An internal method that breaks the optimizer's state_dict into shards within a limited size.
state_dict_sharder = StateDictSharder(size_per_shard)
param_info = optimizer.param_info
for param, state in optimizer.optim.state.items():
if param is None:
continue
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
param_id = param_info['param2id'][id(working_param)]
original_shape = param_info['param2shape'][id(working_param)]
state_ = HypridParallelCheckpointIO.gather_from_sharded_optimizer_state(state,
working_param,
original_shape=original_shape,
dp_group=dp_group,
tp_group=tp_group,
use_zero=use_zero,
inplace=False)
block, block_size = state_dict_sharder.append_optim_state(param_id, state_)
if block is not None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def save_sharded_model(self,
model: nn.Module,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024,
use_safetensors: bool = False) -> None:
"""
Save sharded model checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_model.bin.index.json) containing a map between model params/buffers and file names.
- Multiple files that store state tensors of models.
If pipeline parallelism is used, the filenames are in the form of "pytorch_model.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, they are in the form of "pytorch_model.<prefix>-000XX.bin".
Args:
model (nn.Module): Model on local device to be saved.
checkpoint (str): Checkpointing path which should be a directory path.
gather_dtensor (bool, optional): Whether to gather_dtensor, currently not used. Defaults to True.
prefix (str, optional): Prefix of file to save. Defaults to None.
size_per_shard (int, optional): Size per shard in MB. Defaults to 1024.
use_safetensors (bool, optional): Whether to use safe tensors. Defaults to False.
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copy of the model.
# So only let the device with dp_rank == 0 save the model.
if self.dp_rank != 0:
return
# Then collect the sharded parameters & buffers along tp_group.
# Only devices with tp_rank == 0 are responsible for model saving.
state_dict_shard = HypridParallelCheckpointIO._model_sharder(model, size_per_shard=size_per_shard)
weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.tp_rank == 0)
if self.pp_size == 1:
# When pipeline is not used, save the model shards as in general checkpointIO
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors)
if control_saving:
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
save_config_file(model, checkpoint)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameter has been saved in the "
f"index located at {save_index_file}.")
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
weights_name = weights_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
weights_name = weights_name.replace(".safetensors", f"-stage-{self.pp_rank+1:05d}-shard.safetensors")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=weights_name,
is_master=control_saving,
use_safetensors=use_safetensors,
use_pp_format=True)
if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
# The global master rank integrates the index files and cleans the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for weight, weight_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(weight, weight_filename)
final_index_file.write_index_file(final_index_file_path)
save_config_file(model, checkpoint)
rmtree(tmp_index_file_folder)
if self.verbose:
logging.info(f"The model is split into checkpoint shards. "
f"You can find where each parameter has been saved in the "
f"index located at {final_index_file_path}.")
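# Naming sketch (assuming the default base filenames, not spelled out in this diff): with two
# pipeline stages, the tmp folder would hold pytorch_model.bin.index-stage-00001.json and
# pytorch_model.bin.index-stage-00002.json, which pp_rank 0 merges into pytorch_model.bin.index.json.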
def load_sharded_model(self, model: nn.Module, checkpoint_index_file: Path, strict: bool = False):
"""
Load sharded model with the given path to index file of checkpoint folder.
Args:
model (nn.Module): The model to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
strict (bool, optional): For name matching during loading state_dict. Defaults to False.
This argument should be manually set to False since params on the same device might be stored in different files.
"""
# Check whether the checkpoint uses safetensors.
use_safetensors = False
if "safetensors" in checkpoint_index_file.name:
use_safetensors = True
if use_safetensors and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
strict = False
# Load params & buffers to model.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
def _load(name: str):
if name not in weight_map:
raise ValueError(f"{name} is not stored in checkpoint, please check your checkpointing configuration!")
filename = weight_map[name]
# If this param/buffer has been loaded before, directly return.
if filename in loaded_file:
return
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
missing_keys = []
load_state_dict_into_model(model,
state_dict,
missing_keys=missing_keys,
strict=strict,
load_sub_module=True)
loaded_file.add(filename)
# Load parameters.
for name, _ in model.named_parameters():
_load(name)
# Load buffers.
non_persistent_buffers = set()
for n, m in model.named_modules():
non_persistent_buffers |= set('.'.join((n, b)) for b in m._non_persistent_buffers_set)
for name, buf in model.named_buffers():
if buf is not None and name not in non_persistent_buffers:
_load(name)
# Load extra states.
extra_state_key = _EXTRA_STATE_KEY_SUFFIX
if getattr(model.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
_load(extra_state_key)
# Update master params if mixed-precision training is enabled.
with torch.no_grad():
if self.working_to_master_map is not None:
for param in model.parameters():
if (param is None) or (id(param) not in self.working_to_master_map):
continue
master_param = self.working_to_master_map[id(param)]
if self.use_zero:
# master_param is sharded under Zero setting
padding_size = (self.dp_size - param.numel() % self.dp_size) % self.dp_size
if padding_size > 0:
padded_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
else:
padded_param = param.data.view(-1)
sharded_param = padded_param.split(padded_param.numel() // self.dp_size)[self.dp_rank]
master_param.data.copy_(sharded_param.data)
else:
master_param.data.copy_(param.data)
if self.verbose:
logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
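# Worked example of the ZeRO padding above (hypothetical sizes, not from this diff):
# for a working param with numel() == 10 and dp_size == 4, padding_size = (4 - 10 % 4) % 4 = 2,
# so the flattened param is padded to 12 elements and each rank's master shard holds 12 // 4 = 3.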
def save_sharded_optimizer(self,
optimizer: OptimizerWrapper,
checkpoint: str,
gather_dtensor: bool = True,
prefix: Optional[str] = None,
size_per_shard: int = 1024):
"""
Save sharded optimizer checkpoint under the given checkpointing path.
The following files will be created under the path:
- An index file (pytorch_optim.bin.index.json) containing a map between optimizer states and file names
- A group file (pytorch_optim_group.bin) recording information of param_groups
- Multiple files that store state tensors of optimizers.
If pipeline parallelism is used, the filenames are in the form of "pytorch_optim.<prefix>-stage-000XX-shard-000XX.bin".
If pipeline parallelism is not used, they are in the form of "pytorch_optim.<prefix>-000XX.bin".
Args:
optimizer (OptimizerWrapper): Optimizer to save sharded state_dict
checkpoint (str): Path to save optimizer state_dict
gather_dtensor (bool): Whether to gather_dtensor, not used
prefix (str): Prefix of file to save
size_per_shard (int): Max file size of each file shard that store state tensors
"""
if os.path.isfile(checkpoint):
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
return
Path(checkpoint).mkdir(parents=True, exist_ok=True)
# Devices along the same dp_group share the same copy of optimizer states when zero is not used.
# In this case only let the device with dp_rank == 0 save the states.
if not self.use_zero and self.dp_rank != 0:
return
# Then collect the sharded states along dp_group(if using zero)/tp_group.
# Only devices with (dp_rank == 0 and tp_rank == 0) are responsible for states saving.
state_dict_shard = HypridParallelCheckpointIO._optimizer_sharder(
optimizer,
use_zero=self.use_zero,
dp_group=self.dp_group,
tp_group=self.tp_group,
master_to_working_map=self.master_to_working_map,
size_per_shard=size_per_shard)
states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
index_file = CheckpointIndexFile(checkpoint)
control_saving = (self.dp_rank == 0 and self.tp_rank == 0)
if self.pp_size == 1:
# When pipeline is not used, save the optimizer shards as in general checkpointIO
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving)
if control_saving:
# Store param groups.
index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
# Store index file.
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
if self.verbose:
logging.info(f"The optimizer is going to be split into checkpoint shards. "
f"You can find where each parameter has been saved in the "
f"index located at {save_index_file}.")
else:
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
# Manage filenames of sharded weights and index file for each pipeline stage.
states_name = states_name.replace(".bin", f"-stage-{self.pp_rank+1:05d}-shard.bin")
save_index_file = save_index_file.replace(".json", f"-stage-{self.pp_rank+1:05d}.json")
save_index_file = os.path.join("tmp_index_files", save_index_file)
total_size = save_state_dict_shards(sharded_state_dict=state_dict_shard,
checkpoint=checkpoint,
index_file=index_file,
base_filename=states_name,
is_master=control_saving,
use_pp_format=True)
if control_saving:
assert self.dp_rank == 0 and self.tp_rank == 0, "The saving process should have both dp_rank and tp_rank as 0."
index_file.append_meta_data("total_size", total_size)
index_file.write_index_file(save_index_file)
else:
return
dist.barrier(self.pp_group)
# The global master rank integrates the index files and cleans the folder.
if self.pp_rank == 0:
final_index_file = CheckpointIndexFile(checkpoint)
final_index_file.append_meta_data("total_size", 0)
for filename in os.listdir(tmp_index_file_folder):
stage_index_file = CheckpointIndexFile.from_file(os.path.join(tmp_index_file_folder, filename))
final_index_file.metadata["total_size"] += stage_index_file.metadata["total_size"]
for param_id, state_filename in stage_index_file.weight_map.items():
final_index_file.append_weight_map(param_id, state_filename)
# Store param groups.
final_index_file.append_meta_data("param_groups", param_group_file)
group_file_path = os.path.join(checkpoint, param_group_file)
save_param_groups(optimizer.param_info, group_file_path)
final_index_file.write_index_file(final_index_file_path)
rmtree(tmp_index_file_folder)
if self.verbose:
logging.info(f"The optimizer is split into checkpoint shards. "
f"You can find where each parameter has been saved in the "
f"index located at {final_index_file_path}.")
def load_sharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint_index_file: str, prefix: str = ""):
"""
Load sharded optimizer with the given path to index file of checkpoint folder.
Args:
optimizer (OptimizerWrapper): The optimizer to be loaded.
checkpoint_index_file (str): Path to the index file of checkpointing folder.
prefix (str): Not used.
"""
def _get_param_id_from_optimizer_param(param: torch.Tensor,
master_to_working_map: Optional[Dict[int, torch.Tensor]] = None):
if master_to_working_map is not None:
working_param = master_to_working_map[id(param)]
else:
working_param = param
return optimizer.param_info['param2id'][id(working_param)]
# id_map is a mapping from the param ids kept by the current pipeline stage to their corresponding parameter objects.
# When Zero is used, the mapped parameter objects should be fp32 master parameters.
# IDs are obtained through the param2id mapping saved earlier in optimizer.param_info.
id_map = {}
for pg in optimizer.optim.param_groups:
for param in pg['params']:
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
id_map[param_id] = param
# Read checkpoint index file.
ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
ckpt_root_path = ckpt_index_file.root_path
weight_map = ckpt_index_file.weight_map
weight_map = {int(k): v for k, v in weight_map.items()} # convert saved id from str to int
# Load param_groups
param_group_path = ckpt_index_file.get_param_group_filename()
if param_group_path is None:
raise RuntimeError(f'Invalid index file path {checkpoint_index_file} for an optimizer. \
Lacking param group file under current directory.')
saved_groups = torch.load(param_group_path)
updated_groups = []
for old_pg, saved_pg in zip(optimizer.optim.param_groups, saved_groups):
# obtain updated param group
new_pg = copy.deepcopy(saved_pg)
new_pg['params'] = old_pg['params'] # The parameters in the same group shouldn't change.
updated_groups.append(new_pg)
optimizer.optim.__dict__.update({'param_groups': updated_groups})
# Load saved states to optimizer.
# Keep a record of loaded files so that file will not be repeatedly loaded.
loaded_file = set()
for pg in optimizer.optim.param_groups:
for param in pg['params']:
if param is None:
continue
param_id = _get_param_id_from_optimizer_param(param, self.master_to_working_map)
if param_id not in weight_map:
continue
filename = weight_map[param_id]
# If this param's states have been loaded before, skip loading.
if filename in loaded_file:
continue
file_path = os.path.join(ckpt_root_path, filename)
state_dict = load_shard_state_dict(Path(file_path), use_safetensors=False)
load_states_into_optimizer(optimizer.optim, state_dict, id_map, strict=True)
loaded_file.add(filename)
# Then shard the loaded optimizer states if using tp/zero.
for param, state in optimizer.optim.state.items():
device = param.device
if self.master_to_working_map is not None:
working_param = self.master_to_working_map[id(param)]
else:
working_param = param
original_shape = optimizer.param_info['param2shape'][id(working_param)]
sharded_state = self.shard_from_complete_optimizer_state(state,
current_shape=working_param.shape,
original_shape=original_shape,
device=device,
inplace=True)
optimizer.optim.state[param] = sharded_state
sharded_optimizer_loading_epilogue(optimizer.optim)
if self.verbose:
logging.info(f"The optimizer has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
def load_unsharded_model(self, model: nn.Module, checkpoint: str, strict: bool = True):
# TODO(Baizhou): support this feature after implementing complete state_dict collection
raise NotImplementedError
def save_unsharded_model(self, model: nn.Module, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
# TODO(Baizhou): support this feature after implementing complete state_dict collection
raise NotImplementedError
def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
# TODO(Baizhou): support this feature after implementing complete state_dict collection
raise NotImplementedError
def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
# TODO(Baizhou): support this feature after implementing complete state_dict collection
raise NotImplementedError
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
"""
Save lr scheduler to checkpoint but only on master process.
"""
if self.coordinator.is_master():
super().save_lr_scheduler(lr_scheduler, checkpoint)
def link_master_and_working_param(self, working_to_master_map: Dict[Union[int, torch.Tensor], torch.Tensor],
master_to_working_map: Dict[Union[int, torch.Tensor], torch.Tensor]):
"""
Create mappings between working params (for forward/backward) and master params (for optimizer update) with the passed-in mappings.
This mapping can only be created when mixed precision is used.
The created mappings should be mappings from integer parameter addresses to parameter objects.
Args:
working_to_master_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from working parameter objects/addresses to master parameter objects.
master_to_working_map (Dict[Union[int, torch.Tensor], torch.Tensor]): A mapping from master parameter objects/addresses to working parameter objects.
"""
self.working_to_master_map = dict()
for k, v in working_to_master_map.items():
if isinstance(k, torch.Tensor):
self.working_to_master_map[id(k)] = v
elif isinstance(k, int):
self.working_to_master_map[k] = v
else:
raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
self.master_to_working_map = dict()
for k, v in master_to_working_map.items():
if isinstance(k, torch.Tensor):
self.master_to_working_map[id(k)] = v
elif isinstance(k, int):
self.master_to_working_map[k] = v
else:
raise ValueError(
f"The passed in mapping should have keys of type 'int' or 'torch.Tensor', but got {type(k)}!")
@staticmethod
def gather_from_sharded_optimizer_state(state: OrderedDict, param: torch.Tensor, original_shape: torch.Size,
dp_group: ProcessGroup, tp_group: ProcessGroup, use_zero: bool,
inplace: bool) -> OrderedDict:
"""
With given parameter and its optimizer states, gather the complete optimizer state for saving.
Args:
state (OrderedDict): Optimizer states of given parameter, might be distributed among tp/dp group if using TP/Zero.
param (torch.Tensor): The given parameter. It should be working_param when using Zero.
original_shape (torch.Size): The size of parameter before sharding.
dp_group (ProcessGroup): The process group of data parallel.
tp_group (ProcessGroup): The process group of tensor parallel.
use_zero (bool): Whether Zero is used.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
Returns:
OrderedDict: The complete optimizer state of given parameter.
"""
dp_size = dist.get_world_size(dp_group)
tp_size = dist.get_world_size(tp_group)
current_shape = param.shape
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != 'step':
# First gather Zero shards.
if use_zero:
v = v.cuda()
gather_tensor = [torch.zeros_like(v) for _ in range(dp_size)]
dist.all_gather(gather_tensor, v, group=dp_group)
v = torch.stack(gather_tensor).view(-1)[:param.numel()].reshape_as(param)
# Then gather TP shards.
partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
if partition_dim is not None:
gather_tensor = [torch.zeros_like(v) for _ in range(tp_size)]
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)
state_[k] = v.detach().clone().cpu()
return state_
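# A minimal single-process sketch (assumed shapes, no torch.distributed) of the Zero re-assembly
# above: each dp rank holds a padded flat slice of the state, and stack + flatten +
# truncate-to-numel + reshape recovers the complete tensor.
import torch

param = torch.arange(6.).reshape(2, 3)                      # the working parameter
flat = torch.nn.functional.pad(param.flatten(), [0, 2])     # pad 6 -> 8 so 2 dp ranks split evenly
shards = list(flat.split(4))                                # what each of the 2 dp ranks would hold
gathered = torch.stack(shards).view(-1)[:param.numel()].reshape_as(param)
assert torch.equal(gathered, param)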
def shard_from_complete_optimizer_state(self, state: OrderedDict, current_shape: torch.Size,
original_shape: torch.Size, device: torch.device,
inplace: bool) -> OrderedDict:
"""
With complete optimizer states of a specific parameter loaded from checkpoint,
slice out the sharded optimizer states kept by current device.
Args:
state (OrderedDict): Complete optimizer states of a given parameter, loaded from checkpoint.
current_shape (torch.Size): The size of parameter after sharding.
original_shape (torch.Size): The size of parameter before sharding.
device (torch.device): The destination device of loaded optimizer states.
inplace (bool): If set to True, will update the values of argument 'state' in place. Else will make a copy of state.
Returns:
OrderedDict: The sharded optimizer state of the given parameter.
"""
state_ = state if inplace else copy.deepcopy(state)
for k, v in state_.items():
if isinstance(v, torch.Tensor) and k != 'step':
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
if partition_dim is not None:
slice_size = current_shape[partition_dim]
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
# Shard state along data parallel group when using Zero.
if self.use_zero:
padding_size = (self.dp_size - v.numel() % self.dp_size) % self.dp_size
with torch.no_grad():
v = v.flatten()
if padding_size > 0:
v = torch.nn.functional.pad(v, [0, padding_size])
slice_size = v.numel() // self.dp_size
v = v.split(slice_size, dim=0)[self.dp_rank]
state_[k] = v.detach().clone().to(device)
return state_
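# A minimal sketch (assumed shapes and ranks) of the reverse operation: slice the complete state
# along the inferred tp partition dim, then flatten/pad/split it across dp ranks when Zero is used.
import torch

full_state = torch.arange(24.).reshape(4, 6)                # complete optimizer state
partition_dim, tp_size, tp_rank = 1, 2, 1
tp_shard = full_state.split(full_state.shape[partition_dim] // tp_size, dim=partition_dim)[tp_rank]

dp_size, dp_rank = 2, 0
flat = tp_shard.flatten()
padding = (dp_size - flat.numel() % dp_size) % dp_size
flat = torch.nn.functional.pad(flat, [0, padding])
local_shard = flat.split(flat.numel() // dp_size, dim=0)[dp_rank]   # what this rank keeps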

View File

@ -1,4 +1,5 @@
# coding=utf-8
import copy
import os
import re
from collections import abc as container_abcs
@ -10,10 +11,17 @@ from typing import Iterator, List, Mapping, Optional, OrderedDict, Tuple
import torch
import torch.nn as nn
from torch.optim import Optimizer
from transformers.modeling_utils import PreTrainedModel, get_parameter_dtype
from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
from colossalai.interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
to_global,
to_global_for_customized_distributed_tensor,
)
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@ -88,8 +96,35 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool:
return False
def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Size, tp_size: int) -> Optional[int]:
"""
Given the current shape of parameter and the shape of parameter before sharding,
return the dimension along which the parameter is sharded when using tensor parallel.
If tensor parallel is not used, return None.
Args:
current_shape (torch.Size): The current shape of parameter after sharding.
original_shape (torch.Size): The shape of parameter before sharding.
tp_size (int): The size of tp group.
Returns:
Optional[int]: The dimension along which parameter is partitioned.
"""
partition_dim = None
for dim, length in enumerate(original_shape):
if length > current_shape[dim]:
partition_dim = dim
break
if partition_dim is not None:
assert original_shape[partition_dim] == tp_size * current_shape[partition_dim], \
f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}"
return partition_dim
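# Example usage of the helper above: with tp_size = 2, a weight stored as (1024, 384) on each
# rank but originally (1024, 768) is detected as partitioned along dim 1; an unsharded weight
# yields None.
import torch

assert search_tp_partition_dim(torch.Size([1024, 384]), torch.Size([1024, 768]), tp_size=2) == 1
assert search_tp_partition_dim(torch.Size([1024, 768]), torch.Size([1024, 768]), tp_size=2) is None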
# ======================================
# Helper functions for saving shard file
# Helper classes and functions for saving shard file
# ======================================
def unwrap_optimizer(optimizer: OptimizerWrapper):
'''
@ -104,88 +139,32 @@ def unwrap_optimizer(optimizer: OptimizerWrapper):
return unwrapped_optim
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_safetensors: bool = False) -> int:
'''
Save sharded state dict only on master rank, this method can be used by both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is master.
use_safetensors (bool): Whether to use safetensors to save checkpoint.
class StateDictSharder:
Returns:
int: the total size of shards
'''
def __init__(self, size_per_shard: int) -> None:
self.max_shard_size = size_per_shard
self.current_block = OrderedDict()
self.current_block_size = 0
total_size = 0
for idx, shard_pair in enumerate(sharded_state_dict):
if not is_master:
continue
shard, current_size = shard_pair
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
def append_param(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
# Only save on master rank.
save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
return total_size
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
current_block = {}
current_block_size = 0
for key, weight in state_dict.items():
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
if not is_distributed_tensor(weight):
weight_size = calculate_tensor_size(weight)
# If this weight is going to tip up over the maximal size, we split.
if current_block_size + weight_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
current_block[key] = weight
current_block_size += weight_size
# before we return the current block and create a new block,
# we need to ensure that the current block is not empty
if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()
self.current_block_size = 0
if ret_block != None:
yield ret_block, ret_block_size
self.current_block[name] = tensor
self.current_block_size += tensor_size
return ret_block, ret_block_size
yield current_block, current_block_size
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits an optimizer state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
# Only split state_dict['state']; state_dict['param_group'] is not considered in this function.
states = state_dict['state']
current_block = {}
current_block_size = 0
for param_id, state in states.items():
ret_block = None
ret_block_size = 0
def append_optim_state(self, param_id: int, state: OrderedDict) -> Tuple[Optional[OrderedDict], int]:
# A state might contain more than one tensor.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
@ -204,21 +183,319 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
if not isDTensor:
ret_block = None
ret_block_size = 0
if current_block_size + state_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
# directly return if state is stored as distributed tensor
if isDTensor:
return ret_block, ret_block_size
current_block[param_id] = state
current_block_size += state_size
# before we return the current block and create a new block,
# we need to ensure that the current block is not empty
if self.current_block_size + state_size > self.max_shard_size and self.current_block_size > 0:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()
self.current_block_size = 0
if ret_block != None:
yield ret_block, ret_block_size
self.current_block[param_id] = state
self.current_block_size += state_size
return ret_block, ret_block_size
yield current_block, current_block_size
def gather_distributed_param(param: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
"""
Gather the complete parameter for saving if passed in param is distributed under tp setting.
Args:
param (torch.Tensor): A model parameter, might be d_tensor.
keep_vars (bool, optional): Whether to return the parameter in calculation graph. Defaults to False.
Returns:
torch.Tensor: the complete parameter
"""
param_ = param if keep_vars else param.detach()
if is_distributed_tensor(param_):
return to_global(param_)
elif is_customized_distributed_tensor(param_):
return to_global_for_customized_distributed_tensor(param_)
else:
return param_
def save_state_dict_shards(sharded_state_dict: Iterator[Tuple[OrderedDict, int]],
checkpoint: str,
index_file: "CheckpointIndexFile",
base_filename: str,
is_master: bool,
use_safetensors: bool = False,
use_pp_format: bool = False) -> int:
'''
Save sharded state dict only on master rank; this method can be used by both model and optimizer states.
Args:
sharded_state_dict (Iterator[Tuple[OrderedDict, int]]): a generator of shards, each shard contains state dict and shard size.
checkpoint (str): The path of checkpoint directory as string.
index_file (CheckpointIndexFile): The index file object to be updated.
base_filename (str): Decides the prefix of filenames of shards.
is_master (bool): Whether current rank is the main process.
use_safetensors (bool, optional): Whether to use safetensors to save checkpoint. Defaults to False.
use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
Returns:
int: the total size of shards
'''
total_size = 0
shard_filenames = []
for idx, shard_pair in enumerate(sharded_state_dict):
shard, current_size = shard_pair
if not is_master:
del shard
continue
shard_file = get_shard_filename(base_filename, idx)
total_size = total_size + current_size
for key in shard.keys():
index_file.append_weight_map(key, shard_file)
checkpoint_file_path = os.path.join(checkpoint, shard_file)
# Only save on master rank.
save_state_dict(shard, checkpoint_file_path, use_safetensors=use_safetensors)
shard_filenames.append(shard_file)
del shard
# Clean folder, deleting unneeded files.
clean_folder(checkpoint, base_filename, shard_filenames, is_master=is_master, use_pp_format=use_pp_format)
return total_size
def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits a model state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
state_dict_sharder = StateDictSharder(max_shard_size)
for key, weight in state_dict.items():
if not is_distributed_tensor(weight):
block, block_size = state_dict_sharder.append_param(key, weight)
if block != None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]:
"""
Splits an optimizer state dictionary into sub-checkpoints so that the final size of each sub-checkpoint does not exceed a
given size.
"""
# Only split state_dict['state']; state_dict['param_groups'] is not considered in this function.
states = state_dict['state']
state_dict_sharder = StateDictSharder(max_shard_size)
for param_id, state in states.items():
block, block_size = state_dict_sharder.append_optim_state(param_id, state)
if block != None:
yield block, block_size
# Return the last block in sharder.
yield state_dict_sharder.current_block, state_dict_sharder.current_block_size
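# Usage sketch of the generators above (illustrative tensors; max_shard_size is in the units
# reported by calculate_tensor_size): each yielded pair is a sub-state-dict plus its size.
import torch

dummy_state_dict = {f"layer{i}.weight": torch.zeros(256, 256) for i in range(8)}
for shard, shard_size in shard_model_checkpoint(dummy_state_dict, max_shard_size=1):
    print(f"shard with {len(shard)} tensors, size {shard_size}")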
# ======================================
# Helper functions for saving state dict
# ======================================
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
"""
Save state dict to checkpoint.
Args:
state_dict (dict): state dict.
checkpoint_file_path (str): path to the checkpoint file.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
"""
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith('.safetensors'), \
"safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
"""
Save information of param_groups to given file path.
Args:
state_dict (dict): state dict.
group_file_path (str): path to the group file.
"""
param_groups = state_dict["param_groups"]
torch.save(param_groups, group_file_path)
def clean_folder(checkpoint_path: str,
weights_name: str,
shard_filenames: List[str],
is_master: bool = True,
use_pp_format: bool = False):
"""
Clean the unneeded files in checkpoint directory after shards of state_dict have been saved.
Args:
checkpoint_path (str): Path to the checkpoint directory.
weights_name (str): Decides the prefix of filenames of weight shards.
shard_filenames (List[str]): The list of saved shard filenames which should not be removed.
is_master (bool, optional): Whether current rank is the main process. Defaults to True.
use_pp_format (bool, optional): Whether to save the files in pipeline format including stage information. Defaults to False.
"""
if is_master:
for filename in os.listdir(checkpoint_path):
full_filename = os.path.join(checkpoint_path, filename)
weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
if not use_pp_format:
reg = re.compile(r"(.*?)-\d{5}")
else:
# When this checkpoint is created by a pipeline parallel process, the pattern is a little different.
reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
if (filename.startswith(weights_no_suffix) and os.path.isfile(full_filename)
and filename not in shard_filenames and reg.fullmatch(filename_no_suffix) is not None):
os.remove(full_filename)
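# Illustrative check of the filename patterns used above: stale shards from an earlier save
# match and are removed, while the plain base name and files listed in shard_filenames are kept.
import re

reg = re.compile(r"(.*?)-\d{5}")
assert reg.fullmatch("pytorch_model-00001") is not None       # old shard -> removable
assert reg.fullmatch("pytorch_model") is None                  # base name -> kept
pp_reg = re.compile(r"(.*?)-stage-\d{5}-shard-\d{5}")
assert pp_reg.fullmatch("pytorch_model-stage-00001-shard-00002") is not None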
def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = True):
"""
Save config.json/generation_config.json if model is a Huggingface pretrained model.
This method can only be called when a model is saved in a sharded way.
Args:
model (nn.Module): The model whose config should be saved if it's a huggingface model.
checkpoint_path (str): Path to the checkpoint directory.
is_master (bool): Whether current rank is main process.
"""
if not isinstance(model, PreTrainedModel):
return
model = unwrap_huggingface_model(model)
# save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
dtype = get_parameter_dtype(model)
model.config.torch_dtype = str(dtype).split(".")[1]
# Attach architecture to the config
model.config.architectures = [model.__class__.__name__]
# Save the config
if is_master:
model.config.save_pretrained(checkpoint_path)
if model.can_generate():
model.generation_config.save_pretrained(checkpoint_path)
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
"""
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
only one tensor.
Args:
name (str): name of the distributed tensor.
tensor (Tensor): tensor to be saved.
index_file (CheckpointIndexFile): the checkpoint index file whose weight map will be updated.
use_safetensors (bool): whether to use safetensors to save the tensor.
"""
root_path = index_file.root_path
output_root_path = root_path.joinpath('dtensor')
# create directory
output_root_path.mkdir(exist_ok=True)
# save tensor to this directory
# TODO(YuliangLiu): get index of the tensor shard
# e.g. index =
index = 0
# save tensor to file
ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
ckpt_file_path = output_root_path.joinpath(ckpt_file_name)
# dtensor ckpt file always contains only one tensor
state_dict = {name: tensor}
save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)
# update the weight map
# * means all shards
ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
"""
Get checkpoint file suffix.
Args:
use_safetensors (bool): whether to use safetensors to save the checkpoint.
Returns:
str: checkpoint file suffix.
"""
if use_safetensors:
return '.safetensors'
else:
return '.bin'
def generate_checkpoint_shard_file_name(index: int,
total_number: int,
use_safetensors: bool,
prefix: str = None) -> str:
"""
Generate checkpoint shard file name.
Args:
index (int): index of the shard.
total_number (int): total number of shards.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
prefix (str): prefix of the shard file name. Default: None.
Returns:
str: checkpoint shard file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
if prefix is None:
return f"{index:05d}-of-{total_number:05d}.{suffix}"
else:
return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}"
def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
"""
Generate dtensor file name.
Args:
param_name (str): name of the distributed parameter.
index (int): index of the shard.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
Returns:
str: dtensor file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
return f'{param_name}.{index}.{suffix}'
# ========================================
# Helper functions for loading state dict
# ========================================
def load_shard_state_dict(checkpoint_file: Path, use_safetensors: bool = False):
@ -331,17 +608,21 @@ def load_param_groups_into_optimizer(optimizer: Optimizer, param_group_path: str
return id_map
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict):
def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: dict, strict: bool = False):
r"""Copies states from `state_dict` into an Optimizer object.
Args:
optimizer(Optimizer): An initialized Optimizer object to be loaded
state_dict(dict): a mapping from tensor index (an integer)
state_dict(dict): A mapping from tensor index (an integer)
to its states to be loaded (a mapping from state name to a tensor).
id_map(dict): a mapping from tensor index (an integer)
id_map(dict): A mapping from tensor index (an integer)
to its corresponding parameter (a tensor) whose states will be updated.
strict(bool, optional): If set to True, only load states for parameters whose ids appear in id_map. Defaults to False.
"""
# Ensure that the keys of state_dict are integers.
state_dict = {int(k): v for k, v in state_dict.items()}
def cast(param, value, key=None):
r"""Make a deep copy of value, casting all tensors to device of param."""
if isinstance(value, torch.Tensor):
@ -368,7 +649,7 @@ def load_states_into_optimizer(optimizer: Optimizer, state_dict: dict, id_map: d
if k in id_map:
param = id_map[k]
new_states[param] = cast(param, v)
else:
elif not strict:
new_states[k] = v
optimizer.state.update(new_states)
@ -386,165 +667,6 @@ def sharded_optimizer_loading_epilogue(optimizer: Optimizer):
optimizer.defaults.setdefault('differentiable', False)
# ======================================
# Helper functions for saving state dict
# ======================================
def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
"""
Save state dict to checkpoint.
Args:
state_dict (dict): state dict.
checkpoint_file_path (str): path to the checkpoint file.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
"""
if use_safetensors:
assert is_safetensors_available(), "safetensors is not available."
assert checkpoint_file_path.endswith('.safetensors'), \
"safetensors only supports .safetensors suffix for checkpoint file."
from safetensors.torch import save_file as safe_save_file
safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
else:
torch.save(state_dict, checkpoint_file_path)
def save_param_groups(state_dict: dict, group_file_path: str) -> None:
"""
Save information of param_groups to given file path.
Args:
state_dict (dict): state dict.
group_file_path (str): path to the group file.
"""
param_groups = state_dict["param_groups"]
torch.save(param_groups, group_file_path)
def save_dtensor(name: str, tensor: torch.Tensor, index_file: "CheckpointIndexFile", use_safetensors: bool) -> None:
"""
Save distributed tensor to checkpoint. This checkpoint will be a dictionary which contains
only one tensor.
Args:
tensor (Tensor): tensor to be saved.
index_file (CheckpointIndexFile): path to the checkpoint file.
size_per_shard (int): size per shard in MB.
"""
root_path = index_file.root_path
output_root_path = root_path.joinpath('dtensor')
# create directory
output_root_path.mkdir(exist_ok=True)
# save tensor to this directory
# TODO(YuliangLiu): get index of the tensor shard
# e.g. index =
index = 0
# save tensor to file
ckpt_file_name = generate_dtensor_file_name(name, index, use_safetensors)
ckpt_file_path = output_root_path.joinpath(ckpt_file_name)
# dtensor ckpt file always contains only one tensor
state_dict = {name: tensor}
save_state_dict(state_dict, str(ckpt_file_path), use_safetensors)
# update the weight map
# * means all shards
ckpt_file_name_in_weight_map = 'dtensor/' + generate_dtensor_file_name(name, '*', use_safetensors)
index_file.append_weight_map(name, ckpt_file_name_in_weight_map)
def get_checkpoint_file_suffix(use_safetensors: bool) -> str:
"""
Get checkpoint file suffix.
Args:
use_safetensors (bool): whether to use safetensors to save the checkpoint.
Returns:
str: checkpoint file suffix.
"""
if use_safetensors:
return '.safetensors'
else:
return '.bin'
def generate_checkpoint_shard_file_name(index: int,
total_number: int,
use_safetensors: bool,
prefix: str = None) -> str:
"""
Generate checkpoint shard file name.
Args:
index (int): index of the shard.
total_number (int): total number of shards.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
prefix (str): prefix of the shard file name. Default: None.
Returns:
str: checkpoint shard file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
if prefix is None:
return f"{index:05d}-of-{total_number:05d}.{suffix}"
else:
return f"{prefix}-{index:05d}-of-{total_number:05d}.{suffix}"
def generate_dtensor_file_name(param_name: str, index: int, use_safetensors: bool) -> str:
"""
Generate dtensor file name.
Args:
param_name (str): name of the distributed parameter.
index (int): index of the shard.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
Returns:
str: dtensor file name.
"""
suffix = get_checkpoint_file_suffix(use_safetensors)
return f'{param_name}.{index}.{suffix}'
def save_state_dict_as_shard(
state_dict: dict,
checkpoint_path: str,
index: int,
total_number: int,
use_safetensors: bool,
prefix: str = None,
) -> None:
"""
Save state dict as shard.
Args:
state_dict (dict): state dict.
checkpoint_path (str): path to the checkpoint file.
index (int): index of the shard.
total_number (int): total number of shards.
prefix (str): prefix of the shard file name.
use_safetensors (bool): whether to use safetensors to save the checkpoint.
"""
# generate the shard name
shard_file_name = generate_checkpoint_shard_file_name(index, total_number, use_safetensors, prefix)
shard_file_path = Path(checkpoint_path).joinpath(shard_file_name).absolute()
# save the shard
save_state_dict(state_dict, str(shard_file_path), use_safetensors)
# ========================================
# Helper functions for loading state dict
# ========================================
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
"""
Check whether the checkpoint has an index file.
@ -654,5 +776,5 @@ def get_shard_filename(weights_name: str, idx: int):
get shard file name
"""
shard_file = weights_name.replace(".bin", f"-{idx+1:05d}.bin")
shard_file = shard_file.replace(".safetensors", f"-{idx + 1:05d}.safetensors")
shard_file = shard_file.replace(".safetensors", f"-{idx+1:05d}.safetensors")
return shard_file

View File

@ -94,17 +94,23 @@ class ProcessGroupMesh:
return np.unravel_index(rank, shape)
@staticmethod
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...], mode: str = 'raise') -> int:
"""Convert a coordinate to a rank.
mode: ['raise', 'wrap', 'clip'], see https://numpy.org/doc/stable/reference/generated/numpy.ravel_multi_index.html.
With 'wrap', an out-of-range index is wrapped around.
For instance, ravel((0, i, 0), (1, 2, 1), 'wrap') returns i % 2.
Args:
coord (Tuple[int, ...]): Coordinate to be converted.
shape (Tuple[int, ...]): Shape of the process group mesh.
mode (Optional[str]): The mode for numpy.ravel_multi_index.
Returns:
int: Rank of the coordinate.
"""
return np.ravel_multi_index(coord, shape)
assert mode in ["raise", "wrap", "clip"]
return np.ravel_multi_index(coord, shape, mode)
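# Quick check of the 'wrap' mode: out-of-range coordinates wrap around the mesh, which is how
# the last pipeline stage gets rank 0 as its next rank (and rank 0 gets the last rank as its
# previous rank). Shapes below are illustrative.
import numpy as np

shape = (2, 4)                                                # e.g. a (dp, pp) mesh
assert np.ravel_multi_index((0, 4), shape, mode='wrap') == 0      # pp coord 4 wraps to 0
assert np.ravel_multi_index((0, -1), shape, mode='wrap') == 3     # pp coord -1 wraps to 3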
def get_group(self, ranks_in_group: List[int], backend: Optional[str] = None) -> ProcessGroup:
"""Get the process group with the given ranks. It the process group doesn't exist, it will be created.

View File

@ -173,14 +173,10 @@ class PipelineP2PCommunication:
Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage():
input_tensor = None
else:
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(prev_rank, cur_rank,
self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
input_tensor = _recv_object(prev_rank, cur_rank, self.stage_manager.get_p2p_process_group(prev_rank, cur_rank))
return input_tensor
@ -193,14 +189,11 @@ class PipelineP2PCommunication:
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(next_rank, cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
output_tensor_grad = _recv_object(next_rank, cur_rank,
self.stage_manager.get_p2p_process_group(next_rank, cur_rank))
return output_tensor_grad
@ -211,12 +204,10 @@ class PipelineP2PCommunication:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(output_object, cur_rank, next_rank,
self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
if next_rank is None:
next_rank = self.stage_manager.get_next_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(output_object, cur_rank, next_rank, self.stage_manager.get_p2p_process_group(cur_rank, next_rank))
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
@ -225,9 +216,7 @@ class PipelineP2PCommunication:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank,
self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))
if prev_rank is None:
prev_rank = self.stage_manager.get_prev_rank()
cur_rank = self.stage_manager.get_rank()
_send_object(input_object, cur_rank, prev_rank, self.stage_manager.get_p2p_process_group(cur_rank, prev_rank))

View File

@ -1,9 +1,59 @@
from typing import Any, List, Optional
from collections import OrderedDict
from typing import Any, List, Optional, Tuple
import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
from torch.utils._pytree import (
SUPPORTED_NODES,
LeafSpec,
TreeSpec,
_is_leaf,
_register_pytree_node,
tree_flatten,
tree_map,
tree_unflatten,
)
# this registration is for torch versions below 1.13.1 and may be removed in the future
def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Any]:
return list(d.values()), list(d.keys())
def _odict_unflatten(values: List[Any], context: Any) -> 'OrderedDict[Any, Any]':
return OrderedDict((key, value) for key, value in zip(context, values))
_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten)
def tree_map_hf(fn: Any, pytree: Any):
flat_args, spec = tree_flatten_hf(pytree)
return tree_unflatten([fn(i) for i in flat_args], spec)
# Use this flatten function to handle ModelOutput class instances.
def tree_flatten_hf(pytree: Any) -> Tuple[List[Any], TreeSpec]:
"""Flattens a pytree into a list of values an a TreeSpec that can be used
to reconstruct the pytree.
"""
if isinstance(pytree, OrderedDict):
node_type = OrderedDict
flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
child_pytrees, context = flatten_fn(pytree)
# Recursively flatten the children
result: List[Any] = []
children_specs: List['TreeSpec'] = []
for child in child_pytrees:
flat, child_spec = tree_flatten_hf(child)
result += flat
children_specs.append(child_spec)
return result, TreeSpec(node_type, context, children_specs)
else:
result, tree_spec = tree_flatten(pytree)
return result, tree_spec
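# Sketch of flattening a ModelOutput-like OrderedDict (assumes the OrderedDict pytree node
# registered above); nested OrderedDicts are flattened recursively into their tensor leaves.
from collections import OrderedDict
import torch

out = OrderedDict(hidden_states=torch.ones(2, 3), aux=OrderedDict(loss=torch.tensor(0.5)))
leaves, spec = tree_flatten_hf(out)
print([t.shape for t in leaves])       # [torch.Size([2, 3]), torch.Size([])]
rebuilt = tree_unflatten(leaves, spec)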
def to_device(x: Any, device: Optional[torch.device] = None) -> Any:
@ -104,7 +154,7 @@ def detach(x: Any) -> Any:
return x
def merge_batch(data: List[Any]) -> Any:
def merge_batch(data: List[Any], batch_size_dim=0) -> Any:
"""Merge micro batches into a batch.
Args:
@ -118,12 +168,17 @@ def merge_batch(data: List[Any]) -> Any:
flattened_data = []
tree_spec = None
for d in data:
elems, tree_spec = tree_flatten(d)
# d may be a ModelOutput (an OrderedDict subclass), so use the HF-aware flatten
elems, tree_spec = tree_flatten_hf(d)
flattened_data.append(elems)
merged_data = []
for elem_batch in zip(*flattened_data):
if isinstance(elem_batch[0], torch.Tensor):
merged_data.append(torch.cat(elem_batch, dim=0))
if len(elem_batch[0].shape) == 0: # set loss to None in pipeline outputs
merged_data.append(None)
else:
merged_data.append(torch.cat(elem_batch, dim=batch_size_dim))
else:
merged_data.append(list(elem_batch))
return tree_unflatten(merged_data, tree_spec)
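# Sketch of merging per-microbatch outputs back into one batch (illustrative values): tensors
# are concatenated along batch_size_dim, while 0-d tensors (e.g. per-microbatch losses) are
# replaced by None.
import torch
from collections import OrderedDict

micro_outputs = [
    OrderedDict(logits=torch.zeros(2, 5), loss=torch.tensor(0.1)),
    OrderedDict(logits=torch.ones(2, 5), loss=torch.tensor(0.2)),
]
merged = merge_batch(micro_outputs)    # merged['logits'].shape == (4, 5); merged['loss'] is None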

View File

@ -0,0 +1,370 @@
from functools import partial
from typing import Any, Callable, Iterable, List, Optional, Union
import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from .base import PipelineSchedule
class InterleavedSchedule(PipelineSchedule):
def __init__(self, num_microbatches: int, num_model_chunks: int, stage_manager: PipelineStageManager) -> None:
self.num_model_chunks = num_model_chunks
assert num_microbatches % self.num_model_chunks == 0, \
"Number of microbatches should be an integer multiple of number of model chunks"
super().__init__(stage_manager)
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self.microbatch_size: Optional[int] = None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
Args:
data_iter (Iterable): Data iterator.
device (Optional[torch.device], optional): Target device. Defaults to None.
"""
batch = next(data_iter)
if device is not None:
batch = tree_map(partial(to_device, device=device), batch)
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = [0 for _ in range(self.num_model_chunks)]
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
def load_micro_batch(self, model_chunk_id: int) -> Any:
"""Load a micro batch from the current batch.
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
Any: Micro batch.
"""
micro_batch = get_micro_batch(self.batch, self.microbatch_offset[model_chunk_id], self.microbatch_size)
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
def get_model_chunk_id(self, microbatch_id: int, forward: bool) -> int:
"""Helper method to get the model chunk ID given the iteration number.
Args:
microbatch_id (int): the current microbatch idx
forward (bool): if is the forward process
Returns:
int: The model chunk idx of the input microbatch_id
"""
microbatch_id_in_group = (microbatch_id) % (self.stage_manager.num_stages * self.num_model_chunks)
model_chunk_id = microbatch_id_in_group // self.stage_manager.num_stages
if not forward:
model_chunk_id = (self.num_model_chunks - model_chunk_id - 1)
return model_chunk_id
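# Example (assumed config: 4 stages, 2 model chunks => groups of 8 microbatches):
#   forward:  microbatches 0-3 -> chunk 0, microbatches 4-7 -> chunk 1
#   backward: the chunk order is reversed, so microbatches 0-3 -> chunk 1, 4-7 -> chunk 0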
def is_first_stage(self, model_chunk_id: int) -> bool:
"""Is the current virtual stage the first stage
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
bool: Whether the current virtual stage is the first stage.
"""
if self.stage_manager.is_first_stage() and model_chunk_id == 0:
return True
return False
def is_last_stage(self, model_chunk_id: int) -> bool:
"""Is the current virtual stage the last stage
Args:
model_chunk_id (int): The current model chunk idx.
Returns:
bool: Whether the current virtual stage is the last stage.
"""
if self.stage_manager.is_last_stage() and model_chunk_id == self.num_model_chunks - 1:
return True
return False
def recv_forward(self, model_chunk_id: int, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input tensor or input tensor list.
"""
if self.is_first_stage(model_chunk_id):
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
return input_tensor
def recv_backward(self, model_chunk_id: int, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.is_last_stage(model_chunk_id):
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
return output_tensor_grad
def send_forward(self, model_chunk_id, output_object: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.is_last_stage(model_chunk_id):
self.comm.send_forward(output_object, next_rank)
def send_backward(self, model_chunk_id, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For interleaved 1F1B.
Args:
model_chunk_id (int): The current model chunk idx.
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.is_first_stage(model_chunk_id):
self.comm.send_backward(input_object, prev_rank)
def forward_step(self,
model_chunk: Module,
model_chunk_id: int,
input_obj: Optional[dict],
criterion: Callable,
accum_loss: Optional[torch.Tensor] = None,
outputs: Optional[List[Any]] = None) -> Union[torch.Tensor, dict]:
"""Forward one step of the pipeline
Args:
model_chunk (Module): The model chunk to be run.
model_chunk_id (int): The current model chunk idx.
input_obj (Optional[dict]): The output from the previous stage. If it is the first stage, the `input_obj` is None.
criterion (Callable): Criterion to calculate loss.
accum_loss (Optional[torch.Tensor], optional): Accumulated loss. Defaults to None.
outputs (Optional[List[Any]], optional): List to store the output of the last stage (final output). Defaults to None.
Returns:
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
micro_batch = self.load_micro_batch(model_chunk_id=model_chunk_id)
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it must be a dict
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
if self.is_last_stage(model_chunk_id):
loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
return loss
else:
return output_obj
def backward_step(self, optimizer: OptimizerWrapper, input_obj: Optional[dict],
output_obj: Union[dict, torch.Tensor], output_obj_grad: Optional[dict]) -> Optional[dict]:
"""Backward one step of the pipeline
Args:
optimizer (OptimizerWrapper): Optimizer to update the model
input_obj (Optional[dict]): Output of the previous stage. If it is the first stage, the `input_obj` is None.
output_obj (Union[dict, torch.Tensor]): Output of the current stage. If it is the last stage, the output is the loss (Tensor).
output_obj_grad (dict): Gradient of the `output_obj`. If it is the last stage, the `output_obj_grad` is None.
Returns:
Optional[dict]: Gradient of the `input_obj`. If it is the first stage, the `input_obj_grad` is None.
"""
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
# Backward pass.
if output_obj_grad is None:
optimizer.backward(output_obj)
else:
if "backward_tensor_keys" not in output_obj:
for k, grad in output_obj_grad.items():
optimizer.backward_by_grad(output_obj[k], grad)
else:
for k, grad in output_obj_grad.items():
output_obj[k].grad = grad
for k in output_obj["backward_tensor_keys"]:
tensor_to_backward = output_obj[k]
optimizer.backward_by_grad(tensor_to_backward, tensor_to_backward.grad)
# Collect the grad of the input_obj.
input_obj_grad = None
if input_obj is not None:
input_obj_grad = {}
for k, v in input_obj.items():
if isinstance(v, torch.Tensor) and v.grad is not None:
input_obj_grad[k] = v.grad
return input_obj_grad
def forward_backward_step(self,
model_chunk: Module,
optimizer: OptimizerWrapper,
data_iter: Iterable,
criterion: Callable[..., Any],
return_loss: bool = False,
return_outputs: bool = False) -> dict:
"""Runs interleaved 1F1B schedule, with communication between pipeline stages.
Args:
model_chunk (List[Module]): Model Chunk to be trained.
optimizer (OptimizerWrapper): Optimizer to be used.
data_iter (Iterable): Data iterator.
criterion (Callable[[Any, Any], Tensor]): Criterion to be used. It should take two arguments: model outputs and inputs, and returns loss tensor.
return_loss (bool, optional): Whether to return loss. Defaults to False.
return_outputs (bool, optional): Whether to return model outputs. Defaults to False.
Returns:
dict: A dict with keys: 'loss' and 'outputs'.
"""
forward_only = not torch.is_grad_enabled()
self.load_batch(data_iter)
num_model_chunks = len(model_chunk)
# num_warmup_microbatches is the number of warmup steps during which not all pipeline stages are active
num_microbatches = self.num_microbatches * num_model_chunks
if forward_only:
num_warmup_microbatches = num_microbatches
else:
num_warmup_microbatches = (self.stage_manager.num_stages - self.stage_manager.stage - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * self.stage_manager.num_stages
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_microbatches_remaining = num_microbatches - num_warmup_microbatches
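# Worked example (assumed config): 4 stages, 2 model chunks, self.num_microbatches = 8
# => num_microbatches = 16; for stage 1, num_warmup_microbatches = (4 - 1 - 1) * 2 + (2 - 1) * 4 = 8,
# leaving 8 microbatches for the steady 1F1B phase.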
# Input, output tensors only need to be saved when doing backward passes
input_objs = None
output_objs = None
if not forward_only:
input_objs = [[] for _ in range(num_model_chunks)]
output_objs = [[] for _ in range(num_model_chunks)]
outputs = [] if return_outputs and self.stage_manager.is_last_stage() else None
if return_loss and self.stage_manager.is_last_stage():
accum_loss = torch.zeros(1, device=get_current_device())
else:
accum_loss = None
# every stage except the first receives the initial forward input from its predecessor; the first stage gets None
# print(self.stage_manager.stage,num_microbatches, num_warmup_microbatches, num_microbatches_remaining)
input_obj = self.recv_forward(0)
input_objs[0].append(input_obj)
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=True)
# recv first on the first rank to avoid sending and receiving at the same time
if self.stage_manager.is_first_stage():
input_obj = self.recv_forward(model_chunk_id)
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
self.send_forward(model_chunk_id, output_obj)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
else:
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if not forward_only:
output_objs[model_chunk_id].append(output_obj)
self.send_forward(model_chunk_id, output_obj)
if num_microbatches_remaining == 0 and i + 1 == num_warmup_microbatches:
break
else:
model_chunk_id = self.get_model_chunk_id(i + 1, forward=True)
input_obj = self.recv_forward(model_chunk_id)
if not forward_only:
input_objs[model_chunk_id].append(input_obj)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches, forward=True)
last_iteration = (i == (num_microbatches_remaining - 1))
output_obj = self.forward_step(model_chunk, model_chunk_id, input_obj, criterion, accum_loss, outputs)
if forward_only:
self.send_forward(model_chunk_id, output_obj)
if not last_iteration:
input_obj = self.recv_forward(model_chunk_id)
else:
self.send_forward(model_chunk_id, output_obj)
# Add input_obj and output_obj to end of list.
input_objs[model_chunk_id].append(input_obj)
output_objs[model_chunk_id].append(output_obj)
model_chunk_id = self.get_model_chunk_id(i, forward=False)
output_obj_grad = self.recv_backward(model_chunk_id)
# Pop input_obj and output_obj from the start of the list for
# the backward pass.
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
# backward
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
if last_iteration:
input_obj = None
else:
model_chunk_id = self.get_model_chunk_id(i + num_warmup_microbatches + 1, forward=True)
input_obj = self.recv_forward(model_chunk_id)
model_chunk_id = self.get_model_chunk_id(i, forward=False)
self.send_backward(model_chunk_id, input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
for i in range(num_microbatches_remaining, num_microbatches):
model_chunk_id = self.get_model_chunk_id(i, forward=False)
# print(f"{self.stage_manager.stage}/{model_chunk_id}: {len(input_objs[model_chunk_id])} {len(output_objs[model_chunk_id])} {i}")
input_obj = input_objs[model_chunk_id].pop(0)
output_obj = output_objs[model_chunk_id].pop(0)
output_obj_grad = self.recv_backward(model_chunk_id)
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.send_backward(model_chunk_id, input_obj_grad)
if outputs is not None:
outputs = merge_batch(outputs)
return {'loss': accum_loss, 'outputs': outputs}

View File

@ -6,25 +6,47 @@ import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map
from colossalai.interface import OptimizerWrapper
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.pipeline.p2p import PipelineP2PCommunication
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.utils.cuda import get_current_device
from ._utils import detach, get_batch_size, get_micro_batch, merge_batch, model_forward, retain_grad, to_device
from ._utils import (
detach,
get_batch_size,
get_micro_batch,
merge_batch,
model_forward,
retain_grad,
to_device,
tree_map_hf,
)
from .base import PipelineSchedule
class OneForwardOneBackwardSchedule(PipelineSchedule):
def __init__(self, num_microbatches: int, stage_manager: PipelineStageManager) -> None:
def __init__(self,
stage_manager: PipelineStageManager,
num_microbatches: Optional[int] = None,
microbatch_size: Optional[int] = None) -> None:
"""1F1B pipeline schedule.
Args:
stage_manager (PipelineStageManager): Pipeline stage manager
num_microbatches (Optional[int], optional): The number of microbatches. If not provided, it will be derived from microbatch size. Defaults to None.
microbatch_size (Optional[int], optional): Microbatch size. If num_microbatches is provided, this will be ignored. Defaults to None.
"""
super().__init__(stage_manager)
assert num_microbatches is not None or microbatch_size is not None, \
"Either num_microbatches or microbatch_size should be provided"
self.comm = PipelineP2PCommunication(stage_manager)
self.num_microbatches = num_microbatches
self.microbatch_size = microbatch_size
self.batch: Optional[Any] = None
self.batch_size: Optional[int] = None
self.microbatch_offset: Optional[int] = None
self.microbatch_size: Optional[int] = None
self._use_microbatch_size = num_microbatches is None
def load_batch(self, data_iter: Iterable, device: Optional[torch.device] = None) -> None:
"""Load a batch from data iterator.
@ -39,9 +61,14 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.batch = batch
self.batch_size = get_batch_size(batch)
self.microbatch_offset = 0
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
if not self._use_microbatch_size:
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = self.batch_size // self.num_microbatches
else:
assert self.batch_size % self.microbatch_size == 0, \
"Batch size should divided by the microbatch size"
self.num_microbatches = self.batch_size // self.microbatch_size
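# e.g. a global batch size of 32 with microbatch_size = 4 gives num_microbatches = 8;
# conversely, num_microbatches = 8 would give microbatch_size = 32 // 8 = 4.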
def load_micro_batch(self) -> Any:
"""Load a micro batch from the current batch.
@ -53,6 +80,62 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
self.microbatch_offset += self.microbatch_size
return tree_map(partial(to_device, device=get_current_device()), micro_batch)
def recv_forward(self, prev_rank: int = None) -> Any:
"""Copy the forward output from the previous stage in pipeline as the input tensor of this stage.
For 1F1B.
Args:
prev_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input tensor or input tensor list.
"""
if self.stage_manager.is_first_stage():
input_tensor = None
else:
input_tensor = self.comm.recv_forward(prev_rank)
return input_tensor
def recv_backward(self, next_rank: int = None) -> Any:
"""Copy the gradient tensor from the next stage in pipeline as the input gradient of this stage.
For 1F1B.
Args:
next_rank (int, optional): The rank of the source of the tensor.
Returns:
Any: The input gradient tensor or gradient tensor list.
"""
if self.stage_manager.is_last_stage():
output_tensor_grad = None
else:
output_tensor_grad = self.comm.recv_backward(next_rank)
return output_tensor_grad
def send_forward(self, output_object: Any, next_rank: int = None) -> None:
"""Sends the input tensor to the next stage in pipeline.
For 1F1B.
Args:
output_object (Any): Object to be sent.
next_rank (int, optional): The rank of the recipient of the tensor.
"""
if not self.stage_manager.is_last_stage():
self.comm.send_forward(output_object, next_rank)
def send_backward(self, input_object: Any, prev_rank: int = None) -> None:
"""Sends the gradient tensor to the previous stage in pipeline.
For 1F1B.
Args:
input_object (Any): Object to be sent.
prev_rank (int, optional): The rank of the recipient of the tensor
"""
if not self.stage_manager.is_first_stage():
self.comm.send_backward(input_object, prev_rank)
def forward_step(self,
model: Module,
input_obj: Optional[dict],
@ -72,16 +155,16 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
Union[torch.Tensor, dict]: The intermediate output (dict) of the current stage. If it is the last stage, the output is the loss (Tensor).
"""
micro_batch = self.load_micro_batch()
# for the first stage, input_obj is None
# for the non-first stage, input_obj is the output of the previous stage and it must be a dict
output_obj = model_forward(model, micro_batch, input_obj)
if self.stage_manager.is_last_stage():
loss = criterion(output_obj, micro_batch) / self.num_microbatches
if accum_loss is not None:
accum_loss.add_(loss.detach())
if outputs is not None:
outputs.append(tree_map(detach, output_obj))
outputs.append(tree_map_hf(detach, output_obj))
return loss
else:
return output_obj
@ -102,7 +185,6 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# Retain the grad on the input_obj.
tree_map(retain_grad, input_obj)
# Backward pass.
if output_obj_grad is None:
optimizer.backward(output_obj)
@ -171,11 +253,11 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_obj = self.comm.recv_forward()
input_obj = self.recv_forward()
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
self.comm.send_forward(output_obj)
self.send_forward(output_obj)
if not forward_only:
input_objs.append(input_obj)
@ -185,7 +267,7 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_obj = self.comm.recv_forward()
input_obj = self.recv_forward()
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
@ -193,15 +275,15 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
output_obj = self.forward_step(model, input_obj, criterion, accum_loss, outputs)
if forward_only:
self.comm.send_forward(output_obj)
self.send_forward(output_obj)
if not last_iteration:
input_obj = self.comm.recv_forward()
input_obj = self.recv_forward()
else:
# TODO adjust here
self.comm.send_forward(output_obj)
output_obj_grad = self.comm.recv_backward()
self.send_forward(output_obj)
output_obj_grad = self.recv_backward()
# Add input_obj and output_obj to end of list.
input_objs.append(input_obj)
@ -216,8 +298,8 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
if last_iteration:
input_obj = None
else:
input_obj = self.comm.recv_forward()
self.comm.send_backward(input_obj_grad)
input_obj = self.recv_forward()
self.send_backward(input_obj_grad)
# Run cooldown backward passes.
if not forward_only:
@ -225,10 +307,12 @@ class OneForwardOneBackwardSchedule(PipelineSchedule):
input_obj = input_objs.pop(0)
output_obj = output_objs.pop(0)
output_obj_grad = self.comm.recv_backward()
output_obj_grad = self.recv_backward()
input_obj_grad = self.backward_step(optimizer, input_obj, output_obj, output_obj_grad)
self.comm.send_backward(input_obj_grad)
self.send_backward(input_obj_grad)
if outputs is not None:
outputs = merge_batch(outputs)
if isinstance(model, ModelWrapper):
model = model.unwrap()
outputs = merge_batch(outputs, getattr(model, 'batch_size_dim', 0))
return {'loss': accum_loss, 'outputs': outputs}

View File

@ -17,28 +17,24 @@ class PipelineStageManager:
Attributes:
num_stages (int): Number of stages in the pipeline.
stage (int): The current stage.
num_virtual_stages (int): Number of virtual stages in the pipeline.
virtual_stage (int): The current virtual stage.
"""
def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int) -> None:
def __init__(self, pg_mesh: ProcessGroupMesh, pipeline_axis: int, is_virtual: bool = False) -> None:
self.pg_mesh = pg_mesh
self.pipeline_axis = pipeline_axis
self.num_virtual_stages: Optional[int] = None
self.virtual_stage: Optional[int] = None
self.prev_rank: Optional[Tuple[int, ...]] = None
self.next_rank: Optional[Tuple[int, ...]] = None
self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}
# init prev and next coord
coord = self.pg_mesh.coordinate()
if self.stage > 0:
prev_coord = coord[: self.pipeline_axis] + \
(coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:]
self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape)
if self.stage < self.num_stages - 1:
next_coord = coord[: self.pipeline_axis] + \
(coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape)
# the prev rank of rank0 is the last rank
prev_coord = coord[: self.pipeline_axis] + \
(coord[self.pipeline_axis] - 1,) + coord[self.pipeline_axis + 1:]
self.prev_rank = self.pg_mesh.ravel(prev_coord, self.pg_mesh.shape, mode='wrap')
# the next rank of the last rank is rank0
next_coord = coord[: self.pipeline_axis] + \
(coord[self.pipeline_axis] + 1,) + coord[self.pipeline_axis + 1:]
self.next_rank = self.pg_mesh.ravel(next_coord, self.pg_mesh.shape, mode='wrap')
# init p2p process groups
stages = list(range(self.num_stages))
@ -48,32 +44,28 @@ class PipelineStageManager:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
def is_first_stage(self, virtual: bool = False) -> bool:
"""Is the current stage the first stage.
if is_virtual:
# add the process group of the first rank and the last rank
# only used in interleaved pipeline for now
group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [stages[0], stages[-1]])
if self.stage in [stages[0], stages[-1]]:
ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
self.p2p_groups[tuple(ranks_in_group)] = group
Args:
virtual (bool, optional): Whether to consider virtual stages. Defaults to False.
def is_first_stage(self) -> bool:
"""Is the current stage the first stage.
Returns:
bool: Whether the current stage is the first stage.
"""
if virtual:
assert self.num_virtual_stages is not None
return self.virtual_stage == 0
return self.stage == 0
def is_last_stage(self, virtual: bool = False) -> bool:
def is_last_stage(self) -> bool:
"""Is the current stage the last stage.
Args:
virtual (bool, optional): Whether to consider virtual stages. Defaults to False.
Returns:
bool: Whether the current stage is the last stage.
"""
if virtual:
assert self.num_virtual_stages is not None
return self.virtual_stage == self.num_virtual_stages - 1
return self.stage == self.num_stages - 1
@property
@ -108,7 +100,6 @@ class PipelineStageManager:
Returns:
int: Rank of the previous stage.
"""
assert not self.is_first_stage(), "Cannot get previous rank in the first stage."
return self.prev_rank
def get_next_rank(self) -> int:
@ -117,39 +108,8 @@ class PipelineStageManager:
Returns:
int: Rank of the next stage.
"""
assert not self.is_last_stage(), "Cannot get next rank in the last stage."
return self.next_rank
def set_num_virtual_stages(self, num_virtual_stages: int) -> None:
"""Set the number of virtual stages.
Args:
num_virtual_stages (int): Number of virtual stages.
"""
self.num_virtual_stages = num_virtual_stages
def set_virtual_stage(self, virtual_stage: int) -> None:
"""Set the virtual stage.
Args:
virtual_stage (int): Virtual stage.
"""
self.virtual_stage = virtual_stage
@contextmanager
def switch_virtual_stage(self, virtual_stage: int) -> None:
"""A context manager to switch virtual stage.
Args:
virtual_stage (int): Target virtual stage.
"""
old_stage = self.virtual_stage
try:
self.set_virtual_stage(virtual_stage)
yield
finally:
self.set_virtual_stage(old_stage)
def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
"""Get the p2p process group between two ranks. The order of the two ranks does not matter.

View File

@ -429,12 +429,13 @@ As shown in the figures above, when the sequence length is around 1000 or greate
### Convergence
To validate that training the model using shardformers does not impact its convergence. We [fine-tuned the BERT model](./examples/convergence_benchmark.py) using both shardformer and non-shardformer approaches. We compared the accuracy, loss, F1 score of the training results.
To validate that training the model with Shardformer does not impact its convergence, we [fine-tuned the BERT model](../../examples/language/bert/finetune.py) using both the Shardformer and non-Shardformer approaches. The Shardformer example additionally uses Pipeline Parallelism and Data Parallelism (ZeRO-1). We then compared the accuracy, loss, and F1 score of the training results.
| accuracy | f1 | loss | GPU number | model shard |
| accuracy | f1 | loss | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :---------: |
| 0.82594 | 0.87441 | 0.09913 | 4 | True |
| 0.81884 | 0.87299 | 0.10120 | 2 | True |
| 0.81855 | 0.87124 | 0.10357 | 1 | False |
| 0.84589 | 0.88613 | 0.43414 | 4 | True |
| 0.83594 | 0.88064 | 0.43298 | 1 | False |
Overall, the results demonstrate that using shardformers during model training does not affect the convergence.
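For context, the hybrid setup described above (Shardformer tensor parallelism plus pipeline parallelism plus ZeRO-1 data parallelism) is typically assembled through the hybrid parallel plugin; the sketch below is only indicative, argument names such as tp_size, pp_size and zero_stage are assumed, and the linked fine-tuning script remains the authoritative reference.

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch(config={})          # run under torchrun with 4 processes
plugin = HybridParallelPlugin(tp_size=1,         # Shardformer tensor parallelism
                              pp_size=2,         # pipeline parallelism
                              zero_stage=1,      # ZeRO-1 over the remaining data-parallel ranks
                              num_microbatches=4,
                              precision='fp16')
booster = Booster(plugin=plugin)
# model, optimizer, criterion and dataloader are built as in the fine-tuning script,
# then handed to the booster:
# model, optimizer, criterion, dataloader, _ = booster.boost(model, optimizer, criterion, dataloader)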

View File

@ -1,3 +1,5 @@
from typing import Any
import torch
import torch.distributed as dist
import torch.nn.functional as F
@ -141,6 +143,240 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
return grad_input, grad_weight, grad_bias, None, None, None
class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
overlap (`bool`): Whether to overlap the all-gather op with the gradient computation in backward.
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
input_parallel = _gather(input_, dim, process_group)
if bias is not None:
output = F.linear(input_parallel, weight, bias)
else:
output = F.linear(input_parallel, weight)
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
if not overlap:
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
input_ = input_.contiguous()
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do all-gather in an async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in an async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = grad_output.t().matmul(input_parallel)
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
class _LinearWithReduceScatterForwardGatherBackward(torch.autograd.Function):
"""Gather input from sequence parallel in forward and reduce-scatter gradient in backward
Args:
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
@staticmethod
def forward(ctx, input_, process_group, dim):
ctx.dim = dim
ctx.process_group = process_group
# do reduce-scatter
new_shape = list(input_.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // dist.get_world_size(process_group)
input_list = [item.contiguous() for item in torch.chunk(input_, dist.get_world_size(process_group), dim=dim)]
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
dist.reduce_scatter(output, input_list, group=process_group)
return output
@staticmethod
def backward(ctx, grad_output):
dim = ctx.dim
process_group = ctx.process_group
return _gather(grad_output, dim, process_group), None, None
class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
"""
This class is designed for matmul operation with gather forward and reduce-scatter backward.
Args:
input_ (`torch.Tensor`): input matrix.
dim (int): the dimension to perform split and gather
process_group (`torch.distributed.ProcessGroup`): the process group used for collective communication
"""
@staticmethod
def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
ctx.save_for_backward(input_, weight)
ctx.use_bias = bias is not None
ctx.process_group = process_group
ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
ctx.dim = dim
ctx.overlap = overlap
input_parallel = _gather(input_, dim, process_group)
output = torch.matmul(input_parallel, weight)
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
use_bias = ctx.use_bias
dim = ctx.dim
process_group = ctx.process_group
overlap = ctx.overlap
if not overlap:
input_parallel = _gather(input_, dim, process_group)
total_input = input_parallel
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
total_input = total_input.view(-1, total_input.shape[-1])
if ctx.async_grad_reduce_scatter:
# Asynchronous reduce-scatter
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_parallel.dtype,
device=input_parallel.device).contiguous()
handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# reduce-scatter scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1
grad_weight = total_input.t().matmul(grad_output)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.async_grad_reduce_scatter:
handle.wait()
else:
world_size = dist.get_world_size(process_group)
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
# do all-gather in an async way
gather_handle = dist.all_gather(tensor_list, input_, group=process_group, async_op=True)
# calculate gradient and prepare data asynchronously with all-gather
# calculate
grad_input = grad_output.matmul(weight.T)
grad_output = grad_output.contiguous()
# Convert the tensor shapes to 2D for execution compatibility
if len(grad_output.shape) > 2:
grad_output = grad_output.view(-1, grad_output.shape[-1])
grad_bias = grad_output.sum(dim=0) if use_bias else None
# prepare data
input_list = [
item.contiguous() for item in torch.chunk(grad_input, dist.get_world_size(process_group), dim=dim)
]
output = torch.empty(input_.shape, dtype=input_.dtype, device=input_.device).contiguous()
# wait until all-gather finished
gather_handle.wait()
# do reduce-scatter in an async way
reducescatter_handle = dist.reduce_scatter(output, input_list, group=process_group, async_op=True)
input_parallel = torch.cat(tensor_list, dim=dim).contiguous()
# calculate gradient
if len(input_parallel.shape) > 2:
input_parallel = input_parallel.view(-1, input_parallel.shape[-1])
grad_weight = input_parallel.t().matmul(grad_output)
# wait until reduce-scatter finished
reducescatter_handle.wait()
return output, grad_weight, grad_bias, None, None, None, None
class _SplitForwardGatherBackward(torch.autograd.Function):
"""
Split the input and keep only the chunk corresponding to the rank.
@ -200,6 +436,26 @@ class _ReduceBackward(torch.autograd.Function):
return _reduce(grad_output, ctx.process_group), None
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.
Args:
input_: input matrix.
dim: the dimension along which to gather.
process_group: the process group used for the collective communication.
"""
@staticmethod
def forward(ctx, input_, dim, process_group):
ctx.process_group = process_group
ctx.dim = dim
return _gather(input_, dim, process_group)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
def _reduce(input_, process_group):
# skip if only one rank involved
if dist.get_world_size(process_group) == 1:
@ -235,9 +491,8 @@ def _gather(input_, dim=-1, process_group=None):
return input_
# all gather
rank = dist.get_rank(process_group)
input_ = input_.contiguous()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=process_group)
# concat
@ -246,24 +501,27 @@ def _gather(input_, dim=-1, process_group=None):
return output
class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatenate.
def _reduce_scatter(input_, dim=1, process_group=None):
""" Do reduce-scatter operation.
Args:
input_: input matrix.
parallel_mode: parallel mode.
dim: dimension
input_ (`torch.Tensor`): The input tensor from sequence parallel region.
dim (int): The dimension to perform reduce-scatter.
process_group (`torch.distributed.ProcessGroup`): The process group used for collective communication.
"""
world_size = dist.get_world_size(process_group)
if world_size == 1:
return input_
@staticmethod
def forward(ctx, input_, dim, process_group):
ctx.process_group = process_group
ctx.dim = dim
return _gather(input_, dim, process_group)
# reduce-scatter
new_shape = list(input_.shape)
assert new_shape[dim] % dist.get_world_size(process_group) == 0, \
f'The dimension to split ({new_shape[dim]}) is not a multiple of tensor parallel size ({dist.get_world_size(process_group)}). '
new_shape[dim] = new_shape[dim] // world_size
output = torch.empty(new_shape, dtype=input_.dtype, device=input_.device)
# reduce_scatter expects a list of per-rank chunks rather than the full tensor
input_list = [t.contiguous() for t in torch.chunk(input_, world_size, dim=dim)]
dist.reduce_scatter(output, input_list, group=process_group)
@staticmethod
def backward(ctx, grad_output):
return _split(grad_output, ctx.dim, ctx.process_group), None, None
return output
def matmul_with_async_comm(input_, weight, bias, process_group, async_grad_allreduce):
@ -274,6 +532,22 @@ def linear_with_async_comm(input_, weight, bias, process_group, async_grad_allre
return LinearWithAsyncCommunication.apply(input_, weight, bias, process_group, async_grad_allreduce)
def linear_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
overlap):
return _LinearWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim, overlap)
def linear_reducescatter_forward_gather_backward(input_, process_group, dim):
return _LinearWithReduceScatterForwardGatherBackward.apply(input_, process_group, dim)
def matmul_gather_forward_reducescatter_backward(input_, weight, bias, process_group, async_grad_reduce_scatter, dim,
overlap):
return _MatmulWithGatherForwardReduceScatterBackward.apply(input_, weight, bias, process_group,
async_grad_reduce_scatter, dim, overlap)
def gather_forward_split_backward(input_, dim, process_group):
return _GatherForwardSplitBackward.apply(input_, dim, process_group)
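To make the symmetry of these helpers concrete, the following single-process sketch simulates the tensor-parallel group with torch.chunk and shows what split_forward_gather_backward / gather_forward_split_backward do to a [batch, seq_len, hidden] activation (the real ops run the same reshapes through collectives and reverse them in backward).

import torch

tp_size = 4
x = torch.randn(2, 16, 8)                   # [batch, seq_len, hidden]

# split_forward_gather_backward: forward keeps only this rank's sequence shard,
# backward all-gathers the gradient along the same dimension.
shards = torch.chunk(x, tp_size, dim=1)     # 4 shards of shape [2, 4, 8]
local = shards[1]                           # what rank 1 would keep

# gather_forward_split_backward: forward concatenates the shards back,
# backward hands each rank only the gradient slice of its own shard.
full = torch.cat(shards, dim=1)
assert torch.equal(full, x)                 # the split/gather round trip is lossless
print(local.shape, full.shape)              # torch.Size([2, 4, 8]) torch.Size([2, 16, 8])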

View File

@ -24,6 +24,8 @@ from colossalai.tensor.d_tensor.api import (
from ._operation import (
gather_forward_split_backward,
linear_gather_forward_reducescatter_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
reduce_forward,
split_forward_gather_backward,
@ -50,6 +52,8 @@ class Linear1D_Col(ParallelModule):
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
@ -69,6 +73,9 @@ class Linear1D_Col(ParallelModule):
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = False,
seq_parallel: bool = False,
seq_parallel_dim: int = 1,
overlap: bool = False,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
@ -80,6 +87,9 @@ class Linear1D_Col(ParallelModule):
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.seq_parallel_dim = seq_parallel_dim
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
@ -180,7 +190,12 @@ class Linear1D_Col(ParallelModule):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.seq_parallel:
output_parallel = linear_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
self.process_group, True,
self.seq_parallel_dim, self.overlap)
else:
output_parallel = linear_with_async_comm(input_parallel, self.weight, bias, self.process_group, True)
if self.gather_output:
# All-gather across the partitions.
@ -203,6 +218,8 @@ class Linear1D_Row(ParallelModule):
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
@ -221,6 +238,8 @@ class Linear1D_Row(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel: bool = False,
seq_parallel_dim: int = 1,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
@ -238,6 +257,8 @@ class Linear1D_Row(ParallelModule):
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel = seq_parallel
self.seq_parallel_dim = seq_parallel_dim
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
@ -373,7 +394,11 @@ class Linear1D_Row(ParallelModule):
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = F.linear(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group,
self.seq_parallel_dim)
else:
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:
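As a sanity check of the math behind Linear1D_Col and Linear1D_Row (independent of sequence parallelism), this single-process sketch replaces the collectives with plain concatenation and summation; it is illustrative only.

import torch
import torch.nn.functional as F

tp_size = 2
x = torch.randn(4, 16)                           # [tokens, in_features]
w = torch.randn(32, 16)                          # [out_features, in_features]

# Column parallel: each rank owns a slice of the output features;
# gather_output=True corresponds to the final concatenation.
w_cols = torch.chunk(w, tp_size, dim=0)
col_out = torch.cat([F.linear(x, wi) for wi in w_cols], dim=-1)

# Row parallel: each rank owns a slice of the input features;
# reduce_forward corresponds to summing the partial products.
w_rows = torch.chunk(w, tp_size, dim=1)
x_parts = torch.chunk(x, tp_size, dim=1)
row_out = sum(F.linear(xi, wi) for xi, wi in zip(x_parts, w_rows))

assert torch.allclose(col_out, F.linear(x, w), atol=1e-5)
assert torch.allclose(row_out, F.linear(x, w), atol=1e-5)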

View File

@ -10,6 +10,7 @@ import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.tensor.d_tensor import (
distribute_tensor,
distribute_tensor_with_customization,
@ -56,13 +57,7 @@ class ParallelModule(nn.Module, ABC):
"""
for name, param in self._parameters.items():
if param is not None:
param_ = param if keep_vars else param.detach()
if is_distributed_tensor(param_):
destination[prefix + name] = to_global(param_)
elif is_customized_distributed_tensor(param_):
destination[prefix + name] = to_global_for_customized_distributed_tensor(param_)
else:
destination[prefix + name] = param_
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:

View File

@ -25,7 +25,9 @@ from colossalai.tensor.d_tensor.api import (
from ._operation import (
gather_forward_split_backward,
linear_reducescatter_forward_gather_backward,
linear_with_async_comm,
matmul_gather_forward_reducescatter_backward,
matmul_with_async_comm,
reduce_backward,
reduce_forward,
@ -150,6 +152,7 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
device (`torch.device`): The device of parameters, defaults to None.
n_fused (int): The number of items fused, defaults to 3 (QKV).
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
@ -173,6 +176,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
process_group: ProcessGroup = None,
async_communication: bool = False,
gather_output: bool = False,
seq_parallel: bool = False,
overlap: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None,
@ -185,6 +190,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
self.in_features = in_features
self.out_features = out_features
self.gather_output = gather_output
self.seq_parallel = seq_parallel
self.overlap = overlap
self.skip_bias_add = skip_bias_add
self.device = device
self.n_fused = n_fused
@ -296,15 +303,19 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
assert input_.shape[-1] == self.weight.shape[0], \
'Invalid shapes in Linear1D_Col forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
# input_parallel = input_
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
self.async_communication)
if self.seq_parallel:
input_parallel = input_
output_parallel = matmul_gather_forward_reducescatter_backward(input_parallel, self.weight, bias,
self.process_group, True, 1, self.overlap)
else:
# Set up backprop all-reduce.
input_parallel = reduce_backward(input_, self.process_group)
output_parallel = matmul_with_async_comm(input_parallel, self.weight, bias, self.process_group,
self.async_communication)
if self.gather_output:
# All-gather across the partitions.
@ -329,6 +340,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
parallel_input (bool): If set to ``True``, it's assumed that the input is split, defaults to False.
seq_parallel (`bool`): If set to ``True``, it will use sequence parallelism, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (:class:`typing.Callable`, optional):
The initializer of weight, defaults to kaiming uniform initializer.
@ -346,6 +358,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
seq_parallel: bool = False,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
@ -363,6 +376,7 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.seq_parallel = seq_parallel
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
@ -499,7 +513,10 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
output = torch.cat(output_parallel_list, dim=-1)
else:
output_parallel = torch.matmul(input_, self.weight)
output = reduce_forward(output_parallel, self.process_group)
if self.seq_parallel:
output = linear_reducescatter_forward_gather_backward(output_parallel, self.process_group, 1)
else:
output = reduce_forward(output_parallel, self.process_group)
if not self.skip_bias_add:
if self.bias is not None:

View File

@ -1,6 +1,6 @@
import math
import warnings
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@ -29,6 +29,8 @@ from transformers.models.bert.modeling_bert import (
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
class BertPipelineForwards:
@ -56,6 +58,7 @@ class BertPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None, # this is from the previous stage
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
# TODO(jianghai): add explanation of the output here.
r"""
@ -177,6 +180,17 @@ class BertPipelineForwards:
start_idx, end_idx = stage_index[0], stage_index[1]
# layer_outputs
layer_outputs = hidden_states if hidden_states is not None else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config is not None and shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
if encoder_hidden_states is not None:
encoder_hidden_states = split_forward_gather_backward(
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
for idx, encoder_layer in enumerate(self.encoder.layer[start_idx:end_idx], start=start_idx):
if stage_manager.is_first_stage() and idx == 0:
encoder_attention_mask = encoder_extended_attention_mask
@ -223,11 +237,17 @@ class BertPipelineForwards:
all_cross_attentions = all_cross_attentions + \
(layer_outputs[2],)
# When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
if shard_config is not None and shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
# end of a stage loop
sequence_output = layer_outputs[0] if layer_outputs is not None else None
sequence_output = hidden_states if hidden_states is not None else None
if stage_manager.is_last_stage():
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
@ -268,6 +288,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
logger = logging.get_logger(__name__)
@ -294,6 +315,7 @@ class BertPipelineForwards:
stage_manager=stage_manager,
hidden_states=hidden_states if hidden_states is not None else None,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@ -350,6 +372,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.FloatTensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@ -404,7 +427,8 @@ class BertPipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states if hidden_states is not None else None,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
@ -457,6 +481,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -491,6 +516,7 @@ class BertPipelineForwards:
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
@ -532,6 +558,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**kwargs,
):
# -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
@ -594,7 +621,8 @@ class BertPipelineForwards:
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
@ -636,6 +664,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -666,7 +695,8 @@ class BertPipelineForwards:
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
@ -726,6 +756,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -742,21 +773,20 @@ class BertPipelineForwards:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
outputs = BertPipelineForwards.bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
)
outputs = BertPipelineForwards.bert_model_forward(self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
@ -799,6 +829,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -843,6 +874,7 @@ class BertPipelineForwards:
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
pooled_output = outputs[1]
@ -886,6 +918,7 @@ class BertPipelineForwards:
hidden_states: Optional[torch.Tensor] = None,
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
# NOTE: the arg start_position and end_position are used only for the last stage
r"""
@ -909,21 +942,20 @@ class BertPipelineForwards:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
outputs = BertPipelineForwards.bert_model_forward(
self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
)
outputs = BertPipelineForwards.bert_model_forward(self.bert,
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
hidden_states=hidden_states,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=shard_config)
if stage_manager.is_last_stage():
sequence_output = outputs[0]
@ -1101,3 +1133,153 @@ def get_jit_fused_bert_output_forward():
return hidden_states
return forward
def bert_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
the model is configured as a decoder.
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
else:
use_cache = False
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
batch_size, seq_length = input_shape
device = input_ids.device if input_ids is not None else inputs_embeds.device
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if attention_mask is None:
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
if token_type_ids is None:
if hasattr(self.embeddings, "token_type_ids"):
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_extended_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
embedding_output = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
token_type_ids=token_type_ids,
inputs_embeds=inputs_embeds,
past_key_values_length=past_key_values_length,
)
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
embedding_output = split_forward_gather_backward(embedding_output,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
if encoder_hidden_states is not None:
encoder_hidden_states = split_forward_gather_backward(
encoder_hidden_states, dim=1, process_group=shard_config.tensor_parallel_process_group)
encoder_outputs = self.encoder(
embedding_output,
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
sequence_output = encoder_outputs[0]
# When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
sequence_output = gather_forward_split_backward(sequence_output,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output,
pooler_output=pooled_output,
past_key_values=encoder_outputs.past_key_values,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
)
return forward
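Outside of the policy machinery, the replacement forward produced above could in principle be bound onto a tensor-parallel BertModel along the following lines; this is a hedged sketch (the actual Shardformer policy performs an equivalent method replacement internally, and model/tp_group are assumed to exist in the environment).

from types import MethodType
from colossalai.shardformer import ShardConfig

# Assumed: torch.distributed is initialised, tp_group is the tensor-parallel process group,
# and model is a transformers.BertModel whose linears were already replaced by the
# sequence-parallel Shardformer layers above.
shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           enable_sequence_parallelism=True)
model.forward = MethodType(bert_sequence_parallel_forward_fn(shard_config), model)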

View File

@ -23,6 +23,10 @@ from transformers.models.bloom.modeling_bloom import (
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
logger = logging.get_logger(__name__)
def build_bloom_alibi_tensor_fn(process_group: ProcessGroup) -> torch.Tensor:
@ -111,6 +115,7 @@ class BloomPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], 'BaseModelOutputWithPastAndCrossAttentions']:
@ -205,6 +210,13 @@ class BloomPipelineForwards:
past_key_values_length=past_key_values_length,
)
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
start_idx, end_idx = stage_index[0], stage_index[1]
for i, (block, layer_past) in enumerate(zip(self.h[start_idx:end_idx], past_key_values[start_idx:end_idx]),
start=start_idx):
@ -248,6 +260,12 @@ class BloomPipelineForwards:
all_self_attentions = all_self_attentions + \
(outputs[2 if use_cache else 1],)
# When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
if stage_manager.is_last_stage():
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
@ -287,6 +305,7 @@ class BloomPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments):
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@ -327,7 +346,8 @@ class BloomPipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
past_key_values = None
all_hidden_states = None
all_self_attentions = None
@ -380,6 +400,7 @@ class BloomPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments,
):
r"""
@ -424,6 +445,7 @@ class BloomPipelineForwards:
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@ -503,6 +525,7 @@ class BloomPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
**deprecated_arguments,
):
r"""
@ -547,6 +570,7 @@ class BloomPipelineForwards:
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@ -597,6 +621,7 @@ class BloomPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@ -632,6 +657,7 @@ class BloomPipelineForwards:
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
past_key_values = None
all_hidden_states = None
@ -700,8 +726,7 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False):
fused_qkv = self.query_key_value(hidden_states)
(query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
batch_size, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
batch_size, tgt_len, _ = query_layer.size()
_, kv_length, _, _ = key_layer.size()
@ -896,3 +921,156 @@ def get_jit_fused_bloom_gelu_forward():
return self.bloom_gelu_forward(x, bias)
return forward
def get_bloom_sequence_parallel_forward_fn(shard_config: ShardConfig):
from transformers import BloomModel
def forward(
self: BloomModel,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**deprecated_arguments,
) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
if deprecated_arguments.pop("position_ids", False) is not False:
# `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
warnings.warn(
"`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
" passing `position_ids`.",
FutureWarning,
)
if len(deprecated_arguments) > 0:
raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape batch_size x num_heads x N x N
# head_mask has shape n_layer x batch x num_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
hidden_states = self.word_embeddings_layernorm(inputs_embeds)
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
# Compute alibi tensor: check build_alibi_tensor documentation
seq_length_with_past = seq_length
past_key_values_length = 0
if past_key_values[0] is not None:
past_key_values_length = past_key_values[0][0].shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
else:
attention_mask = attention_mask.to(hidden_states.device)
alibi = self.build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
causal_mask = self._prepare_attn_mask(
attention_mask,
input_shape=(batch_size, seq_length),
past_key_values_length=past_key_values_length,
)
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
alibi,
causal_mask,
layer_past,
head_mask[i],
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=causal_mask,
head_mask=head_mask[i],
use_cache=use_cache,
output_attentions=output_attentions,
alibi=alibi,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
# When sequence parallelism is enabled, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
# Add last hidden state
hidden_states = self.ln_f(hidden_states)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
return forward

View File

@ -9,6 +9,8 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutpu
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
@ -146,6 +148,7 @@ class ChatGLMPipelineForwards:
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None,
):
logger = logging.get_logger(__name__)
output_hidden_states = (output_hidden_states
@ -198,6 +201,11 @@ class ChatGLMPipelineForwards:
all_self_attentions = None
all_hidden_states = () if output_hidden_states else None
start_idx, end_idx = stage_index[0], stage_index[1]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
for idx in range(start_idx, end_idx):
layer = self.encoder._get_layer(idx)
if output_hidden_states:
@ -214,6 +222,11 @@ class ChatGLMPipelineForwards:
hidden_states, kv_cache = layer_ret
if use_cache:
presents = presents + (kv_cache,)
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if stage_manager.is_last_stage():
@ -233,23 +246,22 @@ class ChatGLMPipelineForwards:
return {'hidden_states': hidden_states}
@staticmethod
def chatglm_for_conditional_generation_forward(
self: ChatGLMForConditionalGeneration,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
):
def chatglm_for_conditional_generation_forward(self: ChatGLMForConditionalGeneration,
input_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
return_last_logit: Optional[bool] = False,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None):
logger = logging.get_logger(__name__)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
@ -266,6 +278,7 @@ class ChatGLMPipelineForwards:
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
shard_config=shard_config,
)
if stage_manager.is_last_stage():
hidden_states = transformer_outputs[0]
@ -296,3 +309,91 @@ class ChatGLMPipelineForwards:
)
else:
return transformer_outputs
def get_chatglm_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
full_attention_mask: Optional[torch.BoolTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
inputs_embeds: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (return_dict if return_dict is not None else self.config.use_return_dict)
batch_size, seq_length = input_ids.shape
if inputs_embeds is None:
inputs_embeds = self.embedding(input_ids)
if self.pre_seq_len is not None:
if past_key_values is None:
past_key_values = self.get_prompt(
batch_size=batch_size,
device=input_ids.device,
dtype=inputs_embeds.dtype,
)
if attention_mask is not None:
attention_mask = torch.cat(
[
attention_mask.new_ones((batch_size, self.pre_seq_len)),
attention_mask,
],
dim=-1,
)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
rotary_pos_emb = rotary_pos_emb[position_ids]
else:
rotary_pos_emb = rotary_pos_emb[None, :seq_length]
rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
# Run encoder.
# [seq_len, batch_size, hidden_size] -> [seq_len/TP_size, batch_size, hidden_size]
inputs_embeds = split_forward_gather_backward(inputs_embeds,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
inputs_embeds,
full_attention_mask,
rotary_pos_emb=rotary_pos_emb,
kv_caches=past_key_values,
use_cache=use_cache,
output_hidden_states=output_hidden_states,
)
hidden_states = gather_forward_split_backward(hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group)
if not return_dict:
return tuple(v for v in [
hidden_states,
presents,
all_hidden_states,
all_self_attentions,
] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
return forward
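Editor's note on the split/gather pair used in the sequence-parallel forward above: split_forward_gather_backward shards the activation along the sequence dimension in the forward pass and all-gathers the gradient in the backward pass, while gather_forward_split_backward does the reverse, so each tensor-parallel rank's encoder only ever sees its local 1/TP_size slice of the sequence. The single-process sketch below only illustrates that autograd symmetry, using torch.chunk and a zero-padded gradient as stand-ins for the real collectives; it is not the ColossalAI implementation.

import torch


class _SplitForwardGatherBackward(torch.autograd.Function):
    """Toy stand-in: keep one sequence chunk in forward, rebuild a full-size gradient in backward."""

    @staticmethod
    def forward(ctx, x, dim, rank, world_size):
        ctx.dim, ctx.rank, ctx.world_size = dim, rank, world_size
        return torch.chunk(x, world_size, dim=dim)[rank].contiguous()

    @staticmethod
    def backward(ctx, grad_output):
        # The real collective would all-gather grad_output across the TP group;
        # here we emulate the shape change by padding the other ranks with zeros.
        chunks = [torch.zeros_like(grad_output) for _ in range(ctx.world_size)]
        chunks[ctx.rank] = grad_output
        return torch.cat(chunks, dim=ctx.dim), None, None, None


if __name__ == "__main__":
    x = torch.randn(8, 2, 4, requires_grad=True)            # [seq_len, batch, hidden], as in ChatGLM
    local = _SplitForwardGatherBackward.apply(x, 0, 1, 4)    # this "rank" keeps slice 1 of 4 along dim 0
    local.sum().backward()
    print(local.shape, x.grad.shape)                         # torch.Size([2, 2, 4]) torch.Size([8, 2, 4])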

View File

@ -21,6 +21,8 @@ from transformers.models.gpt2.modeling_gpt2 import (
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
from colossalai.shardformer.shard import ShardConfig
class GPT2PipelineForwards:
@ -47,7 +49,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, BaseModelOutputWithPastAndCrossAttentions]:
# This function is modified on the basis of transformers.models.gpt2.modeling_gpt2.GPT2Model.forward.
# Please refer to original code of transformers for more details.
@ -145,7 +148,7 @@ class GPT2PipelineForwards:
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
@ -159,6 +162,13 @@ class GPT2PipelineForwards:
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
if shard_config.enable_sequence_parallelism:
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
# Going through held blocks.
start_idx, end_idx = stage_index[0], stage_index[1]
for i in range(start_idx, end_idx):
@ -212,6 +222,12 @@ class GPT2PipelineForwards:
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# When sequence parallelism done, gather the output tensor in forward and split it in backward
if shard_config.enable_sequence_parallelism:
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
if stage_manager.is_last_stage():
hidden_states = self.ln_f(hidden_states)
@ -257,7 +273,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@ -285,7 +302,8 @@ class GPT2PipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@ -335,7 +353,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, GPT2DoubleHeadsModelOutput]:
r"""
mc_token_ids (`torch.LongTensor` of shape `(batch_size, num_choices)`, *optional*, default to index of the last token of the input):
Index of the classification token in each input sequence. Selected in the range `[0, input_ids.size(-1) -
@ -367,7 +386,8 @@ class GPT2PipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@ -421,7 +441,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for position (index) of the start of the labelled span for computing the token classification loss.
@ -449,7 +470,8 @@ class GPT2PipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@ -508,7 +530,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@ -534,7 +557,8 @@ class GPT2PipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@ -578,7 +602,8 @@ class GPT2PipelineForwards:
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
stage_index: Optional[List[int]] = None,
shard_config: ShardConfig = None) -> Union[Dict, Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
@ -613,7 +638,8 @@ class GPT2PipelineForwards:
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index)
stage_index=stage_index,
shard_config=shard_config)
# If not at the last stage, return hidden_states as in GPT2Model
if not stage_manager.is_last_stage():
@ -696,7 +722,6 @@ def get_gpt2_flash_attention_forward():
output_attentions: Optional[bool] = False,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
_, tgt_len, _ = hidden_states.size()
assert tgt_len % 4 == 0, "Flash Attention Error: The sequence length should be a multiple of 4."
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
@ -753,3 +778,210 @@ def get_gpt2_flash_attention_forward():
return outputs
return forward
def gpt2_sequence_parallel_forward_fn(shard_config: ShardConfig):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
# GPT2Attention mask.
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = attention_mask.view(batch_size, -1)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
hidden_states = hidden_states + token_type_embeds
hidden_states = self.drop(hidden_states)
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
# split the input tensor along sequence dimension
# [batch_size, seq_len, hidden_size] -> [batch_size, seq_len/TP_size, hidden_size]
hidden_states = split_forward_gather_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
if isinstance(head_mask, torch.Tensor):
head_mask = head_mask.to(hidden_states.device)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
for k, v in self.device_map.items():
if i == v[-1] and "cuda:" + str(k) != self.last_device:
hidden_states = hidden_states.to("cuda:" + str(k + 1))
# When sequence parallelism done, gather the output tensor in forward and split it in backward
hidden_states = gather_forward_split_backward(hidden_states,
dim=1,
process_group=shard_config.tensor_parallel_process_group)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
return forward
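For GPT-2 the activation layout is [batch_size, seq_len, hidden_size], so the same split/gather pair runs along dim=1 instead of ChatGLM's dim=0. A minimal sketch of attaching the generated forward to a Hugging Face GPT2Model outside of the policy machinery follows; the colossalai.shardformer.modeling.gpt2 module path is assumed from the files touched in this diff, and the MethodType binding is an illustrative shortcut rather than the way shardformer registers the replacement.

from types import MethodType

from transformers.models.gpt2.modeling_gpt2 import GPT2Model

# Assumed import locations, inferred from this diff.
from colossalai.shardformer.modeling.gpt2 import gpt2_sequence_parallel_forward_fn
from colossalai.shardformer.shard import ShardConfig


def attach_sequence_parallel_forward(model: GPT2Model, shard_config: ShardConfig) -> GPT2Model:
    # The factory returns an unbound forward that closes over shard_config;
    # MethodType binds it to this module instance, replacing GPT2Model.forward
    # with the dim=1 split/gather variant defined above.
    model.forward = MethodType(gpt2_sequence_parallel_forward_fn(shard_config), model)
    return model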

View File

@ -1,7 +1,26 @@
from typing import Optional, Tuple
import logging
import random
from typing import Dict, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
SequenceClassifierOutput,
)
from transformers.models.whisper.modeling_whisper import (
WhisperEncoder,
WhisperForAudioClassification,
WhisperForConditionalGeneration,
WhisperModel,
shift_tokens_right,
)
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
def get_whisper_flash_attention_forward():
@ -247,3 +266,697 @@ def get_jit_fused_whisper_decoder_layer_forward():
return outputs
return forward
class WhisperPipelineForwards:
'''
This class serves as a micro library for forward function substitution of Whisper models
under pipeline setting.
'''
@staticmethod
def whisper_encoder_forward(
self: WhisperEncoder,
input_features,
attention_mask=None,
head_mask=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
encoder_states=None,
all_attentions=None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
):
r"""
Args:
input_features (`torch.LongTensor` of shape `(batch_size, feature_size, sequence_length)`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
attention_mask (`torch.Tensor`, *optional*):
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
but it is not used. By default the silence in the input log mel spectrogram are ignored.
head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
logger = logging.get_logger(__name__)
stage = stage_manager.stage
at_first_stage = (stage == 0)
at_last_stage = (stage == decoder_starting_stage - 1)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Process inputs if at the first stage of encoder.
if at_first_stage:
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
inputs_embeds = inputs_embeds.permute(0, 2, 1)
embed_pos = self.embed_positions.weight
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
# check if head_mask has a correct number of layers specified if desired
if head_mask is not None:
assert head_mask.size()[0] == (
len(self.layers)
), f"The head_mask should be specified for {len(self.layers)} layers, but it is for {head_mask.size()[0]}."
else:
if hidden_states is None:
raise ValueError(
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
start_idx, end_idx = stage_index[0], stage_index[1]
for idx in range(start_idx, end_idx):
encoder_layer = self.layers[idx]
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states,
None,
(head_mask[idx] if head_mask is not None else None),
)
else:
layer_outputs = encoder_layer(
hidden_states,
None,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
output_attentions=output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if at_last_stage:
hidden_states = self.layer_norm(hidden_states)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return BaseModelOutput(last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions)
else:
return {'hidden_states': hidden_states, 'head_mask': head_mask}
@staticmethod
def whisper_decoder_forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
):
r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using [`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in `[0, 1]`:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
logger = logging.get_logger(__name__)
stage = stage_manager.stage
at_first_stage = (stage == decoder_starting_stage)
at_last_stage = (stage == stage_manager.num_stages - 1)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
if attn_mask is not None:
assert attn_mask.size()[0] == (len(self.layers)), (
f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for"
f" {head_mask.size()[0]}.")
# past_key_values_length
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
if at_first_stage:
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
# embed positions
if input_ids is not None:
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
else:
positions = self.embed_positions(inputs_embeds, past_key_values_length=past_key_values_length)
attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, inputs_embeds,
past_key_values_length)
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
)
use_cache = False
else:
if hidden_states is None:
raise ValueError(
"hidden_states shouldn't be None for stages other than the first stage of encoder/decoder.")
input_shape = hidden_states.size()[:-1]
attention_mask = self._prepare_decoder_attention_mask(attention_mask, input_shape, hidden_states,
past_key_values_length)
start_idx, end_idx = stage_index[0], stage_index[1]
for idx in range(start_idx, end_idx):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
decoder_layer = self.layers[idx]
if output_hidden_states:
all_hidden_states += (hidden_states,)
dropout_probability = random.uniform(0, 1)
if self.training and (dropout_probability < self.layerdrop):
continue
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
attention_mask,
encoder_hidden_states,
None, # encoder attention mask
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, # past_key_value
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
cross_attn_layer_head_mask=(cross_attn_head_mask[idx]
if cross_attn_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
if output_attentions:
all_self_attns += (layer_outputs[1],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
if at_last_stage:
hidden_states = self.layer_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(
v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
if v is not None)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
cross_attentions=all_cross_attentions,
)
else:
return {
'head_mask': head_mask,
'cross_attn_head_mask': cross_attn_head_mask,
'hidden_states': hidden_states,
}
@staticmethod
def whisper_model_forward(
self: WhisperModel,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
):
r"""
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoFeatureExtractor, WhisperModel
>>> from datasets import load_dataset
>>> model = WhisperModel.from_pretrained("openai/whisper-base")
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-base")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
>>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
>>> list(last_hidden_state.shape)
[1, 2, 512]
```"""
logger = logging.get_logger(__name__)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# TODO: left the recording kv-value tensors as () or None type, this feature may be added in the future.
if past_key_values:
logger.warning_once('Non-empty past_key_values is not supported for pipeline models at the moment.')
past_key_values = None
if output_attentions:
logger.warning_once('output_attentions=True is not supported for pipeline models at the moment.')
output_attentions = False
if output_hidden_states:
logger.warning_once('output_hidden_states=True is not supported for pipeline models at the moment.')
output_hidden_states = False
if use_cache:
logger.warning_once('use_cache=True is not supported for pipeline models at the moment.')
use_cache = False
in_decoder = stage_manager.stage >= decoder_starting_stage
if not in_decoder:
if encoder_outputs is None:
input_features = self._mask_input_features(input_features, attention_mask=attention_mask)
encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
self.encoder,
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
if stage_manager.stage == decoder_starting_stage - 1:
# last stage of encoder
return {'encoder_hidden_states': encoder_outputs[0]}
else:
return encoder_outputs
# If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
encoder_outputs = BaseModelOutput(
last_hidden_state=encoder_outputs[0],
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
at_last_decoder_stage = stage_manager.is_last_stage()
at_first_decoder_stage = stage_manager.stage == decoder_starting_stage
if encoder_outputs is not None:
encoder_hidden_states = encoder_outputs[0]
elif encoder_hidden_states is None:
raise ValueError("Non-empty encoder_hidden_states should be passed in at decoder stages.")
if not at_first_decoder_stage and hidden_states is None:
raise ValueError("If not at the first layer of decoder, non-empty hidden_states must be provided.")
# decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
decoder_outputs = WhisperPipelineForwards.whisper_decoder_forward(self.decoder,
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
# Directly return outputs of overloaded Whisper forward if not at last stage.
if not at_last_decoder_stage:
# encoder_hidden_states should be passed to the next stage
decoder_outputs['encoder_hidden_states'] = encoder_hidden_states
return decoder_outputs
if not return_dict:
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_hidden_states,
)
@staticmethod
def whisper_for_conditional_generation_forward(
self: WhisperForConditionalGeneration,
input_features: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
decoder_input_ids: Optional[torch.LongTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
decoder_head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the language modeling loss. Indices should either be in `[0, ..., config.vocab_size]`
or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked), the loss is
only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
Returns:
Example:
```python
>>> import torch
>>> from transformers import AutoProcessor, WhisperForConditionalGeneration
>>> from datasets import load_dataset
>>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en")
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="pt")
>>> input_features = inputs.input_features
>>> generated_ids = model.generate(inputs=input_features)
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> transcription
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
if decoder_input_ids is None and decoder_inputs_embeds is None:
decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id,
self.config.decoder_start_token_id)
in_decoder = stage_manager.stage >= decoder_starting_stage
at_last_decoder_stage = stage_manager.is_last_stage()
outputs = WhisperPipelineForwards.whisper_model_forward(self.model,
input_features,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
if not in_decoder:
return outputs
if not at_last_decoder_stage:
# encoder_hidden_states should be passed to the next stage
outputs['encoder_hidden_states'] = encoder_hidden_states
return outputs
lm_logits = self.proj_out(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(lm_logits.device)
loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
@staticmethod
def whisper_for_audio_classification_forward(
self: WhisperForAudioClassification,
input_features: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
stage_manager: Optional[PipelineStageManager] = None,
hidden_states: Optional[torch.FloatTensor] = None,
encoder_states=None,
all_attentions=None,
stage_index: Optional[List[int]] = None,
decoder_starting_stage: Optional[int] = None,
):
r"""
This function is modified on the basis of transformers.models.whisper.modeling_whisper.WhisperForAudioClassification.forward.
Please refer to original code of transformers for more details.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (output_hidden_states
if output_hidden_states is not None else self.config.output_hidden_states)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# audio_classification only holds encoder
encoder_outputs = WhisperPipelineForwards.whisper_encoder_forward(
self.encoder,
input_features,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
stage_manager=stage_manager,
hidden_states=hidden_states,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage,
)
if not stage_manager.is_last_stage():
return encoder_outputs
if self.config.use_weighted_layer_sum:
hidden_states = torch.stack(encoder_outputs, dim=1)
norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
else:
hidden_states = encoder_outputs[0]
hidden_states = self.projector(hidden_states)
pooled_output = hidden_states.mean(dim=1)
logits = self.classifier(pooled_output)
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
# move labels to correct device to enable PP
labels = labels.to(logits.device)
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + encoder_outputs[1:]
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
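The Whisper forwards above route work by comparing stage_manager.stage against decoder_starting_stage: stages below it hold encoder layers, stages at or above it hold decoder layers, and the boundary stages apply the final encoder/decoder layer norms and hand encoder_hidden_states across. The self-contained sketch below restates that routing rule only; it is not ColossalAI code.

def describe_whisper_stage(stage: int, num_stages: int, decoder_starting_stage: int) -> dict:
    """Describe which part of an encoder-decoder model a pipeline stage owns,
    following the stage checks used in WhisperPipelineForwards above."""
    in_decoder = stage >= decoder_starting_stage
    return {
        "module": "decoder" if in_decoder else "encoder",
        "first_of_module": stage == (decoder_starting_stage if in_decoder else 0),
        # The last encoder stage applies the encoder layer norm and passes
        # encoder_hidden_states on to the first decoder stage.
        "last_of_module": stage == (num_stages - 1 if in_decoder else decoder_starting_stage - 1),
    }


if __name__ == "__main__":
    # 4 pipeline stages with the decoder starting at stage 2: stages 0-1 encoder, 2-3 decoder.
    for s in range(4):
        print(s, describe_whisper_stage(s, num_stages=4, decoder_starting_stage=2))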

View File

@ -125,9 +125,9 @@ _POLICY_LIST = {
# ChatGLM
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
PolicyLocation(file_name="chatglm", class_name="ChatGLMModelPolicy"),
PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"),
"colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration":
PolicyLocation(file_name="chatglm", class_name="ChatGLMForConditionalGenerationPolicy"),
PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"),
}
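The _POLICY_LIST entries above map a model class's fully qualified name to a PolicyLocation(file_name, class_name); the edit simply repoints the ChatGLM entries at the renamed chatglm2 policy file. The sketch below shows one plausible way such a registry could be resolved into a policy class; the importlib-based lookup and the colossalai.shardformer.policies package path are assumptions for illustration, not the exact auto-policy code.

import importlib
from dataclasses import dataclass


@dataclass
class PolicyLocation:
    file_name: str
    class_name: str


# Mirrors the two ChatGLM entries edited above.
_POLICY_LIST = {
    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel":
        PolicyLocation(file_name="chatglm2", class_name="ChatGLMModelPolicy"),
    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration":
        PolicyLocation(file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"),
}


def resolve_policy(model_qualname: str):
    """Load the registered policy class for a model class name (illustrative only)."""
    location = _POLICY_LIST[model_qualname]
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)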

View File

@ -11,17 +11,12 @@ from torch.nn import Module
from colossalai.pipeline.stage_manager import PipelineStageManager
from ..layer.parallel_module import ParallelModule
from ..shard.shard_config import ShardConfig
__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]
class ParallelModule():
def __init__(self):
pass
@dataclass
class SubModuleReplacementDescription:
r"""

View File

@ -10,6 +10,7 @@ import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.bert import (
BertPipelineForwards,
bert_sequence_parallel_forward_fn,
get_bert_flash_attention_forward,
get_jit_fused_bert_output_forward,
get_jit_fused_bert_self_output_forward,
@ -47,13 +48,15 @@ class BertPolicy(Policy):
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertLayer,
BertModel,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BertLayer] = ModulePolicyDescription(attribute_replacement={
"attention.self.all_head_size":
@ -69,14 +72,26 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.self.query",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="attention.self.dropout",
@ -85,6 +100,7 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="attention.output.dropout",
@ -93,10 +109,15 @@ class BertPolicy(Policy):
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
kwargs={
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
kwargs={"seq_parallel": use_sequence_parallel},
),
SubModuleReplacementDescription(
suffix="output.dropout",
@ -115,6 +136,12 @@ class BertPolicy(Policy):
)
])
if use_sequence_parallel:
self.append_or_create_method_replacement(
description={'forward': bert_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
target_key=BertModel)
# optimization configuration
if self.shard_config.enable_fused_normalization:
# Handle bert layer
@ -141,20 +168,26 @@ class BertPolicy(Policy):
# use flash attention
if self.shard_config.enable_flash_attention:
policy[BertSelfAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_bert_flash_attention_forward(),
})
},
policy=policy,
target_key=BertSelfAttention)
# use jit operator
if self.shard_config.enable_jit_fused:
policy[BertSelfOutput] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bert_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BertOutput] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=BertSelfOutput)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bert_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=BertOutput)
return policy
@ -205,7 +238,13 @@ class BertPolicy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
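The pipeline hook added here follows the pattern repeated for every model in this PR: work out the layer range this stage owns, then register a forward replacement built with functools.partial so that stage_manager, stage_index and now shard_config are already bound when the pipeline schedule calls the model. The stand-alone sketch below shows just the binding step with a dummy target class; the method-replacement plumbing that shardformer actually uses is simplified away and the names are placeholders.

from functools import partial
from types import MethodType


class TinyModel:
    def forward(self, x):
        return x


def pipeline_forward(self, x, *, stage_manager=None, stage_index=None, shard_config=None):
    # Stand-in for e.g. BertPipelineForwards.bert_model_forward: the caller only
    # passes the activation, everything else was pre-bound by partial below.
    print(f"stage_index={stage_index}, sp={getattr(shard_config, 'enable_sequence_parallelism', False)}")
    return x


model = TinyModel()
stage_index = [0, 6]    # layer range this stage owns, as Policy.get_stage_index would return
bound = partial(pipeline_forward, stage_manager=object(), stage_index=stage_index, shard_config=None)
model.forward = MethodType(bound, model)
print(model.forward("hidden_states"))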

View File

@ -285,34 +285,30 @@ class BlipPolicy(Policy):
# use flash attention
if self.shard_config.enable_flash_attention:
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_blip2_flash_attention_forward(),
})
},
policy=policy,
target_key=Blip2Attention)
# use jit operator
if self.shard_config.enable_jit_fused:
policy[Blip2QFormerSelfOutput] = ModulePolicyDescription(
method_replacement={
'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[Blip2QFormerOutput] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_blip2_QFormer_self_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=Blip2QFormerSelfOutput)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_blip2_QFormer_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=Blip2QFormerOutput)
return policy
def postprocess(self):
binding_map = {
'language_model.model.decoder.embed_tokens': 'language_model.lm_head',
}
for k, v in binding_map.items():
src_mod = getattr_(self.model, k)
dst_mod = getattr_(self.model, v)
dst_mod.weight = src_mod.weight
return self.model

View File

@ -12,6 +12,7 @@ from ..modeling.bloom import (
BloomPipelineForwards,
build_bloom_alibi_tensor_fn,
get_bloom_flash_attention_forward,
get_bloom_sequence_parallel_forward_fn,
get_jit_fused_bloom_attention_forward,
get_jit_fused_bloom_gelu_forward,
get_jit_fused_bloom_mlp_forward,
@ -43,6 +44,8 @@ class BloomPolicy(Policy):
policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@ -53,11 +56,14 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
),
kwargs={
'seq_parallel': use_sequence_parallel,
'overlap': overlap
}),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
),
kwargs={'seq_parallel': use_sequence_parallel}),
SubModuleReplacementDescription(
suffix="self_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
@ -65,11 +71,14 @@ class BloomPolicy(Policy):
SubModuleReplacementDescription(
suffix="mlp.dense_h_to_4h",
target_module=col_nn.Linear1D_Col,
),
kwargs={
'seq_parallel': use_sequence_parallel,
'overlap': overlap
}),
SubModuleReplacementDescription(
suffix="mlp.dense_4h_to_h",
target_module=col_nn.Linear1D_Row,
),
kwargs={'seq_parallel': use_sequence_parallel}),
])
policy[BloomModel] = ModulePolicyDescription(
@ -116,26 +125,40 @@ class BloomPolicy(Policy):
policy=policy,
target_key=BloomBlock)
if use_sequence_parallel:
self.append_or_create_method_replacement(
description={'forward': get_bloom_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
target_key=BloomModel)
if self.shard_config.enable_flash_attention:
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_bloom_flash_attention_forward(),
'dropout_add': get_dropout_add_func()
})
'dropout_add': get_dropout_add_func(),
},
policy=policy,
target_key=BloomAttention)
# enable jit fused operator
if self.shard_config.enable_jit_fused:
policy[BloomAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BloomMLP] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=BloomAttention)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_mlp_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[BloomGelu] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=BloomMLP)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_bloom_gelu_forward(),
'bloom_gelu_forward': get_jit_fused_gelu_forward_func(),
})
},
policy=policy,
target_key=BloomGelu)
return policy
@ -154,7 +177,13 @@ class BloomPolicy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
}
self.append_or_create_method_replacement(description=method_replacement,
policy=policy,
target_key=model_cls)
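Both the Bert and Bloom hooks rely on Policy.distribute_layers and Policy.get_stage_index to decide which transformer blocks each pipeline stage runs. A plausible reimplementation of that split is sketched below, assuming a near-even distribution with the remainder pushed to the later stages; the real heuristic inside colossalai may differ.

from typing import List


def distribute_layers(num_layers: int, num_stages: int) -> List[int]:
    """Assumed near-even split of layers over pipeline stages (remainder to the last stages)."""
    base, remainder = divmod(num_layers, num_stages)
    return [base + (1 if stage >= num_stages - remainder else 0) for stage in range(num_stages)]


def get_stage_index(layers_per_stage: List[int], stage: int) -> List[int]:
    """Return the [start_idx, end_idx) block range owned by `stage`, matching how the
    policies above slice module.h / module.encoder.layer with stage_index."""
    start = sum(layers_per_stage[:stage])
    return [start, start + layers_per_stage[stage]]


if __name__ == "__main__":
    per_stage = distribute_layers(num_layers=30, num_stages=4)     # e.g. a 30-block model on 4 stages
    print(per_stage)                                               # [7, 7, 8, 8]
    print([get_stage_index(per_stage, s) for s in range(4)])       # [[0, 7], [7, 14], [14, 22], [22, 30]]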

View File

@ -7,7 +7,7 @@ from transformers.modeling_outputs import BaseModelOutputWithPast
import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.modeling.chatglm import ChatGLMPipelineForwards
from colossalai.shardformer.modeling.chatglm2 import ChatGLMPipelineForwards
from colossalai.shardformer.modeling.chatglm2_6b.configuration_chatglm import ChatGLMConfig
from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
ChatGLMForConditionalGeneration,
@ -15,7 +15,11 @@ from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import (
GLMBlock,
)
from ..modeling.chatglm import get_flash_core_attention_forward, get_jit_fused_glm_block_forward
from ..modeling.chatglm2 import (
get_chatglm_sequence_parallel_forward_fn,
get_flash_core_attention_forward,
get_jit_fused_glm_block_forward,
)
from ..modeling.jit import get_jit_fused_dropout_add_func
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -37,6 +41,11 @@ class ChatGLMPolicy(Policy):
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
if self.pipeline_stage_manager is not None:
# the batch_size_dim is bound to the model
bsz_dim = 1
setattr(self.model, 'batch_size_dim', bsz_dim)
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
@ -45,8 +54,9 @@ class ChatGLMPolicy(Policy):
policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[ChatGLMModel] = ModulePolicyDescription(attribute_replacement={},
sub_module_replacement=[
SubModuleReplacementDescription(
@ -55,36 +65,43 @@ class ChatGLMPolicy(Policy):
)
])
policy[GLMBlock] = ModulePolicyDescription(attribute_replacement={
"self_attention.num_attention_heads_per_partition":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attention.projection_size":
(self.model.config.kv_channels * self.model.config.num_attention_heads) //
self.shard_config.tensor_parallel_size,
"self_attention.qkv_hidden_size":
(self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
self.shard_config.tensor_parallel_size,
"self_attention.core_attention.num_attention_heads_per_partition":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attention.core_attention.hidden_size_per_partition":
self.model.config.kv_channels * self.model.config.num_attention_heads //
self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="self_attention.core_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
policy[GLMBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.num_attention_heads_per_partition":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attention.projection_size":
(self.model.config.kv_channels * self.model.config.num_attention_heads) //
self.shard_config.tensor_parallel_size,
"self_attention.qkv_hidden_size":
(self.model.config.kv_channels * self.model.config.num_attention_heads * 3) //
self.shard_config.tensor_parallel_size,
"self_attention.core_attention.num_attention_heads_per_partition":
self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"self_attention.core_attention.hidden_size_per_partition":
self.model.config.kv_channels * self.model.config.num_attention_heads //
self.shard_config.tensor_parallel_size,
},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="self_attention.query_key_value",
target_module=col_nn.Linear1D_Col,
kwargs={
'seq_parallel': use_sequence_parallel,
'seq_parallel_dim': 0,
'overlap': overlap
}),
SubModuleReplacementDescription(suffix="self_attention.dense",
target_module=col_nn.Linear1D_Row,
kwargs={
'seq_parallel': use_sequence_parallel,
'seq_parallel_dim': 0
}),
SubModuleReplacementDescription(
suffix="self_attention.core_attention.attention_dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
# optimization configuration
if self.shard_config.enable_fused_normalization:
if not self.model.config.rmsnorm:
@ -124,16 +141,27 @@ class ChatGLMPolicy(Policy):
# use flash attention
if self.shard_config.enable_flash_attention:
policy[CoreAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_flash_core_attention_forward(),
})
},
policy=policy,
target_key=CoreAttention)
# use sequence parallel
if use_sequence_parallel:
self.append_or_create_method_replacement(
description={'forward': get_chatglm_sequence_parallel_forward_fn(self.shard_config)},
policy=policy,
target_key=ChatGLMModel)
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[GLMBlock] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_glm_block_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=GLMBlock)
return policy
@ -178,7 +206,13 @@ class ChatGLMPolicy(Policy):
layers_per_stage = Policy.distribute_layers(module.num_layers, stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

View File

@ -6,7 +6,7 @@ from torch import Tensor, nn
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward
from ..modeling.gpt2 import GPT2PipelineForwards, get_gpt2_flash_attention_forward, gpt2_sequence_parallel_forward_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@ -37,7 +37,8 @@ class GPT2Policy(Policy):
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2Model
policy = {}
use_sequence_parallel = self.shard_config.enable_sequence_parallelism
overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
policy[GPT2Model] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
@ -49,47 +50,55 @@ class GPT2Policy(Policy):
target_module=col_nn.DropoutForParallelInput,
),
])
policy[GPT2Block] = ModulePolicyDescription(attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
policy[GPT2Block] = ModulePolicyDescription(
attribute_replacement={
"attn.embed_dim": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel": use_sequence_parallel,
}),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
"seq_parallel": use_sequence_parallel,
"overlap": overlap
},
),
SubModuleReplacementDescription(suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
kwargs={
"seq_parallel": use_sequence_parallel,
}),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
# optimization configuration
if self.shard_config.enable_fused_normalization:
@ -117,9 +126,15 @@ class GPT2Policy(Policy):
target_key=GPT2Block)
if self.shard_config.enable_flash_attention:
policy[GPT2Attention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_gpt2_flash_attention_forward(),
})
},
policy=policy,
target_key=GPT2Attention)
if self.shard_config.enable_sequence_parallelism:
policy[GPT2Model].method_replacement = {"forward": gpt2_sequence_parallel_forward_fn(self.shard_config)}
return policy
def postprocess(self):
@ -160,7 +175,13 @@ class GPT2Policy(Policy):
layers_per_stage = Policy.distribute_layers(len(module.h), stage_manager.num_stages)
stage_index = Policy.get_stage_index(layers_per_stage, stage_manager.stage)
method_replacement = {'forward': partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
shard_config=self.shard_config)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)

View File

@ -1,3 +1,4 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Union
@ -35,6 +36,10 @@ class LlamaPolicy(Policy):
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("Llama dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[LlamaDecoderLayer] = ModulePolicyDescription(
attribute_replacement={
@ -105,9 +110,11 @@ class LlamaPolicy(Policy):
target_key=LlamaModel)
if self.shard_config.enable_flash_attention:
policy[LlamaAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_llama_flash_attention_forward(),
})
},
policy=policy,
target_key=LlamaAttention)
return policy

View File

@ -1,3 +1,4 @@
import warnings
from functools import partial
from typing import Callable, Dict, List
@ -39,6 +40,9 @@ class OPTPolicy(Policy):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("OPT dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[OPTDecoder] = ModulePolicyDescription(sub_module_replacement=[
@ -100,16 +104,20 @@ class OPTPolicy(Policy):
# use flash attention
if self.shard_config.enable_flash_attention:
policy[OPTAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_opt_flash_attention_forward(),
})
},
policy=policy,
target_key=OPTAttention)
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[OPTDecoderLayer] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_opt_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=OPTDecoderLayer)
return policy

View File

@ -199,12 +199,16 @@ class SamPolicy(Policy):
# use flash attention
if self.shard_config.enable_flash_attention:
policy[SamAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_sam_flash_attention_forward(),
})
policy[SamVisionAttention] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=SamAttention)
self.append_or_create_method_replacement(description={
'forward': get_sam_vision_flash_attention_forward(),
})
},
policy=policy,
target_key=SamVisionAttention)
return policy

View File

@ -1,6 +1,8 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Optional, Tuple
import numpy as np
from torch import Tensor, nn
from colossalai.shardformer.layer import (
@ -58,6 +60,10 @@ class T5BasePolicy(Policy):
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("T5 dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[T5Stack] = ModulePolicyDescription(sub_module_replacement=[
SubModuleReplacementDescription(
@ -178,24 +184,33 @@ class T5BasePolicy(Policy):
# use flash attention
if self.shard_config.enable_flash_attention:
policy[T5Attention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_t5_flash_attention_forward(),
})
},
policy=policy,
target_key=T5Attention)
# use jit operator
if self.shard_config.enable_jit_fused:
policy[T5LayerFF] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_T5_layer_ff_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[T5LayerSelfAttention] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=T5LayerFF)
self.append_or_create_method_replacement(description={
'forward': get_T5_layer_self_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[T5LayerCrossAttention] = ModulePolicyDescription(method_replacement={
},
policy=policy,
target_key=T5LayerSelfAttention)
self.append_or_create_method_replacement(description={
'forward': get_T5_layer_cross_attention_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=T5LayerCrossAttention)
return policy
def postprocess(self):
@ -228,13 +243,7 @@ class T5BasePolicy(Policy):
def objective(num_encoder_stages):
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
num_encoder_stages = 0
optimal_diff = 2**31 - 1
for i in range(1, num_stages):
attempt = objective(i)
if attempt < optimal_diff:
num_encoder_stages = i
optimal_diff = attempt
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
num_decoder_stages = num_stages - num_encoder_stages
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)

View File

@ -1,3 +1,4 @@
import warnings
from typing import Callable, Dict, List, Union
import torch.nn as nn
@ -32,6 +33,10 @@ class ViTPolicy(Policy):
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("Vit dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[ViTEmbeddings] = ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
@ -90,16 +95,20 @@ class ViTPolicy(Policy):
# use flash attention
if self.shard_config.enable_flash_attention:
policy[ViTSelfAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_vit_flash_self_attention_forward(),
})
},
policy=policy,
target_key=ViTSelfAttention)
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[ViTOutput] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_vit_output_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=ViTOutput)
return policy
def new_model_class(self):

View File

@ -1,10 +1,17 @@
import warnings
from functools import partial
from typing import Callable, Dict, List, Tuple
import numpy as np
import torch.nn as nn
from torch import Tensor
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.whisper import (
WhisperPipelineForwards,
get_jit_fused_whisper_decoder_layer_forward,
get_jit_fused_whisper_encoder_layer_forward,
get_whisper_flash_attention_forward,
@ -12,7 +19,8 @@ from ..modeling.whisper import (
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification'
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy',
'WhisperForAudioClassificationPolicy'
]
@ -26,7 +34,6 @@ class WhisperPolicy(Policy):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
@ -45,6 +52,11 @@ class WhisperPolicy(Policy):
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
"Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
"self_attn.embed_dim":
@ -191,20 +203,26 @@ class WhisperPolicy(Policy):
# enable flash attention
if self.shard_config.enable_flash_attention:
policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_whisper_flash_attention_forward(),
})
},
policy=policy,
target_key=WhisperAttention)
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_whisper_encoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_whisper_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=WhisperDecoderLayer)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_whisper_encoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=WhisperEncoderLayer)
return policy
@ -223,6 +241,146 @@ class WhisperPolicy(Policy):
def postprocess(self):
return self.model
@staticmethod
def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int,
num_stages: int) -> Tuple[List[int], int]:
"""
Distribute whisper layers into stages when pipeline parallelism is used.
Return the layer distribution as a list and the starting stage of the decoder.
If the decoder doesn't exist, the returned decoder starting stage is set to num_stages.
"""
# number of encoder layers must be a positive integer
if num_encoder_layers <= 0:
raise ValueError("The number of encoder layers for whisper must be a positive integer.")
# number of layers should be large enough to fill in every stage
if num_encoder_layers + num_decoder_layers < num_stages:
raise ValueError("The total number of layers can't be smaller than number of stages.")
# in the case of an encoder-only whisper model, set the decoder starting stage to num_stages since the decoder doesn't exist
if num_decoder_layers == 0:
return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages
# the number of stages distributed between encoder and decoder is optimized in this way:
# num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
# s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
def objective(num_encoder_stages):
return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
num_decoder_stages = num_stages - num_encoder_stages
encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
return encoder_distribution + decoder_distribution, num_encoder_stages
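For intuition, here is a worked sketch of the split chosen by the `objective`/`argmin` pair above. It is an illustration only, not part of the diff; the numbers mirror one case from the whisper pipeline unit test added later in this PR (3 encoder layers, 1 decoder layer, 4 stages).
```python
# Worked illustration (not repo code): choosing the encoder/decoder stage split.
import numpy as np

num_encoder_layers, num_decoder_layers, num_stages = 3, 1, 4

def objective(num_encoder_stages):
    # imbalance between layers-per-stage on the encoder side and the decoder side
    return abs(num_encoder_layers / num_encoder_stages
               - num_decoder_layers / (num_stages - num_encoder_stages))

# every candidate split keeps at least one stage for each side
num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
print(num_encoder_stages, num_stages - num_encoder_stages)    # -> 3 1
# hence decoder_starting_stage == 3, the value the unit test expects for this case
```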
@staticmethod
def get_whisper_stage_index(layers_per_stage: List[int], stage: int,
decoder_starting_stage: int) -> Tuple[int, int]:
"""
Given the distribution of layers among stages, the current stage and the first stage of the decoder,
return the starting/ending index of the layers held by the current stage in the encoder/decoder.
"""
if stage < decoder_starting_stage:
return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
else:
return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
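The stage-to-layer-index mapping can be sketched without the library helper. The snippet below is a self-contained illustration, not the actual `Policy.get_stage_index`; the layout (2 encoder layers, 2 decoder layers, 4 stages, decoder starting at stage 2) mirrors a case from the whisper pipeline unit test in this PR.
```python
# Illustration only: which layer indices each pipeline stage holds.
layers_per_stage = [1, 1, 1, 1]    # stages 0-1 hold encoder layers, stages 2-3 decoder layers
decoder_starting_stage = 2

def stage_index(layers, stage):
    # cumulative (start, end) of the layers owned by `stage`
    start = sum(layers[:stage])
    return start, start + layers[stage]

for stage in range(4):
    if stage < decoder_starting_stage:
        print('encoder', stage_index(layers_per_stage[:decoder_starting_stage], stage))
    else:
        print('decoder', stage_index(layers_per_stage[decoder_starting_stage:],
                                      stage - decoder_starting_stage))
# -> encoder (0, 1), encoder (1, 2), decoder (0, 1), decoder (1, 2)
```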
def get_held_layers(self) -> List[nn.Module]:
assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'WhisperModel':
model = self.model
elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
model = self.model.model
else:
model = None
if model:
encoder = self.model.get_encoder()
decoder = self.model.get_decoder()
else:
# whisper for audio classification holds encoder only
encoder = self.model.encoder
decoder = None
num_encoder_layers = len(encoder.layers)
if decoder:
num_decoder_layers = len(decoder.layers)
else:
num_decoder_layers = 0
held_layers = []
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
decoder_starting_stage)
if stage_manager.stage < decoder_starting_stage:
# current stage is in whisper's encoder
if stage_manager.is_first_stage():
held_layers.append(encoder.embed_positions)
held_layers.append(encoder.conv1)
held_layers.append(encoder.conv2)
if stage_manager.stage == decoder_starting_stage - 1:
held_layers.append(encoder.layer_norm)
held_layers.extend(encoder.layers[start_idx:end_idx])
else:
# current stage is in whisper's decoder
# TODO:(Jianghai) We divide encoder and decoder layers into different parts here,
# the case where the encoder and decoder are placed in the same stage should be added in the future.
if stage_manager.stage == decoder_starting_stage:
held_layers.append(decoder.embed_tokens)
held_layers.append(decoder.embed_positions)
if stage_manager.is_last_stage():
held_layers.append(decoder.layer_norm)
held_layers.extend(decoder.layers[start_idx:end_idx])
return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
"""If under pipeline parallel setting, replacing the original forward method of huggingface
to customized forward method, and add this changing to policy."""
if not self.pipeline_stage_manager:
raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
stage_manager = self.pipeline_stage_manager
if self.model.__class__.__name__ == 'WhisperModel':
model = self.model
elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
model = self.model.model
else:
model = None
if model:
encoder = self.model.get_encoder()
decoder = self.model.get_decoder()
else:
encoder = self.model.encoder
decoder = None
num_encoder_layers = len(encoder.layers)
if decoder:
num_decoder_layers = len(decoder.layers)
else:
num_decoder_layers = 0
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
decoder_starting_stage)
method_replacement = {
'forward':
partial(new_forward,
stage_manager=stage_manager,
stage_index=stage_index,
decoder_starting_stage=decoder_starting_stage)
}
self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
@ -230,6 +388,24 @@ class WhisperModelPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
from transformers import WhisperModel
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=WhisperModel,
new_forward=WhisperPipelineForwards.whisper_model_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
"no shared params in whisper model"
return []
# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
@ -238,20 +414,82 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
module_policy = self.add_lm_head_policy(module_policy)
return module_policy
from transformers import WhisperForConditionalGeneration
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration,
new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
policy=policy)
return policy
def postprocess(self):
binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.proj_out)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
module = self.model
model = module.model
if model:
encoder = self.model.get_encoder()
decoder = self.model.get_decoder()
else:
encoder = self.model.encoder
decoder = None
num_encoder_layers = len(encoder.layers)
if decoder:
num_decoder_layers = len(decoder.layers)
else:
num_decoder_layers = 0
stage_manager = self.pipeline_stage_manager
if stage_manager is not None and stage_manager.num_stages > 1:
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers,
stage_manager.num_stages)
shared_params = []
shared_embedding = {}
if id(module.proj_out) == id(model.decoder.embed_tokens):
shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens
shared_embedding[stage_manager.num_stages - 1] = module.proj_out
if len(shared_embedding) > 0:
shared_params.append(shared_embedding)
return shared_params
return []
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
def __init__(self) -> None:
super().__init__()
def preprocess(self):
return self.model
def module_policy(self):
from transformers import WhisperForAudioClassification
policy = super().module_policy()
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(model_cls=WhisperForAudioClassification,
new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
policy=policy)
return policy
def get_held_layers(self) -> List[nn.Module]:
held_layers = super().get_held_layers()
if self.pipeline_stage_manager.is_last_stage():
held_layers.append(self.model.projector)
held_layers.append(self.model.classifier)
return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
return []

View File

@ -20,6 +20,8 @@ class ShardConfig:
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
@ -28,6 +30,8 @@ class ShardConfig:
enable_all_optimization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
# pipeline_parallel_size: int
# data_parallel_size: int
@ -40,6 +44,11 @@ class ShardConfig:
return self._tensor_parallel_size
def __post_init__(self):
if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
raise ValueError(
"enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True")
if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
if not self.enable_tensor_parallelism:
self._tensor_parallel_size = 1
else:
@ -57,3 +66,5 @@ class ShardConfig:
self.enable_fused_normalization = True
self.enable_flash_attention = True
self.enable_jit_fused = True
self.enable_sequence_parallelism = True
self.enable_sequence_overlap = True
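Putting the new flags together, a minimal usage sketch follows. It assumes `ShardConfig` is importable from `colossalai.shardformer` as elsewhere in the repo and that a distributed context has already been launched; the commented-out calls show the combinations rejected by `__post_init__` above.
```python
# Minimal sketch (assumes a colossalai.launch-ed distributed context).
from colossalai.shardformer import ShardConfig

# rejected by __post_init__: sequence parallelism requires tensor parallelism
# ShardConfig(enable_tensor_parallelism=False, enable_sequence_parallelism=True)

# rejected by __post_init__: sequence overlap requires sequence parallelism
# ShardConfig(enable_sequence_overlap=True)

shard_config = ShardConfig(
    enable_tensor_parallelism=True,      # default
    enable_sequence_parallelism=True,    # split activations along the sequence dimension
    enable_sequence_overlap=True,        # overlap sequence-parallel communication with compute
)
# note: enable_all_optimization=True now switches both flags on as well
```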

View File

@ -92,22 +92,21 @@ class ModelSharder(object):
param_replacement (List[Callable]): The function list to get parameter shard information in policy
method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
sub_module_replacement (List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
"""
# released layers are not shardable
can_replace_param_or_layer = include is None or module in include
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
(module.__class__ == origin_cls):
if attr_replacement is not None:
self._replace_attr(module, attr_replacement)
if param_replacement is not None and can_replace_param_or_layer:
if param_replacement is not None and (include is None or module in include):
self._replace_param(module, param_replacement)
if method_replacement is not None:
self._replace_method(module, method_replacement)
if sub_module_replacement is not None and can_replace_param_or_layer:
self._replace_sub_module(module, sub_module_replacement)
if sub_module_replacement is not None:
self._replace_sub_module(module, sub_module_replacement, include)
for name, child in module.named_children():
self._recursive_replace_layer(child,
@ -154,18 +153,17 @@ class ModelSharder(object):
bound_method = MethodType(new_method, module)
setattr(module, method_name, bound_method)
def _replace_sub_module(
self,
org_layer: nn.Module,
sub_module_replacement: List[SubModuleReplacementDescription],
) -> None:
def _replace_sub_module(self,
org_layer: nn.Module,
sub_module_replacement: List[SubModuleReplacementDescription],
include: Optional[Set[nn.Module]] = None) -> None:
r"""
Shard one layer according to the policy; the layer should be of the same class as the key in the policy's argument_policy return dict
Args:
org_layer (torch.nn.Module): The origin layer object to shard
sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list
include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
"""
for description in sub_module_replacement:
suffix = description.suffix
@ -174,9 +172,12 @@ class ModelSharder(object):
assert target_module is not None, 'target_module should not be None'
# TODO: support different parallel mode
native_sub_module = getattr_(org_layer, suffix, ignore=True)
# Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
continue
assert not isinstance(native_sub_module, target_module), \
f"The module with suffix {suffix} has been replaced, please check the policy"

View File

@ -10,8 +10,9 @@ import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.data_parallel import _cast_float, free_storage
@ -733,7 +734,7 @@ class GeminiDDP(ModelWrapper):
Yields:
Iterator[OrderedDict]: A generator of state dict shard
"""
sharder = _StateDictSharder(max_shard_size)
sharder = StateDictSharder(max_shard_size)
# get the mapping between copies and fp16 parameters
fp16_to_fp32 = dict()
@ -755,7 +756,7 @@ class GeminiDDP(ModelWrapper):
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
gathered_param = gathered_param_buffer.pop(fp32_param)
block, block_size = sharder.append(prefix + name, gathered_param)
block, block_size = sharder.append_param(prefix + name, gathered_param)
if block is not None:
yield block, block_size
@ -766,7 +767,7 @@ class GeminiDDP(ModelWrapper):
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = sharder.append(prefix + name, buffer)
block, block_size = sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
# save extra states
@ -774,32 +775,10 @@ class GeminiDDP(ModelWrapper):
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
extra_state = self.get_extra_state()
block, block_size = sharder.append(extra_state_key, extra_state)
block, block_size = sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
yield sharder.current_block, sharder.current_block_size
class _StateDictSharder:
def __init__(self, max_shard_size: int) -> None:
self.max_shard_size = max_shard_size
self.current_block = OrderedDict()
self.current_block_size = 0
def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
tensor_size = calculate_tensor_size(tensor)
ret_block = None
ret_block_size = 0
# before we return the current block and create a new block,
# we need to ensure that the current block is not empty
if self.current_block_size + tensor_size > self.max_shard_size and self.current_block_size > 0:
ret_block = self.current_block
ret_block_size = self.current_block_size
self.current_block = OrderedDict()
self.current_block_size = 0
self.current_block[name] = tensor
self.current_block_size += tensor_size
return ret_block, ret_block_size

View File

@ -10,7 +10,7 @@ from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
@ -692,49 +692,17 @@ class GeminiOptimizer(OptimizerWrapper):
Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
"""
current_block = {}
current_block_size = 0
sharder = StateDictSharder(max_shard_size)
for param_id in self.id_to_real_params.keys():
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
ret_block = None
ret_block_size = 0
block, block_size = sharder.append_optim_state(param_id, state)
if block is not None:
yield block, block_size
# A state might contain more than one tensor.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is not of Tensor class,
# e.g., an SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error.
if not isinstance(state_tensor, torch.Tensor):
continue
# If the states are stored as DTensors, mark isDTensor as true.
if is_distributed_tensor(state_tensor):
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
if not isDTensor:
if current_block_size + state_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
current_block[param_id] = state
current_block_size += state_size
if ret_block != None:
yield ret_block, ret_block_size
yield current_block, current_block_size
yield sharder.current_block, sharder.current_block_size
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Gemini does not support clip_grad_by_value')

View File

@ -338,6 +338,24 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.zero_grad()
def backward_by_grad(self, tensor, grad):
assert not(self._partition_grads and not self.require_grad_sync), \
"ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"
if self.mixed_precision_mixin is not None:
grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
torch.autograd.backward(tensor, grad)
if not self.require_grad_sync:
return
self._reduce_grad(self._partition_grads)
# clear reduced grads
if self._overlap_communication:
torch.cuda.synchronize()
self.zero_grad()
def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
@ -363,7 +381,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
if not self.require_grad_sync:
return

View File

@ -7,13 +7,15 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be
bash test_ci.sh
```
### Results on 2-GPU
### Bert-Finetune Results
| Plugin | Accuracy | F1-score | GPU number |
| -------------- | -------- | -------- | -------- |
| torch_ddp | 84.4% | 88.6% | 2 |
| torch_ddp_fp16 | 84.7% | 88.8% | 2 |
| gemini | 84.0% | 88.4% | 2 |
| hybrid_parallel | 84.5% | 88.6% | 4 |
| Plugin | Accuracy | F1-score |
| -------------- | -------- | -------- |
| torch_ddp | 84.4% | 88.6% |
| torch_ddp_fp16 | 84.7% | 88.8% |
| gemini | 84.0% | 88.4% |
## Benchmark
```

View File

@ -1,12 +1,14 @@
import argparse
from typing import List, Union
from contextlib import nullcontext
from typing import Callable, List, Union
import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
from torch.optim import Optimizer
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
@ -18,8 +20,9 @@ from transformers import (
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
@ -32,14 +35,26 @@ LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1
output_transform_fn = lambda x: x
criterion = lambda x: x.loss
def move_to_cuda(batch):
return {k: v.cuda() for k, v in batch.items()}
@torch.no_grad()
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
def evaluate_model(
model: nn.Module,
optimizer,
criterion,
test_dataloader: Union[DataLoader, List[DataLoader]],
num_labels: int,
task_name: str,
eval_splits: List[str],
booster: Booster,
coordinator: DistCoordinator,
):
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval()
@ -47,23 +62,66 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat
accum_loss = torch.zeros(1, device=get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
outputs = model(**batch)
val_loss, logits = outputs[:2]
accum_loss.add_(val_loss)
if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()
labels = batch["labels"]
batch_size = batch["input_ids"].shape[0]
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
pg_mesh = booster.plugin.pg_mesh
pp_group = booster.plugin.pp_group
current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
current_rank = dist.get_rank()
#TODO pass dataloader to execute_pipeline directly
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
criterion,
optimizer,
return_loss=True,
return_outputs=True)
metric.add_batch(predictions=preds, references=labels)
if booster.plugin.stage_manager.is_last_stage():
val_loss = outputs["loss"]
logits = outputs["outputs"]["logits"]
accum_loss.add_(val_loss)
if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()
dist.broadcast(preds, src=current_rank, group=pp_group)
dist.broadcast(val_loss, src=current_rank, group=pp_group)
metric.add_batch(predictions=preds, references=labels)
elif current_rank in current_pp_group_ranks:
val_loss = torch.empty((1,), device=get_current_device())
preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device())
dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group)
dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group)
accum_loss.add_(val_loss)
metric.add_batch(predictions=preds, references=labels)
else:
batch = move_to_cuda(batch)
outputs = model(**batch)
val_loss, logits = outputs[:2]
accum_loss.add_(val_loss)
if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()
metric.add_batch(predictions=preds, references=labels)
results = metric.compute()
dist.all_reduce(accum_loss.div_(len(dataloader)))
if coordinator.is_master():
if coordinator.is_master() and results is not None:
results['loss'] = accum_loss.item() / coordinator.world_size
return results
if isinstance(test_dataloader, DataLoader):
@ -77,25 +135,43 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat
return final_results
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, lr_scheduler, train_dataloader: DataLoader,
booster: Booster, coordinator: DistCoordinator):
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
model.train()
with tqdm(train_dataloader, desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]', disable=not coordinator.is_master()) as pbar:
is_pp_last_stage = hasattr(
booster.plugin,
"stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage()
with tqdm(train_dataloader,
desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
for batch in pbar:
# Forward pass
batch = move_to_cuda(batch)
outputs = model(**batch)
loss = outputs[0]
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
#TODO pass train_dataloader to execute_pipeline directly
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=True)
# Backward and optimize
if booster.plugin.stage_manager.is_last_stage():
loss = outputs['loss']
pbar.set_postfix({'loss': loss.item()})
else:
outputs = model(**batch)
loss = _criterion(outputs, None)
# Backward
booster.backward(loss, optimizer)
pbar.set_postfix({'loss': loss.item()})
# Backward and optimize
booster.backward(loss, optimizer)
optimizer.step()
optimizer.zero_grad()
lr_scheduler.step()
# Print log info
pbar.set_postfix({'loss': loss.item()})
def main():
# ==============================
@ -107,7 +183,7 @@ def main():
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
help="plugin to use")
parser.add_argument(
"--model_type",
@ -116,6 +192,7 @@ def main():
help="bert or albert",
)
parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context")
args = parser.parse_args()
if args.model_type == 'bert':
@ -145,6 +222,17 @@ def main():
plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == 'hybrid_parallel':
# modify the param accordingly for finetuning test cases
plugin = HybridParallelPlugin(tp_size=1,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
zero_stage=1,
precision='fp16',
initial_scale=1)
booster = Booster(plugin=plugin, **booster_kwargs)
@ -165,8 +253,9 @@ def main():
# bert pretrained model
cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
if model_name == "bert-base-uncased":
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
elif model_name == "albert-xxlarge-v2":
model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
else:
@ -196,19 +285,27 @@ def main():
num_training_steps=total_steps,
)
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
optimizer,
criterion=_criterion,
lr_scheduler=lr_scheduler)
# ==============================
# Train model
# ==============================
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator)
results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task,
data_builder.eval_splits, booster, coordinator)
if coordinator.is_master():
print(results)

View File

@ -3,6 +3,6 @@ set -xe
pip install -r requirements.txt
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
done

View File

@ -1,7 +1,5 @@
[pytest]
markers =
cpu: tests which can run on CPU
gpu: tests which requires a single GPU
dist: tests which are run in a multi-GPU or multi-machine environment
experiment: tests for experimental features
dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs)
largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs)
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy
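As a usage sketch (not part of the diff), tests declare the heavier requirement explicitly and smaller CI runners deselect them with pytest's standard marker filter:
```python
# Illustrative test file showing the new GPU-count markers.
import pytest

@pytest.mark.dist          # multi-GPU / multi-machine, at least 4 GPUs
def test_runs_on_4_gpus():
    ...

@pytest.mark.largedist     # multi-GPU / multi-machine, at least 8 GPUs
def test_runs_on_8_gpus():
    ...

# a 4-GPU runner can then skip the heavy tests:
#   pytest -m "not largedist" tests/
```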

View File

@ -2,7 +2,7 @@ from .albert import *
from .bert import *
from .blip2 import *
from .bloom import *
from .chatglm import *
from .chatglm2 import *
from .gpt import *
from .llama import *
from .opt import *

View File

@ -12,8 +12,8 @@ from ..registry import ModelAttribute, model_zoo
def data_gen():
input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
return dict(input_ids=input_ids, attention_mask=attention_mask)

View File

@ -0,0 +1,164 @@
import pytest
import torch
import torch.distributed as dist
from torch.optim import Adam
from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
assert_close_loose,
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
# TODO (Baizhou): Add test cases for shard=False
@clear_cache_before_run()
@parameterize('shard', [True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
@parameterize('test_config', [{
'tp_size': 4,
'pp_size': 1,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 2,
'pp_size': 1,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
(model_fn, data_gen_fn, output_transform_fn, loss_fn,
_) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = loss_fn
plugin = HybridParallelPlugin(**test_config)
booster = Booster(plugin=plugin)
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
def _preprocess_data(data):
if booster.plugin.stage_manager is not None:
for k, v in data.items():
if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
new_shape = [1] * v.dim()
new_shape[0] = 4
data[k] = v.to('cuda').repeat(*new_shape)
return iter([data])
else:
return {k: v.cuda() for k, v in data.items()}
model = model_fn().cuda()
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
data = data_gen_fn()
model.train()
if booster.plugin.stage_manager is not None:
booster.execute_pipeline(_preprocess_data(data),
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=False)
else:
output = model(**_preprocess_data(data))
loss = criterion(output)
optimizer.backward(loss)
optimizer.step()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
dist.barrier()
new_model = model_fn().cuda()
new_optimizer = Adam(new_model.parameters(), lr=1e-3)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
booster.load_model(new_model, model_ckpt_path)
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False)
dist.barrier()
# Check whether the loaded model & optimizer works smoothly.
model.train()
new_model.train()
if booster.plugin.stage_manager is not None:
booster.execute_pipeline(_preprocess_data(data),
model,
_criterion,
optimizer,
return_loss=True,
return_outputs=False)
booster.execute_pipeline(_preprocess_data(data),
new_model,
_criterion,
new_optimizer,
return_loss=True,
return_outputs=False)
else:
old_model_loss = criterion(model(**_preprocess_data(data)))
optimizer.backward(old_model_loss)
new_model_loss = criterion(new_model(**_preprocess_data(data)))
new_optimizer.backward(new_model_loss)
optimizer.step()
new_optimizer.step()
# Check updated weights.
stage_manager = booster.plugin.stage_manager
if stage_manager is None or stage_manager.is_first_stage():
assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3)
assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data,
new_model.unwrap().h[0].mlp.c_fc.weight.data,
atol=5e-3,
rtol=5e-3)
dist.barrier()
Randomizer.reset_index()
clear_layout_converter()
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_hybrid_ckpIO(world_size):
spawn(run_dist, world_size)

View File

@ -0,0 +1,83 @@
import os
import pytest
import torch
import torch.distributed as dist
from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
@clear_cache_before_run()
@parameterize('model_name', ['transformers_gpt'])
@parameterize('plugin_type', ['ddp', 'zero', 'gemini'])
def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
(model_fn, data_gen_fn, output_transform_fn, loss_fn,
_) = next(iter(model_zoo.get_sub_registry(model_name).values()))
criterion = loss_fn
if plugin_type == 'ddp':
plugin = TorchDDPPlugin()
elif plugin_type == 'zero':
plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
elif plugin_type == 'gemini':
plugin = GeminiPlugin(precision="fp16", initial_scale=32)
else:
raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")
booster = Booster(plugin=plugin)
model = model_fn().cuda()
model_huggingface_cls = model.__class__
optimizer = HybridAdam(model.parameters(), lr=0.001)
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
data = data_gen_fn()
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
output = model(**data)
loss = criterion(output)
booster.backward(loss, optimizer)
optimizer.step()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
dist.barrier()
new_model = model_huggingface_cls.from_pretrained(model_ckpt_path)
new_model = new_model.cuda()
new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
if plugin_type == 'gemini':
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
new_model.unwrap().state_dict(only_rank_0=False), False)
else:
check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
dist.barrier()
def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
exam_from_pretrained()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_huggingface_compatibility(world_size):
spawn(run_dist, world_size)

View File

@ -8,7 +8,6 @@ import pytest
from colossalai.context.config import Config
@pytest.mark.cpu
def test_load_config():
filename = Path(__file__).parent.joinpath('sample_config.py')
config = Config.from_file(filename)

View File

@ -143,7 +143,6 @@ def run_dist(rank, world_size, port, backend, port_list, host):
reset_seeds()
@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_context():
"""

View File

@ -5,11 +5,10 @@ import os
from pathlib import Path
import pytest
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
@pytest.mark.cpu
def test_cifar10_dataset():
# build transform
transform_pipeline = [transforms.ToTensor()]

View File

@ -53,7 +53,6 @@ def run_data_sampler(rank, world_size, port):
torch.cuda.empty_cache()
@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_data_sampler():
spawn(run_data_sampler, 4)

View File

@ -64,7 +64,6 @@ def run_data_sampler(rank, world_size, port):
torch.cuda.empty_cache()
@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_data_sampler():
spawn(run_data_sampler, 4)

View File

@ -0,0 +1,39 @@
from colossalai.shardformer.policies.t5 import T5BasePolicy
def test_t5_pipeline_distribution():
num_test_cases = 8
test_dict = {
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
}
for i in range(num_test_cases):
_, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(test_dict['num_encoder_layers'][i],
test_dict['num_decoder_layers'][i],
test_dict['num_stages'][i])
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
def test_t5_pipeline_layers():
num_test_cases = 4
test_dict = {
'num_encoder_layers': [2, 3, 2, 4],
'num_decoder_layers': [2, 0, 2, 8],
'num_stages': [2, 2, 4, 4],
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
[[0, 4], [0, 3], [3, 6], [6, 8]]]
}
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
for stage in range(test_dict['num_stages'][i]):
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
decoder_starting_stage)
assert start_idx == predicted_start
assert end_idx == predicted_end
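# A minimal usage sketch (assuming only the signatures exercised above) of how the two
# helpers combine, e.g. for a hypothetical 12-encoder/12-decoder T5 on 4 stages:
#   layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(12, 12, 4)
#   for stage in range(4):
#       start, end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage, decoder_starting_stage)
#       # stages before decoder_starting_stage index encoder layers; later stages index decoder layers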

View File

@ -0,0 +1,44 @@
from colossalai.shardformer.policies.whisper import WhisperPolicy
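# The Whisper policy follows the same encoder-first / decoder-after layout as the T5
# policy tested above: stages below decoder_starting_stage hold encoder layers and the
# remaining stages hold decoder layers.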
def test_whisper_pipeline_distribution():
num_test_cases = 8
test_dict = {
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
}
for i in range(num_test_cases):
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i],
test_dict['num_decoder_layers'][i],
test_dict['num_stages'][i])
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
def test_whisper_pipeline_layers():
num_test_cases = 4
test_dict = {
'num_encoder_layers': [2, 3, 2, 4],
'num_decoder_layers': [2, 0, 2, 8],
'num_stages': [2, 2, 4, 4],
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
[[0, 4], [0, 3], [3, 6], [6, 8]]]
}
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
for stage in range(test_dict['num_stages'][i]):
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage,
decoder_starting_stage)
assert start_idx == predicted_start
assert end_idx == predicted_end
if __name__ == '__main__':
test_whisper_pipeline_distribution()
test_whisper_pipeline_layers()

View File

@ -0,0 +1,161 @@
import copy
from functools import partial
from types import MethodType
import pytest
import torch
import torch.nn as nn
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(4, 8)
self.linear2 = nn.Linear(8, 8)
self.linear3 = nn.Linear(8, 8)
self.linear4 = nn.Linear(8, 8)
self.linear5 = nn.Linear(8, 8)
self.linear6 = nn.Linear(8, 8)
self.linear7 = nn.Linear(8, 8)
self.linear8 = nn.Linear(8, 4)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = self.linear4(x)
x = self.linear5(x)
x = self.linear6(x)
x = self.linear7(x)
x = self.linear8(x)
return x
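# pp_linear_fwd wraps a single Linear layer so it can serve as one model chunk in the
# interleaved pipeline: the first chunk on the first stage consumes the raw micro-batch,
# the last chunk on the last stage returns the final output, and every other chunk just
# forwards the activation it received from the previous stage.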
def pp_linear_fwd(forward,
data: torch.Tensor = None,
input_obj: torch.Tensor = None,
stage_mgr: PipelineStageManager = None,
num_chunks: int = None,
model_chunk_id: int = None):
if stage_mgr.is_first_stage() and model_chunk_id == 0:
return {'input_obj': forward(data)}
elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
return forward(input_obj)
else:
return {'input_obj': forward(input_obj)}
@parameterize("num_micro_batches", [4, 8, 12])
def examine_pp(num_micro_batches):
"""
This test examines the correctness of the interleaved 1F1B schedule against a plain torch run.
Be aware that it contains some hard-coded values (see the inline comments).
"""
world_size = torch.distributed.get_world_size()
local_rank = torch.distributed.get_rank()
seed_all(1453)
NUM_MICRO_BATCHS = num_micro_batches
BATCH_SIZE = num_micro_batches
NUM_CHUNKS = 2
# create model
torch_model = MlpModel().cuda()
pp_model = copy.deepcopy(torch_model).cuda()
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
pg_mesh = ProcessGroupMesh(1, world_size, 1)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager)
sharded_model = torch.nn.ModuleList()
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
if idx % (world_size) == local_rank:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd,
stage_mgr=stage_manager,
num_chunks=NUM_CHUNKS,
model_chunk_id=len(sharded_model)), sub_model._forward)
sharded_model.append(sub_model.cuda())
# create optimizer
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))
# create input data
seed_all(1453)
if local_rank == 0:
input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
else:
input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
torch.distributed.all_reduce(input_list[0])
criterion = lambda x, y: torch.mean(x)
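# The all-reduce overwrites the zero tensors on ranks > 0 with rank 0's random batch, so
# every rank sees the same input; the criterion ignores its target and simply averages the
# output, which keeps the torch baseline loss and the pipeline loss directly comparable.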
# forward and backward
torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output, None)
torch_loss.backward()
pp_ret = schedule.forward_backward_step(sharded_model,
pp_optimizer,
iter(input_list),
criterion,
return_loss=True,
return_outputs=True)
# check loss
if stage_manager.is_last_stage():
assert torch.allclose(torch_loss, pp_ret['loss'])
# check gradients
torch_grad = []
for torch_p in torch_model.parameters():
torch_grad.append(torch_p.grad.data)
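# Hard-coded index mapping: with world_size=4 and NUM_CHUNKS=2, rank r owns the r-th and
# (r+4)-th Linear layers (0-indexed), i.e. parameters 2r, 2r+1 of the unsharded model for
# its first chunk and 2r+8, 2r+9 for its second chunk, hence the `+ local_rank * 2` and
# extra `+ 6` offsets below.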
for idx, pp_p in enumerate(sharded_model.parameters()):
if idx < 2:
assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
else:
assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data)
# step
torch_optimizer.step()
pp_optimizer.step()
# check updated param
torch_param = []
for torch_p in torch_model.parameters():
torch_param.append(torch_p.data)
for idx, pp_p in enumerate(sharded_model.parameters()):
if idx < 2:
assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
else:
assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
examine_pp()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp():
spawn(run_dist, 4)
if __name__ == '__main__':
test_pp()

View File

@ -61,7 +61,7 @@ def examine_pp():
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
pg_mesh = ProcessGroupMesh(1, world_size, 1)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager)
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
if idx % (world_size) == local_rank:

View File

@ -49,15 +49,6 @@ def check_stage_manager():
next_rank = ranks_in_group[ranks_in_group.index(rank) + 1]
assert stage_manager.get_next_rank() == next_rank
# check virtual stage
stage_manager.set_num_virtual_stages(PP_SIZE * 2)
assert stage_manager.num_virtual_stages == PP_SIZE * 2
stage_manager.set_virtual_stage(stage_manager.stage * 2)
assert stage_manager.virtual_stage == stage_manager.stage * 2
with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1):
assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1
assert stage_manager.virtual_stage == stage_manager.stage * 2
# check p2p groups
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
if rank in [prev, cur]:

View File

@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
return rearanged_tensor
@parameterize('lazy_init', [False, True])
def check_linear_conv_1d_col(lazy_init: bool):
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
@ -62,7 +61,9 @@ def check_linear_conv_1d_col(lazy_init: bool):
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
n_fused=3)
seq_parallel=seq_parallel,
n_fused=3,
overlap=overlap)
assert linear.weight.shape == torch.Size([48, 192])
assert linear.bias.shape == torch.Size([192])
@ -76,10 +77,11 @@ def check_linear_conv_1d_col(lazy_init: bool):
linear.load_state_dict(linear_conv_col.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(1, 4, 48).cuda()
out = linear(x)
gather_out = linear_conv_col(x)
assert_close(rearrange(out, 1), gather_out)
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
gather_out = linear_conv_col(x_for_shard)
assert_close(rearrange(out, -1), gather_out)
# check backward correctness
out.sum().backward()
@ -89,14 +91,16 @@ def check_linear_conv_1d_col(lazy_init: bool):
assert_close(target_grad, linear_conv_col.weight.grad)
@parameterize('lazy_init', [False, True])
def check_linear_conv_1d_row(lazy_init: bool):
def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy,
process_group=None,
parallel_input=False,
seq_parallel=seq_parallel)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
@ -109,10 +113,11 @@ def check_linear_conv_1d_row(lazy_init: bool):
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(1, 4, 48).cuda()
out = linear(x)
gather_out = linear_row(x)
assert_close(out, gather_out)
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, gather_out)
# check backward correctness
out.sum().backward()
@ -123,12 +128,19 @@ def check_linear_conv_1d_row(lazy_init: bool):
assert_close(target_grad, linear_row.weight.grad)
@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
@parameterize('overlap', [True])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
check_linear_conv_1d_row(lazy_init, seq_parallel)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# test for linear conv
check_linear_conv_1d_col()
check_linear_conv_1d_row()
check_gpt2_qkv_fused_linear_1d()
@rerun_if_address_is_in_use()

View File

@ -12,13 +12,16 @@ from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize('lazy_init', [False, True])
def check_linear_1d_col(lazy_init: bool):
def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
linear_col = Linear1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
seq_parallel=seq_parallel,
overlap=overlap)
# ensure that the parameters are distributed
assert is_distributed_tensor(linear_col.weight)
@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool):
linear_col.load_state_dict(linear.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard.requires_grad_(True)
out = linear(x_for_unshard)
@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool):
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
target_unshard_grad = x_for_unshard.grad if seq_parallel is False else torch.chunk(
x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_unshard_grad, x_for_shard.grad)
@parameterize('lazy_init', [False, True])
def check_linear_1d_row(lazy_init: bool):
def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
linear_row = Linear1D_Row.from_native_module(linear_copy,
process_group=None,
parallel_input=False,
seq_parallel=seq_parallel)
assert linear_row.weight.shape == torch.Size([128, 16])
assert linear_row.bias.shape == torch.Size([128])
@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool):
linear_row.load_state_dict(linear.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool):
# run forward
out = linear(x_for_unshard)
gather_out = linear_row(x_for_shard)
assert_close(out, gather_out)
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, gather_out)
# check backward correctness
out.sum().backward()
@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad)
@parameterize('lazy_init', [False, True])
def check_linear_col_plus_row(lazy_init: bool):
def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear_1 = nn.Linear(32, 128).cuda()
@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool):
with ctx:
linear_1_copy = nn.Linear(32, 128).cuda()
linear_2_copy = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
linear_col = Linear1D_Col.from_native_module(linear_1_copy,
process_group=None,
gather_output=False,
seq_parallel=seq_parallel,
overlap=overlap)
linear_row = Linear1D_Row.from_native_module(linear_2_copy,
process_group=None,
parallel_input=True,
seq_parallel=seq_parallel)
linear_1.load_state_dict(linear_col.state_dict())
linear_col.load_state_dict(linear_1.state_dict())
@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool):
linear_row.load_state_dict(linear_2.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard.requires_grad_(True)
# run forward
unshard_out = linear_2(linear_1(x_for_unshard))
shard_out = linear_row(linear_col(x_for_shard))
assert_close(unshard_out, shard_out)
target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, shard_out)
# check backward correctness
unshard_out.sum().backward()
@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool):
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
target_unshard_grad = x_for_unshard.grad if seq_parallel is False else torch.chunk(
x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_unshard_grad, x_for_shard.grad)
def run_dist(rank, world_size, port):
@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
@parameterize('overlap', [True])
def run_dist_linear_test(lazy_init, seq_parallel, overlap):
check_linear_1d_col(lazy_init, seq_parallel, overlap)
check_linear_1d_row(lazy_init, seq_parallel)
check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
def check_dist_linear(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_1d_col()
check_linear_1d_row()
check_linear_col_plus_row()
run_dist_linear_test()
@rerun_if_address_is_in_use()
def test_linear():
spawn(run_dist, nprocs=2)
spawn(check_dist_linear, nprocs=2)
if __name__ == '__main__':

View File

@ -1,4 +1,5 @@
import copy
import math
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional
@ -12,6 +13,7 @@ from torch.optim import Adam, Optimizer
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@ -25,6 +27,7 @@ def build_model(model_fn,
enable_tensor_parallelism=True,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
use_lazy_init: bool = False):
# create new model
ctx = LazyInitContext() if use_lazy_init else nullcontext()
@ -38,7 +41,8 @@ def build_model(model_fn,
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
enable_jit_fused=enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
@ -135,6 +139,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
return loss
data = data_gen_fn()
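# When sequence parallelism is active, repeat the inputs along the sequence dimension so
# that seq_len becomes a multiple of tp_size (via the LCM computed below); otherwise the
# sequence could not be split evenly across the tensor-parallel ranks.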
if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
seq_len = data['input_ids'].shape[1]
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
times = lcm // seq_len
input_shape = data['input_ids'].shape
for k, v in data.items():
if v.shape == input_shape:
data[k] = v.repeat(1, times)
sharded_model.train()
if booster.plugin.stage_manager is not None:
for k, v in data.items():
@ -177,11 +191,10 @@ def check_output_hidden_state(org_output: Tensor,
org_hidden_state = org_output.last_hidden_state
if stage_manager is None:
sharded_hidden_state = sharded_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)
sharded_hidden_state = sharded_output['outputs']['last_hidden_state']
else:
sharded_hidden_state = sharded_output.last_hidden_state
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
@ -219,6 +232,43 @@ def check_weight(org_model: Module,
f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
def get_grad_tensors_for_check(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: ProcessGroup = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False,
name: str = None):
grad_to_check = {}
for suffix in layer_suffix:
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)
# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[:org_grad.shape[0], :]
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
grad_to_check[suffix] = {
"org_grad": org_grad.float(),
"shard_grad": shard_grad.float(),
"rtol": rtol,
"atol": atol
}
return grad_to_check
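# Note: get_grad_tensors_for_check only records the (all-gathered) gradients together with
# their tolerances; the actual comparison is deferred to check_all_grad_tensors, so grads
# can be captured before optimizer.step() and verified afterwards.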
# used by sam/blip2
def check_grad(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
@ -231,7 +281,6 @@ def check_grad(org_model: Module,
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
@ -246,3 +295,30 @@ def check_grad(org_model: Module,
assert torch.allclose(
org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
def unwrap_model(module: Module,
base_model_class_name: Optional[str] = None,
base_model_attribute_name: Optional[str] = None):
if isinstance(module, HybridParallelModule):
module = module.unwrap()
if base_model_class_name is None:
return module
if module.__class__.__name__ == base_model_class_name:
return module
return getattr(module, base_model_attribute_name, None)
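# Example, mirroring the model tests below: unwrap a HybridParallelModule and, when the
# wrapped object is e.g. BertForPreTraining rather than a bare BertModel, descend into
# its `bert` attribute:
#   bert = unwrap_model(org_model, 'BertModel', 'bert')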
def check_all_grad_tensors(check_tensors):
"""
"org_grad": tensor to be compared from the original model
"shard_grad": tensor to be compared from the sharded model
"""
for suffix, check_info in check_tensors.items():
org_grad = check_info["org_grad"]
shard_grad = check_info["shard_grad"]
rtol = check_info["rtol"]
atol = check_info["atol"]
assert torch.allclose(
org_grad, shard_grad, atol=atol, rtol=rtol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"

View File

@ -10,11 +10,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -32,8 +34,46 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
bert = unwrap_model(org_model, 'BertModel', 'bert')
sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert')
col_layer_for_check = ['encoder.layer[0].output.dense']
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
col_layer_grads = get_grad_tensors_for_check(bert,
sharded_bert,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
row_layer_grads = get_grad_tensors_for_check(bert,
sharded_bert,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@ -44,30 +84,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'BertModel':
bert = org_model
sharded_bert = sharded_model.unwrap()
else:
bert = org_model.bert
sharded_bert = sharded_model.unwrap().bert
col_layer_for_check = ['encoder.layer[0].output.dense']
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
#check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
#check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
else:
@ -75,6 +93,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@ -98,6 +119,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_bert_test(test_config):
@ -111,12 +155,50 @@ def run_bert_test(test_config):
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_bert_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bert_test()
def check_bert_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bert_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@ -124,5 +206,13 @@ def test_bert():
spawn(check_bert, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert_3d():
spawn(check_bert_3d, 8)
if __name__ == "__main__":
test_bert()
test_bert_3d()

View File

@ -3,16 +3,19 @@ import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -34,6 +37,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
bloom = unwrap_model(org_model, 'BloomModel', 'transformer')
sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer')
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
col_layer_for_check = ['h[0].self_attention.dense']
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(bloom,
sharded_bloom,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
col_layer_grads = get_grad_tensors_for_check(bloom,
sharded_bloom,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@ -45,28 +85,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'BloomModel':
bloom = org_model
sharded_bloom = sharded_model.unwrap()
else:
bloom = org_model.transformer
sharded_bloom = sharded_model.unwrap().transformer
# check grad
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
col_layer_for_check = ['h[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 5e-3, 5e-3
check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
@ -74,6 +92,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@ -97,18 +118,72 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_bloom_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_bloom_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@ -118,6 +193,12 @@ def check_bloom(rank, world_size, port):
run_bloom_test()
def check_bloom_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bloom_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@ -125,5 +206,13 @@ def test_bloom():
spawn(check_bloom, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom_3d():
spawn(check_bloom_3d, 8)
if __name__ == "__main__":
test_bloom()
test_bloom_3d()

View File

@ -4,16 +4,19 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -35,6 +38,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer')
shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer')
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-3
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(chatglm_model,
shard_chatglm_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
col_layer_grads = get_grad_tensors_for_check(chatglm_model,
shard_chatglm_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@ -47,43 +88,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'ChatGLMModel':
chatglm_model = org_model
shard_chatglm_model = sharded_model.unwrap()
else:
chatglm_model = org_model.transformer
shard_chatglm_model = sharded_model.unwrap().transformer
# check grad
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_grad(chatglm_model,
shard_chatglm_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
check_grad(chatglm_model,
shard_chatglm_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
@ -98,6 +103,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()
@ -121,12 +130,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}])
def run_chatglm_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_chatglm_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@ -142,6 +194,12 @@ def check_chatglm(rank, world_size, port):
run_chatglm_test()
def check_chatglm_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_chatglm_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@ -149,5 +207,13 @@ def test_chatglm():
spawn(check_chatglm, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm_3d():
spawn(check_chatglm_3d, 8)
if __name__ == "__main__":
test_chatglm()
test_chatglm_3d()

View File

@ -3,18 +3,20 @@ import torch
from torch import distributed as dist
import colossalai
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -36,6 +38,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
col_layer_grads = get_grad_tensors_for_check(gpt2,
sharded_gpt2,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
row_layer_grads = get_grad_tensors_for_check(gpt2,
sharded_gpt2,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@ -48,32 +87,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
def unwrap(module):
if isinstance(module, HybridParallelModule):
module = module.unwrap()
if module.__class__.__name__ == 'GPT2Model':
return module
return module.transformer
# unwrap model
gpt2 = unwrap(org_model)
sharded_gpt2 = unwrap(sharded_model)
col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
# check grad
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
@ -81,6 +95,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()
@ -106,12 +124,80 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': True,
'use_lazy_init': True,
'enable_sequence_parallelism': True,
'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'enable_sequence_parallelism': True,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
@clear_cache_before_run()
def run_gpt2_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@ -127,10 +213,13 @@ def check_gpt2(rank, world_size, port):
run_gpt2_test()
# TODO(ver217): fix this
def check_gpt2_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt2_3d_test()
@pytest.mark.skip("this will stuck in CI")
@pytest.mark.skip(reason="This test will hang in CI")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@ -138,5 +227,13 @@ def test_gpt2():
spawn(check_gpt2, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2_3d():
spawn(check_gpt2_3d, 8)
if __name__ == "__main__":
test_gpt2()
test_gpt2_3d()

View File

@ -6,16 +6,19 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@ -39,6 +42,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
col_layer_for_check = ['layers[0].self_attn.o_proj']
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-4
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(llama_model,
shard_llama_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
col_layer_grads = get_grad_tensors_for_check(llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@ -51,42 +91,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'LlamaModel':
llama_model = org_model
shard_llama_model = sharded_model.unwrap()
else:
llama_model = org_model.model
shard_llama_model = sharded_model.unwrap().model
# check grad
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
col_layer_for_check = ['layers[0].self_attn.o_proj']
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-4
else:
atol, rtol = 5e-3, 5e-3
check_grad(llama_model,
shard_llama_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
check_grad(llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
@ -101,6 +106,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@ -128,19 +136,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_llama_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_llama_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@ -150,6 +213,12 @@ def check_llama(rank, world_size, port):
run_llama_test()
def check_llama_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@ -157,5 +226,13 @@ def test_llama():
spawn(check_llama, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama_3d():
spawn(check_llama_3d, 8)
if __name__ == "__main__":
test_llama()
test_llama_3d()

View File

@ -6,16 +6,19 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@ -39,6 +42,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
opt_model = unwrap_model(org_model, 'OPTModel', 'model')
shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model')
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-3
else:
atol, rtol = 4e-2, 4e-2
row_layer_grads = get_grad_tensors_for_check(opt_model,
shard_opt_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
col_layer_grads = get_grad_tensors_for_check(opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@ -50,42 +90,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'OPTModel':
opt_model = org_model
shard_opt_model = sharded_model.unwrap()
else:
opt_model = org_model.model
shard_opt_model = sharded_model.unwrap().model
# check grad
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-3
else:
atol, rtol = 3e-2, 3e-2
check_grad(opt_model,
shard_opt_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
check_grad(opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
@ -100,6 +105,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()
@ -123,12 +132,62 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_opt_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
clear_layout_converter()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_opt_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@ -144,6 +203,12 @@ def check_OPTModel(rank, world_size, port):
run_opt_test()
def check_opt_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_opt_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@ -151,5 +216,13 @@ def test_OPTModel():
spawn(check_OPTModel, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_opt_3d():
spawn(check_opt_3d, 8)
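# note: the 3D configs above use tp_size=2 and pp_size=2, so the 8 spawned ranks
# presumably leave a data-parallel size of 2 (tp x pp x dp = 2 x 2 x 2)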
if __name__ == '__main__':
test_OPTModel()
test_opt_3d()

View File

@ -1,5 +1,6 @@
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers
@ -9,11 +10,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -35,6 +38,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
t5 = unwrap_model(org_model)
sharded_t5 = unwrap_model(sharded_model)
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
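# (grads are read out before optimizer.step(), which updates the parameters and may
# release the gradient buffers; the comparison is also limited to the first pipeline
# stage and to zero_stage == 0, where the unsharded gradients are directly accessible)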
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
row_layer_grads = get_grad_tensors_for_check(t5,
sharded_t5,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@ -47,30 +76,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
t5 = org_model
sharded_t5 = sharded_model.unwrap()
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
# check weights and gradients
# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
atol, rtol = 5e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@ -99,17 +115,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
@clear_cache_before_run()
def run_t5_test(test_config):
# TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO(baizhou): add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
@ -125,12 +160,49 @@ def run_t5_test(test_config):
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_t5_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
def check_t5(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_t5_test()
def check_t5_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_t5_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@ -138,5 +210,13 @@ def test_t5():
spawn(check_t5, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_t5_3d():
spawn(check_t5_3d, 8)
if __name__ == "__main__":
test_t5()
test_t5_3d()

View File

@ -9,11 +9,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
@ -35,6 +37,44 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
vit_model = unwrap_model(org_model, 'ViTModel', 'vit')
shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit')
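# (unwrap_model presumably returns the module itself when its class name matches
# 'ViTModel' and otherwise drills into the given 'vit' attribute of task-specific wrappers)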
# check grad
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(vit_model,
shard_vit_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
col_layer_grads = get_grad_tensors_for_check(vit_model,
shard_vit_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
@ -44,45 +84,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if org_model.__class__.__name__ == 'ViTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'ViTModel':
vit_model = org_model
shard_vit_model = sharded_model.unwrap()
else:
vit_model = org_model.vit
shard_vit_model = sharded_model.unwrap().vit
# check grad
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_grad(vit_model,
shard_vit_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
check_grad(vit_model,
shard_vit_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
@ -97,9 +101,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
# TODO: loss becomes inf when num_microbatches = 2
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
@ -120,20 +128,71 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': True,
'use_lazy_init': False,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_vit_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
# TODO(baizhou): fix bug when setting lazy_init for Conv2D Layers in ViT models
# TODO: fix bug when setting lazy_init for Conv2D Layers in ViT models
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
])
def run_vit_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@ -143,6 +202,12 @@ def check_vit(rank, world_size, port):
run_vit_test()
def check_vit_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_vit_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
@ -150,5 +215,13 @@ def test_vit():
spawn(check_vit, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit_3d():
spawn(check_vit_3d, 8)
if __name__ == "__main__":
test_vit()
test_vit_3d()

View File

@ -3,6 +3,8 @@ import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
@ -11,55 +13,205 @@ from colossalai.testing import (
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5)
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
# do backward
org_loss.backward()
shard_loss.backward()
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"sharded model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap the model
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
whisper = org_model.model
sharded_whisper = sharded_model.model
sharded_whisper = sharded_model.unwrap().model
else:
whisper = org_model
sharded_whisper = sharded_model
sharded_whisper = sharded_model.unwrap()
# check grad
if org_model.__class__.__name__ == 'WhisperForAudioClassification':
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj']
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj']
else:
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj']
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj']
check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
col_layer_for_check = [
'encoder.layers[0].self_attn.q_proj',
# 'decoder.layers[0].self_attn.q_proj'
]
row_layer_for_check = [
'encoder.layers[0].self_attn.out_proj',
# 'decoder.layers[0].self_attn.out_proj'
]
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if test_config['precision'] == 'fp32':
atol, rtol = 2e-4, 2e-4
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
row_layer_grads = get_grad_tensors_for_check(whisper,
sharded_whisper,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1)
col_layer_grads = get_grad_tensors_for_check(whisper,
sharded_whisper,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 2e-4, 2e-4
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'WhisperModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(whisper,
sharded_whisper,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
check_weight(whisper,
sharded_whisper,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
# TODO: fix WhisperForConditionalGeneration with jit fused operator enabled
# TODO(jianghai): fix fp16
@parameterize(
'test_config',
[
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
},
{
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
},
# whisper does not support fp16 for now.
])
def run_whisper_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn,
enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification':
continue
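# (WhisperForAudioClassification is encoder-only, which is presumably why the
# deeper pipeline configs are skipped for it)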
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
])
def run_whisper_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
@ -69,12 +221,26 @@ def check_whisper(rank, world_size, port):
run_whisper_test()
def check_whisper_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_whisper_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_whisper():
spawn(check_whisper, 2)
spawn(check_whisper, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_whisper_3d():
spawn(check_whisper_3d, 8)
if __name__ == "__main__":
test_whisper()
test_whisper_3d()

View File

@ -40,7 +40,6 @@ def forward_inplace(x, weight):
return out
@pytest.mark.gpu
@clear_cache_before_run()
@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])