[rename] convert_to_dist -> redistribute (#1243)

pull/1245/head^2
Jiarui Fang 2022-07-11 13:05:44 +08:00 committed by GitHub
parent f6add9b720
commit 2699dfbbfd
7 changed files with 18 additions and 18 deletions
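This is a pure rename: redistribute(dist_spec) keeps the old signature and behavior of convert_to_dist_spec(dist_spec), converting a ColoTensor from its current distribution spec to a new one such as distspec.shard([-1], [world_size]) or distspec.replicate(). A minimal single-process sketch of that shard/replicate round trip; shard_last_dim and unshard_last_dim are illustrative stand-ins, not part of the ColoTensor API:

import torch

def shard_last_dim(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Local stand-in for redistribute(distspec.shard([-1], [world_size])):
    # keep only this rank's slice of the last dimension.
    return torch.chunk(t, world_size, dim=-1)[rank]

def unshard_last_dim(shards: list) -> torch.Tensor:
    # Local stand-in for redistribute(distspec.replicate()):
    # gather all slices back into the full tensor.
    return torch.cat(shards, dim=-1)

world_size = 4
full = torch.randn(2, 8)
shards = [shard_last_dim(full, r, world_size) for r in range(world_size)]
assert torch.equal(unshard_last_dim(shards), full)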

View File

@@ -11,7 +11,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# mat1:S[1] x mat2:S[0] = Output:P
# beta * input + alpha * All-Reduce(Output) = res
- mat1 = mat1.convert_to_dist_spec(distspec.shard([-1], [mat2.get_tp_world_size()]))
+ mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()]))
# Output:P
partial_output = torch.mm(mat1, mat2)
@@ -28,7 +28,7 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
alpha: Number) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
compute_spec = mat2.compute_spec
- mat1 = mat1.convert_to_dist_spec(distspec.replicate())
+ mat1 = mat1.redistribute(distspec.replicate())
mat1 = reduce_grad(mat1, mat1.get_process_group())
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
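The comments in the 1Drow hunk above state the row-parallel identity these call sites rely on: with mat1 sharded along its columns (S[1]) and mat2 along its rows (S[0]), the per-rank partial products sum (via All-Reduce) to the full product, which is then combined as beta * input + alpha * that sum. A single-process numeric check of the identity, with the all-reduce replaced by a local sum over shards:

import torch

world_size = 4
input_t = torch.randn(6, 5)
mat1 = torch.randn(6, 8)    # sharded along dim 1 (S[1]) across ranks
mat2 = torch.randn(8, 5)    # sharded along dim 0 (S[0]) across ranks
beta, alpha = 0.5, 2.0

mat1_shards = torch.chunk(mat1, world_size, dim=1)
mat2_shards = torch.chunk(mat2, world_size, dim=0)

# Each rank would compute one partial product; summing them plays the role of All-Reduce.
partial_sum = sum(torch.mm(a, b) for a, b in zip(mat1_shards, mat2_shards))
expected = torch.addmm(input_t, mat1, mat2, beta=beta, alpha=alpha)
assert torch.allclose(beta * input_t + alpha * partial_sum, expected, atol=1e-5)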

View File

@@ -14,7 +14,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
- input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+ input_tensor = input_tensor.redistribute(distspec.replicate())
output_parallel = F.embedding(input_tensor,
weight,
@@ -46,7 +46,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# Find index in this shard and mask those not here
# Reduce all
pg = weight.get_process_group()
- input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+ input_tensor = input_tensor.redistribute(distspec.replicate())
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = weight.get_process_group().tp_local_rank()

View File

@@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
# embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
pg = weight.get_process_group()
- input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+ input_tensor = input_tensor.redistribute(distspec.replicate())
output_parallel = F.embedding_bag(input_tensor,
weight,

View File

@@ -16,7 +16,7 @@ def colo_layernorm(
assert isinstance(weight, ColoTensor)
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
bias = convert_to_colo_tensor(bias, weight.get_process_group())
- input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+ input_tensor = input_tensor.redistribute(distspec.replicate())
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))

View File

@@ -12,7 +12,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Reduce(Output) + bias = res
# Input:S[1]
pg = weight.get_process_group()
- input_tensor = input_tensor.convert_to_dist_spec(distspec.shard([-1], [weight.get_tp_world_size()]))
+ input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()]))
# Output:P
partial_output = F.linear(input_tensor, weight)
@@ -33,7 +33,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Gather(Output)
# Input:B
compute_spec = weight.compute_spec
- input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
+ input_tensor = input_tensor.redistribute(distspec.replicate())
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
output_parallel = F.linear(input_parallel, weight, bias)
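colo_linear_1Dcol above follows the column-parallel pattern from its comments: the input is replicated (Input:B), each rank holds a slice of the weight's output rows, and the per-rank outputs are concatenated (All-Gather) along the last dimension. A single-process check of that equivalence using plain torch:

import torch
import torch.nn.functional as F

world_size = 4
x = torch.randn(3, 16)         # Input:B, replicated on every rank
weight = torch.randn(8, 16)    # full weight, shape (out_features, in_features)
bias = torch.randn(8)

w_shards = torch.chunk(weight, world_size, dim=0)   # each rank keeps out_features / P rows
b_shards = torch.chunk(bias, world_size, dim=0)

# Per-rank partial outputs; the concat stands in for All-Gather(Output).
parts = [F.linear(x, w, b) for w, b in zip(w_shards, b_shards)]
gathered = torch.cat(parts, dim=-1)
assert torch.allclose(gathered, F.linear(x, weight, bias), atol=1e-5)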

View File

@@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
Tensor._base.__get__,
Tensor.grad.__get__,
Tensor._grad.__get__,
- Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
+ Tensor.data.__get__,  # make .data returns torch.Tensor rather than ColoTensor
}
@@ -140,7 +140,7 @@ class ColoTensor(torch.Tensor):
"""
assert isinstance(dist_spec, _DistSpec)
assert self.process_group is not None
- self._convert_to_dist_spec(dist_spec)
+ self._redistribute(dist_spec)
def set_tensor_spec(self, dist_spec, compute_spec):
if dist_spec:
@@ -174,8 +174,8 @@ class ColoTensor(torch.Tensor):
def __repr__(self):
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'
- def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
- """_convert_to_dist_spec
+ def _redistribute(self, dist_spec: _DistSpec) -> None:
+ """_redistribute
Note the function will not handle the logic of backward propagation!
It is used during model tensor initializations as an internal function.
Args:
@@ -186,7 +186,7 @@ class ColoTensor(torch.Tensor):
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
self.dist_spec = dist_spec
- def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
+ def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
@@ -194,13 +194,13 @@ class ColoTensor(torch.Tensor):
"""to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE
"""
- self._convert_to_dist_spec(dist_spec=distspec.replicate())
+ self._redistribute(dist_spec=distspec.replicate())
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converting dist spec of the tensor to REPLICATE
"""
- return self.convert_to_dist_spec(distspec.replicate())
+ return self.redistribute(distspec.replicate())
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
@@ -234,7 +234,7 @@ class ColoTensor(torch.Tensor):
"""
if self.is_replicate():
return super().view(*args)
- replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
+ replicated_t = self.redistribute(dist_spec=distspec.replicate())
return replicated_t.view(*args)
def size_global(self, args: Optional[int] = None):
@@ -280,4 +280,4 @@ class ColoTensor(torch.Tensor):
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
def is_sharded(self):
- return self.dist_spec.placement == DistPlacementPattern.SHARD
+ return self.dist_spec.placement == DistPlacementPattern.SHARD
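The two renamed methods in this file keep the same split as before: _redistribute mutates the tensor in place (no backward handling, intended for tensor initialization), while redistribute returns a new ColoTensor carrying the target spec. A toy sketch of that in-place vs. functional distinction; ToyDistTensor is illustrative only and unrelated to the real ColoTensor implementation:

import copy
import torch

class ToyDistTensor:
    # Illustrative only: mirrors the split between ColoTensor._redistribute
    # (in-place, used at init time) and ColoTensor.redistribute (returns a new object).
    def __init__(self, data: torch.Tensor, spec: str = 'REPLICATE'):
        self.data, self.spec = data, spec

    def _redistribute(self, spec: str) -> None:
        self.spec = spec                 # in-place: the caller keeps the same object

    def redistribute(self, spec: str) -> 'ToyDistTensor':
        new = copy.copy(self)            # functional: the original object is untouched
        new.spec = spec
        return new

t = ToyDistTensor(torch.zeros(4))
s = t.redistribute('SHARD')
assert t.spec == 'REPLICATE' and s.spec == 'SHARD'
t._redistribute('SHARD')
assert t.spec == 'SHARD'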

View File

@@ -22,7 +22,7 @@ def check_cross_entropy():
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
- input_shard = input_t_colo.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()]))
+ input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
output = F.cross_entropy(input_t, target)