diff --git a/colossalai/core.py b/colossalai/core.py
index 4ae054d46..153247bbe 100644
--- a/colossalai/core.py
+++ b/colossalai/core.py
@@ -2,3 +2,5 @@
 # -*- encoding: utf-8 -*-
 
 from colossalai.context.parallel_context import global_context
+
+__all__ = ['global_context']
\ No newline at end of file
diff --git a/colossalai/nn/_ops/addmm.py b/colossalai/nn/_ops/addmm.py
index 78c7a154c..2091f3247 100644
--- a/colossalai/nn/_ops/addmm.py
+++ b/colossalai/nn/_ops/addmm.py
@@ -68,11 +68,11 @@ def colo_addmm(input_tensor: GeneralTensor,
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not mat2.has_compute_spec():    # No Model Parallel Applied
-        assert mat2.tensor_spec.is_gathered(), 'Invalid mat2 spec for native addmm op'
-        assert input_tensor.tensor_spec.is_gathered(), 'Invalid input spec for native addmm op'
+        assert mat2.tensor_spec.is_replicate(), 'Invalid mat2 spec for native addmm op'
+        assert input_tensor.tensor_spec.is_replicate(), 'Invalid input spec for native addmm op'
         ret_tensor = ColoTensor.from_torch_tensor(torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha))
     elif mat2.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_gathered():
+        if mat2.tensor_spec.is_1D_row() and input_tensor.tensor_spec.is_replicate():
             mode = 'row'
         elif mat2.tensor_spec.is_1D_col() and (input_tensor.tensor_spec.is_1D_col()
                                                or input_tensor.tensor_spec.is_1D_row()):
diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py
index 284ed1f00..5f41b0c6e 100644
--- a/colossalai/nn/_ops/embedding.py
+++ b/colossalai/nn/_ops/embedding.py
@@ -51,7 +51,7 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
 
     tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
-    num_embeddings_per_partition = weight.size(0)
+    num_embeddings_per_partition = weight.size_base(0)
     vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
     vocab_end_index = vocab_start_index + num_embeddings_per_partition
 
@@ -115,7 +115,7 @@ def colo_embedding(input_tensor: GeneralTensor,
     # Handle differen parallel actions.
 
     if not weight.has_compute_spec():    # No Model Parallel Applied
-        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
+        assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding(input_tensor,
                         weight,
diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py
index 77a2d685e..825dd8d92 100644
--- a/colossalai/nn/_ops/embedding_bag.py
+++ b/colossalai/nn/_ops/embedding_bag.py
@@ -90,7 +90,7 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
     # Handle differen parallel actions.
 
     if not weight.has_compute_spec():    # No Model Parallel Applied
-        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native embedding op'
+        assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native embedding op'
         return ColoTensor.from_torch_tensor(
             F.embedding_bag(input_tensor,
                             weight,
diff --git a/colossalai/nn/_ops/linear.py b/colossalai/nn/_ops/linear.py
index 1de4d2dca..01dcef6a6 100644
--- a/colossalai/nn/_ops/linear.py
+++ b/colossalai/nn/_ops/linear.py
@@ -67,17 +67,17 @@ def colo_linear_imp(input_tensor: GeneralTensor,
     # Add communication logic before and after linear call.
     ret_tensor = None
     if not weight.has_compute_spec():    # No Model Parallel Applied
-        assert weight.tensor_spec.is_gathered(), 'Invalid weight spec for native Linear op'
-        assert bias is None or bias.tensor_spec.is_gathered(), 'Invalid bias spec for native Linear op'
+        assert weight.tensor_spec.is_replicate(), 'Invalid weight spec for native Linear op'
+        assert bias is None or bias.tensor_spec.is_replicate(), 'Invalid bias spec for native Linear op'
         ret_tensor = ColoTensor.from_torch_tensor(F.linear(input_tensor, weight, bias))
     elif weight.tensor_spec.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
-        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_gathered()):
+        if weight.tensor_spec.is_1D_col() and (bias is None or bias.tensor_spec.is_replicate()):
             mode = 'row'
         elif weight.tensor_spec.is_1D_row() and (bias is None or bias.tensor_spec.is_1D_row()
                                                  or bias.tensor_spec.is_1D_col()):
             mode = 'col'
         else:
-            raise NotImplementedError
+            raise RuntimeError(f"Invalid weight or bias tensor spec: weight {weight.tensor_spec}, bias {bias}")
         ret_tensor = colo_linear_1d(mode, input_tensor, weight, bias)
     else:
         raise NotImplementedError
diff --git a/colossalai/nn/_ops/loss.py b/colossalai/nn/_ops/loss.py
index 0082b1979..1fc814937 100644
--- a/colossalai/nn/_ops/loss.py
+++ b/colossalai/nn/_ops/loss.py
@@ -18,7 +18,7 @@ def colo_cross_entropy(input_tensor: GeneralTensor,
                        label_smoothing: float = 0.0):
     input_tensor, target, weight = tuple(map(convert_to_colo_tensor, (input_tensor, target, weight)))
 
-    if input_tensor.tensor_spec.is_gathered():    # Input is gathered
+    if input_tensor.tensor_spec.is_replicate():    # Input is replicated
         output = F.cross_entropy(input_tensor,
                                  target,
                                  weight=weight,
diff --git a/colossalai/tensor/chunk.py b/colossalai/tensor/chunk.py
index 4cacb0c7b..bb7a17ae5 100644
--- a/colossalai/tensor/chunk.py
+++ b/colossalai/tensor/chunk.py
@@ -114,7 +114,7 @@ class Chunk:
         # if the process owns the rank, then copy the tensor to its chunk buffer
         # otherwise set its storage size to 0 to reduce memory consumption
         if self.is_src_rank:
-            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.view(-1))
+            self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
             tensor_state = TensorState.HOLD
             tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
         else:
diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py
index 0414c7d07..960c26474 100644
--- a/colossalai/tensor/colo_parameter.py
+++ b/colossalai/tensor/colo_parameter.py
@@ -101,3 +101,13 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
         # TODO(jzy) we don't support object reflection now.
         # distspec cannot be pickled or rebuilt because it's tightly connected to runtime attribute `process_group`.
         raise NotImplementedError
+
+    #### ColoParameter should use torch.Tensor's built-in methods ####
+
+    def view(self, *args) -> 'ColoTensor':
+        return super().view_base(*args)
+
+    def size(self, *args, **kwargs) -> torch.Size:
+        # import inspect
+        # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()])
+        return super().size_base(*args, **kwargs)
diff --git a/colossalai/tensor/colo_tensor.py b/colossalai/tensor/colo_tensor.py
index d277207bf..bd7859dee 100644
--- a/colossalai/tensor/colo_tensor.py
+++ b/colossalai/tensor/colo_tensor.py
@@ -8,6 +8,7 @@ from colossalai.tensor import TensorSpec
 from colossalai.tensor import distspec
 from colossalai.tensor.dist_spec_mgr import DistSpecManager
 from colossalai.tensor.distspec import _DistSpec
+from typing import Optional
 
 
 def _convert_output(output):
@@ -60,6 +61,12 @@ class ColoTensor(torch.Tensor):
     def tensor_spec(self) -> TensorSpec:
         return self._tensor_spec
 
+    @tensor_spec.setter
+    def tensor_spec(self, tensor_spec: TensorSpec):
+        spec = copy(tensor_spec)
+        self._convert_to_dist_spec(spec.dist_spec)
+        self._tensor_spec = spec
+
     def set_tensor_spec(self, spec: TensorSpec) -> None:
         spec = copy(spec)
         self._convert_to_dist_spec(spec.dist_spec)
@@ -136,4 +143,52 @@ class ColoTensor(torch.Tensor):
                 data = self.data.clone()
             tensor = ColoTensor(data, spec=copy(self.tensor_spec))
             memo[id(self)] = tensor
-            return tensor
\ No newline at end of file
+            return tensor
+
+    #### override built-in functions that must use the tensor in replicate placement ####
+
+    def view_base(self, *args) -> 'ColoTensor':
+        return super().view(*args)
+
+    def size_base(self, *args, **kwargs) -> torch.Size:
+        return super().size(*args, **kwargs)
+
+    def view(self, *args) -> 'ColoTensor':
+        """override the torch built-in view()
+        the tensor is converted to replicate placement before the view is applied.
+        Returns:
+            ColoTensor: the viewed tensor.
+        """
+        if self.tensor_spec.is_replicate():
+            return super().view(*args)
+        # TODO(jiaruifang) check why this does not work
+        # self.data = self.to_replicate()
+        self.data = DistSpecManager.handle_trans_spec(self.data, self.tensor_spec.dist_spec, distspec.replicate())
+        self._tensor_spec.dist_spec = distspec.replicate()
+        return super().view(*args)
+
+    def size(self, args: Optional[int] = None):
+        """override the torch built-in size()
+        the size returned is the global size, as if the tensor were in replicate placement.
+        Returns:
+            torch.Size or int: the global size, or the global size of dimension `args`.
+ """ + if self.tensor_spec.is_replicate(): + if args is not None: + return super().size(args) + else: + return super().size() + + spec = self.tensor_spec.dist_spec + dims = spec.dims + num_partitions = spec.num_partitions + # import inspect + # print(*['{:40}| {}:{}\n'.format(x.function, x.filename, x.lineno) for x in inspect.stack()]) + + size_list = list(super().size()) + for dim, num_partition in zip(dims, num_partitions): + size_list[dim] *= num_partition + if args is not None: + return size_list[args] + else: + return torch.Size(size_list) diff --git a/colossalai/tensor/dist_spec_mgr.py b/colossalai/tensor/dist_spec_mgr.py index 9bb91ad3e..a37261fb4 100644 --- a/colossalai/tensor/dist_spec_mgr.py +++ b/colossalai/tensor/dist_spec_mgr.py @@ -68,6 +68,7 @@ class DistSpecManager: num_parts = prod(dist_spec.num_partitions) for i, dim in enumerate(dist_spec.dims): num_parts //= dist_spec.num_partitions[i] + chunk_size = divide(tensor.size(dim), dist_spec.num_partitions[i]) chunk = chunk.narrow(dim, idx // num_parts * chunk_size, chunk_size) idx %= num_parts diff --git a/colossalai/tensor/tensor_spec.py b/colossalai/tensor/tensor_spec.py index 9d554f68e..4dc944eca 100644 --- a/colossalai/tensor/tensor_spec.py +++ b/colossalai/tensor/tensor_spec.py @@ -26,7 +26,7 @@ class TensorSpec(object): def get_placement(self): return self.dist_spec.placement - def is_gathered(self): + def is_replicate(self): return self.dist_spec.placement == DistPlacementPattern.REPLICATE \ or (len(self.dist_spec.num_partitions) == 1 and self.dist_spec.num_partitions[0] == 1) \ diff --git a/tests/test_tensor/test_gpt.py b/tests/test_tensor/test_gpt.py index 0b5c7f2b7..ffd5ebc85 100644 --- a/tests/test_tensor/test_gpt.py +++ b/tests/test_tensor/test_gpt.py @@ -101,4 +101,4 @@ def test_gpt(world_size, use_ddp): if __name__ == '__main__': - test_gpt(4, False) + test_gpt(4, True) diff --git a/tests/test_tensor/test_tensor.py b/tests/test_tensor/test_tensor.py index a940234a9..e28a94bd8 100644 --- a/tests/test_tensor/test_tensor.py +++ b/tests/test_tensor/test_tensor.py @@ -60,6 +60,19 @@ def test_operand(): #### Test Distributed init a Colotensor +def _run_view(world_size): + t_ref = torch.randn(4, 5) + t = ColoTensor.from_torch_tensor( + t_ref, TensorSpec(distspec.shard(process_group=gpc.get_group(ParallelMode.DATA), dims=[0], num_partitions=[2]))) + + assert t.size()[0] == 4 * world_size + assert t.size(1) == 5 + assert t.size() == torch.Size([4 * world_size, 5]) + + t = t.view(4 * 5 * world_size) + assert t.shape == torch.Size([4 * 5 * world_size]) + + def _run_tensor_shard_init(world_size): t_ref = torch.randn(4, 5) print(gpc.get_group(ParallelMode.DATA).size()) @@ -77,20 +90,21 @@ def _run_tensor_replicated_init(world_size): assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}" -def run_tensor_init(rank, world_size, port): +def run_dist_tests(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') _run_tensor_shard_init(world_size) _run_tensor_replicated_init(world_size) + _run_view(world_size) @pytest.mark.dist @pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() -def _test_dist_init(world_size): - run_func = partial(run_tensor_init, world_size=world_size, port=free_port()) +def _test_dist_cases(world_size): + run_func = partial(run_dist_tests, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) if __name__ == '__main__': # _test_dist_init(4) - test_new() + _test_dist_cases(2)