mirror of https://github.com/hpcaitech/ColossalAI
[ColoTensor] rename APIs and add output_replicate to ComputeSpec (#1168)
parent
f4ef224358
commit
4b9bba8116
|
@ -13,34 +13,37 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
|
|||
# beta * input + alpha * All-Reduce(Output) = res
|
||||
|
||||
mat1 = mat1.convert_to_dist_spec(
|
||||
distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]))
|
||||
distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]))
|
||||
|
||||
# Output:P
|
||||
partial_output = torch.mm(mat1, mat2)
|
||||
# Reduce(Output)
|
||||
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
|
||||
# input
|
||||
assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
|
||||
assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
|
||||
output = beta * input_tensor + alpha * output
|
||||
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.spec.get_process_group())))
|
||||
output = ColoTensor.from_torch_tensor(output,
|
||||
spec=TensorSpec(distspec.replicate(mat2.tensor_spec.get_process_group())))
|
||||
return output
|
||||
|
||||
|
||||
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
|
||||
alpha: Number) -> ColoTensor:
|
||||
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
|
||||
parallel_action = mat2.spec.compute_spec
|
||||
mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
|
||||
compute_spec = mat2.tensor_spec.compute_spec
|
||||
mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.tensor_spec.get_process_group()))
|
||||
mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
|
||||
|
||||
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
|
||||
output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
output_spec = TensorSpec(
|
||||
distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
|
||||
# TODO(jiaruifang) addam is special case
|
||||
# since gpt call view after the Op.
|
||||
return output.to_replicate()
|
||||
if compute_spec.output_replicate:
|
||||
return output.to_replicate()
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
|
||||
|
@ -64,14 +67,15 @@ def colo_addmm(input_tensor: GeneralTensor,
|
|||
|
||||
# Add communication logic before and after linear call.
|
||||
ret_tensor = None
|
||||
if not mat2.has_spec(): # No Model Parallel Applied
|
||||
assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
|
||||
assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
|
||||
if not mat2.has_compute_spec(): # No Model Parallel Applied
|
||||
assert mat2.tensor_spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
|
||||
assert input_tensor.tensor_spec.is_gathered(), 'Invalid input spec for native addmm op'
|
||||
ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
|
||||
elif mat2.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
|
||||
elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_gathered():
|
||||
mode = 'row'
|
||||
elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
|
||||
elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col()
|
||||
or input_tensor.tensor_spec.is_1D_row()):
|
||||
mode = 'col'
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -18,7 +18,7 @@ def register_elementwise_op(op):
|
|||
"""
|
||||
output = op(input_tensor, *args, **kwargs)
|
||||
if isinstance(input_tensor, ColoTensor):
|
||||
spec = copy(input_tensor.spec)
|
||||
spec = copy(input_tensor.tensor_spec)
|
||||
return ColoTensor.from_torch_tensor(output, spec=spec)
|
||||
return ColoTensor.from_torch_tensor(output)
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
|
|||
sparse: bool = False) -> ColoTensor:
|
||||
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
|
||||
# Gather splitted lookup table
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
|
||||
|
||||
output_parallel = F.embedding(input_tensor,
|
||||
weight,
|
||||
|
@ -27,10 +27,15 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
|
|||
scale_grad_by_freq=scale_grad_by_freq,
|
||||
sparse=sparse)
|
||||
output_spec = TensorSpec(
|
||||
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
|
||||
distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
return output.to_replicate()
|
||||
|
||||
compute_spec = weight.tensor_spec.compute_spec
|
||||
if compute_spec.output_replicate:
|
||||
return output.to_replicate()
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
def colo_embedding_1Drow(input_tensor: ColoTensor,
|
||||
|
@ -43,7 +48,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
|
|||
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
|
||||
# Find index in this shard and mask those not here
|
||||
# Reduce all
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
|
||||
|
||||
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
num_embeddings_per_partition = weight.size(0)
|
||||
|
@ -70,7 +75,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
|
|||
partial_output[input_mask, :] = 0.
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
|
||||
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
|
||||
output = ColoTensor.from_torch_tensor(output,
|
||||
spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
|
||||
return output
|
||||
|
||||
|
||||
|
@ -108,8 +114,8 @@ def colo_embedding(input_tensor: GeneralTensor,
|
|||
|
||||
# Handle differen parallel actions.
|
||||
|
||||
if not weight.has_spec(): # No Model Parallel Applied
|
||||
assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
|
||||
if not weight.has_compute_spec(): # No Model Parallel Applied
|
||||
assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
|
||||
return ColoTensor.from_torch_tensor(
|
||||
F.embedding(input_tensor,
|
||||
weight,
|
||||
|
@ -118,10 +124,10 @@ def colo_embedding(input_tensor: GeneralTensor,
|
|||
norm_type=norm_type,
|
||||
scale_grad_by_freq=scale_grad_by_freq,
|
||||
sparse=sparse))
|
||||
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if weight.spec.is_1D_row():
|
||||
elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if weight.tensor_spec.is_1D_row():
|
||||
mode = 'row'
|
||||
elif weight.spec.is_1D_col():
|
||||
elif weight.tensor_spec.is_1D_col():
|
||||
mode = 'col'
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -19,7 +19,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
|
|||
padding_idx: Optional[int] = None) -> ColoTensor:
|
||||
# embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
|
||||
# Gather splitted lookup table
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
|
||||
|
||||
output_parallel = F.embedding_bag(input_tensor,
|
||||
weight,
|
||||
|
@ -33,11 +33,14 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
|
|||
include_last_offset=include_last_offset,
|
||||
padding_idx=padding_idx)
|
||||
output_spec = TensorSpec(
|
||||
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
|
||||
distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
|
||||
return output.to_replicate()
|
||||
if weight.tensor_spec.compute_spec.output_replicate:
|
||||
return output.to_replicate()
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
def colo_embedding_bag_1d(tp_mode: str,
|
||||
|
@ -86,8 +89,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
|
|||
|
||||
# Handle differen parallel actions.
|
||||
|
||||
if not weight.has_spec(): # No Model Parallel Applied
|
||||
assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
|
||||
if not weight.has_compute_spec(): # No Model Parallel Applied
|
||||
assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
|
||||
return ColoTensor.from_torch_tensor(
|
||||
F.embedding_bag(input_tensor,
|
||||
weight,
|
||||
|
@ -100,8 +103,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
|
|||
per_sample_weights=per_sample_weights,
|
||||
include_last_offset=include_last_offset,
|
||||
padding_idx=padding_idx))
|
||||
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if weight.spec.is_1D_col():
|
||||
elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if weight.tensor_spec.is_1D_col():
|
||||
tp_mode = 'col'
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -17,8 +17,8 @@ def colo_layernorm(
|
|||
input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
|
||||
|
||||
# TODO (ver217): check dist spec
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.spec.get_process_group()))
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.tensor_spec.get_process_group()))
|
||||
|
||||
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
|
||||
output = ColoTensor.from_torch_tensor(output, input_tensor.spec)
|
||||
output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)
|
||||
return output
|
||||
|
|
|
@ -13,7 +13,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
|||
# All-Reduce(Output) + bias = res
|
||||
# Input:S[1]
|
||||
input_tensor = input_tensor.convert_to_dist_spec(
|
||||
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]))
|
||||
distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]))
|
||||
|
||||
# Output:P
|
||||
partial_output = F.linear(input_tensor, weight)
|
||||
|
@ -21,10 +21,11 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
|||
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
|
||||
# Bias
|
||||
if bias is not None:
|
||||
assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
|
||||
assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
|
||||
output = output + bias
|
||||
|
||||
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
|
||||
output = ColoTensor.from_torch_tensor(output,
|
||||
spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
|
||||
return output
|
||||
|
||||
|
||||
|
@ -32,17 +33,20 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
|
|||
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
|
||||
# All-Gather(Output)
|
||||
# Input:B
|
||||
parallel_action = weight.spec.compute_spec
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
|
||||
compute_spec = weight.tensor_spec.compute_spec
|
||||
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
|
||||
input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
|
||||
|
||||
output_parallel = F.linear(input_parallel, weight, bias)
|
||||
output = ColoTensor.from_torch_tensor(output_parallel,
|
||||
spec=TensorSpec(
|
||||
distspec.shard(weight.spec.get_process_group(), [-1],
|
||||
[weight.spec.get_process_group_size()]),
|
||||
distspec.shard(weight.tensor_spec.get_process_group(), [-1],
|
||||
[weight.tensor_spec.get_process_group_size()]),
|
||||
ComputeSpec(ComputePattern.TP1D)))
|
||||
return output.to_replicate()
|
||||
if compute_spec.output_replicate:
|
||||
return output.to_replicate()
|
||||
else:
|
||||
return output
|
||||
|
||||
|
||||
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
|
||||
|
@ -62,14 +66,15 @@ def colo_linear_imp(input_tensor: GeneralTensor,
|
|||
|
||||
# Add communication logic before and after linear call.
|
||||
ret_tensor = None
|
||||
if not weight.has_spec(): # No Model Parallel Applied
|
||||
assert weight.spec.is_gathered(), 'Invalid weight spec for native Linear op'
|
||||
assert bias is None or bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
|
||||
if not weight.has_compute_spec(): # No Model Parallel Applied
|
||||
assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native Linear op'
|
||||
assert bias is None or bias.tensor_spec.is_gathered(), 'Invalid bias spec for native Linear op'
|
||||
ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
|
||||
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
|
||||
elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||
if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_gathered()):
|
||||
mode = 'row'
|
||||
elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
|
||||
elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row()
|
||||
or bias.tensor_spec.is_1D_col()):
|
||||
mode = 'col'
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -18,7 +18,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
|
|||
label_smoothing: float = 0.0):
|
||||
input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))
|
||||
|
||||
if input_tensor.spec.is_gathered(): # Input is gathered
|
||||
if input_tensor.tensor_spec.is_gathered(): # Input is gathered
|
||||
output = F.cross_entropy(input_tensor,
|
||||
target,
|
||||
weight=weight,
|
||||
|
@ -28,8 +28,8 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
|
|||
reduction=reduction,
|
||||
label_smoothing=label_smoothing)
|
||||
return ColoTensor.from_torch_tensor(output)
|
||||
elif input_tensor.has_spec(): # Single Model Parallel Applied
|
||||
if input_tensor.spec.is_1D_col():
|
||||
elif input_tensor.has_compute_spec(): # Single Model Parallel Applied
|
||||
if input_tensor.tensor_spec.is_1D_col():
|
||||
output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
|
||||
return ColoTensor.from_torch_tensor(output)
|
||||
else:
|
||||
|
|
|
@ -38,8 +38,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
|||
param = module.get_parameter(param_name)
|
||||
if not isinstance(param, ColoParameter):
|
||||
raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
|
||||
if param.has_spec():
|
||||
cur_compute_pattern = param.spec.compute_spec.compute_pattern
|
||||
if param.has_compute_spec():
|
||||
cur_compute_pattern = param.tensor_spec.compute_spec.compute_pattern
|
||||
if compute_pattern is None:
|
||||
compute_pattern = cur_compute_pattern
|
||||
else:
|
||||
|
@ -61,8 +61,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
|
|||
cur_match = True
|
||||
for param_name, dist_spec in param_specs.items():
|
||||
param = module.get_parameter(param_name)
|
||||
if param.has_spec():
|
||||
if dist_spec != param.spec.dist_spec:
|
||||
if param.has_compute_spec():
|
||||
if dist_spec != param.tensor_spec.dist_spec:
|
||||
cur_match = False
|
||||
break
|
||||
else:
|
||||
|
@ -97,7 +97,7 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
|
|||
param = module.get_parameter(param_name)
|
||||
if isinstance(param, ColoParameter):
|
||||
spec = TensorSpec(dist_spec, parallel_action)
|
||||
param.set_spec(spec)
|
||||
param.set_tensor_spec(spec)
|
||||
for mod in param.shared_param_modules:
|
||||
modules_update_param.add(mod)
|
||||
for mod in modules_update_param:
|
||||
|
|
|
@ -82,7 +82,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
|
|||
else:
|
||||
with torch._C.DisableTorchFunction():
|
||||
data = self.data.clone()
|
||||
tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
|
||||
tensor = ColoParameter(data, self.requires_grad, spec=copy(self.tensor_spec))
|
||||
memo[id(self)] = tensor
|
||||
return tensor
|
||||
|
||||
|
|
|
@ -57,15 +57,15 @@ class ColoTensor(torch.Tensor):
|
|||
self._graph_node = None
|
||||
|
||||
@property
|
||||
def spec(self) -> TensorSpec:
|
||||
def tensor_spec(self) -> TensorSpec:
|
||||
return self._tensor_spec
|
||||
|
||||
def set_spec(self, spec: TensorSpec) -> None:
|
||||
def set_tensor_spec(self, spec: TensorSpec) -> None:
|
||||
spec = copy(spec)
|
||||
self._convert_to_dist_spec(spec.dist_spec)
|
||||
self._tensor_spec = spec
|
||||
|
||||
def has_spec(self) -> bool:
|
||||
def has_compute_spec(self) -> bool:
|
||||
return self._tensor_spec.compute_spec is not None
|
||||
|
||||
def is_model_data(self) -> bool:
|
||||
|
@ -100,27 +100,27 @@ class ColoTensor(torch.Tensor):
|
|||
dist_spec (_DistSpec): the target dist. spec.
|
||||
"""
|
||||
with DistSpecManager.no_grad():
|
||||
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
|
||||
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
|
||||
self._tensor_spec.dist_spec = dist_spec
|
||||
|
||||
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
|
||||
tensor_spec = copy(self._tensor_spec)
|
||||
tensor_spec.dist_spec = dist_spec
|
||||
ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
|
||||
ret = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
|
||||
return ColoTensor.from_torch_tensor(ret, tensor_spec)
|
||||
|
||||
def to_replicate_(self):
|
||||
"""to_replicate_
|
||||
an inline member function, converting dist spec of the tensor to REPLICATE
|
||||
"""
|
||||
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, distspec.replicate())
|
||||
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, distspec.replicate())
|
||||
self._tensor_spec.dist_spec = distspec.replicate()
|
||||
|
||||
def to_replicate(self) -> 'ColoTensor':
|
||||
"""to_replicate
|
||||
converting dist spec of the tensor to REPLICATE
|
||||
"""
|
||||
return self.convert_to_dist_spec(distspec.replicate(self.spec.get_process_group()))
|
||||
return self.convert_to_dist_spec(distspec.replicate(self.tensor_spec.get_process_group()))
|
||||
|
||||
@staticmethod
|
||||
def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
|
||||
|
@ -134,16 +134,6 @@ class ColoTensor(torch.Tensor):
|
|||
else:
|
||||
with torch._C.DisableTorchFunction():
|
||||
data = self.data.clone()
|
||||
tensor = ColoTensor(data, spec=copy(self.spec))
|
||||
tensor = ColoTensor(data, spec=copy(self.tensor_spec))
|
||||
memo[id(self)] = tensor
|
||||
return tensor
|
||||
|
||||
# TODO(jiaruifang) a patch for gpt test.
|
||||
# We need to override the member function must operate on a replicated tensor
|
||||
# def view(self, *args, **kwargs):
|
||||
# self.data = DistSpecManager.handle_trans_spec(self,
|
||||
# self.spec.dist_spec,
|
||||
# distspec.replicate(self.spec.get_process_group()))
|
||||
# # self._tensor_spec.dist_spec = distspec.replicate(self.spec.get_process_group())
|
||||
# self.data.view(*args, **kwargs)
|
||||
# return ColoTensor.from_torch_tensor(self.data)
|
||||
return tensor
|
|
@ -18,6 +18,8 @@ class ComputeSpec(object):
|
|||
def __init__(self, compute_pattern: ComputePattern) -> None:
|
||||
assert isinstance(compute_pattern, ComputePattern)
|
||||
self.compute_pattern = compute_pattern
|
||||
# Make sure output tensors are replicate
|
||||
self.output_replicate = True
|
||||
|
||||
def __repr__(self):
|
||||
return f'compute pattern: {self.compute_pattern}'
|
||||
|
|
|
@ -129,7 +129,7 @@ def _get_colo_tensors_info(*args) -> list:
|
|||
info = []
|
||||
for arg in args:
|
||||
if isinstance(arg, ColoTensor):
|
||||
info.append((arg.__class__, arg.spec))
|
||||
info.append((arg.__class__, arg.tensor_spec))
|
||||
else:
|
||||
info.append(None)
|
||||
return info
|
||||
|
|
|
@ -42,10 +42,10 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
|
|||
has_dist_parameter = False
|
||||
with torch.no_grad():
|
||||
for param in self.parameters():
|
||||
if isinstance(param, ColoParameter) and param.has_spec():
|
||||
if isinstance(param, ColoParameter) and param.has_compute_spec():
|
||||
has_dist_parameter = True
|
||||
mapping[id(param)] = copy(param.spec)
|
||||
param.set_spec(TensorSpec(distspec.replicate()))
|
||||
mapping[id(param)] = copy(param.tensor_spec)
|
||||
param.set_tensor_spec(TensorSpec(distspec.replicate()))
|
||||
|
||||
# TODO: fix when keep_vars = True
|
||||
# when keep_vars = False, the state_dict_func will call detach to create
|
||||
|
@ -62,7 +62,7 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
|
|||
param_id = id(param)
|
||||
if param_id in mapping:
|
||||
spec = mapping[id(param)]
|
||||
param.set_spec(spec)
|
||||
param.set_tensor_spec(spec)
|
||||
return ret
|
||||
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ def init_1d_row(weight, bias):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col(weight, bias):
|
||||
|
@ -51,8 +51,8 @@ def init_1d_col(weight, bias):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
bias.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
bias.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
|
@ -63,6 +63,7 @@ def run_with_spec(spec_init_func):
|
|||
x = torch.rand(2, 16).cuda()
|
||||
out = model(x)
|
||||
colo_out = torch.addmm(bias, x, weight)
|
||||
colo_out = colo_out.to_replicate()
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
|
|
|
@ -20,7 +20,7 @@ def init_1d_col(weight):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
|
|
|
@ -20,7 +20,7 @@ def init_1d_row(weight):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col(weight):
|
||||
|
@ -28,7 +28,7 @@ def init_1d_col(weight):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
|
|
|
@ -22,7 +22,7 @@ def init_1d_row_spec(model):
|
|||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
p.set_spec(spec)
|
||||
p.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_spec(model):
|
||||
|
@ -32,7 +32,7 @@ def init_1d_col_spec(model):
|
|||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
p.set_spec(spec)
|
||||
p.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def check_param_equal(model, torch_model):
|
||||
|
|
|
@ -21,7 +21,7 @@ def init_1d_row(weight, bias):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col(weight, bias):
|
||||
|
@ -29,8 +29,8 @@ def init_1d_col(weight, bias):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
bias.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
bias.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_with_spec(spec_init_func):
|
||||
|
|
|
@ -23,7 +23,7 @@ def init_1d_row_linear(weight):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_linear(weight):
|
||||
|
@ -31,7 +31,7 @@ def init_1d_col_linear(weight):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_row_embedding(weight):
|
||||
|
@ -39,7 +39,7 @@ def init_1d_row_embedding(weight):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_embedding(weight):
|
||||
|
@ -47,7 +47,7 @@ def init_1d_col_embedding(weight):
|
|||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
with DistSpecManager.no_grad():
|
||||
weight.set_spec(spec)
|
||||
weight.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def run_1d_hybrid_tp(model_name):
|
||||
|
|
|
@ -157,7 +157,7 @@ def run_check_shared_param():
|
|||
col_spec = TensorSpec(
|
||||
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
ComputeSpec(ComputePattern.TP1D))
|
||||
model.cls.predictions.bias.set_spec(col_spec)
|
||||
model.cls.predictions.bias.set_tensor_spec(col_spec)
|
||||
try:
|
||||
check_colo_module(model.cls.predictions.decoder, recursive=False)
|
||||
except Exception as e:
|
||||
|
|
|
@ -36,10 +36,10 @@ def test_layernorm():
|
|||
|
||||
def check_spec_eq(tensor, other):
|
||||
assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
|
||||
for k in dir(tensor.spec.dist_spec):
|
||||
for k in dir(tensor.tensor_spec.dist_spec):
|
||||
if not k.startswith('__'):
|
||||
assert hasattr(other.spec.dist_spec, k)
|
||||
assert getattr(tensor.spec.dist_spec, k) == getattr(other.spec.dist_spec, k)
|
||||
assert hasattr(other.tensor_spec.dist_spec, k)
|
||||
assert getattr(tensor.tensor_spec.dist_spec, k) == getattr(other.tensor_spec.dist_spec, k)
|
||||
|
||||
|
||||
def check_element_wise_ops():
|
||||
|
|
|
@ -66,7 +66,7 @@ def _run_tensor_shard_init(world_size):
|
|||
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
|
||||
tensor_spec = TensorSpec(shard_spec)
|
||||
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
|
||||
t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
|
||||
t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
|
||||
assert t.shape == torch.Size((4 * world_size, 5))
|
||||
|
||||
|
||||
|
|
|
@ -51,7 +51,7 @@ def init_1d_row_spec(model):
|
|||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'weight' in n and 'ln' not in n:
|
||||
p.set_spec(spec)
|
||||
p.set_tensor_spec(spec)
|
||||
|
||||
|
||||
def init_1d_col_spec(model):
|
||||
|
@ -61,7 +61,7 @@ def init_1d_col_spec(model):
|
|||
with DistSpecManager.no_grad():
|
||||
for n, p in model.named_parameters():
|
||||
if 'ln' not in n and ('weight' in n or 'bias' in n):
|
||||
p.set_spec(spec)
|
||||
p.set_tensor_spec(spec)
|
||||
|
||||
|
||||
@parameterize('use_chunk', [False, True])
|
||||
|
|
Loading…
Reference in New Issue