mirror of https://github.com/hpcaitech/ColossalAI
[utils] optimize partition_tensor_parallel_state_dict (#1546)
parent d8a5aded19
commit 2bed096848
@@ -29,26 +29,37 @@ def partition_tensor_parallel_state_dict(state_dict: OrderedDict,
                                           partition_states: dict = dict()):
     src_rank = gpc.get_ranks_in_group(parallel_mode)[0]
     depth = gpc.get_world_size(parallel_mode)
-
-    if gpc.get_local_rank(parallel_mode) == 0:
-
-        partitioned_state_list = [dict() for _ in range(depth)]
-
-        for key in list(state_dict.keys()):
-            param = state_dict.pop(key)
-            dim = dims.get(key, 0)
-            do_partition = partition_states.get(key, True)
-            if do_partition:
-                param = torch.chunk(param, depth, dim=dim)
-            for i, p in enumerate(partitioned_state_list):
-                p[key] = param[i] if do_partition else param
-
-    else:
-        partitioned_state_list = [None for _ in range(depth)]
-
-    partitioned_state = [None]
-    scatter_object_list(partitioned_state, partitioned_state_list, src=src_rank, group=gpc.get_cpu_group(parallel_mode))
-    return partitioned_state[0]
+    group = gpc.get_cpu_group(parallel_mode)
+    is_rank0 = gpc.get_local_rank(parallel_mode) == 0
+    partition_info = [None]
+    if is_rank0:
+        partition_info_dict = OrderedDict()
+        for key, param in state_dict.items():
+            dim = dims[key]
+            is_partitioned = partition_states[key]
+            shape = list(param.shape)
+            if is_partitioned:
+                shape[dim] = shape[dim] // depth
+            partition_info_dict[key] = (is_partitioned, param.dtype, shape, dim)
+        partition_info[0] = partition_info_dict
+    dist.broadcast_object_list(partition_info, src_rank, group=group)
+    partitioned_state = OrderedDict()
+    for key, (is_partitioned, dtype, shape, dim) in partition_info[0].items():
+        if is_partitioned:
+            output = torch.empty(shape, dtype=dtype)
+            if is_rank0:
+                scatter_list = [t.contiguous() for t in state_dict[key].chunk(depth, dim)]
+            else:
+                scatter_list = None
+            dist.scatter(output, scatter_list, src_rank, group=group)
+        else:
+            if is_rank0:
+                output = state_dict[key]
+            else:
+                output = torch.empty(shape, dtype=dtype)
+            dist.broadcast(output, src_rank, group=group)
+        partitioned_state[key] = output
+    return partitioned_state
 
 
 def gather_tensor_parallel_state_dict(
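For reference, a minimal usage sketch of the optimized function follows; it is not part of this commit. It assumes colossalai has been launched with a 1D tensor-parallel group, that partition_tensor_parallel_state_dict is importable from colossalai.utils.checkpointing, and that the parameter names, shapes, and the dims/partition_states mappings are purely illustrative.

# Usage sketch only; not taken from commit 2bed096848. The import path
# colossalai.utils.checkpointing and the parameter names/shapes below are
# assumptions made for illustration.
from collections import OrderedDict

import torch

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils.checkpointing import partition_tensor_parallel_state_dict

# Only the first rank of the tensor-parallel group needs the full checkpoint;
# the other ranks pass an empty dict and receive their shards via dist.scatter.
full_state = OrderedDict()
if gpc.get_local_rank(ParallelMode.PARALLEL_1D) == 0:
    full_state["weight"] = torch.randn(4096, 1024)
    full_state["bias"] = torch.randn(4096)

local_state = partition_tensor_parallel_state_dict(
    full_state,
    ParallelMode.PARALLEL_1D,
    dims={"weight": 0, "bias": 0},                    # chunk dimension per key
    partition_states={"weight": True, "bias": True},  # False broadcasts the whole tensor
)
# With the optimized version, only (is_partitioned, dtype, shape, dim) metadata
# goes through broadcast_object_list; the tensor data itself moves with
# dist.scatter / dist.broadcast instead of being pickled by scatter_object_list.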