diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py
index 71f868fa5..78c7a154c 100644
--- a/colossalai/nn/_ops/addmm.py
+++ b/colossalai/nn/_ops/addmm.py
@@ -13,34 +13,37 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
     # beta * input + alpha * All-Reduce(Output) = res
     mat1 = mat1.convert_to_dist_spec(
-        distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]))
+        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]))
 
     # Output:P
     partial_output = torch.mm(mat1, mat2)
     # Reduce(Output)
     output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # input
-    assert not input_tensor.has_spec(), 'Invalid input spec for 1Drow addmm op'
+    assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
     output = beta * input_tensor + alpha * output
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output,
+                                          spec=TensorSpec(distspec.replicate(mat2.tensor_spec.get_process_group())))
     return output
 
 
 def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
                      alpha: Number) -> ColoTensor:
     # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
-    parallel_action = mat2.spec.compute_spec
-    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.spec.get_process_group()))
+    compute_spec = mat2.tensor_spec.compute_spec
+    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.tensor_spec.get_process_group()))
     mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
 
     output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
-    output_spec = TensorSpec(distspec.shard(mat2.spec.get_process_group(), [-1], [mat2.spec.get_process_group_size()]),
-                             ComputeSpec(ComputePattern.TP1D))
+    output_spec = TensorSpec(
+        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]),
+        ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
 
-    # TODO(jiaruifang) addam is special case
-    # since gpt call view after the Op.
-    return output.to_replicate()
+    if compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 
 
 def colo_addmm_1d(mode: str, input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -64,14 +67,15 @@ def colo_addmm(input_tensor: GeneralTensor,
 
     # Add communication logic before and after linear call.
     ret_tensor = None
-    if not mat2.has_spec():    # No Model Parallel Applied
-        assert mat2.spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
-        assert input_tensor.spec.is_gathered(), 'Invalid input spec for native addmm op'
+    if not mat2.has_compute_spec():    # No Model Parallel Applied
+        assert mat2.tensor_spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
+        assert input_tensor.tensor_spec.is_gathered(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
-    elif mat2.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if mat2.spec.is_1D_row() and input_tensor.spec.is_gathered():
+    elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_gathered():
             mode = 'row'
-        elif mat2.spec.is_1D_col() and (input_tensor.spec.is_1D_col() or input_tensor.spec.is_1D_row()):
+        elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col()
+                                               or input_tensor.tensor_spec.is_1D_row()):
             mode = 'col'
         else:
             raise NotImplementedError
diff --git a/colossalai/nn/_ops/element_wise.py b/colossalai/nn/_ops/element_wise.py
index 44de07f83..49f45d96b 100644
--- a/colossalai/nn/_ops/element_wise.py
+++ b/colossalai/nn/_ops/element_wise.py
@@ -18,7 +18,7 @@ def register_elementwise_op(op):
         """
         output = op(input_tensor, *args, **kwargs)
         if isinstance(input_tensor, ColoTensor):
-            spec = copy(input_tensor.spec)
+            spec = copy(input_tensor.tensor_spec)
             return ColoTensor.from_torch_tensor(output, spec=spec)
         return ColoTensor.from_torch_tensor(output)
diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py
index 03ce57a76..284ed1f00 100644
--- a/colossalai/nn/_ops/embedding.py
+++ b/colossalai/nn/_ops/embedding.py
@@ -17,7 +17,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                          sparse: bool = False) -> ColoTensor:
     # embedding_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
 
     output_parallel = F.embedding(input_tensor,
                                   weight,
@@ -27,10 +27,15 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                   scale_grad_by_freq=scale_grad_by_freq,
                                   sparse=sparse)
     output_spec = TensorSpec(
-        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
+        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
         ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    return output.to_replicate()
+
+    compute_spec = weight.tensor_spec.compute_spec
+    if compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 
 
 def colo_embedding_1Drow(input_tensor: ColoTensor,
@@ -43,7 +48,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
     # Find index in this shard and mask those not here
     # Reduce all
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
 
     tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
     num_embeddings_per_partition = weight.size(0)
@@ -70,7 +75,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     partial_output[input_mask, :] = 0.
     # Reduce across all the model parallel GPUs.
     output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output,
+                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
     return output
 
 
@@ -108,8 +114,8 @@ def colo_embedding(input_tensor: GeneralTensor,
 
     # Handle differen parallel actions.
 
-    if not weight.has_spec():    # No Model Parallel Applied
-        assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
+    if not weight.has_compute_spec():    # No Model Parallel Applied
+        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding(input_tensor,
                         weight,
@@ -118,10 +124,10 @@ def colo_embedding(input_tensor: GeneralTensor,
                         norm_type=norm_type,
                         scale_grad_by_freq=scale_grad_by_freq,
                         sparse=sparse))
-    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.spec.is_1D_row():
+    elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if weight.tensor_spec.is_1D_row():
             mode = 'row'
-        elif weight.spec.is_1D_col():
+        elif weight.tensor_spec.is_1D_col():
             mode = 'col'
         else:
             raise NotImplementedError
diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py
index eb6e495e9..77a2d685e 100644
--- a/colossalai/nn/_ops/embedding_bag.py
+++ b/colossalai/nn/_ops/embedding_bag.py
@@ -19,7 +19,7 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                              padding_idx: Optional[int] = None) -> ColoTensor:
     # embedding_bag_1Dcol split the weight(lookup table) to (num_embeddings, embedding_dim/P)
     # Gather splitted lookup table
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
 
     output_parallel = F.embedding_bag(input_tensor,
                                       weight,
@@ -33,11 +33,14 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                       include_last_offset=include_last_offset,
                                       padding_idx=padding_idx)
     output_spec = TensorSpec(
-        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
+        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
         ComputeSpec(ComputePattern.TP1D))
     output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
-    return output.to_replicate()
+
+    if weight.tensor_spec.compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 
 
 def colo_embedding_bag_1d(tp_mode: str,
@@ -86,8 +89,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
 
     # Handle differen parallel actions.
 
-    if not weight.has_spec():    # No Model Parallel Applied
-        assert weight.spec.is_gathered(), 'Invalid weight spec for native embedding op'
+    if not weight.has_compute_spec():    # No Model Parallel Applied
+        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding_bag(input_tensor,
                             weight,
@@ -100,8 +103,8 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
                             per_sample_weights=per_sample_weights,
                             include_last_offset=include_last_offset,
                             padding_idx=padding_idx))
-    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.spec.is_1D_col():
+    elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if weight.tensor_spec.is_1D_col():
             tp_mode = 'col'
         else:
             raise NotImplementedError
diff --git a/colossalai/nn/_ops/layernorm.py b/colossalai/nn/_ops/layernorm.py
index 8f3ca8cac..12dcf6bfb 100644
--- a/colossalai/nn/_ops/layernorm.py
+++ b/colossalai/nn/_ops/layernorm.py
@@ -17,8 +17,8 @@ def colo_layernorm(
     input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))
 
     # TODO (ver217): check dist spec
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.spec.get_process_group()))
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.tensor_spec.get_process_group()))
 
     output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
-    output = ColoTensor.from_torch_tensor(output, input_tensor.spec)
+    output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)
     return output
diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py
index dc4487e7b..1de4d2dca 100644
--- a/colossalai/nn/_ops/linear.py
+++ b/colossalai/nn/_ops/linear.py
@@ -13,7 +13,7 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # All-Reduce(Output) + bias = res
     # Input:S[1]
     input_tensor = input_tensor.convert_to_dist_spec(
-        distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]))
+        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]))
 
     # Output:P
     partial_output = F.linear(input_tensor, weight)
@@ -21,10 +21,11 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
     # Bias
     if bias is not None:
-        assert not bias.has_spec(), 'Invalid bias spec for 1Drow Linear op'
+        assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
         output = output + bias
-    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.spec.get_process_group())))
+    output = ColoTensor.from_torch_tensor(output,
+                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
     return output
 
 
@@ -32,17 +33,20 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
     # Input:B x Weight:S[1] + Bias:S[1] = Output:S[1]
     # All-Gather(Output)
     # Input:B
-    parallel_action = weight.spec.compute_spec
-    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
+    compute_spec = weight.tensor_spec.compute_spec
+    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
     input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
 
     output_parallel = F.linear(input_parallel, weight, bias)
     output = ColoTensor.from_torch_tensor(output_parallel,
                                           spec=TensorSpec(
-                                              distspec.shard(weight.spec.get_process_group(), [-1],
-                                                             [weight.spec.get_process_group_size()]),
+                                              distspec.shard(weight.tensor_spec.get_process_group(), [-1],
+                                                             [weight.tensor_spec.get_process_group_size()]),
                                               ComputeSpec(ComputePattern.TP1D)))
-    return output.to_replicate()
+    if compute_spec.output_replicate:
+        return output.to_replicate()
+    else:
+        return output
 
 
 def colo_linear_1d(mode: str, input_tensor: ColoTensor, weight: ColoTensor, bias: Optional[ColoTensor]) -> 'ColoTensor':
@@ -62,14 +66,15 @@ def colo_linear_imp(input_tensor: GeneralTensor,
 
     # Add communication logic before and after linear call.
     ret_tensor = None
-    if not weight.has_spec():    # No Model Parallel Applied
-        assert weight.spec.is_gathered(), 'Invalid weight spec for native Linear op'
-        assert bias is None or bias.spec.is_gathered(), 'Invalid bias spec for native Linear op'
+    if not weight.has_compute_spec():    # No Model Parallel Applied
+        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native Linear op'
+        assert bias is None or bias.tensor_spec.is_gathered(), 'Invalid bias spec for native Linear op'
         ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
-    elif weight.spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.spec.is_1D_col() and (bias is None or bias.spec.is_gathered()):
+    elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
+        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_gathered()):
             mode = 'row'
-        elif weight.spec.is_1D_row() and (bias is None or bias.spec.is_1D_row() or bias.spec.is_1D_col()):
+        elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row()
+                                                 or bias.tensor_spec.is_1D_col()):
             mode = 'col'
         else:
             raise NotImplementedError
diff --git a/colossalai/nn/_ops/loss.py b/colossalai/nn/_ops/loss.py
index cf4468c43..0082b1979 100644
--- a/colossalai/nn/_ops/loss.py
+++ b/colossalai/nn/_ops/loss.py
@@ -18,7 +18,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                        label_smoothing: float = 0.0):
     input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))
 
-    if input_tensor.spec.is_gathered():    # Input is gathered
+    if input_tensor.tensor_spec.is_gathered():    # Input is gathered
         output = F.cross_entropy(input_tensor,
                                  target,
                                  weight=weight,
@@ -28,8 +28,8 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                                  reduction=reduction,
                                  label_smoothing=label_smoothing)
         return ColoTensor.from_torch_tensor(output)
-    elif input_tensor.has_spec():    # Single Model Parallel Applied
-        if input_tensor.spec.is_1D_col():
+    elif input_tensor.has_compute_spec():    # Single Model Parallel Applied
+        if input_tensor.tensor_spec.is_1D_col():
             output = VocabParallelCrossEntropyLoss1D()(input_tensor, target)
             return ColoTensor.from_torch_tensor(output)
         else:
diff --git a/colossalai/nn/parallel/layers/module_utils.py b/colossalai/nn/parallel/layers/module_utils.py
index 5474480bb..37d8afbec 100644
--- a/colossalai/nn/parallel/layers/module_utils.py
+++ b/colossalai/nn/parallel/layers/module_utils.py
@@ -38,8 +38,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
             param = module.get_parameter(param_name)
            if not isinstance(param, ColoParameter):
                 raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
-            if param.has_spec():
-                cur_compute_pattern = param.spec.compute_spec.compute_pattern
+            if param.has_compute_spec():
+                cur_compute_pattern = param.tensor_spec.compute_spec.compute_pattern
                 if compute_pattern is None:
                     compute_pattern = cur_compute_pattern
                 else:
@@ -61,8 +61,8 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                 cur_match = True
                 for param_name, dist_spec in param_specs.items():
                     param = module.get_parameter(param_name)
-                    if param.has_spec():
-                        if dist_spec != param.spec.dist_spec:
+                    if param.has_compute_spec():
+                        if dist_spec != param.tensor_spec.dist_spec:
                             cur_match = False
                             break
                     else:
@@ -97,7 +97,7 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ComputeSpec, recu
             param = module.get_parameter(param_name)
             if isinstance(param, ColoParameter):
                 spec = TensorSpec(dist_spec, parallel_action)
-                param.set_spec(spec)
+                param.set_tensor_spec(spec)
                 for mod in param.shared_param_modules:
                     modules_update_param.add(mod)
     for mod in modules_update_param:
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 501185d13..0414c7d07 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -82,7 +82,7 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.spec))
+            tensor = ColoParameter(data, self.requires_grad, spec=copy(self.tensor_spec))
             memo[id(self)] = tensor
             return tensor
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index 98684ca54..d277207bf 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -57,15 +57,15 @@ class ColoTensor(torch.Tensor):
         self._graph_node = None
 
     @property
-    def spec(self) -> TensorSpec:
+    def tensor_spec(self) -> TensorSpec:
         return self._tensor_spec
 
-    def set_spec(self, spec: TensorSpec) -> None:
+    def set_tensor_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
         self._convert_to_dist_spec(spec.dist_spec)
         self._tensor_spec = spec
 
-    def has_spec(self) -> bool:
+    def has_compute_spec(self) -> bool:
         return self._tensor_spec.compute_spec is not None
 
     def is_model_data(self) -> bool:
@@ -100,27 +100,27 @@ class ColoTensor(torch.Tensor):
             dist_spec (_DistSpec): the target dist. spec.
         """
         with DistSpecManager.no_grad():
-            self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
+            self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
         self._tensor_spec.dist_spec = dist_spec
 
     def convert_to_dist_spec(self, dist_spec: _DistSpec) -> 'ColoTensor':
         tensor_spec = copy(self._tensor_spec)
         tensor_spec.dist_spec = dist_spec
-        ret = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, dist_spec)
+        ret = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, dist_spec)
         return ColoTensor.from_torch_tensor(ret, tensor_spec)
 
     def to_replicate_(self):
         """to_replicate_
         an inline member function, converting dist spec of the tensor to REPLICATE
         """
-        self.data = DistSpecManager.handle_trans_spec(self, self.spec.dist_spec, distspec.replicate())
+        self.data = DistSpecManager.handle_trans_spec(self, self.tensor_spec.dist_spec, distspec.replicate())
         self._tensor_spec.dist_spec = distspec.replicate()
 
     def to_replicate(self) -> 'ColoTensor':
         """to_replicate
         converting dist spec of the tensor to REPLICATE
         """
-        return self.convert_to_dist_spec(distspec.replicate(self.spec.get_process_group()))
+        return self.convert_to_dist_spec(distspec.replicate(self.tensor_spec.get_process_group()))
 
     @staticmethod
     def from_torch_tensor(tensor: torch.Tensor, spec: TensorSpec = TensorSpec(distspec.replicate())) -> 'ColoTensor':
@@ -134,16 +134,6 @@ class ColoTensor(torch.Tensor):
         else:
             with torch._C.DisableTorchFunction():
                 data = self.data.clone()
-            tensor = ColoTensor(data, spec=copy(self.spec))
+            tensor = ColoTensor(data, spec=copy(self.tensor_spec))
             memo[id(self)] = tensor
-            return tensor
-
-    # TODO(jiaruifang) a patch for gpt test.
-    # We need to override the member function must operate on a replicated tensor
-    # def view(self, *args, **kwargs):
-    #     self.data = DistSpecManager.handle_trans_spec(self,
-    #                                                   self.spec.dist_spec,
-    #                                                   distspec.replicate(self.spec.get_process_group()))
-    #     # self._tensor_spec.dist_spec = distspec.replicate(self.spec.get_process_group())
-    #     self.data.view(*args, **kwargs)
-    #     return ColoTensor.from_torch_tensor(self.data)
+            return tensor
\ No newline at end of file
diff --git a/colossalai/tensor/compute_spec.py b/colossalai/tensor/compute_spec.py
index b1e07af12..acaba2a46 100644
--- a/colossalai/tensor/compute_spec.py
+++ b/colossalai/tensor/compute_spec.py
@@ -18,6 +18,8 @@ class ComputeSpec(object):
     def __init__(self, compute_pattern: ComputePattern) -> None:
         assert isinstance(compute_pattern, ComputePattern)
         self.compute_pattern = compute_pattern
+        # Make sure output tensors are replicate
+        self.output_replicate = True
 
     def __repr__(self):
         return f'compute pattern: {self.compute_pattern}'
diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py
index fee6a0a6b..8c83d9914 100644
--- a/colossalai/tensor/param_op_hook.py
+++ b/colossalai/tensor/param_op_hook.py
@@ -129,7 +129,7 @@ def _get_colo_tensors_info(*args) -> list:
     info = []
     for arg in args:
         if isinstance(arg, ColoTensor):
-            info.append((arg.__class__, arg.spec))
+            info.append((arg.__class__, arg.tensor_spec))
         else:
             info.append(None)
     return info
diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index 283af5acc..6298ac102 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -42,10 +42,10 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
     has_dist_parameter = False
     with torch.no_grad():
         for param in self.parameters():
-            if isinstance(param, ColoParameter) and param.has_spec():
+            if isinstance(param, ColoParameter) and param.has_compute_spec():
                 has_dist_parameter = True
-                mapping[id(param)] = copy(param.spec)
-                param.set_spec(TensorSpec(distspec.replicate()))
+                mapping[id(param)] = copy(param.tensor_spec)
+                param.set_tensor_spec(TensorSpec(distspec.replicate()))
 
     # TODO: fix when keep_vars = True
     # when keep_vars = False, the state_dict_func will call detach to create
@@ -62,7 +62,7 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di
             param_id = id(param)
             if param_id in mapping:
                 spec = mapping[id(param)]
-                param.set_spec(spec)
+                param.set_tensor_spec(spec)
     return ret
diff --git a/tests/test_tensor/test_addmm_tp.py b/tests/test_tensor/test_addmm_tp.py
index 2aa7f2753..985fe7818 100644
--- a/tests/test_tensor/test_addmm_tp.py
+++ b/tests/test_tensor/test_addmm_tp.py
@@ -43,7 +43,7 @@ def init_1d_row(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_col(weight, bias):
@@ -51,8 +51,8 @@ def init_1d_col(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
-        bias.set_spec(spec)
+        weight.set_tensor_spec(spec)
+        bias.set_tensor_spec(spec)
 
 
 def run_with_spec(spec_init_func):
@@ -63,6 +63,7 @@ def run_with_spec(spec_init_func):
     x = torch.rand(2, 16).cuda()
     out = model(x)
     colo_out = torch.addmm(bias, x, weight)
+    colo_out = colo_out.to_replicate()
     assert tensor_equal(out, colo_out)
     grad = torch.rand_like(out)
     out.backward(grad)
diff --git a/tests/test_tensor/test_embedding_bag_tp.py b/tests/test_tensor/test_embedding_bag_tp.py
index 819d763e8..d290fb980 100644
--- a/tests/test_tensor/test_embedding_bag_tp.py
+++ b/tests/test_tensor/test_embedding_bag_tp.py
@@ -20,7 +20,7 @@ def init_1d_col(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def run_with_spec(spec_init_func):
diff --git a/tests/test_tensor/test_embedding_tp.py b/tests/test_tensor/test_embedding_tp.py
index 38ad1fe93..5ac785200 100644
--- a/tests/test_tensor/test_embedding_tp.py
+++ b/tests/test_tensor/test_embedding_tp.py
@@ -20,7 +20,7 @@ def init_1d_row(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_col(weight):
@@ -28,7 +28,7 @@ def init_1d_col(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def run_with_spec(spec_init_func):
diff --git a/tests/test_tensor/test_gpt.py b/tests/test_tensor/test_gpt.py
index ae90b2378..0b5c7f2b7 100644
--- a/tests/test_tensor/test_gpt.py
+++ b/tests/test_tensor/test_gpt.py
@@ -22,7 +22,7 @@ def init_1d_row_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)
 
 
 def init_1d_col_spec(model):
@@ -32,7 +32,7 @@ def init_1d_col_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)
 
 
 def check_param_equal(model, torch_model):
diff --git a/tests/test_tensor/test_linear_tp.py b/tests/test_tensor/test_linear_tp.py
index 5d57fb218..1462c296f 100644
--- a/tests/test_tensor/test_linear_tp.py
+++ b/tests/test_tensor/test_linear_tp.py
@@ -21,7 +21,7 @@ def init_1d_row(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_col(weight, bias):
@@ -29,8 +29,8 @@ def init_1d_col(weight, bias):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
-        bias.set_spec(spec)
+        weight.set_tensor_spec(spec)
+        bias.set_tensor_spec(spec)
 
 
 def run_with_spec(spec_init_func):
diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index 4d56b221b..2880af885 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -23,7 +23,7 @@ def init_1d_row_linear(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_col_linear(weight):
@@ -31,7 +31,7 @@ def init_1d_col_linear(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_row_embedding(weight):
@@ -39,7 +39,7 @@ def init_1d_row_embedding(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def init_1d_col_embedding(weight):
@@ -47,7 +47,7 @@ def init_1d_col_embedding(weight):
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
     with DistSpecManager.no_grad():
-        weight.set_spec(spec)
+        weight.set_tensor_spec(spec)
 
 
 def run_1d_hybrid_tp(model_name):
diff --git a/tests/test_tensor/test_module_spec.py b/tests/test_tensor/test_module_spec.py
index fd0727795..3223deb9b 100644
--- a/tests/test_tensor/test_module_spec.py
+++ b/tests/test_tensor/test_module_spec.py
@@ -157,7 +157,7 @@ def run_check_shared_param():
     col_spec = TensorSpec(
         distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
         ComputeSpec(ComputePattern.TP1D))
-    model.cls.predictions.bias.set_spec(col_spec)
+    model.cls.predictions.bias.set_tensor_spec(col_spec)
     try:
         check_colo_module(model.cls.predictions.decoder, recursive=False)
     except Exception as e:
diff --git a/tests/test_tensor/test_op.py b/tests/test_tensor/test_op.py
index 5298f292d..26d64b009 100644
--- a/tests/test_tensor/test_op.py
+++ b/tests/test_tensor/test_op.py
@@ -36,10 +36,10 @@ def test_layernorm():
 
 def check_spec_eq(tensor, other):
     assert isinstance(tensor, ColoTensor) and isinstance(other, ColoTensor)
-    for k in dir(tensor.spec.dist_spec):
+    for k in dir(tensor.tensor_spec.dist_spec):
         if not k.startswith('__'):
-            assert hasattr(other.spec.dist_spec, k)
-            assert getattr(tensor.spec.dist_spec, k) == getattr(other.spec.dist_spec, k)
+            assert hasattr(other.tensor_spec.dist_spec, k)
+            assert getattr(tensor.tensor_spec.dist_spec, k) == getattr(other.tensor_spec.dist_spec, k)
 
 
 def check_element_wise_ops():
diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py
index e2cf25e83..a940234a9 100644
--- a/tests/test_tensor/test_tensor.py
+++ b/tests/test_tensor/test_tensor.py
@@ -66,7 +66,7 @@ def _run_tensor_shard_init(world_size):
     shard_spec = distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[world_size])
     tensor_spec = TensorSpec(shard_spec)
     t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
-    t.set_spec(TensorSpec(dist_spec=distspec.replicate()))
+    t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
     assert t.shape == torch.Size((4 * world_size, 5))
diff --git a/tests/test_tensor/test_zero_optim.py b/tests/test_tensor/test_zero_optim.py
index 94747563d..bd756fefd 100644
--- a/tests/test_tensor/test_zero_optim.py
+++ b/tests/test_tensor/test_zero_optim.py
@@ -51,7 +51,7 @@ def init_1d_row_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'weight' in n and 'ln' not in n:
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)
 
 
 def init_1d_col_spec(model):
@@ -61,7 +61,7 @@ def init_1d_col_spec(model):
     with DistSpecManager.no_grad():
         for n, p in model.named_parameters():
             if 'ln' not in n and ('weight' in n or 'bias' in n):
-                p.set_spec(spec)
+                p.set_tensor_spec(spec)
 
 
 @parameterize('use_chunk', [False, True])
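
A minimal sketch of how the renamed spec API and the new `ComputeSpec.output_replicate` flag are used after this change. It assumes the import layout of the test files above (`colossalai.tensor`, `colossalai.core.global_context`) and a 1D tensor-parallel group already initialized via `colossalai.launch`, as the tests' spawn helpers do; it is an illustrative sketch, not code from this diff.

```python
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.tensor import (ColoTensor, ComputePattern, ComputeSpec, DistSpecManager, TensorSpec, distspec)

# Shard the last dim of a weight across the 1D tensor-parallel group and attach
# a TP1D compute pattern, mirroring the init_1d_* helpers in the tests above.
weight = ColoTensor.from_torch_tensor(torch.randn(16, 16).cuda())
spec = TensorSpec(
    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
    ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
    weight.set_tensor_spec(spec)        # was: weight.set_spec(spec)

assert weight.has_compute_spec()        # was: weight.has_spec()
print(weight.tensor_spec.dist_spec)     # was: weight.spec.dist_spec

# ComputeSpec.output_replicate defaults to True. Setting it to False makes the
# 1D-col ops (linear/addmm/embedding/embedding_bag) keep their output sharded,
# so the caller gathers explicitly when a replicated result is needed, as the
# updated addmm test does:
#     colo_out = colo_out.to_replicate()
weight.tensor_spec.compute_spec.output_replicate = False
```

With `output_replicate` left at its default of `True`, the 1D-col ops still call `to_replicate()` internally and return a replicated tensor, so existing callers keep their previous behavior.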