mirror of https://github.com/hpcaitech/ColossalAI
[example] gpt, shard init on all processes (#2366)
parent 1f8ab6f1f5
commit 1aaeb596c6
@@ -117,7 +117,7 @@ class ColoTensor(torch.Tensor):
     def set_process_group(self, pg: ProcessGroup):
         """set_process_group
         change the pg of the ColoTensor. Note that the valid use cases is limited.
-        Only existing pg is DP and dist spec is REPLICaTE is valid.
+        It works for the target pg is DP and TP only and current dist spec of the Tensor is Replica.

         Args:
             pg (ProcessGroup): target pg
@@ -127,10 +127,10 @@ class ColoTensor(torch.Tensor):
         # if the new pg is the same as the old pg, just returns
         if self.process_group == pg:
             return
-        assert self.process_group.tp_world_size() == 1, \
-            "Can not set_process_group on a ColoTensor whose process_group has tp world group"
+        assert self.process_group.tp_world_size() == 1 or self.process_group.dp_world_size() == 1, \
+            "Can not set_process_group on a ColoTensor whose process_group is both tp > 1 and world group > 1"
         assert self.dist_spec.placement.value == 'r', \
-            "Can not set_process_group on a ColoTensor whose dist spec is not REPLICATE"
+            "Can not set_process_group on a ColoTensor whose dist spec is not Replica"

         self.process_group = pg
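Taken together, these two hunks relax the precondition of set_process_group: the current process group no longer has to be tp-degenerate, it only needs to be degenerate along one axis (tp_world_size() == 1 or dp_world_size() == 1), and the tensor must currently hold a Replica dist spec. That is exactly the situation produced by shard init, where the init-time group uses tp_degree == world_size and therefore has dp_world_size() == 1. Below is a minimal sketch of the pattern, not part of the commit, assuming a 4-rank job whose distributed state has already been set up by a colossalai launcher; the import paths and the Linear stand-in are assumptions based on the example scripts of this period.

# Sketch only: assumes torch.distributed is already initialised by the launcher
# and that the import paths below match your ColossalAI version.
import torch
from colossalai.tensor import ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

world_size = torch.distributed.get_world_size()    # e.g. 4 ranks
shard_pg = ProcessGroup(tp_degree=world_size)       # shard-init pg: dp_world_size() == 1
tp_dp_pg = ProcessGroup(tp_degree=2)                # target pg: tp = 2, dp = world_size // 2

# Parameters created under the context start out sharded across all ranks.
with ColoInitContext(device=get_current_device(),
                     default_dist_spec=ShardSpec([-1], [world_size]),
                     default_pg=shard_pg):
    layer = torch.nn.Linear(16, 16)    # stand-in for the real GPT model

for p in layer.parameters():
    p.set_dist_spec(ReplicaSpec())     # gather back to a full replica first
    p.set_process_group(tp_dp_pg)      # passes the relaxed assert: shard_pg has dp_world_size() == 1 and the spec is Replica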
@@ -148,10 +148,16 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
     """
     for mn, module in model.named_modules():
         for pn, param in module.named_parameters(recurse=False):
-            # NOTE() a param maybe shared by tow modules
+            # NOTE() a param maybe shared by two modules
             if hasattr(param, 'visited'):
                 continue
+
+            # if shard init, then convert param to replica and use the dp-only ProcessGroup
+            param: ColoParameter = param
+            param.set_dist_spec(ReplicaSpec())
+            param.set_process_group(pg)
+
             # shard it w.r.t tp pattern
             if 'mlp.c_fc' in mn:
                 if 'weight' in pn or 'bias' in pn:
                     split_param_col_tp1d(param, pg)    # colmn slice
@@ -170,7 +176,6 @@ def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
                     split_param_col_tp1d(param, pg)    # colmn slice
             else:
                 param.set_dist_spec(ReplicaSpec())

             param.visited = True
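Note that split_param_col_tp1d is a helper defined in the example script rather than a library API. A hedged re-implementation using only the primitives visible in this diff could look like the sketch below; the real helper may additionally attach a TP1D compute spec, which is omitted here as an assumption.

from colossalai.tensor import ShardSpec

def split_param_col_tp1d(param, pg):
    # Column-style 1D tensor parallelism: shard the last dimension of the
    # parameter across the tp ranks of pg. Illustrative sketch only; the
    # helper in the example script may do more than set the dist spec.
    param.set_dist_spec(ShardSpec([-1], [pg.tp_world_size()]))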
@@ -248,27 +253,28 @@ def main():
     torch.manual_seed(123)
     if args.distplan == "colossalai":
         # all param must use the same process group.
-        default_pg = ProcessGroup(tp_degree=args.tp_degree)
-        default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
+        world_size = torch.distributed.get_world_size()
+        shard_pg = ProcessGroup(tp_degree=world_size)
+        default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None

         # build GPT model
         if version.parse(CAI_VERSION) > version.parse("0.1.10"):
             with ColoInitContext(device=get_current_device(),
                                  dtype=torch.half,
                                  default_dist_spec=default_dist_spec,
-                                 default_pg=default_pg):
+                                 default_pg=shard_pg):
                 model = model_builder(args.model_type)(checkpoint=True)
         else:
             with ColoInitContext(device=get_current_device()):
                 model = model_builder(args.model_type)(checkpoint=True)

-        pg = default_pg
+        tp_pg = ProcessGroup(tp_degree=args.tp_degree)
         # Tensor Parallelism (TP)
-        tensor_parallelize(model, pg)
+        tensor_parallelize(model, tp_pg)

         # build a Gemini model and a highly optimized cpu optimizer
         # Gemini + ZeRO DP, Note it must be used after TP
-        model, optimizer = build_gemini(model, pg, args.placement)
+        model, optimizer = build_gemini(model, tp_pg, args.placement)

         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
     else:
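The net effect of this change to main() is that, when shard init is enabled (args.shardinit), parameters are sharded across every rank during initialisation instead of only across the tp_degree ranks, and are re-distributed to the smaller tp group afterwards by tensor_parallelize. A rough back-of-envelope illustration of why that lowers peak memory at init time follows; the sizes and degrees are invented for the example, not taken from the commit.

# Hypothetical numbers: 8 ranks, tp_degree = 2, one billion fp16 parameters.
numel = 1_000_000_000
param_bytes = 2 * numel                  # fp16: 2 bytes per element

per_rank_old = param_bytes // 2          # sharded over tp_degree only -> ~0.93 GiB per rank
per_rank_new = param_bytes // 8          # sharded over all 8 ranks    -> ~0.23 GiB per rank

print(f"old: {per_rank_old / 2**30:.2f} GiB, new: {per_rank_new / 2**30:.2f} GiB")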