[utils] optimize partition_tensor_parallel_state_dict (#1546)

pull/1548/head
ver217 2022-09-06 17:45:31 +08:00 committed by GitHub
parent d8a5aded19
commit 2bed096848
1 changed file with 31 additions and 20 deletions


@@ -29,26 +29,37 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
                                           partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
-    if gpc.get_local_rank(parallel_mode) == 0:
-        partitioned_state_list = [dict() for _ in range(depth)]
-        for key in list(state_dict.keys()):
-            param = state_dict.pop(key)
-            dim = dims.get(key, 0)
-            do_partition = partition_states.get(key, True)
-            if do_partition:
-                param = torch.chunk(param, depth, dim=dim)
-            for i, p in enumerate(partitioned_state_list):
-                p[key] = param[i] if do_partition else param
-    else:
-        partitioned_state_list = [None for _ in range(depth)]
-    partitioned_state = [None]
-    scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
-    return partitioned_state[0]
+    group = gpc.get_cpu_group(parallel_mode)
+    is_rank0 = gpc.get_local_rank(parallel_mode) == 0
+    partition_info = [None]
+    if is_rank0:
+        partition_info_dict = OrderedDict()
+        for key, param in state_dict.items():
+            dim = dims[key]
+            is_partitioned = partition_states[key]
+            shape = list(param.shape)
+            if is_partitioned:
+                shape[dim] = shape[dim] // depth
+            partition_info_dict[key] = (is_partitioned, param.dtype, shape, dim)
+        partition_info[0] = partition_info_dict
+    dist.broadcast_object_list(partition_info, src_rank, group=group)
+    partitioned_state = OrderedDict()
+    for key, (is_partitioned, dtype, shape, dim) in partition_info[0].items():
+        if is_partitioned:
+            output = torch.empty(shape, dtype=dtype)
+            if is_rank0:
+                scatter_list = [t.contiguous() for t in state_dict[key].chunk(depth, dim)]
+            else:
+                scatter_list = None
+            dist.scatter(output, scatter_list, src_rank, group=group)
+        else:
+            if is_rank0:
+                output = state_dict[key]
+            else:
+                output = torch.empty(shape, dtype=dtype)
+            dist.broadcast(output, src_rank, group=group)
+        partitioned_state[key] = output
+    return partitioned_state
def gather_tensor_parallel_state_dict(
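
Why this is faster: the old path scattered per-rank dicts with scatter_object_list, which serializes (pickles) every tensor it sends, while the new path broadcasts only a small metadata object per key (partition flag, dtype, shape, dim) and then moves the actual tensor payloads with dist.scatter / dist.broadcast into preallocated buffers. Below is a minimal, self-contained sketch of that pattern using plain torch.distributed on a gloo group; the run_worker helper, port, and tensor values are illustrative assumptions, not part of this PR or the ColossalAI API.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run_worker(rank: int, world_size: int):
    # Sketch of the commit's pattern: broadcast lightweight metadata as a
    # Python object, then move tensor payloads with raw collectives instead
    # of pickling whole tensors through object scattering.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501",
                            rank=rank, world_size=world_size)
    src = 0

    # Rank 0 owns the full "weight" and describes each shard (shape, dtype).
    weight = torch.arange(8, dtype=torch.float32) if rank == src else None
    meta = [((8 // world_size,), torch.float32) if rank == src else None]
    dist.broadcast_object_list(meta, src=src)      # cheap: metadata only
    shard_shape, shard_dtype = meta[0]

    # Partitioned case: each rank receives its chunk via dist.scatter.
    shard = torch.empty(shard_shape, dtype=shard_dtype)
    scatter_list = ([t.contiguous() for t in weight.chunk(world_size, dim=0)]
                    if rank == src else None)
    dist.scatter(shard, scatter_list, src=src)

    # Non-partitioned case: every rank receives the full tensor via broadcast.
    bias = torch.ones(3) if rank == src else torch.empty(3)
    dist.broadcast(bias, src=src)

    print(f"rank {rank}: shard={shard.tolist()}, bias={bias.tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run_worker, args=(2,), nprocs=2)

Running the script spawns two local processes: rank 0 ends up with the first half of the weight, rank 1 with the second half, and both ranks hold the full bias, mirroring the partitioned and non-partitioned branches in the new code above.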