[ColoTensor] rename APIs and add output_replicate to ComputeSpec (#1168)

pull/1174/head
Jiarui Fang 2022-06-24 13:08:54 +08:00 committed by GitHub
parent f4ef224358
commit 4b9bba8116
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 116 additions and 105 deletions

View File

@ -13,34 +13,37 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
# beta * input + alpha * All-Reduce(Output) = res
mat1 = mat1.convert_to_dist_spec(
distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]))
distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]))
# Output:P
partial_output = torch.mm(mat1, mat2)
# Reduce(Output)
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
# input
assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
output = beta * input_tensor + alpha * output
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.spec.get_process_group())))
output = ColoTensor.from_torch_tensor(output,
spec=TensorSpec(distspec.replicate(mat2.tensor_spec.get_process_group())))
return output
def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
alpha: Number) -> ColoTensor:
# mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
parallel_action = mat2.spec.compute_spec
mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
compute_spec = mat2.tensor_spec.compute_spec
mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.tensor_spec.get_process_group()))
mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
ComputeSpec(ComputePattern.TP1D))
output_spec = TensorSpec(
distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
# TODO(jiaruifang) addam is special case
# since gpt call view after the Op.
return output.to_replicate()
if compute_spec.output_replicate:
return output.to_replicate()
else:
return output
def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@ -64,14 +67,15 @@ def colo_addmm(input_tensor: GeneralTensor,
# Add communication logic before and after linear call.
ret_tensor = None
if not mat2.has_spec(): # No Model Parallel Applied
assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
if not mat2.has_compute_spec(): # No Model Parallel Applied
assert mat2.tensor_spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
assert input_tensor.tensor_spec.is_gathered(), 'Invalid input spec for native addmm op'
ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
elif mat2.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_gathered():
mode = 'row'
elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col()
or input_tensor.tensor_spec.is_1D_row()):
mode = 'col'
else:
raise NotImplementedError

View File

@ -18,7 +18,7 @@ def register_elementwise_op(op):
"""
output = op(input_tensor, *args, **kwargs)
if isinstance(input_tensor, ColoTensor):
spec = copy(input_tensor.spec)
spec = copy(input_tensor.tensor_spec)
return ColoTensor.from_torch_tensor(output, spec=spec)
return ColoTensor.from_torch_tensor(output)

View File

@ -17,7 +17,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
sparse: bool = False) -> ColoTensor:
# embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
output_parallel = F.embedding(input_tensor,
weight,
@ -27,10 +27,15 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
output_spec = TensorSpec(
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
return output.to_replicate()
compute_spec = weight.tensor_spec.compute_spec
if compute_spec.output_replicate:
return output.to_replicate()
else:
return output
def colo_embedding_1Drow(input_tensor: ColoTensor,
@ -43,7 +48,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
num_embeddings_per_partition = weight.size(0)
@ -70,7 +75,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
partial_output[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
output = ColoTensor.from_torch_tensor(output,
spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
return output
@ -108,8 +114,8 @@ def colo_embedding(input_tensor: GeneralTensor,
# Handle differen parallel actions.
if not weight.has_spec(): # No Model Parallel Applied
assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(
F.embedding(input_tensor,
weight,
@ -118,10 +124,10 @@ def colo_embedding(input_tensor: GeneralTensor,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse))
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.spec.is_1D_row():
elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.tensor_spec.is_1D_row():
mode = 'row'
elif weight.spec.is_1D_col():
elif weight.tensor_spec.is_1D_col():
mode = 'col'
else:
raise NotImplementedError

View File

@ -19,7 +19,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
padding_idx: Optional[int] = None) -> ColoTensor:
# embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
# Gather splitted lookup table
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
output_parallel = F.embedding_bag(input_tensor,
weight,
@ -33,11 +33,14 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
include_last_offset=include_last_offset,
padding_idx=padding_idx)
output_spec = TensorSpec(
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
ComputeSpec(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
return output.to_replicate()
if weight.tensor_spec.compute_spec.output_replicate:
return output.to_replicate()
else:
return output
def colo_embedding_bag_1d(tp_mode: str,
@ -86,8 +89,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
# Handle differen parallel actions.
if not weight.has_spec(): # No Model Parallel Applied
assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(
F.embedding_bag(input_tensor,
weight,
@ -100,8 +103,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
per_sample_weights=per_sample_weights,
include_last_offset=include_last_offset,
padding_idx=padding_idx))
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.spec.is_1D_col():
elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.tensor_spec.is_1D_col():
tp_mode = 'col'
else:
raise NotImplementedError

View File

@ -17,8 +17,8 @@ def colo_layernorm(
input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
# TODO (ver217): check dist spec
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.spec.get_process_group()))
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.tensor_spec.get_process_group()))
output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
output = ColoTensor.from_torch_tensor(output, input_tensor.spec)
output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)
return output

View File

@ -13,7 +13,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# All-Reduce(Output) + bias = res
# Input:S[1]
input_tensor = input_tensor.convert_to_dist_spec(
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]))
distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]))
# Output:P
partial_output = F.linear(input_tensor, weight)
@ -21,10 +21,11 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
# Bias
if bias is not None:
assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
output = output + bias
output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
output = ColoTensor.from_torch_tensor(output,
spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
return output
@ -32,17 +33,20 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
# Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
# All-Gather(Output)
# Input:B
parallel_action = weight.spec.compute_spec
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
compute_spec = weight.tensor_spec.compute_spec
input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
output_parallel = F.linear(input_parallel, weight, bias)
output = ColoTensor.from_torch_tensor(output_parallel,
spec=TensorSpec(
distspec.shard(weight.spec.get_process_group(), [-1],
[weight.spec.get_process_group_size()]),
distspec.shard(weight.tensor_spec.get_process_group(), [-1],
[weight.tensor_spec.get_process_group_size()]),
ComputeSpec(ComputePattern.TP1D)))
return output.to_replicate()
if compute_spec.output_replicate:
return output.to_replicate()
else:
return output
def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@ -62,14 +66,15 @@ def colo_linear_imp(input_tensor: GeneralTensor,
# Add communication logic before and after linear call.
ret_tensor = None
if not weight.has_spec(): # No Model Parallel Applied
assert weight.spec.is_gathered(), 'Invalid weight spec for native Linear op'
assert bias is None or bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native Linear op'
assert bias is None or bias.tensor_spec.is_gathered(), 'Invalid bias spec for native Linear op'
ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
elif weight.spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_gathered()):
mode = 'row'
elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row()
or bias.tensor_spec.is_1D_col()):
mode = 'col'
else:
raise NotImplementedError

View File

@ -18,7 +18,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
label_smoothing: float = 0.0):
input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))
if input_tensor.spec.is_gathered(): # Input is gathered
if input_tensor.tensor_spec.is_gathered(): # Input is gathered
output = F.cross_entropy(input_tensor,
target,
weight=weight,
@ -28,8 +28,8 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
reduction=reduction,
label_smoothing=label_smoothing)
return ColoTensor.from_torch_tensor(output)
elif input_tensor.has_spec(): # Single Model Parallel Applied
if input_tensor.spec.is_1D_col():
elif input_tensor.has_compute_spec(): # Single Model Parallel Applied
if input_tensor.tensor_spec.is_1D_col():
output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
return ColoTensor.from_torch_tensor(output)
else:

View File

@ -38,8 +38,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
param = module.get_parameter(param_name)
if not isinstance(param, ColoParameter):
raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
if param.has_spec():
cur_compute_pattern = param.spec.compute_spec.compute_pattern
if param.has_compute_spec():
cur_compute_pattern = param.tensor_spec.compute_spec.compute_pattern
if compute_pattern is None:
compute_pattern = cur_compute_pattern
else:
@ -61,8 +61,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
cur_match = True
for param_name, dist_spec in param_specs.items():
param = module.get_parameter(param_name)
if param.has_spec():
if dist_spec != param.spec.dist_spec:
if param.has_compute_spec():
if dist_spec != param.tensor_spec.dist_spec:
cur_match = False
break
else:
@ -97,7 +97,7 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
param = module.get_parameter(param_name)
if isinstance(param, ColoParameter):
spec = TensorSpec(dist_spec, parallel_action)
param.set_spec(spec)
param.set_tensor_spec(spec)
for mod in param.shared_param_modules:
modules_update_param.add(mod)
for mod in modules_update_param:

View File

@ -82,7 +82,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
tensor = ColoParameter(data, self.requires_grad, spec=copy(self.tensor_spec))
memo[id(self)] = tensor
return tensor

View File

@ -57,15 +57,15 @@ class ColoTensor(torch.Tensor):
self._graph_node = None
@property
def spec(self) -> TensorSpec:
def tensor_spec(self) -> TensorSpec:
return self._tensor_spec
def set_spec(self, spec: TensorSpec) -> None:
def set_tensor_spec(self, spec: TensorSpec) -> None:
spec = copy(spec)
self._convert_to_dist_spec(spec.dist_spec)
self._tensor_spec = spec
def has_spec(self) -> bool:
def has_compute_spec(self) -> bool:
return self._tensor_spec.compute_spec is not None
def is_model_data(self) -> bool:
@ -100,27 +100,27 @@ class ColoTensor(torch.Tensor):
dist_spec (_DistSpec): the target dist. spec.
"""
with DistSpecManager.no_grad():
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
self._tensor_spec.dist_spec = dist_spec
def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
tensor_spec = copy(self._tensor_spec)
tensor_spec.dist_spec = dist_spec
ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
ret = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
return ColoTensor.from_torch_tensor(ret, tensor_spec)
def to_replicate_(self):
"""to_replicate_
an inline member function, converting dist spec of the tensor to REPLICATE
"""
self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, distspec.replicate())
self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, distspec.replicate())
self._tensor_spec.dist_spec = distspec.replicate()
def to_replicate(self) -> 'ColoTensor':
"""to_replicate
converting dist spec of the tensor to REPLICATE
"""
return self.convert_to_dist_spec(distspec.replicate(self.spec.get_process_group()))
return self.convert_to_dist_spec(distspec.replicate(self.tensor_spec.get_process_group()))
@staticmethod
def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
@ -134,16 +134,6 @@ class ColoTensor(torch.Tensor):
else:
with torch._C.DisableTorchFunction():
data = self.data.clone()
tensor = ColoTensor(data, spec=copy(self.spec))
tensor = ColoTensor(data, spec=copy(self.tensor_spec))
memo[id(self)] = tensor
return tensor
# TODO(jiaruifang) a patch for gpt test.
# We need to override the member function must operate on a replicated tensor
# def view(self, *args, **kwargs):
# self.data = DistSpecManager.handle_trans_spec(self,
# self.spec.dist_spec,
# distspec.replicate(self.spec.get_process_group()))
# # self._tensor_spec.dist_spec = distspec.replicate(self.spec.get_process_group())
# self.data.view(*args, **kwargs)
# return ColoTensor.from_torch_tensor(self.data)
return tensor

View File

@ -18,6 +18,8 @@ class ComputeSpec(object):
def __init__(self, compute_pattern: ComputePattern) -> None:
assert isinstance(compute_pattern, ComputePattern)
self.compute_pattern = compute_pattern
# Make sure output tensors are replicate
self.output_replicate = True
def __repr__(self):
return f'compute pattern: {self.compute_pattern}'

View File

@ -129,7 +129,7 @@ def _get_colo_tensors_info(*args) -> list:
info = []
for arg in args:
if isinstance(arg, ColoTensor):
info.append((arg.__class__, arg.spec))
info.append((arg.__class__, arg.tensor_spec))
else:
info.append(None)
return info

View File

@ -42,10 +42,10 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
has_dist_parameter = False
with torch.no_grad():
for param in self.parameters():
if isinstance(param, ColoParameter) and param.has_spec():
if isinstance(param, ColoParameter) and param.has_compute_spec():
has_dist_parameter = True
mapping[id(param)] = copy(param.spec)
param.set_spec(TensorSpec(distspec.replicate()))
mapping[id(param)] = copy(param.tensor_spec)
param.set_tensor_spec(TensorSpec(distspec.replicate()))
# TODO: fix when keep_vars = True
# when keep_vars = False, the state_dict_func will call detach to create
@ -62,7 +62,7 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
param_id = id(param)
if param_id in mapping:
spec = mapping[id(param)]
param.set_spec(spec)
param.set_tensor_spec(spec)
return ret

View File

@ -43,7 +43,7 @@ def init_1d_row(weight, bias):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
weight.set_tensor_spec(spec)
def init_1d_col(weight, bias):
@ -51,8 +51,8 @@ def init_1d_col(weight, bias):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
bias.set_spec(spec)
weight.set_tensor_spec(spec)
bias.set_tensor_spec(spec)
def run_with_spec(spec_init_func):
@ -63,6 +63,7 @@ def run_with_spec(spec_init_func):
x = torch.rand(2, 16).cuda()
out = model(x)
colo_out = torch.addmm(bias, x, weight)
colo_out = colo_out.to_replicate()
assert tensor_equal(out, colo_out)
grad = torch.rand_like(out)
out.backward(grad)

View File

@ -20,7 +20,7 @@ def init_1d_col(weight):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
weight.set_tensor_spec(spec)
def run_with_spec(spec_init_func):

View File

@ -20,7 +20,7 @@ def init_1d_row(weight):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
weight.set_tensor_spec(spec)
def init_1d_col(weight):
@ -28,7 +28,7 @@ def init_1d_col(weight):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
weight.set_tensor_spec(spec)
def run_with_spec(spec_init_func):

View File

@ -22,7 +22,7 @@ def init_1d_row_spec(model):
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
p.set_spec(spec)
p.set_tensor_spec(spec)
def init_1d_col_spec(model):
@ -32,7 +32,7 @@ def init_1d_col_spec(model):
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n):
p.set_spec(spec)
p.set_tensor_spec(spec)
def check_param_equal(model, torch_model):

View File

@ -21,7 +21,7 @@ def init_1d_row(weight, bias):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
weight.set_tensor_spec(spec)
def init_1d_col(weight, bias):
@ -29,8 +29,8 @@ def init_1d_col(weight, bias):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
bias.set_spec(spec)
weight.set_tensor_spec(spec)
bias.set_tensor_spec(spec)
def run_with_spec(spec_init_func):

View File

@ -23,7 +23,7 @@ def init_1d_row_linear(weight):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
weight.set_tensor_spec(spec)
def init_1d_col_linear(weight):
@ -31,7 +31,7 @@ def init_1d_col_linear(weight):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
weight.set_tensor_spec(spec)
def init_1d_row_embedding(weight):
@ -39,7 +39,7 @@ def init_1d_row_embedding(weight):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
weight.set_tensor_spec(spec)
def init_1d_col_embedding(weight):
@ -47,7 +47,7 @@ def init_1d_col_embedding(weight):
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_spec(spec)
weight.set_tensor_spec(spec)
def run_1d_hybrid_tp(model_name):

View File

@ -157,7 +157,7 @@ def run_check_shared_param():
col_spec = TensorSpec(
distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
ComputeSpec(ComputePattern.TP1D))
model.cls.predictions.bias.set_spec(col_spec)
model.cls.predictions.bias.set_tensor_spec(col_spec)
try:
check_colo_module(model.cls.predictions.decoder, recursive=False)
except Exception as e:

View File

@ -36,10 +36,10 @@ def test_layernorm():
def check_spec_eq(tensor, other):
assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
for k in dir(tensor.spec.dist_spec):
for k in dir(tensor.tensor_spec.dist_spec):
if not k.startswith('__'):
assert hasattr(other.spec.dist_spec, k)
assert getattr(tensor.spec.dist_spec, k) == getattr(other.spec.dist_spec, k)
assert hasattr(other.tensor_spec.dist_spec, k)
assert getattr(tensor.tensor_spec.dist_spec, k) == getattr(other.tensor_spec.dist_spec, k)
def check_element_wise_ops():

View File

@ -66,7 +66,7 @@ def _run_tensor_shard_init(world_size):
shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
tensor_spec = TensorSpec(shard_spec)
t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
assert t.shape == torch.Size((4 * world_size, 5))

View File

@ -51,7 +51,7 @@ def init_1d_row_spec(model):
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'weight' in n and 'ln' not in n:
p.set_spec(spec)
p.set_tensor_spec(spec)
def init_1d_col_spec(model):
@ -61,7 +61,7 @@ def init_1d_col_spec(model):
with DistSpecManager.no_grad():
for n, p in model.named_parameters():
if 'ln' not in n and ('weight' in n or 'bias' in n):
p.set_spec(spec)
p.set_tensor_spec(spec)
@parameterize('use_chunk', [False, True])