[example] gpt, shard init on all processes (#2366)

Jiarui Fang 2023-01-06 15:44:50 +08:00 committed by GitHub
parent 1f8ab6f1f5
commit 1aaeb596c6
2 changed files with 18 additions and 12 deletions


@@ -117,7 +117,7 @@ class ColoTensor(torch.Tensor):
     def set_process_group(self, pg: ProcessGroup):
         """set_process_group
         change the pg of the ColoTensor. Note that the valid use cases are limited.
-        Only existing pg is DP and dist spec is REPLICaTE is valid.
+        It only works when the target pg is DP-only or TP-only and the current dist spec of the Tensor is Replica.

         Args:
             pg (ProcessGroup): target pg
@@ -127,10 +127,10 @@ class ColoTensor(torch.Tensor):
         # if the new pg is the same as the old pg, just returns
         if self.process_group == pg:
             return
-        assert self.process_group.tp_world_size() == 1, \
-            "Can not set_process_group on a ColoTensor whose process_group has tp world group"
+        assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
+            "Cannot set_process_group on a ColoTensor whose process_group has both tp_world_size > 1 and dp_world_size > 1"
         assert self.dist_spec.placement.value == 'r', \
-            "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
+            "Cannot set_process_group on a ColoTensor whose dist spec is not Replica"
         self.process_group = pg
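
For context, a minimal sketch of the call pattern the relaxed assertion enables: a parameter created under a shard-init ProcessGroup spanning all ranks is first turned back into a replica and then handed over to the training-time group. The single Linear layer, the group degrees, and the import paths here are illustrative assumptions, and torch.distributed is assumed to be already initialized (e.g. via colossalai.launch_from_torch).

import torch
from colossalai.tensor import ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

world_size = torch.distributed.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size)    # init-time pg: TP-only, so dp_world_size() == 1
tp_pg = ProcessGroup(tp_degree=2)                # training-time pg: DP x TP

# shard-init path: the weight is created as a 1/world_size slice on shard_pg
with ColoInitContext(device=get_current_device(),
                     default_pg=shard_pg,
                     default_dist_spec=ShardSpec([-1], [world_size])):
    layer = torch.nn.Linear(8, 8)

param = layer.weight
param.set_dist_spec(ReplicaSpec())    # must be a replica before the pg may change
param.set_process_group(tp_pg)        # now accepted, since shard_pg has dp_world_size() == 1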


@@ -148,10 +148,16 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
-            # NOTE() a param maybe shared by tow modules
+            # NOTE() a param may be shared by two modules
            if hasattr(param, 'visited'):
                continue
+            # if shard init, then convert param to replica and use the dp-only ProcessGroup
+            param: ColoParameter = param
+            param.set_dist_spec(ReplicaSpec())
+            param.set_process_group(pg)
+            # shard it w.r.t. the tp pattern
            if 'mlp.c_fc' in mn:
                if 'weight' in pn or 'bias' in pn:
                    split_param_col_tp1d(param, pg)    # column slice
@@ -170,7 +176,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
                    split_param_col_tp1d(param, pg)    # column slice
                else:
                    param.set_dist_spec(ReplicaSpec())
            param.visited = True
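
For readers unfamiliar with the split helpers called above: they are thin wrappers that attach a 1D tensor-parallel ShardSpec plus a ComputeSpec to a parameter. A minimal sketch under that assumption follows; the exact helper bodies in the example script may differ.

from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec


def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    # shard `param` along `dim` across the TP ranks of `pg`, using the 1D TP compute pattern
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)     # row slice: shard dim 0


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)    # column slice: shard the last dim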
@@ -248,27 +253,28 @@ def main():
    torch.manual_seed(123)
    if args.distplan == "colossalai":
        # all param must use the same process group.
-        default_pg = ProcessGroup(tp_degree=args.tp_degree)
-        default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
+        world_size = torch.distributed.get_world_size()
+        shard_pg = ProcessGroup(tp_degree=world_size)
+        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
        # build GPT model
        if version.parse(CAI_VERSION) > version.parse("0.1.10"):
            with ColoInitContext(device=get_current_device(),
                                 dtype=torch.half,
                                 default_dist_spec=default_dist_spec,
-                                 default_pg=default_pg):
+                                 default_pg=shard_pg):
                model = model_builder(args.model_type)(checkpoint=True)
        else:
            with ColoInitContext(device=get_current_device()):
                model = model_builder(args.model_type)(checkpoint=True)
-        pg = default_pg
+        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
        # Tensor Parallelism (TP)
-        tensor_parallelize(model, pg)
+        tensor_parallelize(model, tp_pg)
        # build a Gemini model and a highly optimized cpu optimizer
        # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, pg, args.placement)
+        model, optimizer = build_gemini(model, tp_pg, args.placement)
        logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
    else:
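
The net effect is that two different ProcessGroups are now in play: a TP-only group spanning every rank for initialization, and a DP x TP group for training. A sketch of the two group shapes on 8 ranks with --tp_degree 2 (the numbers are illustrative, and torch.distributed is assumed to be initialized):

import torch
from colossalai.tensor import ProcessGroup

world_size = torch.distributed.get_world_size()   # 8 in this sketch

shard_pg = ProcessGroup(tp_degree=world_size)     # init: tp_world_size() == 8, dp_world_size() == 1
tp_pg = ProcessGroup(tp_degree=2)                 # train: tp_world_size() == 2, dp_world_size() == 4

# shard_pg being TP-only is exactly what the relaxed assert in set_process_group relies on
assert shard_pg.dp_world_size() == 1
assert tp_pg.tp_world_size() * tp_pg.dp_world_size() == world_size

Under shard init, every parameter therefore starts life as a 1/world_size slice on shard_pg, is briefly gathered to a replica inside tensor_parallelize, and ends up sharded over tp_pg, so no rank has to hold the full fp16 model at once during initialization.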