[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
pull/4157/head
Frank Lee 2023-06-22 11:42:11 +08:00
parent 7740c55c55
commit 8eb09a4c69
19 changed files with 493 additions and 102 deletions

View File

@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer

-from colossalai.tensor.d_tensor.d_tensor import DTensor
+from colossalai.tensor.d_tensor import is_distributed_tensor

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"

@@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
     for key, weight in state_dict.items():
         ret_block = None
         ret_block_size = 0

-        if type(weight) != DTensor:
+        if is_distributed_tensor(weight):
             weight_size = calculate_tensor_size(weight)

             # If this weight is going to tip up over the maximal size, we split.
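Under the new d_tensor API a distributed tensor is a plain torch.Tensor tagged with a dist_layout attribute rather than a DTensor subclass, so the check above switches from a type comparison to is_distributed_tensor(). A minimal sketch of the predicate (shapes are illustrative, and the sharding call is left commented out because it needs an initialized process group):

    import torch
    from colossalai.tensor.d_tensor import is_distributed_tensor, shard_rowwise

    plain = torch.rand(8, 8)
    assert not is_distributed_tensor(plain)    # no dist_layout attribute

    # after sharding (e.g. once colossalai.launch has set up a process group),
    # the data is still a torch.Tensor but carries a dist_layout:
    # sharded = shard_rowwise(plain.cuda())
    # assert is_distributed_tensor(sharded)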

View File

@@ -8,8 +8,9 @@ from torch import Tensor
 from torch.utils._pytree import tree_map

 from colossalai._analyzer._subclasses import MetaTensor
-from colossalai.tensor.d_tensor.d_tensor import DTensor
-from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.d_tensor import distribute_tensor
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
 _NORMAL_FACTORY = [

@@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor):
         """
         target = self._materialize_data()
         self.clean()
-        local_tensor = DTensor(target, layout).local_tensor
+        local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
         return _convert_cls(self, local_tensor)

     def clean(self) -> None:

View File

@@ -13,8 +13,7 @@ from torch.nn.parameter import Parameter

 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
-from colossalai.utils.cuda import get_current_device
+from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param

 from ._operation import gather_forward_split_backward, reduce_input
 from .parallel_module import ParallelModule

@@ -69,18 +68,17 @@ class Embedding1D(ParallelModule):
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(process_group)
-        self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.gather_output = gather_output

-        if device is None:
-            device = get_current_device()
-
-        self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
+        # Parameters.
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
+        sharded_weight = shard_colwise(weight, process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()

@@ -194,7 +192,7 @@ class VocabParallelEmbedding1D(ParallelModule):
                  **kwargs):
         super().__init__()
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs

@@ -208,8 +206,11 @@ class VocabParallelEmbedding1D(ParallelModule):
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

-        self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
+        # parameter
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
+        sharded_weight = shard_rowwise(weight, process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()

@@ -252,7 +253,7 @@ class VocabParallelEmbedding1D(ParallelModule):
     def reset_parameters(self, weight_initializer) -> None:
         with self.randomizer.fork_rng(enable_cpu=True):
-            fan_in, fan_out = self.num_embeddings, self.embed_dim
+            fan_in, fan_out = self.num_embeddings, self.embedding_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
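The constructors now allocate the full-size weight first and shard it afterwards, so the parameter kept on each rank is a distributed tensor that still remembers its global shape. A rough sketch of the same pattern outside the module, assuming a 2-rank process group already launched via colossalai.launch (the sizes are illustrative):

    import torch
    from colossalai.tensor.d_tensor import get_global_shape, shard_colwise, sharded_tensor_to_param

    num_embeddings, embedding_dim = 32, 128

    # Embedding1D: shard the embedding (last) dim across the group, as shard_colwise does
    weight = torch.empty(num_embeddings, embedding_dim, device='cuda')
    param = sharded_tensor_to_param(shard_colwise(weight))

    # each rank holds roughly a (32, 64) local shard, while the layout still records (32, 128)
    assert get_global_shape(param) == torch.Size([num_embeddings, embedding_dim])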

View File

@@ -14,7 +14,7 @@ from torch.nn.parameter import Parameter

 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
+from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
 from colossalai.utils.cuda import get_current_device

 from ._operation import (

@@ -76,22 +76,21 @@ class Linear1D_Col(ParallelModule):
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

-        self.out_features_per_partition = divide(out_features, self.num_partitions)
-
         # Parameters.
-        # Initialize weight.
-        if device is None:
-            device = get_current_device()
         factory_kwargs = {'device': device, 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
+        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        sharded_weight = shard_rowwise(weight, self.process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         if bias:
-            self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
+            bias = torch.empty(self.out_features, **factory_kwargs)
+            sharded_bias = shard_colwise(bias, self.process_group)
+            self.bias = sharded_tensor_to_param(sharded_bias)
         else:
             self.bias = None

@@ -128,7 +127,6 @@ class Linear1D_Col(ParallelModule):
                                *args,
                                **kwargs)

-        # TODO: copy the sharded weights
         with torch.no_grad():
             # the weigh to the linear layer is a transpose
             # thus shard on row is equal to shard on column

@@ -137,7 +135,6 @@ class Linear1D_Col(ParallelModule):
             if bias:
                 sharded_bias = shard_colwise(module.bias.data, process_group)
                 linear_1d.bias.copy_(sharded_bias)
-
         return linear_1d

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:

@@ -212,21 +209,20 @@ class Linear1D_Row(ParallelModule):
         self.parallel_input = parallel_input
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

-        # Divide the weight matrix along the last dimension.
-        self.input_size_per_partition = divide(in_features, self.num_partitions)
-
         # Parameters.
         # Initialize weight.
         if device is None:
             device = get_current_device()

         factory_kwargs = {'device': device, 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
+        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        sharded_weight = shard_colwise(weight, self.process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         if self.stream_chunk_num > 1:
             # TODO() work for inference only

@@ -340,3 +336,5 @@ class Linear1D_Row(ParallelModule):
             return output
         else:
             return output, self.bias
+            return output, self.bias
+            return output, self.bias
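For Linear1D_Col the stored (out_features, in_features) weight is sharded row-wise and the bias column-wise, which splits out_features in both cases; Linear1D_Row shards the weight column-wise to split in_features instead. A hedged sketch of the resulting shapes, again assuming a 2-rank group already launched via colossalai.launch (sizes are illustrative):

    import torch
    from colossalai.tensor.d_tensor import (get_global_shape, shard_colwise, shard_rowwise,
                                            sharded_tensor_to_param)

    out_features, in_features = 128, 32

    # column-parallel linear: split out_features on both weight and bias
    col_weight = sharded_tensor_to_param(shard_rowwise(torch.empty(out_features, in_features, device='cuda')))
    col_bias = sharded_tensor_to_param(shard_colwise(torch.empty(out_features, device='cuda')))

    # local shards are roughly (64, 32) and (64,), but the recorded global shapes are preserved
    assert get_global_shape(col_weight) == torch.Size([out_features, in_features])
    assert get_global_shape(col_bias) == torch.Size([out_features])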

View File

@@ -31,7 +31,6 @@ __all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']

 class LinearConv1D_Col(ParallelModule):
     r"""Linear layer with column parallelism.
-    Specially created for HuggingFace's GPT2 model.

     The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
     its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface.

@@ -189,7 +188,6 @@ class LinearConv1D_Col(ParallelModule):

 class LinearConv1D_Row(ParallelModule):
     r""" Linear layer with row parallelism
-    Specially created for HuggingFace's GPT2 model.

     Args:
         in_features (int): size of each input sample.

View File

@@ -1,11 +1,23 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import itertools
 from abc import ABC, abstractmethod
 from typing import List, Union

+import torch
 import torch.nn as nn
 from torch.distributed import ProcessGroup
+from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
+
+from colossalai.tensor.d_tensor import (
+    distribute_tensor,
+    get_device_mesh,
+    get_sharding_spec,
+    is_distributed_tensor,
+    sharded_tensor_to_param,
+    to_global,
+)

 __all__ = ['ParallelModule']

@@ -25,3 +37,133 @@ class ParallelModule(nn.Module, ABC):
             in the ith axis of the device mesh. Defaults to None, which means the global process group.
         """
         pass
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        r"""Saves module state to `destination` dictionary, containing a state
+        of the module, but not its descendants. This is called on every
+        submodule in :meth:`~torch.nn.Module.state_dict`.
+
+        In rare cases, subclasses can achieve class-specific behavior by
+        overriding this method with custom logic.
+
+        Args:
+            destination (dict): a dict where state will be stored
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+        """
+        for name, param in self._parameters.items():
+            if param is not None:
+                param_ = param if keep_vars else param.detach()
+                if is_distributed_tensor(param_):
+                    destination[prefix + name] = to_global(param_)
+                else:
+                    destination[prefix + name] = param_
+
+        for name, buf in self._buffers.items():
+            if buf is not None and name not in self._non_persistent_buffers_set:
+                destination[prefix + name] = buf if keep_vars else buf.detach()
+
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
+            destination[extra_state_key] = self.get_extra_state()
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                              error_msgs):
+        r"""Copies parameters and buffers from :attr:`state_dict` into only
+        this module, but not its descendants. This is called on every submodule
+        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
+        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+        For state dicts without metadata, :attr:`local_metadata` is empty.
+        Subclasses can achieve class-specific backward compatible loading using
+        the version number at `local_metadata.get("version", None)`.
+
+        .. note::
+            :attr:`state_dict` is not the same object as the input
+            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
+            it can be modified.
+
+        Args:
+            state_dict (dict): a dict containing parameters and
+                persistent buffers.
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+            local_metadata (dict): a dict containing the metadata for this module.
+                See
+            strict (bool): whether to strictly enforce that the keys in
+                :attr:`state_dict` with :attr:`prefix` match the names of
+                parameters and buffers in this module
+            missing_keys (list of str): if ``strict=True``, add missing keys to
+                this list
+            unexpected_keys (list of str): if ``strict=True``, add unexpected
+                keys to this list
+            error_msgs (list of str): error messages should be added to this
+                list, and will be reported together in
+                :meth:`~torch.nn.Module.load_state_dict`
+        """
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        for name, param in local_state.items():
+            key = prefix + name
+            if key in state_dict:
+                input_param = state_dict[key]
+                if not torch.overrides.is_tensor_like(input_param):
+                    error_msgs.append('While copying the parameter named "{}", '
+                                      'expected torch.Tensor or Tensor-like object from checkpoint but '
+                                      'received {}'.format(key, type(input_param)))
+                    continue
+
+                if is_distributed_tensor(param):
+                    # shard the input param
+                    device_mesh = get_device_mesh(param)
+                    sharding_spec = get_sharding_spec(param)
+                    sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
+                    input_param = sharded_tensor_to_param(sharded_tensor)
+
+                # This is used to avoid copying uninitialized parameters into
+                # non-lazy modules, since they dont have the hook to do the checks
+                # in such case, it will error when accessing the .shape attribute.
+                is_param_lazy = torch.nn.parameter.is_lazy(param)
+
+                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+                    input_param = input_param[0]
+
+                if not is_param_lazy and input_param.shape != param.shape:
+                    # local shape should match the one in checkpoint
+                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                      'the shape in current model is {}.'.format(key, input_param.shape, param.shape))
+                    continue
+
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append('While copying the parameter named "{}", '
+                                      'whose dimensions in the model are {} and '
+                                      'whose dimensions in the checkpoint are {}, '
+                                      'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
+                                                                           ex.args))
+            elif strict:
+                missing_keys.append(key)
+
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+            if extra_state_key in state_dict:
+                self.set_extra_state(state_dict[extra_state_key])
+            elif strict:
+                missing_keys.append(extra_state_key)
+        elif strict and (extra_state_key in state_dict):
+            unexpected_keys.append(extra_state_key)
+
+        if strict:
+            for key in state_dict.keys():
+                if key.startswith(prefix) and key != extra_state_key:
+                    input_name = key[len(prefix):]
+                    input_name = input_name.split('.', 1)[0]    # get the name of param/buffer/child
+                    if input_name not in self._modules and input_name not in local_state:
+                        unexpected_keys.append(key)
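These two overrides are what make sharded layers checkpoint-compatible with their unsharded counterparts: saving gathers every distributed parameter with to_global(), and loading re-shards a global tensor with distribute_tensor() before copying. A rough round-trip sketch, assuming a 2-GPU run already launched with colossalai.launch (it mirrors the updated linear test further below):

    import torch.nn as nn
    from colossalai.shardformer.layer import Linear1D_Col

    linear = nn.Linear(32, 128).cuda()
    linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

    # the sharded module saves full-size tensors, so the plain module can load them ...
    linear.load_state_dict(linear_col.state_dict())

    # ... and the sharded module re-shards a full-size checkpoint on load
    linear_col.load_state_dict(linear.state_dict())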

View File

@@ -0,0 +1,24 @@
+from .api import (
+    compute_global_numel,
+    distribute_tensor,
+    get_device_mesh,
+    get_global_shape,
+    get_layout,
+    get_sharding_spec,
+    is_distributed_tensor,
+    is_sharded,
+    redistribute,
+    shard_colwise,
+    shard_rowwise,
+    sharded_tensor_to_param,
+    to_global,
+)
+from .layout import Layout
+from .sharding_spec import ShardingSpec
+
+__all__ = [
+    'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
+    'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
+    'redistribute', 'get_layout',
+    'Layout', 'ShardingSpec'
+]
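With the package re-exporting the functional API, call sites can import from colossalai.tensor.d_tensor directly instead of reaching into the private submodules, for example:

    from colossalai.tensor.d_tensor import (distribute_tensor, is_distributed_tensor,
                                            shard_colwise, shard_rowwise,
                                            sharded_tensor_to_param, to_global)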

View File

@@ -1,3 +1,6 @@
+import copy
+import operator
+from functools import reduce
 from typing import Union

 import torch

@@ -6,13 +9,165 @@ from torch.distributed import ProcessGroup

 from colossalai.device.device_mesh import DeviceMesh

-from .d_tensor import DTensor
+from .layout import Layout
+from .layout_converter import LayoutConverter
 from .sharding_spec import ShardingSpec

+layout_converter = LayoutConverter()
+
+
+def is_distributed_tensor(tensor: torch.Tensor) -> bool:
+    """
+    Check whether the given tensor is a distributed tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be checked.
+
+    Returns:
+        bool: Whether the given tensor is a distributed tensor.
+    """
+    return hasattr(tensor, "dist_layout")
+
+
+def is_sharded(dtensor: torch.Tensor) -> bool:
+    """
+    Check if a tensor is sharded.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be checked.
+
+    Returns:
+        bool: True if the tensor is sharded, False otherwise.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    return list(dtensor.shape) == list(dtensor.dist_layout.global_shape)
+
+
+def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
+    """
+    Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be hijacked.
+
+    Returns:
+        torch.Tensor: The hijacked tensor.
+    """
+    dtensor._old_detach = dtensor.detach
+    dtensor._old_clone = dtensor.clone
+
+    def new_detach(self):
+        t_ = self._old_detach()
+        t_.dist_layout = copy.deepcopy(self.dist_layout)
+        return t_
+
+    def new_clone(self, *args, **kwargs):
+        t_ = self._old_clone(*args, **kwargs)
+        t_.dist_layout = copy.deepcopy(self.dist_layout)
+        return t_
+
+    # bind the new methods to the tensor
+    dtensor.detach = new_detach.__get__(dtensor)
+    dtensor.clone = new_clone.__get__(dtensor)
+    return dtensor
+
+
+def _construct_default_sharding_spec(tensor: torch.Tensor) -> ShardingSpec:
+    '''
+    Construct the default sharding specification for the tensor.
+
+    Args:
+        tensor (`torch.Tensor`): the tensor to be sharded.
+
+    Returns:
+        A `ShardingSpec` object without any sharding specified.
+    '''
+    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
+
+
+def _apply_layout(tensor, layout):
+    '''
+    Apply the layout to the local tensor during initializing process.
+    '''
+    # the layout converter requires a source and a target layout
+    # we construct the source layout for an unsharded tensor
+    # and use the given layout as the target layout for the sharded tensor
+    source_spec = _construct_default_sharding_spec(tensor)
+    source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape)
+    sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout)
+    return sharded_tensor
+
+
+def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
+    """
+    Convert the given tensor to a distributed tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be converted.
+        device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
+        sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded.
+
+    Returns:
+        torch.Tensor: The distributed tensor.
+    """
+    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
+    dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
+
+    # shard tensor
+    sharded_tensor = _apply_layout(tensor, dist_layout)
+
+    # hack some tensor methods
+    _hijack_detach_and_clone(sharded_tensor)
+
+    return sharded_tensor
+
+
+def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
+    '''
+    Convert the layout of the tensor from source_spec to target_spec.
+    This will update the `local_tensor` and `dist_layout` in place.
+
+    Args:
+        dtensor (torch.Tensor): the distributed tensor to be converted.
+        device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
+        target_layout (Layout): the target layout specification.
+    '''
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    global_shape = get_global_shape(dtensor)
+    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
+    resharded_tensor = layout_converter.apply(tensor=dtensor,
+                                              source_layout=dtensor.dist_layout,
+                                              target_layout=target_layout)
+    return resharded_tensor
+
+
+def to_global(dtensor: torch.Tensor) -> torch.Tensor:
+    """
+    Convert a distributed tensor to the global tensor with the given layout.
+    This function returns a native `torch.Tensor` object.
+
+    Args:
+        dtensor (torch.Tensor): the distributed tensor to be converted.
+
+    Returns:
+        torch.Tensor: the global tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    layout_converter = LayoutConverter()
+
+    global_sharding_spec = ShardingSpec(dtensor.dim(), {})
+    device_mesh = get_device_mesh(dtensor)
+    global_shape = get_global_shape(dtensor)
+    global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)
+    global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
+    return global_tensor
+
+
-def shard_rowwise(tensor: torch.Tensor,
-                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
-                  inplace: bool = False) -> DTensor:
+def shard_rowwise(
+    tensor: torch.Tensor,
+    group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
+) -> torch.Tensor:
     """
     Shard the first dim of the given tensor.

@@ -24,7 +179,7 @@ def shard_rowwise(tensor: torch.Tensor,
         inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.

     Returns:
-        DTensor: The sharded tensor.
+        torch.Tensor: The sharded tensor.
     """
     # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
     if group_or_device_mesh is None:

@@ -35,17 +190,13 @@ def shard_rowwise(tensor: torch.Tensor,
     else:
         assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
         device_mesh = group_or_device_mesh

     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})

-    if not inplace:
-        tensor = tensor.detach().clone()
-
-    return DTensor(tensor, device_mesh, sharding_spec)
+    return distribute_tensor(tensor, device_mesh, sharding_spec)


-def shard_colwise(tensor: torch.Tensor,
-                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
-                  inplace: bool = False) -> DTensor:
+def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor:
     """
     Shard the first dim of the given tensor.

@@ -57,7 +208,7 @@ def shard_colwise(tensor: torch.Tensor,
         inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.

     Returns:
-        DTensor: The sharded tensor.
+        torch.Tensor: The sharded tensor.
     """
     # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
     if group_or_device_mesh is None:

@@ -70,7 +221,87 @@ def shard_colwise(tensor: torch.Tensor,
         device_mesh = group_or_device_mesh

     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})

-    if not inplace:
-        tensor = tensor.detach().clone()
-
-    return DTensor(tensor, device_mesh, sharding_spec)
+    return distribute_tensor(tensor, device_mesh, sharding_spec)
+
+
+def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
+
+    # make it distributed as well
+    param.dist_layout = dtensor.dist_layout
+    _hijack_detach_and_clone(param)
+    return param
+
+
+def compute_global_numel(dtensor: torch.Tensor) -> int:
+    """
+    Compute the global number of elements in the distributed tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        int: The global number of elements in the distributed tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    numel = reduce(operator.mul, dtensor.dist_layout.global_shape)
+    return numel
+
+
+def get_layout(dtensor: torch.Tensor) -> Layout:
+    """
+    Get the layout of the distributed tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        Layout: The layout of the distributed tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    return dtensor.dist_layout
+
+
+def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
+    """
+    Get the global shape of the distributed tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        torch.Size: The global shape of the distributed tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    return dtensor.dist_layout.global_shape
+
+
+def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
+    """
+    Get the device mesh of the distributed tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        DeviceMesh: The device mesh of the distributed tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    return dtensor.dist_layout.device_mesh
+
+
+def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
+    """
+    Get the sharding spec of the distributed tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        ShardingSpec: The sharding spec of the distributed tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    return dtensor.dist_layout.sharding_spec
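Taken together, the module now exposes distributed tensors as plain torch.Tensor objects tagged with a dist_layout. A condensed usage sketch, assuming a 4-GPU run launched via colossalai.launch (it follows the updated test_dtensor further below; the tensor shape is illustrative):

    import torch
    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.d_tensor import (ShardingSpec, distribute_tensor, get_global_shape,
                                            is_distributed_tensor, redistribute, to_global)

    original = torch.rand(4, 8).cuda()
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

    # shard dim 0 across the first mesh axis
    spec = ShardingSpec(dim_size=original.dim(), dim_partition_dict={0: [0]})
    d_tensor = distribute_tensor(original, device_mesh, spec)

    assert is_distributed_tensor(d_tensor)                  # it carries a dist_layout
    assert get_global_shape(d_tensor) == original.shape     # global shape is remembered
    assert to_global(d_tensor).equal(original)              # gather back to the full tensor

    # reshard, replacing the old DTensor.layout_convert() call
    d_tensor = redistribute(d_tensor, device_mesh, ShardingSpec(original.dim(), {0: [0, 1]}))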

View File

@@ -1,12 +1,11 @@
 import operator
-from dataclasses import dataclass
 from functools import reduce

 import torch

 from colossalai.device.device_mesh import DeviceMesh

-from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError
+from .misc import DuplicatedShardingDimensionError, ShardingNotDivisibleError
 from .sharding_spec import ShardingSpec

View File

@@ -28,18 +28,6 @@ class LayoutConverterOptions:
     pass


-def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
-    layout_converter = LayoutConverter()
-    global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
-    global_layout = Layout(device_mesh=layout.device_mesh,
-                           device_type=layout.device_type,
-                           sharding_spec=global_sharding_spec,
-                           entire_shape=layout.entire_shape)
-    with torch.no_grad():
-        global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
-    return global_tensor
-
-
 def set_layout_converting_options(options: LayoutConverterOptions):
     """
     Configure the shape consistency manager via function call.

@@ -553,4 +541,5 @@ class LayoutConverter(metaclass=SingletonMeta):
         _, comm_action_sequence = self.layout_converting(source_layout, target_layout)
         for comm_spec in comm_action_sequence:
             tensor = comm_spec.covert_spec_to_action(tensor)
+        tensor.dist_layout = target_layout
         return tensor

View File

@@ -29,7 +29,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals
         # the comm size for all gather is the size of the gathered tensor
         gather_dim = comm_spec.gather_dim
         all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1]
-        all_gather_size = device_mesh.mesh_shape[all_gather_axis]
+        all_gather_size = device_mesh.shape[all_gather_axis]
         comm_size_for_all_gather = comm_size * all_gather_size
         forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis)
         # give a tiny cost to shard

test.py (new file)
View File

@@ -0,0 +1 @@
+from colossalai.tensor.d_tensor.api import to_distributed_tensor

View File

@@ -7,7 +7,8 @@ import torch
 from packaging import version

 from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
-from colossalai.tensor.d_tensor.layout_converter import to_global
+from colossalai.tensor.d_tensor import to_global
+from colossalai.tensor.d_tensor.layout import Layout
 from tests.kit.model_zoo.registry import ModelAttribute

 SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')

@@ -91,6 +92,8 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
         assert n1 == n2
         t1 = t1.cuda()
         t2 = t2.cuda()
-        if n2 in layout_dict:
-            t2 = to_global(t2, layout_dict[n2])
+        if n2 in sharding_spec_dict:
+            layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
+            t2.dist_layout = layout
+            t2 = to_global(t2)
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'

View File

@@ -14,6 +14,10 @@ def check_embedding_1d():
     assert embedding_1d.weight.shape == torch.Size([32, 64])

+    # ensure state dict is reversibly loadable
+    embedding.load_state_dict(embedding_1d.state_dict())
+    embedding_1d.load_state_dict(embedding.state_dict())
+
     # check computation correctness
     x = torch.randint(low=0, high=32, size=(4, 32)).cuda()
     out = embedding(x)

View File

@@ -5,6 +5,7 @@ from torch.testing import assert_close

 import colossalai
 from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.testing import rerun_if_address_is_in_use, spawn


@@ -12,9 +13,18 @@ def check_linear_1d_col():
     linear = nn.Linear(32, 128).cuda()
     linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

+    # ensure that the parameters are distributed
+    assert is_distributed_tensor(linear_col.weight)
+    assert is_distributed_tensor(linear_col.bias)
+
+    # ensure the shape is correct
     assert linear_col.weight.shape == torch.Size([64, 32])
     assert linear_col.bias.shape == torch.Size([64])

+    # ensure state dict is reversibly loadable
+    linear.load_state_dict(linear_col.state_dict())
+    linear_col.load_state_dict(linear.state_dict())
+
     # check computation correctness
     x = torch.rand(4, 32).cuda()
     out = linear(x)

@@ -55,7 +65,7 @@ def check_linear_1d_row():
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     check_linear_1d_col()
-    check_linear_1d_row()
+    # check_linear_1d_row()


 @rerun_if_address_is_in_use()

View File

@@ -14,7 +14,11 @@ def check_vocab_embedding_1d():
     assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
     assert dist_embedding_1d.num_embeddings == 64
-    assert dist_embedding_1d.embed_dim == 32
+    assert dist_embedding_1d.embedding_dim == 32
+
+    # ensure state dict is reversibly loadable
+    embedding.load_state_dict(dist_embedding_1d.state_dict())
+    dist_embedding_1d.load_state_dict(embedding.state_dict())

     # check embedding correctness
     x = torch.randint(0, 128, (4, 32)).to('cuda')

View File

@@ -1,14 +1,11 @@
 import pytest
 import torch
-import torch.distributed as dist
-from torch.distributed import ReduceOp

 from colossalai.core import global_context as gpc
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
-from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn

View File

@@ -3,9 +3,7 @@ import torch
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
-from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global
 from colossalai.testing import rerun_if_address_is_in_use, spawn


@@ -31,22 +29,18 @@ def check_dtensor(rank, world_size, port):
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
-    layout = Layout(device_mesh=device_mesh,
-                    device_type=torch.device('cuda'),
-                    sharding_spec=target_sharding_spec,
-                    entire_shape=original_tensor.shape)
-    d_tensor = DTensor(original_tensor, layout)
+    d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)

-    assert d_tensor.entire_shape == original_tensor.shape
-    assert d_tensor.data_type == original_tensor.dtype
+    assert get_global_shape(d_tensor) == original_tensor.shape
+    assert d_tensor.dtype == original_tensor.dtype

     if rank in (0, 1):
-        assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
+        assert d_tensor.equal(original_tensor.narrow(0, 0, 2))
     elif rank in (2, 3):
-        assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
+        assert d_tensor.equal(original_tensor.narrow(0, 2, 2))
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')

-    assert d_tensor.to_global().equal(original_tensor)
+    assert to_global(d_tensor).equal(original_tensor)

     output = test_model(d_tensor)

     if rank in (0, 1):

@@ -57,34 +51,29 @@ def check_dtensor(rank, world_size, port):
         raise ValueError(f'rank {rank} is not in the device mesh')

     new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
-    new_layout = Layout(device_mesh=device_mesh,
-                        device_type=torch.device('cuda'),
-                        sharding_spec=new_sharding_spec,
-                        entire_shape=original_tensor.shape)
-
-    d_tensor.layout_convert(new_layout)
+    d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec)

     if rank == 0:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 0, 1))
     elif rank == 1:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 1, 1))
     elif rank == 2:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 2, 1))
     elif rank == 3:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 3, 1))
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')

     dtensor_from_local = distribute_tensor(original_tensor, new_layout)

     if rank == 0:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1))
     elif rank == 1:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 1, 1))
     elif rank == 2:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 2, 1))
     elif rank == 3:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1))
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')

View File

@@ -9,7 +9,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
-from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn

 entire_shape = torch.Size((64, 32, 16))