mirror of https://github.com/hpcaitech/ColossalAI
[example] add TP to GPT example (#1828)
parent
49216d7ab1
commit
a25f755331
|
@ -1 +1 @@
|
||||||
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=2 train_gpt_demo.py 2>&1 | tee run.log
|
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log
|
||||||
|
|
|
@ -10,13 +10,48 @@ import colossalai
|
||||||
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.nn.parallel import ZeroDDP
|
from colossalai.nn.parallel import ZeroDDP
|
||||||
from colossalai.tensor import ProcessGroup
|
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.zero import ZeroOptimizer
|
from colossalai.zero import ZeroOptimizer
|
||||||
from transformers import GPT2Config, GPT2LMHeadModel
|
from transformers import GPT2Config, GPT2LMHeadModel
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = colossalai.get_default_parser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--tp_degree",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Tensor Parallelism Degree.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--placement",
|
||||||
|
type=str,
|
||||||
|
default='cpu',
|
||||||
|
help="Placement Policy for Gemini.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
## Parameter Sharding Strategies for Tensor Parallelism
|
||||||
|
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
||||||
|
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
||||||
|
if param.process_group.tp_world_size() == 1:
|
||||||
|
param.set_process_group(pg)
|
||||||
|
param.set_tensor_spec(*spec)
|
||||||
|
|
||||||
|
|
||||||
|
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||||
|
split_param_single_dim_tp1d(0, param, pg)
|
||||||
|
|
||||||
|
|
||||||
|
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
|
||||||
|
split_param_single_dim_tp1d(-1, param, pg)
|
||||||
|
|
||||||
|
|
||||||
|
## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
|
||||||
class GPTLMModel(nn.Module):
|
class GPTLMModel(nn.Module):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
|
@ -56,6 +91,7 @@ class GPTLMLoss(nn.Module):
|
||||||
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||||
|
|
||||||
|
|
||||||
|
## Randomly Generated Data
|
||||||
def get_data(batch_size, seq_len, vocab_size):
|
def get_data(batch_size, seq_len, vocab_size):
|
||||||
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
|
||||||
attention_mask = torch.ones_like(input_ids)
|
attention_mask = torch.ones_like(input_ids)
|
||||||
|
@ -90,54 +126,96 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
|
||||||
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
# Tensor Parallel
|
||||||
BATCH_SIZE = 8
|
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
||||||
SEQ_LEN = 1024
|
"""tensor_parallelize
|
||||||
VOCAB_SIZE = 50257
|
Sharding the Model Parameters.
|
||||||
NUM_STEPS = 10
|
|
||||||
PLACEMENT_POLICY = 'auto'
|
|
||||||
disable_existing_loggers()
|
|
||||||
colossalai.launch_from_torch(config={})
|
|
||||||
pg = ProcessGroup()
|
|
||||||
logger = get_dist_logger()
|
|
||||||
|
|
||||||
logger.info(get_mem_info(), ranks=[0])
|
Args:
|
||||||
# build GPT model
|
model (torch.nn.Module): a torch module to be sharded
|
||||||
with ColoInitContext(device=get_current_device()):
|
"""
|
||||||
model = gpt2_medium(checkpoint=True)
|
for mn, module in model.named_modules():
|
||||||
numel = sum([p.numel() for p in model.parameters()])
|
for pn, param in module.named_parameters(recurse=False):
|
||||||
logger.info(f'Model numel: {numel}', ranks=[0])
|
# set process group for all parameters
|
||||||
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
|
param.set_process_group(pg)
|
||||||
|
|
||||||
|
if 'mlp.c_fc' in mn:
|
||||||
|
if 'weight' in pn or 'bias' in pn:
|
||||||
|
split_param_col_tp1d(param, pg) # colmn slice
|
||||||
|
# keep the shape of the output from c_fc
|
||||||
|
param.compute_spec.set_output_replicate(False)
|
||||||
|
elif 'mlp.c_proj' in mn:
|
||||||
|
if 'weight' in pn:
|
||||||
|
split_param_row_tp1d(param, pg) # row slice
|
||||||
|
elif 'wte' in mn or 'wpe' in mn:
|
||||||
|
split_param_col_tp1d(param, pg) # colmn slice
|
||||||
|
elif 'c_attn' in mn or 'c_proj' in mn:
|
||||||
|
split_param_col_tp1d(param, pg) # colmn slice
|
||||||
|
|
||||||
|
|
||||||
|
# Gemini + ZeRO DDP
|
||||||
|
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy: str = "auto"):
|
||||||
cai_version = colossalai.__version__
|
cai_version = colossalai.__version__
|
||||||
logger.info(f'using Colossal-AI version {cai_version}')
|
|
||||||
if version.parse(cai_version) > version.parse("0.1.10"):
|
if version.parse(cai_version) > version.parse("0.1.10"):
|
||||||
from colossalai.nn.parallel import GeminiDDP
|
from colossalai.nn.parallel import GeminiDDP
|
||||||
model = GeminiDDP(model,
|
model = GeminiDDP(model,
|
||||||
device=get_current_device(),
|
device=get_current_device(),
|
||||||
placement_policy=PLACEMENT_POLICY,
|
placement_policy=placememt_policy,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
search_range_mb=32)
|
search_range_mb=32)
|
||||||
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
|
||||||
from colossalai.gemini import ChunkManager, GeminiManager
|
from colossalai.gemini import ChunkManager, GeminiManager
|
||||||
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
|
||||||
gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
|
gemini_manager = GeminiManager(placememt_policy, chunk_manager)
|
||||||
chunk_manager = ChunkManager(chunk_size,
|
chunk_manager = ChunkManager(chunk_size,
|
||||||
pg,
|
pg,
|
||||||
enable_distributed_storage=True,
|
enable_distributed_storage=True,
|
||||||
init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
|
init_device=GeminiManager.get_default_device(placememt_policy))
|
||||||
model = ZeroDDP(model, gemini_manager)
|
model = ZeroDDP(model, gemini_manager)
|
||||||
|
else:
|
||||||
|
raise NotImplemented(f"CAI version {cai_version} is not supported")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
BATCH_SIZE = 8
|
||||||
|
SEQ_LEN = 1024
|
||||||
|
VOCAB_SIZE = 50257
|
||||||
|
NUM_STEPS = 10
|
||||||
|
|
||||||
|
disable_existing_loggers()
|
||||||
|
colossalai.launch_from_torch(config={})
|
||||||
|
|
||||||
|
pg = ProcessGroup(tp_degree=args.tp_degree)
|
||||||
|
|
||||||
|
logger = get_dist_logger()
|
||||||
|
logger.info(get_mem_info(), ranks=[0])
|
||||||
|
|
||||||
|
# build GPT model
|
||||||
|
with ColoInitContext(device=get_current_device()):
|
||||||
|
model = gpt2_medium(checkpoint=True)
|
||||||
|
|
||||||
|
numel = sum([p.numel() for p in model.parameters()])
|
||||||
|
logger.info(f'Model numel: {numel}', ranks=[0])
|
||||||
|
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
|
||||||
|
|
||||||
|
# Tensor Parallelism (TP)
|
||||||
|
tensor_parallelize(model, pg)
|
||||||
|
# Gemini + ZeRO DP, Note it must be used after TP
|
||||||
|
model = gemini_zero_dpp(model, pg, args.placement)
|
||||||
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
||||||
|
|
||||||
# build criterion
|
# build criterion
|
||||||
criterion = GPTLMLoss()
|
criterion = GPTLMLoss()
|
||||||
|
|
||||||
# optimizer
|
# build optimizer
|
||||||
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
||||||
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
|
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
|
||||||
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
|
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
model.train()
|
model.train()
|
||||||
for n in range(NUM_STEPS):
|
for n in range(NUM_STEPS):
|
||||||
# we just use randomly generated data here
|
# we just use randomly generated data here
|
||||||
|
@ -156,6 +234,8 @@ def main():
|
||||||
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
|
f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
|
||||||
ranks=[0])
|
ranks=[0])
|
||||||
|
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -36,7 +36,6 @@ from datasets import load_dataset
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
from utils import colo_memory_cap
|
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
import transformers
|
import transformers
|
||||||
|
@ -47,7 +46,6 @@ from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.nn.parallel import ZeroDDP
|
from colossalai.nn.parallel import ZeroDDP
|
||||||
from colossalai.tensor import ProcessGroup
|
from colossalai.tensor import ProcessGroup
|
||||||
from colossalai.utils import get_current_device, get_dataloader
|
from colossalai.utils import get_current_device, get_dataloader
|
||||||
from colossalai.utils.checkpoint import load_checkpoint, save_checkpoint
|
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.zero import ZeroOptimizer
|
from colossalai.zero import ZeroOptimizer
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
@ -249,12 +247,20 @@ def parse_args():
|
||||||
return args
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def colo_memory_cap(size_in_GB):
|
||||||
|
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
|
||||||
|
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
||||||
|
if size_in_GB * (1024**3) < cuda_capacity:
|
||||||
|
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
|
||||||
|
print("Using {} GB of GPU memory".format(size_in_GB))
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
disable_existing_loggers()
|
disable_existing_loggers()
|
||||||
colossalai.launch_from_torch(config=dict())
|
colossalai.launch_from_torch(config=dict())
|
||||||
logger = get_dist_logger()
|
logger = get_dist_logger()
|
||||||
is_main_process = gpc.get_local_rank(ParallelMode.DATA) == 0
|
is_main_process = dist.get_rank() == 0
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
datasets.utils.logging.set_verbosity_warning()
|
datasets.utils.logging.set_verbosity_warning()
|
||||||
|
|
|
@ -1,28 +0,0 @@
|
||||||
import torch
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
|
|
||||||
def memory_cap(size_in_GB):
|
|
||||||
print(f"use only {size_in_GB} GB of CUDA memory")
|
|
||||||
assert dist.is_initialized(), "memory_cap must be used after dist init"
|
|
||||||
local_rank = dist.get_rank()
|
|
||||||
cuda_capacity = torch.cuda.get_device_properties(local_rank).total_memory
|
|
||||||
size_in_B = (size_in_GB * 1024**3)
|
|
||||||
if size_in_B > cuda_capacity:
|
|
||||||
print(f'memory_cap is uselsess since {cuda_capacity / 1024**3} less than {size_in_GB}')
|
|
||||||
return
|
|
||||||
fraction = (size_in_GB * 1024**3) / cuda_capacity
|
|
||||||
print(f'mem faction is {fraction}')
|
|
||||||
torch.cuda.set_per_process_memory_fraction(fraction, local_rank)
|
|
||||||
|
|
||||||
|
|
||||||
def colo_memory_cap(size_in_GB):
|
|
||||||
from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
|
|
||||||
cuda_capacity = colo_device_memory_capacity(get_current_device())
|
|
||||||
if size_in_GB * (1024**3) < cuda_capacity:
|
|
||||||
colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
|
|
||||||
print("Using {} GB of GPU memory".format(size_in_GB))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
memory_cap(40)
|
|
Loading…
Reference in New Issue