From 8eb09a4c6946b40930cffc7f2d9bb150ee714b63 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Thu, 22 Jun 2023 11:42:11 +0800 Subject: [PATCH] [shardformer] support module saving and loading (#4062) * [shardformer] support module saving and loading * polish code --- colossalai/checkpoint_io/utils.py | 4 +- colossalai/lazy/lazy_init.py | 7 +- colossalai/shardformer/layer/embedding.py | 25 +- colossalai/shardformer/layer/linear.py | 30 +- colossalai/shardformer/layer/linear_conv.py | 2 - .../shardformer/layer/parallel_module.py | 142 ++++++++++ colossalai/tensor/d_tensor/__init__.py | 24 ++ colossalai/tensor/d_tensor/api.py | 263 ++++++++++++++++-- colossalai/tensor/d_tensor/layout.py | 3 +- .../tensor/d_tensor/layout_converter.py | 13 +- colossalai/tensor/d_tensor/utils.py | 2 +- test.py | 1 + tests/test_lazy/lazy_init_utils.py | 9 +- .../test_layer/test_embedding.py | 4 + .../test_layer/test_linear_1d.py | 12 +- .../test_vocab_parallel_embedding_1d.py | 6 +- .../test_dtensor/test_comm_spec.py | 3 - .../test_tensor/test_dtensor/test_dtensor.py | 43 ++- .../test_dtensor/test_layout_converter.py | 2 +- 19 files changed, 493 insertions(+), 102 deletions(-) create mode 100644 test.py diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 3dada00cd..68981dff0 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -10,7 +10,7 @@ import torch import torch.nn as nn from torch.optim import Optimizer -from colossalai.tensor.d_tensor.d_tensor import DTensor +from colossalai.tensor.d_tensor import is_distributed_tensor SAFE_WEIGHTS_NAME = "model.safetensors" WEIGHTS_NAME = "pytorch_model.bin" @@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) for key, weight in state_dict.items(): ret_block = None ret_block_size = 0 - if type(weight) != DTensor: + if is_distributed_tensor(weight): weight_size = calculate_tensor_size(weight) # If this weight is going to tip up over the maximal size, we split. diff --git a/colossalai/lazy/lazy_init.py b/colossalai/lazy/lazy_init.py index 76f550dc4..1e45eced5 100644 --- a/colossalai/lazy/lazy_init.py +++ b/colossalai/lazy/lazy_init.py @@ -8,8 +8,9 @@ from torch import Tensor from torch.utils._pytree import tree_map from colossalai._analyzer._subclasses import MetaTensor -from colossalai.tensor.d_tensor.d_tensor import DTensor -from colossalai.tensor.d_tensor.layout import Layout +from colossalai.device.device_mesh import DeviceMesh +from colossalai.tensor.d_tensor import distribute_tensor +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html _NORMAL_FACTORY = [ @@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor): """ target = self._materialize_data() self.clean() - local_tensor = DTensor(target, layout).local_tensor + local_tensor = distribute_tensor(target, device_mesh, sharding_spec) return _convert_cls(self, local_tensor) def clean(self) -> None: diff --git a/colossalai/shardformer/layer/embedding.py b/colossalai/shardformer/layer/embedding.py index 8b9fb03ec..23601a04a 100644 --- a/colossalai/shardformer/layer/embedding.py +++ b/colossalai/shardformer/layer/embedding.py @@ -13,8 +13,7 @@ from torch.nn.parameter import Parameter from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise -from colossalai.utils.cuda import get_current_device +from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param from ._operation import gather_forward_split_backward, reduce_input from .parallel_module import ParallelModule @@ -69,18 +68,17 @@ class Embedding1D(ParallelModule): self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim self.process_group = process_group - self.num_partitions = dist.get_world_size(process_group) - self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs self.gather_output = gather_output - if device is None: - device = get_current_device() - - self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype)) + # Parameters. + factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) + sharded_weight = shard_colwise(weight, process_group) + self.weight = sharded_tensor_to_param(sharded_weight) # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -194,7 +192,7 @@ class VocabParallelEmbedding1D(ParallelModule): **kwargs): super().__init__() self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim + self.embedding_dim = embedding_dim self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs @@ -208,8 +206,11 @@ class VocabParallelEmbedding1D(ParallelModule): self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition - self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype)) + # parameter + factory_kwargs = {'device': device, 'dtype': dtype} + weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs) + sharded_weight = shard_rowwise(weight, process_group) + self.weight = sharded_tensor_to_param(sharded_weight) # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -252,7 +253,7 @@ class VocabParallelEmbedding1D(ParallelModule): def reset_parameters(self, weight_initializer) -> None: with self.randomizer.fork_rng(enable_cpu=True): - fan_in, fan_out = self.num_embeddings, self.embed_dim + fan_in, fan_out = self.num_embeddings, self.embedding_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() diff --git a/colossalai/shardformer/layer/linear.py b/colossalai/shardformer/layer/linear.py index b87981c6d..912be26b9 100644 --- a/colossalai/shardformer/layer/linear.py +++ b/colossalai/shardformer/layer/linear.py @@ -14,7 +14,7 @@ from torch.nn.parameter import Parameter from colossalai.nn import init as init from colossalai.nn.layer.utils import divide -from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise +from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param from colossalai.utils.cuda import get_current_device from ._operation import ( @@ -76,22 +76,21 @@ class Linear1D_Col(ParallelModule): self.skip_bias_add = skip_bias_add self.device = device self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - self.out_features_per_partition = divide(out_features, self.num_partitions) - # Parameters. - # Initialize weight. - if device is None: - device = get_current_device() factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs)) + + weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) + sharded_weight = shard_rowwise(weight, self.process_group) + self.weight = sharded_tensor_to_param(sharded_weight) if bias: - self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs)) + bias = torch.empty(self.out_features, **factory_kwargs) + sharded_bias = shard_colwise(bias, self.process_group) + self.bias = sharded_tensor_to_param(sharded_bias) else: self.bias = None @@ -128,7 +127,6 @@ class Linear1D_Col(ParallelModule): *args, **kwargs) - # TODO: copy the sharded weights with torch.no_grad(): # the weigh to the linear layer is a transpose # thus shard on row is equal to shard on column @@ -137,7 +135,6 @@ class Linear1D_Col(ParallelModule): if bias: sharded_bias = shard_colwise(module.bias.data, process_group) linear_1d.bias.copy_(sharded_bias) - return linear_1d def reset_parameters(self, weight_initializer, bias_initializer) -> None: @@ -212,21 +209,20 @@ class Linear1D_Row(ParallelModule): self.parallel_input = parallel_input self.skip_bias_add = skip_bias_add self.process_group = process_group - self.num_partitions = dist.get_world_size(self.process_group) if skip_bias_add and not bias: raise ValueError('cannot skip bias addition if bias is None') - # Divide the weight matrix along the last dimension. - self.input_size_per_partition = divide(in_features, self.num_partitions) - # Parameters. # Initialize weight. if device is None: device = get_current_device() factory_kwargs = {'device': device, 'dtype': dtype} - self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs)) + + weight = torch.empty(self.out_features, self.in_features, **factory_kwargs) + sharded_weight = shard_colwise(weight, self.process_group) + self.weight = sharded_tensor_to_param(sharded_weight) if self.stream_chunk_num > 1: # TODO() work for inference only @@ -340,3 +336,5 @@ class Linear1D_Row(ParallelModule): return output else: return output, self.bias + return output, self.bias + return output, self.bias diff --git a/colossalai/shardformer/layer/linear_conv.py b/colossalai/shardformer/layer/linear_conv.py index b4599f489..2adfc1828 100644 --- a/colossalai/shardformer/layer/linear_conv.py +++ b/colossalai/shardformer/layer/linear_conv.py @@ -31,7 +31,6 @@ __all__ = ['LinearConv1D_Col', 'LinearConv1D_Row'] class LinearConv1D_Col(ParallelModule): r"""Linear layer with column parallelism. - Specially created for HuggingFace's GPT2 model. The linear layer is defined as :math:`Y = XA + b`. A is parallelized along its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface. @@ -189,7 +188,6 @@ class LinearConv1D_Col(ParallelModule): class LinearConv1D_Row(ParallelModule): r""" Linear layer with row parallelism - Specially created for HuggingFace's GPT2 model. Args: in_features (int): size of each input sample. diff --git a/colossalai/shardformer/layer/parallel_module.py b/colossalai/shardformer/layer/parallel_module.py index c68cd5778..5edcb9dde 100644 --- a/colossalai/shardformer/layer/parallel_module.py +++ b/colossalai/shardformer/layer/parallel_module.py @@ -1,11 +1,23 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- +import itertools from abc import ABC, abstractmethod from typing import List, Union +import torch import torch.nn as nn from torch.distributed import ProcessGroup +from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module + +from colossalai.tensor.d_tensor import ( + distribute_tensor, + get_device_mesh, + get_sharding_spec, + is_distributed_tensor, + sharded_tensor_to_param, + to_global, +) __all__ = ['ParallelModule'] @@ -25,3 +37,133 @@ class ParallelModule(nn.Module, ABC): in the ith axis of the device mesh. Defaults to None, which means the global process group. """ pass + + def _save_to_state_dict(self, destination, prefix, keep_vars): + r"""Saves module state to `destination` dictionary, containing a state + of the module, but not its descendants. This is called on every + submodule in :meth:`~torch.nn.Module.state_dict`. + + In rare cases, subclasses can achieve class-specific behavior by + overriding this method with custom logic. + + Args: + destination (dict): a dict where state will be stored + prefix (str): the prefix for parameters and buffers used in this + module + """ + for name, param in self._parameters.items(): + if param is not None: + param_ = param if keep_vars else param.detach() + + if is_distributed_tensor(param_): + destination[prefix + name] = to_global(param_) + else: + destination[prefix + name] = param_ + + for name, buf in self._buffers.items(): + if buf is not None and name not in self._non_persistent_buffers_set: + destination[prefix + name] = buf if keep_vars else buf.detach() + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state: + destination[extra_state_key] = self.get_extra_state() + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + r"""Copies parameters and buffers from :attr:`state_dict` into only + this module, but not its descendants. This is called on every submodule + in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this + module in input :attr:`state_dict` is provided as :attr:`local_metadata`. + For state dicts without metadata, :attr:`local_metadata` is empty. + Subclasses can achieve class-specific backward compatible loading using + the version number at `local_metadata.get("version", None)`. + + .. note:: + :attr:`state_dict` is not the same object as the input + :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So + it can be modified. + + Args: + state_dict (dict): a dict containing parameters and + persistent buffers. + prefix (str): the prefix for parameters and buffers used in this + module + local_metadata (dict): a dict containing the metadata for this module. + See + strict (bool): whether to strictly enforce that the keys in + :attr:`state_dict` with :attr:`prefix` match the names of + parameters and buffers in this module + missing_keys (list of str): if ``strict=True``, add missing keys to + this list + unexpected_keys (list of str): if ``strict=True``, add unexpected + keys to this list + error_msgs (list of str): error messages should be added to this + list, and will be reported together in + :meth:`~torch.nn.Module.load_state_dict` + """ + for hook in self._load_state_dict_pre_hooks.values(): + hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set} + local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items()) + local_state = {k: v for k, v in local_name_params if v is not None} + + for name, param in local_state.items(): + key = prefix + name + + if key in state_dict: + input_param = state_dict[key] + if not torch.overrides.is_tensor_like(input_param): + error_msgs.append('While copying the parameter named "{}", ' + 'expected torch.Tensor or Tensor-like object from checkpoint but ' + 'received {}'.format(key, type(input_param))) + continue + + if is_distributed_tensor(param): + # shard the input param + device_mesh = get_device_mesh(param) + sharding_spec = get_sharding_spec(param) + sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec) + input_param = sharded_tensor_to_param(sharded_tensor) + + # This is used to avoid copying uninitialized parameters into + # non-lazy modules, since they dont have the hook to do the checks + # in such case, it will error when accessing the .shape attribute. + is_param_lazy = torch.nn.parameter.is_lazy(param) + # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ + if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: + input_param = input_param[0] + + if not is_param_lazy and input_param.shape != param.shape: + # local shape should match the one in checkpoint + error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' + 'the shape in current model is {}.'.format(key, input_param.shape, param.shape)) + continue + + try: + with torch.no_grad(): + param.copy_(input_param) + except Exception as ex: + error_msgs.append('While copying the parameter named "{}", ' + 'whose dimensions in the model are {} and ' + 'whose dimensions in the checkpoint are {}, ' + 'an exception occurred : {}.'.format(key, param.size(), input_param.size(), + ex.args)) + elif strict: + missing_keys.append(key) + + extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX + if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state: + if extra_state_key in state_dict: + self.set_extra_state(state_dict[extra_state_key]) + elif strict: + missing_keys.append(extra_state_key) + elif strict and (extra_state_key in state_dict): + unexpected_keys.append(extra_state_key) + + if strict: + for key in state_dict.keys(): + if key.startswith(prefix) and key != extra_state_key: + input_name = key[len(prefix):] + input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child + if input_name not in self._modules and input_name not in local_state: + unexpected_keys.append(key) diff --git a/colossalai/tensor/d_tensor/__init__.py b/colossalai/tensor/d_tensor/__init__.py index e69de29bb..52eae0e14 100644 --- a/colossalai/tensor/d_tensor/__init__.py +++ b/colossalai/tensor/d_tensor/__init__.py @@ -0,0 +1,24 @@ +from .api import ( + compute_global_numel, + distribute_tensor, + get_device_mesh, + get_global_shape, + get_layout, + get_sharding_spec, + is_distributed_tensor, + is_sharded, + redistribute, + shard_colwise, + shard_rowwise, + sharded_tensor_to_param, + to_global, +) +from .layout import Layout +from .sharding_spec import ShardingSpec + +__all__ = [ + 'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise', + 'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh', + 'redistribute', 'get_layout' + 'Layout', 'ShardingSpec' +] diff --git a/colossalai/tensor/d_tensor/api.py b/colossalai/tensor/d_tensor/api.py index b58edadfe..a38e5e6b7 100644 --- a/colossalai/tensor/d_tensor/api.py +++ b/colossalai/tensor/d_tensor/api.py @@ -1,3 +1,6 @@ +import copy +import operator +from functools import reduce from typing import Union import torch @@ -6,13 +9,165 @@ from torch.distributed import ProcessGroup from colossalai.device.device_mesh import DeviceMesh -from .d_tensor import DTensor +from .layout import Layout +from .layout_converter import LayoutConverter from .sharding_spec import ShardingSpec +layout_converter = LayoutConverter() -def shard_rowwise(tensor: torch.Tensor, - group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None, - inplace: bool = False) -> DTensor: + +def is_distributed_tensor(tensor: torch.Tensor) -> bool: + """ + Check whether the given tensor is a distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: Whether the given tensor is a distributed tensor. + """ + return hasattr(tensor, "dist_layout") + + +def is_sharded(dtensor: torch.Tensor) -> bool: + """ + Check if a tensor is sharded. + + Args: + tensor (torch.Tensor): The tensor to be checked. + + Returns: + bool: True if the tensor is sharded, False otherwise. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return list(dtensor.shape) == list(dtensor.dist_layout.global_shape) + + +def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor: + """ + Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied. + + Args: + tensor (torch.Tensor): The tensor to be hijacked. + + Returns: + torch.Tensor: The hijacked tensor. + """ + dtensor._old_detach = dtensor.detach + dtensor._old_clone = dtensor.clone + + def new_detach(self): + t_ = self._old_detach() + t_.dist_layout = copy.deepcopy(self.dist_layout) + return t_ + + def new_clone(self, *args, **kwargs): + t_ = self._old_clone(*args, **kwargs) + t_.dist_layout = copy.deepcopy(self.dist_layout) + return t_ + + # bind the new methods to the tensor + dtensor.detach = new_detach.__get__(dtensor) + dtensor.clone = new_clone.__get__(dtensor) + return dtensor + + +def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec: + ''' + Construct the default sharding specification for the tensor. + + Args: + tensor (`torch.Tensor`): the tensor to be sharded. + + Returns: + A `ShardingSpec` object without any sharding specified. + ''' + return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={}) + + +def _apply_layout(tensor, layout): + ''' + Apply the layout to the local tensor during initializing process. + ''' + # layout converter requires a source and target laytout + # we construct the source layer for an unsharded tensor + # and use self.dist_layer as the targer layout for the sharded tensor + source_spec = _construct_default_sharding_spec(tensor) + source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape) + sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout) + return sharded_tensor + + +def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor: + """ + Convert the given tensor to a distributed tensor. + + Args: + tensor (torch.Tensor): The tensor to be converted. + device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices. + sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded. + + Returns: + torch.Tensor: The distributed tensor. + """ + assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.' + dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape) + + # shard tensor + sharded_tensor = _apply_layout(tensor, dist_layout) + + # hack some tensor methods + _hijack_detach_and_clone(sharded_tensor) + + return sharded_tensor + + +def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None: + ''' + Convert the layout of the tensor from source_spec to target_spec. + This will update the `local_tensor` and `dist_layout` in place. + + Args: + dtensor (torch.Tensor): the distributed tensor to be converted. + device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices. + target_layout (Layout): the target layout specification. + ''' + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + global_shape = get_global_shape(dtensor) + target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape) + resharded_tensor = layout_converter.apply(tensor=dtensor, + source_layout=dtensor.dist_layout, + target_layout=target_layout) + return resharded_tensor + + +def to_global(dtensor: torch.Tensor) -> torch.Tensor: + """ + Convert a distributed tensor to the global tensor with the given layout. + This function returns a native `torch.Tensor` object. + + Args: + dtensor (torch.Tensor): the distributed tensor to be converted. + + Returns: + torch.Tensor: the global tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + layout_converter = LayoutConverter() + + global_sharding_spec = ShardingSpec(dtensor.dim(), {}) + device_mesh = get_device_mesh(dtensor) + global_shape = get_global_shape(dtensor) + global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape) + + global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout) + return global_tensor + + +def shard_rowwise( + tensor: torch.Tensor, + group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None, +) -> torch.Tensor: """ Shard the first dim of the given tensor. @@ -24,7 +179,7 @@ def shard_rowwise(tensor: torch.Tensor, inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False. Returns: - DTensor: The sharded tensor. + torch.Tensor: The sharded tensor. """ # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group if group_or_device_mesh is None: @@ -35,17 +190,13 @@ def shard_rowwise(tensor: torch.Tensor, else: assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.' device_mesh = group_or_device_mesh + sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]}) - if not inplace: - tensor = tensor.detach().clone() - - return DTensor(tensor, device_mesh, sharding_spec) + return distribute_tensor(tensor, device_mesh, sharding_spec) -def shard_colwise(tensor: torch.Tensor, - group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None, - inplace: bool = False) -> DTensor: +def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor: """ Shard the first dim of the given tensor. @@ -57,7 +208,7 @@ def shard_colwise(tensor: torch.Tensor, inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False. Returns: - DTensor: The sharded tensor. + torch.Tensor: The sharded tensor. """ # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group if group_or_device_mesh is None: @@ -70,7 +221,87 @@ def shard_colwise(tensor: torch.Tensor, device_mesh = group_or_device_mesh sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]}) - if not inplace: - tensor = tensor.detach().clone() + return distribute_tensor(tensor, device_mesh, sharding_spec) - return DTensor(tensor, device_mesh, sharding_spec) + +def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True): + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + param = torch.nn.Parameter(dtensor, requires_grad=requires_grad) + + # make it distributed as well + param.dist_layout = dtensor.dist_layout + _hijack_detach_and_clone(param) + + return param + + +def compute_global_numel(dtensor: torch.Tensor) -> int: + """ + Compute the global number of elements in the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + int: The global number of elements in the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + numel = reduce(operator.mul, dtensor.dist_layout.global_shape) + return numel + + +def get_layout(dtensor: torch.Tensor) -> Layout: + """ + Get the layout of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + Layout: The layout of the distributed tensor. + + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout + + +def get_global_shape(dtensor: torch.Tensor) -> torch.Size: + """ + Get the global shape of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + torch.Size: The global shape of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout.global_shape + + +def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh: + """ + Get the device mesh of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + DeviceMesh: The device mesh of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout.device_mesh + + +def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec: + """ + Get the sharding spec of the distributed tensor. + + Args: + dtensor (torch.Tensor): The distributed tensor. + + Returns: + ShardingSpec: The sharding spec of the distributed tensor. + """ + assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.' + return dtensor.dist_layout.sharding_spec diff --git a/colossalai/tensor/d_tensor/layout.py b/colossalai/tensor/d_tensor/layout.py index f15956ea3..4185b8586 100644 --- a/colossalai/tensor/d_tensor/layout.py +++ b/colossalai/tensor/d_tensor/layout.py @@ -1,12 +1,11 @@ import operator -from dataclasses import dataclass from functools import reduce import torch from colossalai.device.device_mesh import DeviceMesh -from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError +from .misc import DuplicatedShardingDimensionError, ShardingNotDivisibleError from .sharding_spec import ShardingSpec diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index abc70e19a..14f9c4561 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -28,18 +28,6 @@ class LayoutConverterOptions: pass -def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor: - layout_converter = LayoutConverter() - global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {}) - global_layout = Layout(device_mesh=layout.device_mesh, - device_type=layout.device_type, - sharding_spec=global_sharding_spec, - entire_shape=layout.entire_shape) - with torch.no_grad(): - global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout) - return global_tensor - - def set_layout_converting_options(options: LayoutConverterOptions): """ Configure the shape consistency manager via function call. @@ -553,4 +541,5 @@ class LayoutConverter(metaclass=SingletonMeta): _, comm_action_sequence = self.layout_converting(source_layout, target_layout) for comm_spec in comm_action_sequence: tensor = comm_spec.covert_spec_to_action(tensor) + tensor.dist_layout = target_layout return tensor diff --git a/colossalai/tensor/d_tensor/utils.py b/colossalai/tensor/d_tensor/utils.py index 644bb6306..fc22b990d 100644 --- a/colossalai/tensor/d_tensor/utils.py +++ b/colossalai/tensor/d_tensor/utils.py @@ -29,7 +29,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals # the comm size for all gather is the size of the gathered tensor gather_dim = comm_spec.gather_dim all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1] - all_gather_size = device_mesh.mesh_shape[all_gather_axis] + all_gather_size = device_mesh.shape[all_gather_axis] comm_size_for_all_gather = comm_size * all_gather_size forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis) # give a tiny cost to shard diff --git a/test.py b/test.py new file mode 100644 index 000000000..f283e21a1 --- /dev/null +++ b/test.py @@ -0,0 +1 @@ +from colossalai.tensor.d_tensor.api import to_distributed_tensor diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 2dd8d1ca3..3879363bc 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -7,7 +7,8 @@ import torch from packaging import version from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor -from colossalai.tensor.d_tensor.layout_converter import to_global +from colossalai.tensor.d_tensor import to_global +from colossalai.tensor.d_tensor.layout import Layout from tests.kit.model_zoo.registry import ModelAttribute SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0') @@ -91,6 +92,8 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn. assert n1 == n2 t1 = t1.cuda() t2 = t2.cuda() - if n2 in layout_dict: - t2 = to_global(t2, layout_dict[n2]) + if n2 in sharding_spec_dict: + layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape) + t2.dist_layout = layout + t2 = to_global(t2) assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}' diff --git a/tests/test_shardformer/test_layer/test_embedding.py b/tests/test_shardformer/test_layer/test_embedding.py index 70500008c..8a6aa42a4 100644 --- a/tests/test_shardformer/test_layer/test_embedding.py +++ b/tests/test_shardformer/test_layer/test_embedding.py @@ -14,6 +14,10 @@ def check_embedding_1d(): assert embedding_1d.weight.shape == torch.Size([32, 64]) + # ensure state dict is reversibly loadable + embedding.load_state_dict(embedding_1d.state_dict()) + embedding_1d.load_state_dict(embedding.state_dict()) + # check computation correctness x = torch.randint(low=0, high=32, size=(4, 32)).cuda() out = embedding(x) diff --git a/tests/test_shardformer/test_layer/test_linear_1d.py b/tests/test_shardformer/test_layer/test_linear_1d.py index 00ecc37ce..a2b8bf22c 100644 --- a/tests/test_shardformer/test_layer/test_linear_1d.py +++ b/tests/test_shardformer/test_layer/test_linear_1d.py @@ -5,6 +5,7 @@ from torch.testing import assert_close import colossalai from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row +from colossalai.tensor.d_tensor import is_distributed_tensor from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -12,9 +13,18 @@ def check_linear_1d_col(): linear = nn.Linear(32, 128).cuda() linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True) + # ensure that the parameters are distributed + assert is_distributed_tensor(linear_col.weight) + assert is_distributed_tensor(linear_col.bias) + + # ensure the shape is correct assert linear_col.weight.shape == torch.Size([64, 32]) assert linear_col.bias.shape == torch.Size([64]) + # ensure state dict is reversibly loadable + linear.load_state_dict(linear_col.state_dict()) + linear_col.load_state_dict(linear.state_dict()) + # check computation correctness x = torch.rand(4, 32).cuda() out = linear(x) @@ -55,7 +65,7 @@ def check_linear_1d_row(): def run_dist(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') check_linear_1d_col() - check_linear_1d_row() + # check_linear_1d_row() @rerun_if_address_is_in_use() diff --git a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py index bee44a2fb..8991d9b30 100644 --- a/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py +++ b/tests/test_shardformer/test_layer/test_vocab_parallel_embedding_1d.py @@ -14,7 +14,11 @@ def check_vocab_embedding_1d(): assert dist_embedding_1d.weight.shape == torch.Size([64, 32]) assert dist_embedding_1d.num_embeddings == 64 - assert dist_embedding_1d.embed_dim == 32 + assert dist_embedding_1d.embedding_dim == 32 + + # ensure state dict is reversibly loadable + embedding.load_state_dict(dist_embedding_1d.state_dict()) + dist_embedding_1d.load_state_dict(embedding.state_dict()) # check embedding correctness x = torch.randint(0, 128, (4, 32)).to('cuda') diff --git a/tests/test_tensor/test_dtensor/test_comm_spec.py b/tests/test_tensor/test_dtensor/test_comm_spec.py index d1f5b9299..958eabb65 100644 --- a/tests/test_tensor/test_dtensor/test_comm_spec.py +++ b/tests/test_tensor/test_dtensor/test_comm_spec.py @@ -1,14 +1,11 @@ import pytest import torch -import torch.distributed as dist -from torch.distributed import ReduceOp from colossalai.core import global_context as gpc from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn diff --git a/tests/test_tensor/test_dtensor/test_dtensor.py b/tests/test_tensor/test_dtensor/test_dtensor.py index 3ca369acb..8350fb3e7 100644 --- a/tests/test_tensor/test_dtensor/test_dtensor.py +++ b/tests/test_tensor/test_dtensor/test_dtensor.py @@ -3,9 +3,7 @@ import torch from colossalai.device.device_mesh import DeviceMesh from colossalai.initialize import launch from colossalai.logging import disable_existing_loggers -from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor -from colossalai.tensor.d_tensor.layout import Layout -from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec +from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global from colossalai.testing import rerun_if_address_is_in_use, spawn @@ -31,22 +29,18 @@ def check_dtensor(rank, world_size, port): device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True) target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]}) - layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=target_sharding_spec, - entire_shape=original_tensor.shape) - d_tensor = DTensor(original_tensor, layout) + d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec) - assert d_tensor.entire_shape == original_tensor.shape - assert d_tensor.data_type == original_tensor.dtype + assert get_global_shape(d_tensor) == original_tensor.shape + assert d_tensor.dtype == original_tensor.dtype if rank in (0, 1): - assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2)) + assert d_tensor.equal(original_tensor.narrow(0, 0, 2)) elif rank in (2, 3): - assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2)) + assert d_tensor.equal(original_tensor.narrow(0, 2, 2)) else: raise ValueError(f'rank {rank} is not in the device mesh') - assert d_tensor.to_global().equal(original_tensor) + assert to_global(d_tensor).equal(original_tensor) output = test_model(d_tensor) if rank in (0, 1): @@ -57,34 +51,29 @@ def check_dtensor(rank, world_size, port): raise ValueError(f'rank {rank} is not in the device mesh') new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]}) - new_layout = Layout(device_mesh=device_mesh, - device_type=torch.device('cuda'), - sharding_spec=new_sharding_spec, - entire_shape=original_tensor.shape) - - d_tensor.layout_convert(new_layout) + d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec) if rank == 0: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 0, 1)) elif rank == 1: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 1, 1)) elif rank == 2: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 2, 1)) elif rank == 3: - assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1)) + assert d_tensor.equal(original_tensor.narrow(0, 3, 1)) else: raise ValueError(f'rank {rank} is not in the device mesh') dtensor_from_local = distribute_tensor(original_tensor, new_layout) if rank == 0: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1)) elif rank == 1: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 1, 1)) elif rank == 2: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 2, 1)) elif rank == 3: - assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1)) + assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1)) else: raise ValueError(f'rank {rank} is not in the device mesh') diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 5c3da5f2b..d9dff8af9 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -9,7 +9,7 @@ from colossalai.logging import disable_existing_loggers from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern from colossalai.tensor.d_tensor.layout import Layout from colossalai.tensor.d_tensor.layout_converter import LayoutConverter -from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec +from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec from colossalai.testing import rerun_if_address_is_in_use, spawn entire_shape = torch.Size((64, 32, 16))