[example] add TP to GPT example (#1828)

pull/1832/head
Jiarui Fang 2022-11-08 17:17:19 +08:00 committed by GitHub
parent 49216d7ab1
commit a25f755331
4 changed files with 113 additions and 55 deletions


@@ -1 +1 @@
-env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=2 train_gpt_demo.py 2>&1 | tee run.log
+env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log
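The updated command launches 4 processes with `--tp_degree=2`; the leftover factor is what the script uses as the ZeRO/Gemini data-parallel degree. A minimal sketch of that relationship, assuming (as the example does) that the world size divides evenly by the TP degree:

```python
# Sketch only: how the launch flags relate. With --nproc_per_node=4 and
# --tp_degree=2, the remaining factor is used as the data-parallel degree.
world_size = 4                        # torchrun --nproc_per_node
tp_degree = 2                         # --tp_degree passed to the script
assert world_size % tp_degree == 0    # assumed requirement
dp_degree = world_size // tp_degree   # -> 2 ranks of ZeRO data parallelism
print(f"tp={tp_degree}, dp={dp_degree}")
```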


@@ -10,13 +10,48 @@ import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
-from colossalai.tensor import ProcessGroup
+from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel
+
+
+def parse_args():
+    parser = colossalai.get_default_parser()
+    parser.add_argument(
+        "--tp_degree",
+        type=int,
+        default=1,
+        help="Tensor Parallelism Degree.",
+    )
+    parser.add_argument(
+        "--placement",
+        type=str,
+        default='cpu',
+        help="Placement Policy for Gemini.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+## Parameter Sharding Strategies for Tensor Parallelism
+def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
+    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    if param.process_group.tp_world_size() == 1:
+        param.set_process_group(pg)
+    param.set_tensor_spec(*spec)
+
+
+def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
+    split_param_single_dim_tp1d(0, param, pg)
+
+
+def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
+    split_param_single_dim_tp1d(-1, param, pg)
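For intuition, these helpers only decide which dimension of a parameter is split across the TP ranks. A plain-PyTorch illustration of the resulting shapes (this is not the ColossalAI API, just the shapes involved):

```python
# Illustration only: what a 1D "row slice" vs "column slice" of a weight looks
# like across tp_world_size ranks, using torch.chunk instead of ShardSpec.
import torch

weight = torch.randn(8, 6)   # stand-in for e.g. a Linear weight
tp_world_size = 2

row_shards = torch.chunk(weight, tp_world_size, dim=0)    # split dim 0  -> "row slice"
col_shards = torch.chunk(weight, tp_world_size, dim=-1)   # split dim -1 -> "column slice"

assert row_shards[0].shape == (4, 6)   # each rank holds half of the rows
assert col_shards[0].shape == (8, 3)   # each rank holds half of the columns
```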
+
+
+## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
class GPTLMModel(nn.Module):

    def __init__(self,
@@ -56,6 +91,7 @@ class GPTLMLoss(nn.Module):
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


+## Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
@@ -90,54 +126,96 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
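As a rough sanity check of this throughput formula (the factor of 8 is consistent with roughly 2 FLOPs per parameter per token in the forward pass, 4 in the backward pass, and 2 more for recomputation under activation checkpointing), with made-up numbers:

```python
# Back-of-the-envelope check of get_tflops with hypothetical numbers; the
# script logs the real model_numel and per-step time at runtime.
model_numel = 355e6           # roughly GPT-2 medium sized, for illustration
batch_size, seq_len = 8, 1024
step_time = 1.0               # seconds, made up
tflops = model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
print(f"{tflops:.1f} TFLOPS")   # ~23.3 with these numbers
```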
-def main():
-    BATCH_SIZE = 8
-    SEQ_LEN = 1024
-    VOCAB_SIZE = 50257
-    NUM_STEPS = 10
-    PLACEMENT_POLICY = 'auto'
-    disable_existing_loggers()
-    colossalai.launch_from_torch(config={})
-    pg = ProcessGroup()
-    logger = get_dist_logger()
-    logger.info(get_mem_info(), ranks=[0])
-    # build GPT model
-    with ColoInitContext(device=get_current_device()):
-        model = gpt2_medium(checkpoint=True)
-    numel = sum([p.numel() for p in model.parameters()])
-    logger.info(f'Model numel: {numel}', ranks=[0])
-    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
+# Tensor Parallel
+def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
+    """tensor_parallelize
+    Sharding the Model Parameters.
+
+    Args:
+        model (torch.nn.Module): a torch module to be sharded
+    """
+    for mn, module in model.named_modules():
+        for pn, param in module.named_parameters(recurse=False):
+            # set process group for all parameters
+            param.set_process_group(pg)
+
+            if 'mlp.c_fc' in mn:
+                if 'weight' in pn or 'bias' in pn:
+                    split_param_col_tp1d(param, pg)    # column slice
+                    # keep the shape of the output from c_fc
+                    param.compute_spec.set_output_replicate(False)
+            elif 'mlp.c_proj' in mn:
+                if 'weight' in pn:
+                    split_param_row_tp1d(param, pg)    # row slice
+            elif 'wte' in mn or 'wpe' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
+            elif 'c_attn' in mn or 'c_proj' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
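To make the parameter routing above easier to follow, here is a standalone sketch of just the name-matching logic, applied to typical Hugging Face GPT-2 module names (plain strings, no ColossalAI objects involved):

```python
# Sketch of the routing in tensor_parallelize, on name strings only.
def sharding_rule(mn: str, pn: str) -> str:
    if 'mlp.c_fc' in mn:
        if 'weight' in pn or 'bias' in pn:
            return 'column slice, output kept sharded'
    elif 'mlp.c_proj' in mn:
        if 'weight' in pn:
            return 'row slice'
    elif 'wte' in mn or 'wpe' in mn:
        return 'column slice'
    elif 'c_attn' in mn or 'c_proj' in mn:
        return 'column slice'
    return 'replicated'

print(sharding_rule('transformer.h.0.mlp.c_fc', 'weight'))    # column slice, output kept sharded
print(sharding_rule('transformer.h.0.attn.c_attn', 'bias'))   # column slice
print(sharding_rule('transformer.h.0.mlp.c_proj', 'bias'))    # replicated
print(sharding_rule('transformer.ln_f', 'weight'))            # replicated
```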
+# Gemini + ZeRO DDP
+def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
    cai_version = colossalai.__version__
+    logger.info(f'using Colossal-AI version {cai_version}')
    if version.parse(cai_version) > version.parse("0.1.10"):
        from colossalai.nn.parallel import GeminiDDP
        model = GeminiDDP(model,
                          device=get_current_device(),
-                          placement_policy=PLACEMENT_POLICY,
+                          placement_policy=placememt_policy,
                          pin_memory=True,
                          search_range_mb=32)
    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
        from colossalai.gemini import ChunkManager, GeminiManager
        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
-        gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
+        gemini_manager = GeminiManager(placememt_policy, chunk_manager)
        chunk_manager = ChunkManager(chunk_size,
                                     pg,
                                     enable_distributed_storage=True,
-                                     init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
+                                     init_device=GeminiManager.get_default_device(placememt_policy))
        model = ZeroDDP(model, gemini_manager)
+    else:
+        raise NotImplementedError(f"CAI version {cai_version} is not supported")
+    return model
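The wrapper above dispatches on the installed Colossal-AI version; a quick standalone check of that gate using `packaging.version`:

```python
# Standalone check of the version gate used in gemini_zero_dpp.
from packaging import version

for v in ("0.1.8", "0.1.9", "0.1.10", "0.1.11"):
    if version.parse(v) > version.parse("0.1.10"):
        path = "GeminiDDP"
    elif version.parse("0.1.9") <= version.parse(v) <= version.parse("0.1.10"):
        path = "ChunkManager + GeminiManager + ZeroDDP"
    else:
        path = "unsupported"
    print(v, "->", path)
# 0.1.8 -> unsupported, 0.1.9 and 0.1.10 -> ChunkManager path, 0.1.11 -> GeminiDDP
```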

+def main():
+    args = parse_args()
+
+    BATCH_SIZE = 8
+    SEQ_LEN = 1024
+    VOCAB_SIZE = 50257
+    NUM_STEPS = 10
+
+    disable_existing_loggers()
+    colossalai.launch_from_torch(config={})
+
+    pg = ProcessGroup(tp_degree=args.tp_degree)
+
+    logger = get_dist_logger()
+    logger.info(get_mem_info(), ranks=[0])
+
+    # build GPT model
+    with ColoInitContext(device=get_current_device()):
+        model = gpt2_medium(checkpoint=True)
+
+    numel = sum([p.numel() for p in model.parameters()])
+    logger.info(f'Model numel: {numel}', ranks=[0])
+    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
+
+    # Tensor Parallelism (TP)
+    tensor_parallelize(model, pg)
+    # Gemini + ZeRO DP; note that it must be applied after TP
+    model = gemini_zero_dpp(model, pg, args.placement)
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

-    # optimizer
+    # build optimizer
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
    logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
+    torch.cuda.synchronize()
    model.train()
    for n in range(NUM_STEPS):
        # we just use randomly generated data here

@@ -156,6 +234,8 @@ def main():
            f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
            ranks=[0])

+    torch.cuda.synchronize()


if __name__ == '__main__':
    main()


@@ -36,7 +36,6 @@ from datasets import load_dataset
from packaging import version
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
-from utils import colo_memory_cap

import colossalai
import transformers

@@ -47,7 +46,6 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
-from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import (
@@ -249,12 +247,20 @@ def parse_args():
    return args


+def colo_memory_cap(size_in_GB):
+    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
+    cuda_capacity = colo_device_memory_capacity(get_current_device())
+    if size_in_GB * (1024**3) < cuda_capacity:
+        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
+        print("Using {} GB of GPU memory".format(size_in_GB))

def main():
    args = parse_args()
    disable_existing_loggers()
    colossalai.launch_from_torch(config=dict())
    logger = get_dist_logger()

-    is_main_process = gpc.get_local_rank(ParallelMode.DATA) == 0
+    is_main_process = dist.get_rank() == 0

    if is_main_process:
        datasets.utils.logging.set_verbosity_warning()


@@ -1,28 +0,0 @@
-import torch
-import torch.distributed as dist
-
-
-def memory_cap(size_in_GB):
-    print(f"use only {size_in_GB} GB of CUDA memory")
-    assert dist.is_initialized(), "memory_cap must be used after dist init"
-    local_rank = dist.get_rank()
-    cuda_capacity = torch.cuda.get_device_properties(local_rank).total_memory
-    size_in_B = (size_in_GB * 1024**3)
-    if size_in_B > cuda_capacity:
-        print(f'memory_cap is useless since {cuda_capacity / 1024**3} is less than {size_in_GB}')
-        return
-    fraction = (size_in_GB * 1024**3) / cuda_capacity
-    print(f'mem fraction is {fraction}')
-    torch.cuda.set_per_process_memory_fraction(fraction, local_rank)
-
-
-def colo_memory_cap(size_in_GB):
-    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
-    cuda_capacity = colo_device_memory_capacity(get_current_device())
-    if size_in_GB * (1024**3) < cuda_capacity:
-        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
-        print("Using {} GB of GPU memory".format(size_in_GB))
-
-
-if __name__ == '__main__':
-    memory_cap(40)