[examples] adding tp to PaLM (#2319)

ZijianYY 2023-01-05 17:57:50 +08:00 committed by GitHub
parent 9c9246c0d9
commit f7fd592bf4
1 changed file with 43 additions and 1 deletion

@@ -104,6 +104,48 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
         raise NotImplementedError(f"CAI version {cai_version} is not supported")
     return model
 
+
+# Parameter Sharding Strategies for Tensor Parallelism
+def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
+    # attach a 1D shard spec along `dim` and a TP1D compute pattern to the parameter
+    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    param.set_tensor_spec(*spec)
+
+
+def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
+    split_param_single_dim_tp1d(0, param, pg)
+
+
+def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
+    split_param_single_dim_tp1d(-1, param, pg)
+
+
+# Tensor Parallel
+def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
+    """Shard the model parameters for 1D tensor parallelism.
+
+    Args:
+        model (torch.nn.Module): a torch module to be sharded
+        pg (ProcessGroup): the tensor-parallel process group
+    """
+    for mn, module in model.named_modules():
+        for pn, param in module.named_parameters(recurse=False):
+            # skip parameters that have already been sharded (e.g. shared/tied weights)
+            if hasattr(param, 'visited'):
+                continue
+            param.set_dist_spec(ReplicaSpec())
+            # choose the split dimension based on the submodule name
+            if 'net.0' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
+            elif 'to_q' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
+            elif 'to_kv' in mn:
+                split_param_row_tp1d(param, pg)    # row slice
+            elif 'to_out' in mn:
+                split_param_row_tp1d(param, pg)    # row slice
+            elif '1.1' in mn:
+                split_param_col_tp1d(param, pg)    # column slice
+            elif '1.2' in mn:
+                split_param_row_tp1d(param, pg)    # row slice
+            else:
+                param.set_dist_spec(ReplicaSpec())
+
+            param.visited = True
+
 
 args = parse_args()
 if args.distplan not in ["colossalai", "pytorch"]:
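
Note: the helper functions in this hunk only attach sharding metadata (a ShardSpec plus a TP1D ComputeSpec) to each ColoParameter; ColossalAI performs the actual split from those specs. As a rough, self-contained illustration of what a "row" (dim 0) versus "column" (last-dim) 1D slice means for a Linear weight, here is a plain-PyTorch sketch (the shard_1d helper and the world size of 2 are made up for illustration and are not part of this example):

import torch

def shard_1d(weight: torch.Tensor, dim: int, world_size: int, rank: int) -> torch.Tensor:
    # return the chunk of `weight` owned by `rank` when split evenly along `dim`
    return torch.chunk(weight, world_size, dim=dim)[rank]

weight = torch.randn(8, 4)               # a Linear weight of shape [out_features, in_features]
row_shard = shard_1d(weight, 0, 2, 0)    # "row" slice along dim 0    -> shape [4, 4]
col_shard = shard_1d(weight, -1, 2, 0)   # "column" slice along dim -1 -> shape [8, 2]
print(row_shard.shape, col_shard.shape)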
@@ -150,7 +192,7 @@ if args.distplan == "colossalai":
         model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
 
     pg = default_pg
-    #tensor_parallelize(model, pg)
+    tensor_parallelize(model, pg)
     model = gemini_zero_dpp(model, pg, args.placement)
 
     #optimizer
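
Note: the order of the two calls matters here. tensor_parallelize(model, pg) attaches the shard specs to the ColoParameters first, so the tensor-parallel layout is already fixed when gemini_zero_dpp wraps the model with Gemini/ZeRO and takes over parameter management; calling it after the wrap would likely have no effect.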