mirror of https://github.com/hpcaitech/ColossalAI
[examples]adding tp to PaLM (#2319)
parent 9c9246c0d9
commit f7fd592bf4

@@ -104,6 +104,48 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
        raise NotImplementedError(f"CAI version {cai_version} is not supported")
    return model

## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)

# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
    Sharding the Model Parameters.

    Args:
        model (torch.nn.Module): a torch module to be sharded
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            if hasattr(param, 'visited'):
                continue
            param.set_dist_spec(ReplicaSpec())
            if 'net.0' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'to_q' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'to_kv' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            elif 'to_out' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            elif '1.1' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif '1.2' in mn:
                split_param_row_tp1d(param, pg)    # row slice
            else:
                param.set_dist_spec(ReplicaSpec())

            param.visited = True

args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]:
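The three helpers in the hunk above implement 1D tensor-parallel sharding: split_param_row_tp1d shards dim 0 and split_param_col_tp1d shards the last dim, each across pg.tp_world_size() ranks. For intuition, the following is a minimal plain-PyTorch sketch of what that slicing means for a Linear weight; it uses torch.chunk and a made-up helper name (simulate_tp1d_split), not the ColossalAI API, and assumes a tensor-parallel world size of 2.

import torch

def simulate_tp1d_split(weight: torch.Tensor, dim: int, tp_world_size: int):
    # Illustration only: one shard per tensor-parallel rank along `dim`.
    return torch.chunk(weight, tp_world_size, dim=dim)

weight = torch.randn(8, 4)                       # a Linear weight of shape [out_features, in_features]
row_shards = simulate_tp1d_split(weight, 0, 2)   # "row slice": each rank keeps a [4, 4] block
col_shards = simulate_tp1d_split(weight, -1, 2)  # "column slice": each rank keeps an [8, 2] block
print([tuple(s.shape) for s in row_shards])      # [(4, 4), (4, 4)]
print([tuple(s.shape) for s in col_shards])      # [(8, 2), (8, 2)]

The name checks in tensor_parallelize ('net.0', 'to_q', 'to_kv', 'to_out', '1.1', '1.2') route each matching PaLM projection parameter to one of these two splits; every other parameter stays replicated via ReplicaSpec().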
@@ -150,7 +192,7 @@ if args.distplan == "colossalai":
     model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)

     pg = default_pg
-    #tensor_parallelize(model, pg)
+    tensor_parallelize(model, pg)
     model = gemini_zero_dpp(model, pg, args.placement)

     #optimizer
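These hunks rely on ColossalAI tensor primitives (ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ReplicaSpec, ProcessGroup) that are imported elsewhere in the example script and are not visible in this diff. A hedged sketch of the import the new code presumably depends on, based on the colossalai.tensor module of that release line:

# Presumed imports at the top of the example (not part of this diff).
from colossalai.tensor import (ColoParameter, ComputePattern, ComputeSpec,
                               ProcessGroup, ReplicaSpec, ShardSpec)

With those names in scope, the second hunk first calls tensor_parallelize(model, pg) to shard the PaLM parameters in place, and only then hands the model to gemini_zero_dpp(model, pg, args.placement), so the Gemini/ZeRO wrapper receives already-sharded parameters.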