[utils] optimize partition_tensor_parallel_state_dict (#1546)

pull/1548/head
ver217 2022-09-06 17:45:31 +08:00 committed by GitHub
parent d8a5aded19
commit 2bed096848
1 changed file with 31 additions and 20 deletions


@@ -29,26 +29,37 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
                                          partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
-
-    if gpc.get_local_rank(parallel_mode) == 0:
-
-        partitioned_state_list = [dict() for _ in range(depth)]
-
-        for key in list(state_dict.keys()):
-            param = state_dict.pop(key)
-            dim = dims.get(key, 0)
-            do_partition = partition_states.get(key, True)
-            if do_partition:
-                param = torch.chunk(param, depth, dim=dim)
-            for i, p in enumerate(partitioned_state_list):
-                p[key] = param[i] if do_partition else param
-
-    else:
-        partitioned_state_list = [None for _ in range(depth)]
-
-    partitioned_state = [None]
-    scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
-    return partitioned_state[0]
+    group = gpc.get_cpu_group(parallel_mode)
+    is_rank0 = gpc.get_local_rank(parallel_mode) == 0
+    partition_info = [None]
+    if is_rank0:
+        partition_info_dict = OrderedDict()
+        for key, param in state_dict.items():
+            dim = dims[key]
+            is_partitioned = partition_states[key]
+            shape = list(param.shape)
+            if is_partitioned:
+                shape[dim] = shape[dim] // depth
+            partition_info_dict[key] = (is_partitioned, param.dtype, shape, dim)
+        partition_info[0] = partition_info_dict
+    dist.broadcast_object_list(partition_info, src_rank, group=group)
+    partitioned_state = OrderedDict()
+    for key, (is_partitioned, dtype, shape, dim) in partition_info[0].items():
+        if is_partitioned:
+            output = torch.empty(shape, dtype=dtype)
+            if is_rank0:
+                scatter_list = [t.contiguous() for t in state_dict[key].chunk(depth, dim)]
+            else:
+                scatter_list = None
+            dist.scatter(output, scatter_list, src_rank, group=group)
+        else:
+            if is_rank0:
+                output = state_dict[key]
+            else:
+                output = torch.empty(shape, dtype=dtype)
+            dist.broadcast(output, src_rank, group=group)
+        partitioned_state[key] = output
+    return partitioned_state
 
 
 def gather_tensor_parallel_state_dict(
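
For reference, below is a minimal, standalone sketch of the communication pattern the new code path uses: the source rank broadcasts only lightweight metadata (dtype, per-rank shape, partition dim) with dist.broadcast_object_list, then each tensor shard is sent as raw data via dist.scatter, instead of pickling one full state dict per rank through scatter_object_list as the old code did. This is not ColossalAI code: gpc and ParallelMode are replaced by the default process group, the state dict is a single hard-coded "weight" tensor, and the file name scatter_sketch.py is just an example. It assumes 2 processes so the tensor divides evenly; run with e.g. torchrun --nproc_per_node=2 scatter_sketch.py.

# scatter_sketch.py (hypothetical example, not part of the commit)
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="gloo")    # CPU group, analogous to gpc.get_cpu_group(parallel_mode)
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    src_rank = 0
    is_src = rank == src_rank

    # Step 1: the source rank describes each entry (dtype, per-rank shape, partition dim)
    # and broadcasts only this small metadata object, mirroring dist.broadcast_object_list
    # in the new code path.
    partition_info = [None]
    if is_src:
        full_weight = torch.arange(8.0).reshape(4, 2)    # stands in for state_dict["weight"]
        shape = list(full_weight.shape)
        dim = 0
        shape[dim] = shape[dim] // world_size
        partition_info[0] = {"weight": (full_weight.dtype, shape, dim)}
    dist.broadcast_object_list(partition_info, src_rank)

    # Step 2: every rank pre-allocates its shard from the metadata and receives raw
    # tensor data via dist.scatter, rather than receiving a pickled per-rank dict
    # as scatter_object_list did before this commit.
    partitioned_state = {}
    for key, (dtype, shape, dim) in partition_info[0].items():
        output = torch.empty(shape, dtype=dtype)
        scatter_list = [t.contiguous() for t in full_weight.chunk(world_size, dim)] if is_src else None
        dist.scatter(output, scatter_list, src=src_rank)
        partitioned_state[key] = output

    print(f"rank {rank}: {partitioned_state['weight'].tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

The design point illustrated here is the same one the commit makes: object-based collectives serialize whole Python objects per rank, while broadcasting metadata once and scattering pre-allocated tensors keeps the payload to raw tensor bytes.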