mirror of https://github.com/hpcaitech/ColossalAI
remove gather out in parallel action (#1163)
parent 51f1ec96b0
commit 177c374401
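The diff below removes the `gather_out` flag from `ParallelAction` and moves the gather decision into the ops themselves: every TP1D column-parallel op now finishes with the new `ColoTensor.to_replicate()` helper. A minimal before/after sketch of the caller-facing change; the import path and the helper names are assumptions for illustration, not part of this commit.

# Sketch only: the colossalai.tensor import path and the helper names
# build_col_spec / gather are assumed, not taken from this commit.
from colossalai.tensor import TensorSpec, ParallelAction, ComputePattern, distspec

def build_col_spec(pg, pg_size):
    # Before this commit the gather decision lived in the spec:
    #   ParallelAction(ComputePattern.TP1D, gather_out=True)
    return TensorSpec(distspec.shard(pg, [-1], [pg_size]),
                      ParallelAction(ComputePattern.TP1D))

def gather(output):
    # After this commit, gathering is an explicit call on the output tensor.
    return output.to_replicate()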
@@ -37,10 +37,10 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
                              ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    if parallel_action.gather_out:
-        # All-Gather(Output)
-        output = output.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
-    return output
+    # TODO(jiaruifang) addam is special case
+    # since gpt call view after the Op.
+    return output.to_replicate()


 def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -62,11 +62,6 @@ def colo_addmm(input_tensor: GeneralTensor,
     """
     input_tensor, mat1, mat2 = tuple(map(convert_to_colo_tensor, (input_tensor, mat1, mat2)))

-    # building the computing graph, inputs -> op
-    # if GraphGlobalEnv().graph_building:
-    #     cur_op_node = GraphOpNode('linear', [weight, bias])
-    #     cur_op_node.add_prev_tensor(input_tensor)
-
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not mat2.has_spec():    # No Model Parallel Applied
@@ -84,8 +79,4 @@ def colo_addmm(input_tensor: GeneralTensor,
     else:
         raise NotImplementedError

-    # building the computing graph, op -> output
-    # if GraphGlobalEnv().graph_building:
-    #     cur_op_node.add_post_tensor(ret_tensor)
-
     return ret_tensor
@@ -30,9 +30,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
         distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
         ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    if weight.spec.parallel_action.gather_out:
-        output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
-    return output
+    return output.to_replicate()


 def colo_embedding_1Drow(input_tensor: ColoTensor,
@@ -36,9 +36,8 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
         distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
         ParallelAction(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    if weight.spec.parallel_action.gather_out:
-        output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
-    return output
+    return output.to_replicate()


 def colo_embedding_bag_1d(tp_mode: str,
@@ -42,10 +42,7 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
                                               distspec.shard(weight.spec.get_process_group(), [-1],
                                                              [weight.spec.get_process_group_size()]),
                                               ParallelAction(ComputePattern.TP1D)))
-    if parallel_action.gather_out:
-        # All-Gather(Output)
-        output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
-    return output
+    return output.to_replicate()


 def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
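All four column-parallel ops above now end with `return output.to_replicate()`. As the ColoTensor hunk further down shows, that call converts the output to a replicated dist spec on its own process group, i.e. the same all-gather each op previously spelled out inline. A hedged sketch of the equivalence; `old_gather` and `new_gather` are illustrative names, not library code, and the import path is assumed.

# Illustrative only: old_gather reproduces the lines removed above,
# new_gather is what the ops call now; both act on a column-sharded ColoTensor.
from colossalai.tensor import distspec    # assumed import path

def old_gather(output):
    pg = output.spec.get_process_group()
    return output.convert_to_dist_spec(distspec.replicate(pg))

def new_gather(output):
    return output.to_replicate()    # same conversion, done inside ColoTensor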
@@ -92,10 +92,13 @@ class ColoTensor(torch.Tensor):
     def __repr__(self):
         return f'ColoTensor: {super().__repr__()}'

-    def is_model_data(self) -> bool:
-        return self._type == TensorType.MODEL
-
     def _convert_to_dist_spec(self, dist_spec: _DistSpec) -> None:
+        """_convert_to_dist_spec
+        Note the function will not handle the logic of backward propagation!
+        It is used during model tensor initializations as an internal function.
+        Args:
+            dist_spec (_DistSpec): the target dist. spec.
+        """
         with DistSpecManager.no_grad():
             self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
         self._tensor_spec.dist_spec = dist_spec
@@ -106,6 +109,19 @@ class ColoTensor(torch.Tensor):
         ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
         return ColoTensor.from_torch_tensor(ret, tensor_spec)

+    def to_replicate_(self):
+        """to_replicate_
+        an inline member function, converting dist spec of the tensor to REPLICATE
+        """
+        self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, distspec.replicate())
+        self._tensor_spec.dist_spec = distspec.replicate()
+
+    def to_replicate(self) -> 'ColoTensor':
+        """to_replicate
+        converting dist spec of the tensor to REPLICATE
+        """
+        return self.convert_to_dist_spec(distspec.replicate(self.spec.get_process_group()))
+
     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
         tensor = tensor.as_subclass(ColoTensor)
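The two new ColoTensor helpers differ only in where the result lands: `to_replicate()` goes through `convert_to_dist_spec()` and returns a fresh ColoTensor, while `to_replicate_()` rewrites `self.data` and the tensor's own dist spec in place. A short usage sketch; `t` is a placeholder for any sharded ColoTensor, not a name from this commit.

# Usage sketch; `t` is assumed to be a ColoTensor carrying a sharded dist spec.
replica = t.to_replicate()    # out-of-place: new ColoTensor with a replicated spec
t.to_replicate_()             # in-place: t.data and t's dist spec are rewritten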
@@ -121,3 +137,13 @@ class ColoTensor(torch.Tensor):
             tensor = ColoTensor(data, spec=copy(self.spec))
             memo[id(self)] = tensor
             return tensor
+
+    # TODO(jiaruifang) a patch for gpt test.
+    # We need to override the member function must operate on a replicated tensor
+    # def view(self, *args, **kwargs):
+    #     self.data = DistSpecManager.handle_trans_spec(self,
+    #                                                   self.spec.dist_spec,
+    #                                                   distspec.replicate(self.spec.get_process_group()))
+    #     # self._tensor_spec.dist_spec = distspec.replicate(self.spec.get_process_group())
+    #     self.data.view(*args, **kwargs)
+    #     return ColoTensor.from_torch_tensor(self.data)
@@ -13,13 +13,12 @@ class ComputePattern(Enum):

 class ParallelAction(object):

-    def __init__(self, compute_pattern: ComputePattern, gather_out: bool = True) -> None:
+    def __init__(self, compute_pattern: ComputePattern) -> None:
         assert isinstance(compute_pattern, ComputePattern)
         self.compute_pattern = compute_pattern
-        self.gather_out = gather_out

     def __repr__(self):
-        return f'compute pattern: {self.compute_pattern}, gather out: {self.gather_out}'
+        return f'compute pattern: {self.compute_pattern}'


 class TensorSpec(object):
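For downstream code the constructor change above is breaking: the old keyword is simply gone. A quick sketch of the post-commit behavior, grounded only in the hunk above.

# Sketch of the API after this commit.
action = ParallelAction(ComputePattern.TP1D)                # only accepted form
try:
    ParallelAction(ComputePattern.TP1D, gather_out=True)    # pre-commit form
except TypeError:
    pass    # gather_out no longer exists on ParallelAction
print(action)    # prints something like: compute pattern: ComputePattern.TP1D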
@@ -41,6 +41,7 @@ def run_with_spec(spec_init_func):
     x = torch.rand(2, 4).cuda()
     out = model(x)
     colo_out = F.linear(x, weight, bias)
+    colo_out = colo_out.to_replicate()
     assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
@@ -26,10 +26,10 @@ def init_1d_row_linear(weight):
         weight.set_spec(spec)


-def init_1d_col_linear(weight, gather_out=True):
+def init_1d_col_linear(weight):
     spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-        ParallelAction(ComputePattern.TP1D, gather_out=gather_out))
+        ParallelAction(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
         weight.set_spec(spec)

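Call sites that relied on `gather_out=False` (the classifier parameters in the hunk below) lose that option: whether the output stays sharded is no longer encoded in the parameter's spec, since the op gathers via `to_replicate()` itself. A sketch of the updated initialization path; `model` and `sample_input` are placeholders, only `init_1d_col_linear` comes from the test code above.

# Sketch with placeholder names; behavior is inferred from the hunks in this commit.
init_1d_col_linear(model.classifier.weight)   # previously: init_1d_col_linear(p, gather_out=False)
out = model(sample_input)                     # the TP1D linear op now returns a replicated output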
@@ -98,7 +98,7 @@ def run_1d_hybrid_tp(model_name):
         if 'proj2' in name and 'weight' in name:
             init_1d_row_linear(p)
         if 'classifier' in name and ('weight' in name or 'bias' in name):
-            init_1d_col_linear(p, gather_out=False)
+            init_1d_col_linear(p)

     model = model.cuda()
     colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)