mirror of https://github.com/hpcaitech/ColossalAI
[rename] convert_to_dist -> redistribute (#1243)
parent
f6add9b720
commit
2699dfbbfd
|
@ -11,7 +11,7 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
|
|||
# mat1:S[1] x mat2:S[0] = Output:P
|
||||
# beta * input + alpha * All-Reduce(Output) = res
|
||||
|
||||
mat1 = mat1.convert_to_dist_spec(distspec.shard([-1], [mat2.get_tp_world_size()]))
|
||||
mat1 = mat1.redistribute(distspec.shard([-1], [mat2.get_tp_world_size()]))
|
||||
|
||||
# Output:P
|
||||
partial_output = torch.mm(mat1, mat2)
|
||||
|
@ -28,7 +28,7 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
|
|||
alpha: Number) -> ColoTensor:
|
||||
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
|
||||
compute_spec = mat2.compute_spec
|
||||
mat1 = mat1.convert_to_dist_spec(distspec.replicate())
|
||||
mat1 = mat1.redistribute(distspec.replicate())
|
||||
mat1 = reduce_grad(mat1, mat1.get_process_group())
|
||||
|
||||
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
|
||||
|
|
|
@ -14,7 +14,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
|
|||
sparse: bool = False) -> ColoTensor:
|
||||
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
|
||||
# Gather splitted lookup table
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
|
||||
input_tensor = input_tensor.redistribute(distspec.replicate())
|
||||
|
||||
output_parallel = F.embedding(input_tensor,
|
||||
weight,
|
||||
|
@ -46,7 +46,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
|
|||
# Find index in this shard and mask those not here
|
||||
# Reduce all
|
||||
pg = weight.get_process_group()
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
|
||||
input_tensor = input_tensor.redistribute(distspec.replicate())
|
||||
|
||||
# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
tensor_parallel_rank = weight.get_process_group().tp_local_rank()
|
||||
|
|
|
@ -20,7 +20,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
|
|||
# embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
|
||||
# Gather splitted lookup table
|
||||
pg = weight.get_process_group()
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
|
||||
input_tensor = input_tensor.redistribute(distspec.replicate())
|
||||
|
||||
output_parallel = F.embedding_bag(input_tensor,
|
||||
weight,
|
||||
|
|
|
@ -16,7 +16,7 @@ def colo_layernorm(
|
|||
assert isinstance(weight, ColoTensor)
|
||||
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
|
||||
bias = convert_to_colo_tensor(bias, weight.get_process_group())
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
|
||||
input_tensor = input_tensor.redistribute(distspec.replicate())
|
||||
|
||||
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
|
||||
output = ColoTensor.from_torch_tensor(output, ColoTensorSpec(input_tensor.get_process_group()))
|
||||
|
|
|
@ -12,7 +12,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
|||
# All-Reduce(Output) + bias = res
|
||||
# Input:S[1]
|
||||
pg = weight.get_process_group()
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.shard([-1], [weight.get_tp_world_size()]))
|
||||
input_tensor = input_tensor.redistribute(distspec.shard([-1], [weight.get_tp_world_size()]))
|
||||
|
||||
# Output:P
|
||||
partial_output = F.linear(input_tensor, weight)
|
||||
|
@ -33,7 +33,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
|||
# All-Gather(Output)
|
||||
# Input:B
|
||||
compute_spec = weight.compute_spec
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate())
|
||||
input_tensor = input_tensor.redistribute(distspec.replicate())
|
||||
input_parallel = reduce_grad(input_tensor, weight.get_process_group())
|
||||
|
||||
output_parallel = F.linear(input_parallel, weight, bias)
|
||||
|
|
|
@ -18,7 +18,7 @@ def _get_my_nowrap_functions() -> Set[Callable]:
|
|||
Tensor._base.__get__,
|
||||
Tensor.grad.__get__,
|
||||
Tensor._grad.__get__,
|
||||
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
|
||||
Tensor.data.__get__, # make .data returns torch.Tensor rather than ColoTensor
|
||||
}
|
||||
|
||||
|
||||
|
@ -140,7 +140,7 @@ class ColoTensor(torch.Tensor):
|
|||
"""
|
||||
assert isinstance(dist_spec, _DistSpec)
|
||||
assert self.process_group is not None
|
||||
self._convert_to_dist_spec(dist_spec)
|
||||
self._redistribute(dist_spec)
|
||||
|
||||
def set_tensor_spec(self, dist_spec, compute_spec):
|
||||
if dist_spec:
|
||||
|
@ -174,8 +174,8 @@ class ColoTensor(torch.Tensor):
|
|||
def __repr__(self):
|
||||
return f'ColoTensor:\n{super().__repr__()}\n{self.dist_spec}\n{self.process_group}'
|
||||
|
||||
def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
|
||||
"""_convert_to_dist_spec
|
||||
def _redistribute(self, dist_spec: _DistSpec) -> None:
|
||||
"""_redistribute
|
||||
Note the function will not handle the logic of backward propagation!
|
||||
It is used during model tensor initializations as an internal function.
|
||||
Args:
|
||||
|
@ -186,7 +186,7 @@ class ColoTensor(torch.Tensor):
|
|||
self.data = DistSpecManager.handle_trans_spec(self.data, self.dist_spec, dist_spec, self.process_group)
|
||||
self.dist_spec = dist_spec
|
||||
|
||||
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
|
||||
def redistribute(self, dist_spec: _DistSpec) -> 'ColoTensor':
|
||||
ret = DistSpecManager.handle_trans_spec(self, self.dist_spec, dist_spec, self.process_group)
|
||||
return ColoTensor.from_torch_tensor(ret, ColoTensorSpec(self.process_group, dist_attr=dist_spec))
|
||||
|
||||
|
@ -194,13 +194,13 @@ class ColoTensor(torch.Tensor):
|
|||
"""to_replicate_
|
||||
an inline member function, converting dist spec of the tensor to REPLICATE
|
||||
"""
|
||||
self._convert_to_dist_spec(dist_spec=distspec.replicate())
|
||||
self._redistribute(dist_spec=distspec.replicate())
|
||||
|
||||
def to_replicate(self) -> 'ColoTensor':
|
||||
"""to_replicate
|
||||
converting dist spec of the tensor to REPLICATE
|
||||
"""
|
||||
return self.convert_to_dist_spec(distspec.replicate())
|
||||
return self.redistribute(distspec.replicate())
|
||||
|
||||
@staticmethod
|
||||
def from_torch_tensor(tensor: torch.Tensor, spec: Optional[ColoTensorSpec] = None) -> 'ColoTensor':
|
||||
|
@ -234,7 +234,7 @@ class ColoTensor(torch.Tensor):
|
|||
"""
|
||||
if self.is_replicate():
|
||||
return super().view(*args)
|
||||
replicated_t = self.convert_to_dist_spec(dist_spec=distspec.replicate())
|
||||
replicated_t = self.redistribute(dist_spec=distspec.replicate())
|
||||
return replicated_t.view(*args)
|
||||
|
||||
def size_global(self, args: Optional[int] = None):
|
||||
|
@ -280,4 +280,4 @@ class ColoTensor(torch.Tensor):
|
|||
and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == 0
|
||||
|
||||
def is_sharded(self):
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD
|
||||
return self.dist_spec.placement == DistPlacementPattern.SHARD
|
||||
|
|
|
@ -22,7 +22,7 @@ def check_cross_entropy():
|
|||
world_size = torch.distributed.get_world_size()
|
||||
pg = ProcessGroup(tp_degree=world_size)
|
||||
input_t_colo = ColoTensor.from_torch_tensor(tensor=input_ct, spec=ColoTensorSpec(pg))
|
||||
input_shard = input_t_colo.convert_to_dist_spec(distspec.shard([-1], [pg.tp_world_size()]))
|
||||
input_shard = input_t_colo.redistribute(distspec.shard([-1], [pg.tp_world_size()]))
|
||||
input_shard.set_tensor_spec(dist_spec=None, compute_spec=ComputeSpec(ComputePattern.TP1D))
|
||||
|
||||
output = F.cross_entropy(input_t, target)
|
||||
|
|
Loading…
Reference in New Issue