|
|
@ -5,12 +5,13 @@ import psutil
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
|
from packaging import version
|
|
|
|
from packaging import version
|
|
|
|
|
|
|
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
|
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
import colossalai
|
|
|
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
|
|
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
|
|
from colossalai.nn.parallel import ZeroDDP
|
|
|
|
from colossalai.nn.parallel import ZeroDDP
|
|
|
|
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
|
|
|
|
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
|
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
from colossalai.utils import get_current_device
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
|
|
from colossalai.zero import ZeroOptimizer
|
|
|
|
from colossalai.zero import ZeroOptimizer
|
|
|
@ -19,17 +20,30 @@ from transformers import GPT2Config, GPT2LMHeadModel
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args():
|
|
|
|
def parse_args():
|
|
|
|
parser = colossalai.get_default_parser()
|
|
|
|
parser = colossalai.get_default_parser()
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
|
|
|
"--distplan",
|
|
|
|
|
|
|
|
type=str,
|
|
|
|
|
|
|
|
default='colossalai',
|
|
|
|
|
|
|
|
help="The distributed plan [colossalai, ddp, zero].",
|
|
|
|
|
|
|
|
)
|
|
|
|
parser.add_argument(
|
|
|
|
parser.add_argument(
|
|
|
|
"--tp_degree",
|
|
|
|
"--tp_degree",
|
|
|
|
type=int,
|
|
|
|
type=int,
|
|
|
|
default=1,
|
|
|
|
default=1,
|
|
|
|
help="Tensor Parallelism Degree.",
|
|
|
|
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
parser.add_argument(
|
|
|
|
parser.add_argument(
|
|
|
|
"--placement",
|
|
|
|
"--placement",
|
|
|
|
type=str,
|
|
|
|
type=str,
|
|
|
|
default='cpu',
|
|
|
|
default='cpu',
|
|
|
|
help="Placement Policy for Gemini.",
|
|
|
|
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
|
|
|
"--shardinit",
|
|
|
|
|
|
|
|
type=bool,
|
|
|
|
|
|
|
|
default=False,
|
|
|
|
|
|
|
|
help=
|
|
|
|
|
|
|
|
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
|
|
|
|
)
|
|
|
|
)
|
|
|
|
args = parser.parse_args()
|
|
|
|
args = parser.parse_args()
|
|
|
|
return args
|
|
|
|
return args
|
|
|
@ -38,8 +52,6 @@ def parse_args():
|
|
|
|
## Parameter Sharding Strategies for Tensor Parallelism
|
|
|
|
## Parameter Sharding Strategies for Tensor Parallelism
|
|
|
|
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
|
|
|
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
|
|
|
|
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
|
|
|
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
|
|
|
|
if param.process_group.tp_world_size() == 1:
|
|
|
|
|
|
|
|
param.set_process_group(pg)
|
|
|
|
|
|
|
|
param.set_tensor_spec(*spec)
|
|
|
|
param.set_tensor_spec(*spec)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -136,21 +148,30 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
for mn, module in model.named_modules():
|
|
|
|
for mn, module in model.named_modules():
|
|
|
|
for pn, param in module.named_parameters(recurse=False):
|
|
|
|
for pn, param in module.named_parameters(recurse=False):
|
|
|
|
# set process group for all parameters
|
|
|
|
# NOTE() a param maybe shared by tow modules
|
|
|
|
param.set_process_group(pg)
|
|
|
|
if hasattr(param, 'visited'):
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
param.set_dist_spec(ReplicaSpec())
|
|
|
|
if 'mlp.c_fc' in mn:
|
|
|
|
if 'mlp.c_fc' in mn:
|
|
|
|
if 'weight' in pn or 'bias' in pn:
|
|
|
|
if 'weight' in pn or 'bias' in pn:
|
|
|
|
split_param_col_tp1d(param, pg) # colmn slice
|
|
|
|
split_param_col_tp1d(param, pg) # colmn slice
|
|
|
|
# keep the shape of the output from c_fc
|
|
|
|
# keep the shape of the output from c_fc
|
|
|
|
param.compute_spec.set_output_replicate(False)
|
|
|
|
param.compute_spec.set_output_replicate(False)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
param.set_dist_spec(ReplicaSpec())
|
|
|
|
elif 'mlp.c_proj' in mn:
|
|
|
|
elif 'mlp.c_proj' in mn:
|
|
|
|
if 'weight' in pn:
|
|
|
|
if 'weight' in pn:
|
|
|
|
split_param_row_tp1d(param, pg) # row slice
|
|
|
|
split_param_row_tp1d(param, pg) # row slice
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
param.set_dist_spec(ReplicaSpec())
|
|
|
|
elif 'wte' in mn or 'wpe' in mn:
|
|
|
|
elif 'wte' in mn or 'wpe' in mn:
|
|
|
|
split_param_col_tp1d(param, pg) # colmn slice
|
|
|
|
split_param_col_tp1d(param, pg) # colmn slice
|
|
|
|
elif 'c_attn' in mn or 'c_proj' in mn:
|
|
|
|
elif 'c_attn' in mn or 'c_proj' in mn:
|
|
|
|
split_param_col_tp1d(param, pg) # colmn slice
|
|
|
|
split_param_col_tp1d(param, pg) # colmn slice
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
param.set_dist_spec(ReplicaSpec())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
param.visited = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gemini + ZeRO DDP
|
|
|
|
# Gemini + ZeRO DDP
|
|
|
@ -188,32 +209,49 @@ def main():
|
|
|
|
disable_existing_loggers()
|
|
|
|
disable_existing_loggers()
|
|
|
|
colossalai.launch_from_torch(config={})
|
|
|
|
colossalai.launch_from_torch(config={})
|
|
|
|
|
|
|
|
|
|
|
|
pg = ProcessGroup(tp_degree=args.tp_degree)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = get_dist_logger()
|
|
|
|
logger = get_dist_logger()
|
|
|
|
logger.info(get_mem_info(), ranks=[0])
|
|
|
|
logger.info(f"using dist plan {args.distplan}", ranks=[0])
|
|
|
|
|
|
|
|
|
|
|
|
# build GPT model
|
|
|
|
|
|
|
|
with ColoInitContext(device=get_current_device()):
|
|
|
|
|
|
|
|
model = gpt2_medium(checkpoint=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
numel = sum([p.numel() for p in model.parameters()])
|
|
|
|
|
|
|
|
logger.info(f'Model numel: {numel}', ranks=[0])
|
|
|
|
|
|
|
|
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Tensor Parallelism (TP)
|
|
|
|
|
|
|
|
tensor_parallelize(model, pg)
|
|
|
|
|
|
|
|
# Gemini + ZeRO DP, Note it must be used after TP
|
|
|
|
|
|
|
|
model = gemini_zero_dpp(model, pg, args.placement)
|
|
|
|
|
|
|
|
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# build criterion
|
|
|
|
# build criterion
|
|
|
|
criterion = GPTLMLoss()
|
|
|
|
criterion = GPTLMLoss()
|
|
|
|
|
|
|
|
|
|
|
|
# build optimizer
|
|
|
|
torch.manual_seed(123)
|
|
|
|
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
|
|
|
if args.distplan == "colossalai":
|
|
|
|
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
|
|
|
|
# all param must use the same process group.
|
|
|
|
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
|
|
|
|
default_pg = ProcessGroup(tp_degree=args.tp_degree)
|
|
|
|
|
|
|
|
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# build GPT model
|
|
|
|
|
|
|
|
with ColoInitContext(device='cuda', default_dist_spec=default_dist_spec, default_pg=default_pg):
|
|
|
|
|
|
|
|
model = gpt2_medium(checkpoint=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pg = default_pg
|
|
|
|
|
|
|
|
# Tensor Parallelism (TP)
|
|
|
|
|
|
|
|
tensor_parallelize(model, pg)
|
|
|
|
|
|
|
|
# Gemini + ZeRO DP, Note it must be used after TP
|
|
|
|
|
|
|
|
model = gemini_zero_dpp(model, pg, args.placement)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# build optimizer
|
|
|
|
|
|
|
|
optimizer = HybridAdam(model.parameters(), lr=1e-3)
|
|
|
|
|
|
|
|
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
|
|
|
|
|
|
|
|
logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
elif args.distplan == "ddp":
|
|
|
|
|
|
|
|
model = gpt2_medium(checkpoint=True).cuda()
|
|
|
|
|
|
|
|
ddp_model = DDP(model)
|
|
|
|
|
|
|
|
optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
elif args.distplan == "zero":
|
|
|
|
|
|
|
|
from torch.distributed.optim import ZeroRedundancyOptimizer
|
|
|
|
|
|
|
|
model = gpt2_medium(checkpoint=True).cuda()
|
|
|
|
|
|
|
|
ddp_model = DDP(model)
|
|
|
|
|
|
|
|
optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
raise TypeError(f"{args.distplan} is error")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
numel = sum([p.numel() for p in model.parameters()])
|
|
|
|
|
|
|
|
logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
|
|
|
|
|
|
|
|
get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
|
|
|
|
|
|
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
model.train()
|
|
|
|
model.train()
|
|
|
@ -225,7 +263,11 @@ def main():
|
|
|
|
outputs = model(input_ids, attn_mask)
|
|
|
|
outputs = model(input_ids, attn_mask)
|
|
|
|
loss = criterion(outputs, input_ids)
|
|
|
|
loss = criterion(outputs, input_ids)
|
|
|
|
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
|
|
|
|
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
|
|
|
|
optimizer.backward(loss)
|
|
|
|
if args.distplan == "colossalai":
|
|
|
|
|
|
|
|
optimizer.backward(loss)
|
|
|
|
|
|
|
|
elif args.distplan in ["ddp", "zero"]:
|
|
|
|
|
|
|
|
loss.backward()
|
|
|
|
|
|
|
|
|
|
|
|
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
|
|
|
|
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
|
|
|
|
optimizer.step()
|
|
|
|
optimizer.step()
|
|
|
|
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
|
|
|
|
logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
|
|
|
|