mirror of https://github.com/hpcaitech/ColossalAI
[example] update gpt example for larger model scale (#2211)
parent 24246f7aa5
commit d5e3e3ec01
@@ -59,7 +59,6 @@ class MemStatsCollector:
        return [t - self._sampling_time[0] for t in self._sampling_time]

    def start_collection(self):
        print('start collection')
        self._start_flag = True
        self._mem_monitor.start()

@@ -68,7 +67,6 @@ class MemStatsCollector:
        # self._step_total = len(self._sampling_time)
        self._step_total = len(self._memstats.non_model_data_list('cuda'))
        self._start_flag = False
        self._mem_monitor.finish()
        print(f'finish_collection {self._step_total}')

    # deprecated
@@ -62,7 +62,7 @@ ColossalAI version 0.1.13.

How does the batch size affect the efficiency?

| model | #GPU | policy | TP |batch | Tflops |
| model | #GPU | policy | TP | batch per DP | Tflops |
| ---------- | --------- |--------- |--------- |--------- |--------- |
| gpt2_10b | 2 | cpu | 1 | 32 | 122.046 |
| gpt2_10b | 2 | cpu | 1 | 16 | 82.649 |

@@ -71,7 +71,7 @@ How does the batch size affect the efficiency?

How does the placement policy affect the efficiency?

| model | #GPU | policy | TP |batch | Tflops |
| model | #GPU | policy | TP | batch per DP | Tflops |
| ---------- | --------- |--------- |--------- |--------- |--------- |
| gpt2_10b | 4 | auto | 1 | 8 | 88.657 |
| gpt2_10b | 4 | cuda | 1 | 8 | OOM |

@@ -80,9 +80,23 @@ How does the placement policy affect the efficiency?

How does the tensor parallel degree affect the efficiency?

| model | #GPU | policy | TP |batch | Tflops |
| model | #GPU | policy | TP | batch per DP | Tflops |
| ---------- | --------- |--------- |--------- |--------- |--------- |
| gpt2_10b | 4 | auto | 1 | 8 | 88.657 |
| gpt2_10b | 4 | auto | 2 | 8 | 56.687 |
| gpt2_10b | 4 | auto | 4 | 8 | 29.019 |
| gpt2_10b | 4 | auto | 4 | 64 | 50.411 |
| gpt2_20b | 1 | cpu | 1 | 8 | 43.102 |
| gpt2_20b | 4 | cpu | 4 | 8 | 28.491 |

Pushing the limits of model scale and batch size.

| model | #GPU | policy | TP | batch per DP | Tflops |
| ---------- | --------- |--------- |--------- |--------- |--------- |
| gpt2_20b | 4 | cpu | 1 | 64 | CUDA OOM |
| gpt2_20b | 4 | auto | 1/2 | 64 | CUDA OOM |
| gpt2_20b | 4 | cpu | 2 | 64 | 121.394 |
| gpt2_20b | 4 | cpu | 2 | 8 | 43.102 |
| gpt2_20b | 8 | cpu | 2 | 64 | 125.170 |
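A note on reading the tables above: "batch per DP" is the number of samples each data-parallel rank processes per step, and Tflops is per-GPU throughput. A minimal sketch of how the columns relate, assuming the data-parallel degree is simply #GPU divided by the TP degree (the helper below is illustrative and not part of the example code):

def global_batch_size(batch_per_dp: int, num_gpus: int, tp_degree: int) -> int:
    # Assumption: every GPU belongs to exactly one TP group, so the
    # data-parallel degree is the GPU count divided by the TP degree.
    dp_degree = num_gpus // tp_degree
    return batch_per_dp * dp_degree

# For the gpt2_20b row with 8 GPUs, TP = 2 and batch per DP = 64:
# dp_degree = 4, so the global batch is 64 * 4 = 256.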
@@ -0,0 +1,71 @@
from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel


## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
class GPTLMModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size,
                       n_layer=num_layers,
                       n_head=num_attention_heads,
                       n_positions=max_seq_len,
                       n_ctx=max_seq_len,
                       vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=True):
    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)


def gpt2_10b(checkpoint=True):
    return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_14b(checkpoint=True):
    return GPTLMModel(hidden_size=4096, num_layers=70, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_20b(checkpoint=True):
    return GPTLMModel(hidden_size=8192, num_layers=25, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_24b(checkpoint=True):
    return GPTLMModel(hidden_size=8192, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)


def model_builder(model_size: str):
    if model_size == "gpt2_medium":
        return gpt2_medium
    elif model_size == "gpt2_xl":
        return gpt2_xl
    elif model_size == "gpt2_10b":
        return gpt2_10b
    elif model_size == "gpt2_14b":
        return gpt2_14b
    elif model_size == "gpt2_20b":
        return gpt2_20b
    elif model_size == "gpt2_24b":
        return gpt2_24b


__all__ = ['model_builder']
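For orientation, here is a minimal sketch of how the new model_zoo module is meant to be used on its own. The dummy batch shapes below are illustrative and not part of the diff:

import torch

from model_zoo import model_builder

# model_builder returns the builder function for the requested scale;
# calling it constructs the GPTLMModel (here with activation checkpointing on).
model = model_builder("gpt2_medium")(checkpoint=True)

# Dummy inputs: token ids in [0, vocab_size) and an all-ones attention mask.
input_ids = torch.randint(0, 50257, (2, 128))
attention_mask = torch.ones_like(input_ids)

logits = model(input_ids, attention_mask)    # forward() returns only the lm_logits
print(logits.shape)                          # expected: torch.Size([2, 128, 50257])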
@@ -2,9 +2,12 @@
export DISTPAN="colossalai"

# The following options are only valid when DISTPAN="colossalai"
export TPDEGREE=4
export GPUNUM=4
export PLACEMENT='auto'
export TPDEGREE=2
export GPUNUM=8
export PLACEMENT='cpu'
export USE_SHARD_INIT=False
export BATCH_SIZE=64
export MODEL_TYPE="gpt2_20b"

env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee run.log
mkdir -p logs
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=${GPUNUM} train_gpt_demo.py --tp_degree=${TPDEGREE} --model_type=${MODEL_TYPE} --batch_size=${BATCH_SIZE} --placement ${PLACEMENT} --shardinit ${USE_SHARD_INIT} --distplan ${DISTPAN} 2>&1 | tee ./logs/${MODEL_TYPE}_${DISTPAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}.log
@@ -6,18 +6,16 @@ import torch
import torch.nn as nn
from packaging import version
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import GPT2Config, GPT2LMHeadModel

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.optimizer.gemini_optimizer import GeminiAdamOptimizer
from colossalai.nn.optimizer.zero_optimizer import ZeroOptimizer
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
from model_zoo import model_builder


def parse_args():
@@ -47,6 +45,18 @@ def parse_args():
        help=
        "Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
        help="batch size per DP group for training.",
    )
    parser.add_argument(
        "--model_type",
        type=str,
        default='gpt2_medium',
        help="model scale",
    )
    args = parser.parse_args()
    return args

@@ -65,33 +75,6 @@ def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)


## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
class GPTLMModel(nn.Module):

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint
        self.model = GPT2LMHeadModel(
            GPT2Config(n_embd=hidden_size,
                       n_layer=num_layers,
                       n_head=num_attention_heads,
                       n_positions=max_seq_len,
                       n_ctx=max_seq_len,
                       vocab_size=vocab_size))
        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        # Only return lm_logits
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]


class GPTLMLoss(nn.Module):

    def __init__(self):
@@ -112,18 +95,6 @@ def get_data(batch_size, seq_len, vocab_size):
    return input_ids, attention_mask


def gpt2_medium(checkpoint=False):
    return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)


def gpt2_xl(checkpoint=True):
    return GPTLMModel(hidden_size=1600, num_layers=48, num_attention_heads=32, checkpoint=checkpoint)


def gpt2_10b(checkpoint=True):
    return GPTLMModel(hidden_size=4096, num_layers=50, num_attention_heads=16, checkpoint=checkpoint)


def get_cpu_mem():
    return psutil.Process().memory_info().rss / 1024**2
@@ -210,7 +181,8 @@ def main():
    if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
        raise TypeError(f"{args.distplan} is error")

    BATCH_SIZE = 64
    # batch size per DP degree
    BATCH_SIZE = args.batch_size
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257

@@ -220,7 +192,7 @@ def main():
    colossalai.launch_from_torch(config={})

    logger = get_dist_logger()
    logger.info(f"using dist plan {args.distplan}", ranks=[0])
    logger.info(f"{args.model_type}, {args.distplan}, batch size {BATCH_SIZE}", ranks=[0])

    # build criterion
    criterion = GPTLMLoss()
@@ -232,8 +204,11 @@ def main():
    default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None

    # build GPT model
    with ColoInitContext(device=get_current_device(), default_dist_spec=default_dist_spec, default_pg=default_pg):
        model = gpt2_10b(checkpoint=True)
    with ColoInitContext(device=get_current_device(),
                         dtype=torch.half,
                         default_dist_spec=default_dist_spec,
                         default_pg=default_pg):
        model = model_builder(args.model_type)(checkpoint=True)

    pg = default_pg
    # Tensor Parallelism (TP)
@@ -246,7 +221,7 @@ def main():
        optimizer = GeminiAdamOptimizer(model, lr=1e-3, initial_scale=2**5)
        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
    else:
        model = gpt2_10b(checkpoint=True).cuda()
        model = model_builder(args.model_type)(checkpoint=True).cuda()

    if args.distplan.startswith("torch"):
        model = DDP(model)
@@ -262,10 +237,14 @@ def main():
                                          overlap_communication=True,
                                          partition_grad=partition_flag,
                                          verbose=True)
        # notice that the model is still in fp32

    # model is sharded after TP
    numel = sum([p.numel() for p in model.parameters()])
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # Tflops_per_GPU = global_batch * global_numel * seq_len * 8 / #gpu
    #               = (batch_per_DP_group * dp_degree) * (numel * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    #               = batch_per_DP_group * numel * seq_len * 8
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)

    torch.cuda.synchronize()
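The Tflops comment above compresses a short derivation: the TP and DP factors cancel, so per-GPU throughput depends only on the per-DP-group batch and the per-rank parameter count. The factor 8 presumably counts forward (2), backward (4), and the extra forward from activation checkpointing (2) FLOPs per parameter per token. A small sketch of that arithmetic with an assumed measured step time; the function name and numbers are illustrative, not the example's own get_tflops:

def tflops_per_gpu(numel_per_rank: float, batch_per_dp: int, seq_len: int, step_time_s: float) -> float:
    # global_batch * global_numel * seq_len * 8 / #gpu
    #   = (batch_per_dp * dp_degree) * (numel_per_rank * tp_degree) * seq_len * 8 / (tp_degree * dp_degree)
    #   = batch_per_dp * numel_per_rank * seq_len * 8
    flop_per_step = 8 * batch_per_dp * numel_per_rank * seq_len
    return flop_per_step / step_time_s / 1e12

# e.g. a 10e9-parameter shard, batch per DP 8, seq len 1024, 6.0 s per step
print(tflops_per_gpu(10e9, 8, 1024, 6.0))   # roughly 109 TFLOPS per GPU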