[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading

* polish code
pull/4157/head
Frank Lee 2023-06-22 11:42:11 +08:00
parent 7740c55c55
commit 8eb09a4c69
19 changed files with 493 additions and 102 deletions

View File

@ -10,7 +10,7 @@ import torch
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor import is_distributed_tensor
SAFE_WEIGHTS_NAME = "model.safetensors"
WEIGHTS_NAME = "pytorch_model.bin"
@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
for key, weight in state_dict.items():
ret_block = None
ret_block_size = 0
if type(weight) != DTensor:
if is_distributed_tensor(weight):
weight_size = calculate_tensor_size(weight)
# If this weight is going to tip up over the maximal size, we split.

View File

@ -8,8 +8,9 @@ from torch import Tensor
from torch.utils._pytree import tree_map
from colossalai._analyzer._subclasses import MetaTensor
from colossalai.tensor.d_tensor.d_tensor import DTensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import distribute_tensor
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
_NORMAL_FACTORY = [
@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor):
"""
target = self._materialize_data()
self.clean()
local_tensor = DTensor(target, layout).local_tensor
local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
return _convert_cls(self, local_tensor)
def clean(self) -> None:

View File

@ -13,8 +13,7 @@ from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
from colossalai.utils.cuda import get_current_device
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
from ._operation import gather_forward_split_backward, reduce_input
from .parallel_module import ParallelModule
@ -69,18 +68,17 @@ class Embedding1D(ParallelModule):
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.process_group = process_group
self.num_partitions = dist.get_world_size(process_group)
self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.gather_output = gather_output
if device is None:
device = get_current_device()
self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
# Parameters.
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
sharded_weight = shard_colwise(weight, process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
@ -194,7 +192,7 @@ class VocabParallelEmbedding1D(ParallelModule):
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
@ -208,8 +206,11 @@ class VocabParallelEmbedding1D(ParallelModule):
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
# parameter
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
sharded_weight = shard_rowwise(weight, process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
@ -252,7 +253,7 @@ class VocabParallelEmbedding1D(ParallelModule):
def reset_parameters(self, weight_initializer) -> None:
with self.randomizer.fork_rng(enable_cpu=True):
fan_in, fan_out = self.num_embeddings, self.embed_dim
fan_in, fan_out = self.num_embeddings, self.embedding_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()

View File

@ -14,7 +14,7 @@ from torch.nn.parameter import Parameter
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
from colossalai.utils.cuda import get_current_device
from ._operation import (
@ -76,22 +76,21 @@ class Linear1D_Col(ParallelModule):
self.skip_bias_add = skip_bias_add
self.device = device
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
self.out_features_per_partition = divide(out_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
sharded_weight = shard_rowwise(weight, self.process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
if bias:
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
bias = torch.empty(self.out_features, **factory_kwargs)
sharded_bias = shard_colwise(bias, self.process_group)
self.bias = sharded_tensor_to_param(sharded_bias)
else:
self.bias = None
@ -128,7 +127,6 @@ class Linear1D_Col(ParallelModule):
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
# the weigh to the linear layer is a transpose
# thus shard on row is equal to shard on column
@ -137,7 +135,6 @@ class Linear1D_Col(ParallelModule):
if bias:
sharded_bias = shard_colwise(module.bias.data, process_group)
linear_1d.bias.copy_(sharded_bias)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@ -212,21 +209,20 @@ class Linear1D_Row(ParallelModule):
self.parallel_input = parallel_input
self.skip_bias_add = skip_bias_add
self.process_group = process_group
self.num_partitions = dist.get_world_size(self.process_group)
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, self.num_partitions)
# Parameters.
# Initialize weight.
if device is None:
device = get_current_device()
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
sharded_weight = shard_colwise(weight, self.process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
if self.stream_chunk_num > 1:
# TODO() work for inference only
@ -340,3 +336,5 @@ class Linear1D_Row(ParallelModule):
return output
else:
return output, self.bias
return output, self.bias
return output, self.bias

View File

@ -31,7 +31,6 @@ __all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']
class LinearConv1D_Col(ParallelModule):
r"""Linear layer with column parallelism.
Specially created for HuggingFace's GPT2 model.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface.
@ -189,7 +188,6 @@ class LinearConv1D_Col(ParallelModule):
class LinearConv1D_Row(ParallelModule):
r""" Linear layer with row parallelism
Specially created for HuggingFace's GPT2 model.
Args:
in_features (int): size of each input sample.

View File

@ -1,11 +1,23 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import itertools
from abc import ABC, abstractmethod
from typing import List, Union
import torch
import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
from colossalai.tensor.d_tensor import (
distribute_tensor,
get_device_mesh,
get_sharding_spec,
is_distributed_tensor,
sharded_tensor_to_param,
to_global,
)
__all__ = ['ParallelModule']
@ -25,3 +37,133 @@ class ParallelModule(nn.Module, ABC):
in the ith axis of the device mesh. Defaults to None, which means the global process group.
"""
pass
def _save_to_state_dict(self, destination, prefix, keep_vars):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
for name, param in self._parameters.items():
if param is not None:
param_ = param if keep_vars else param.detach()
if is_distributed_tensor(param_):
destination[prefix + name] = to_global(param_)
else:
destination[prefix + name] = param_
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
destination[extra_state_key] = self.get_extra_state()
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if not torch.overrides.is_tensor_like(input_param):
error_msgs.append('While copying the parameter named "{}", '
'expected torch.Tensor or Tensor-like object from checkpoint but '
'received {}'.format(key, type(input_param)))
continue
if is_distributed_tensor(param):
# shard the input param
device_mesh = get_device_mesh(param)
sharding_spec = get_sharding_spec(param)
sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
input_param = sharded_tensor_to_param(sharded_tensor)
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they dont have the hook to do the checks
# in such case, it will error when accessing the .shape attribute.
is_param_lazy = torch.nn.parameter.is_lazy(param)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if not is_param_lazy and input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'.format(key, input_param.shape, param.shape))
continue
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append('While copying the parameter named "{}", '
'whose dimensions in the model are {} and '
'whose dimensions in the checkpoint are {}, '
'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
ex.args))
elif strict:
missing_keys.append(key)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix):]
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)

View File

@ -0,0 +1,24 @@
from .api import (
compute_global_numel,
distribute_tensor,
get_device_mesh,
get_global_shape,
get_layout,
get_sharding_spec,
is_distributed_tensor,
is_sharded,
redistribute,
shard_colwise,
shard_rowwise,
sharded_tensor_to_param,
to_global,
)
from .layout import Layout
from .sharding_spec import ShardingSpec
__all__ = [
'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
'redistribute', 'get_layout'
'Layout', 'ShardingSpec'
]

View File

@ -1,3 +1,6 @@
import copy
import operator
from functools import reduce
from typing import Union
import torch
@ -6,13 +9,165 @@ from torch.distributed import ProcessGroup
from colossalai.device.device_mesh import DeviceMesh
from .d_tensor import DTensor
from .layout import Layout
from .layout_converter import LayoutConverter
from .sharding_spec import ShardingSpec
layout_converter = LayoutConverter()
def shard_rowwise(tensor: torch.Tensor,
group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
inplace: bool = False) -> DTensor:
def is_distributed_tensor(tensor: torch.Tensor) -> bool:
"""
Check whether the given tensor is a distributed tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: Whether the given tensor is a distributed tensor.
"""
return hasattr(tensor, "dist_layout")
def is_sharded(dtensor: torch.Tensor) -> bool:
"""
Check if a tensor is sharded.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: True if the tensor is sharded, False otherwise.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
return list(dtensor.shape) == list(dtensor.dist_layout.global_shape)
def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
"""
Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
Args:
tensor (torch.Tensor): The tensor to be hijacked.
Returns:
torch.Tensor: The hijacked tensor.
"""
dtensor._old_detach = dtensor.detach
dtensor._old_clone = dtensor.clone
def new_detach(self):
t_ = self._old_detach()
t_.dist_layout = copy.deepcopy(self.dist_layout)
return t_
def new_clone(self, *args, **kwargs):
t_ = self._old_clone(*args, **kwargs)
t_.dist_layout = copy.deepcopy(self.dist_layout)
return t_
# bind the new methods to the tensor
dtensor.detach = new_detach.__get__(dtensor)
dtensor.clone = new_clone.__get__(dtensor)
return dtensor
def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
'''
Construct the default sharding specification for the tensor.
Args:
tensor (`torch.Tensor`): the tensor to be sharded.
Returns:
A `ShardingSpec` object without any sharding specified.
'''
return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
def _apply_layout(tensor, layout):
'''
Apply the layout to the local tensor during initializing process.
'''
# layout converter requires a source and target laytout
# we construct the source layer for an unsharded tensor
# and use self.dist_layer as the targer layout for the sharded tensor
source_spec = _construct_default_sharding_spec(tensor)
source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape)
sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout)
return sharded_tensor
def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
"""
Convert the given tensor to a distributed tensor.
Args:
tensor (torch.Tensor): The tensor to be converted.
device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded.
Returns:
torch.Tensor: The distributed tensor.
"""
assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
# shard tensor
sharded_tensor = _apply_layout(tensor, dist_layout)
# hack some tensor methods
_hijack_detach_and_clone(sharded_tensor)
return sharded_tensor
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
'''
Convert the layout of the tensor from source_spec to target_spec.
This will update the `local_tensor` and `dist_layout` in place.
Args:
dtensor (torch.Tensor): the distributed tensor to be converted.
device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
target_layout (Layout): the target layout specification.
'''
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
global_shape = get_global_shape(dtensor)
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
resharded_tensor = layout_converter.apply(tensor=dtensor,
source_layout=dtensor.dist_layout,
target_layout=target_layout)
return resharded_tensor
def to_global(dtensor: torch.Tensor) -> torch.Tensor:
"""
Convert a distributed tensor to the global tensor with the given layout.
This function returns a native `torch.Tensor` object.
Args:
dtensor (torch.Tensor): the distributed tensor to be converted.
Returns:
torch.Tensor: the global tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
layout_converter = LayoutConverter()
global_sharding_spec = ShardingSpec(dtensor.dim(), {})
device_mesh = get_device_mesh(dtensor)
global_shape = get_global_shape(dtensor)
global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)
global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
return global_tensor
def shard_rowwise(
tensor: torch.Tensor,
group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
) -> torch.Tensor:
"""
Shard the first dim of the given tensor.
@ -24,7 +179,7 @@ def shard_rowwise(tensor: torch.Tensor,
inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
Returns:
DTensor: The sharded tensor.
torch.Tensor: The sharded tensor.
"""
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
if group_or_device_mesh is None:
@ -35,17 +190,13 @@ def shard_rowwise(tensor: torch.Tensor,
else:
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
if not inplace:
tensor = tensor.detach().clone()
return DTensor(tensor, device_mesh, sharding_spec)
return distribute_tensor(tensor, device_mesh, sharding_spec)
def shard_colwise(tensor: torch.Tensor,
group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
inplace: bool = False) -> DTensor:
def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor:
"""
Shard the first dim of the given tensor.
@ -57,7 +208,7 @@ def shard_colwise(tensor: torch.Tensor,
inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
Returns:
DTensor: The sharded tensor.
torch.Tensor: The sharded tensor.
"""
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
if group_or_device_mesh is None:
@ -70,7 +221,87 @@ def shard_colwise(tensor: torch.Tensor,
device_mesh = group_or_device_mesh
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
if not inplace:
tensor = tensor.detach().clone()
return distribute_tensor(tensor, device_mesh, sharding_spec)
return DTensor(tensor, device_mesh, sharding_spec)
def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
# make it distributed as well
param.dist_layout = dtensor.dist_layout
_hijack_detach_and_clone(param)
return param
def compute_global_numel(dtensor: torch.Tensor) -> int:
"""
Compute the global number of elements in the distributed tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
int: The global number of elements in the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
numel = reduce(operator.mul, dtensor.dist_layout.global_shape)
return numel
def get_layout(dtensor: torch.Tensor) -> Layout:
"""
Get the layout of the distributed tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
Layout: The layout of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
return dtensor.dist_layout
def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
"""
Get the global shape of the distributed tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
torch.Size: The global shape of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
return dtensor.dist_layout.global_shape
def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
"""
Get the device mesh of the distributed tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
DeviceMesh: The device mesh of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
return dtensor.dist_layout.device_mesh
def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
"""
Get the sharding spec of the distributed tensor.
Args:
dtensor (torch.Tensor): The distributed tensor.
Returns:
ShardingSpec: The sharding spec of the distributed tensor.
"""
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
return dtensor.dist_layout.sharding_spec

View File

@ -1,12 +1,11 @@
import operator
from dataclasses import dataclass
from functools import reduce
import torch
from colossalai.device.device_mesh import DeviceMesh
from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError
from .misc import DuplicatedShardingDimensionError, ShardingNotDivisibleError
from .sharding_spec import ShardingSpec

View File

@ -28,18 +28,6 @@ class LayoutConverterOptions:
pass
def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
layout_converter = LayoutConverter()
global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
global_layout = Layout(device_mesh=layout.device_mesh,
device_type=layout.device_type,
sharding_spec=global_sharding_spec,
entire_shape=layout.entire_shape)
with torch.no_grad():
global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
return global_tensor
def set_layout_converting_options(options: LayoutConverterOptions):
"""
Configure the shape consistency manager via function call.
@ -553,4 +541,5 @@ class LayoutConverter(metaclass=SingletonMeta):
_, comm_action_sequence = self.layout_converting(source_layout, target_layout)
for comm_spec in comm_action_sequence:
tensor = comm_spec.covert_spec_to_action(tensor)
tensor.dist_layout = target_layout
return tensor

View File

@ -29,7 +29,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals
# the comm size for all gather is the size of the gathered tensor
gather_dim = comm_spec.gather_dim
all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1]
all_gather_size = device_mesh.mesh_shape[all_gather_axis]
all_gather_size = device_mesh.shape[all_gather_axis]
comm_size_for_all_gather = comm_size * all_gather_size
forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis)
# give a tiny cost to shard

1
test.py Normal file
View File

@ -0,0 +1 @@
from colossalai.tensor.d_tensor.api import to_distributed_tensor

View File

@ -7,7 +7,8 @@ import torch
from packaging import version
from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
from colossalai.tensor.d_tensor.layout_converter import to_global
from colossalai.tensor.d_tensor import to_global
from colossalai.tensor.d_tensor.layout import Layout
from tests.kit.model_zoo.registry import ModelAttribute
SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')
@ -91,6 +92,8 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
assert n1 == n2
t1 = t1.cuda()
t2 = t2.cuda()
if n2 in layout_dict:
t2 = to_global(t2, layout_dict[n2])
if n2 in sharding_spec_dict:
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
t2.dist_layout = layout
t2 = to_global(t2)
assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'

View File

@ -14,6 +14,10 @@ def check_embedding_1d():
assert embedding_1d.weight.shape == torch.Size([32, 64])
# ensure state dict is reversibly loadable
embedding.load_state_dict(embedding_1d.state_dict())
embedding_1d.load_state_dict(embedding.state_dict())
# check computation correctness
x = torch.randint(low=0, high=32, size=(4, 32)).cuda()
out = embedding(x)

View File

@ -5,6 +5,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
@ -12,9 +13,18 @@ def check_linear_1d_col():
linear = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
# ensure that the parameters are distributed
assert is_distributed_tensor(linear_col.weight)
assert is_distributed_tensor(linear_col.bias)
# ensure the shape is correct
assert linear_col.weight.shape == torch.Size([64, 32])
assert linear_col.bias.shape == torch.Size([64])
# ensure state dict is reversibly loadable
linear.load_state_dict(linear_col.state_dict())
linear_col.load_state_dict(linear.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
out = linear(x)
@ -55,7 +65,7 @@ def check_linear_1d_row():
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_1d_col()
check_linear_1d_row()
# check_linear_1d_row()
@rerun_if_address_is_in_use()

View File

@ -14,7 +14,11 @@ def check_vocab_embedding_1d():
assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
assert dist_embedding_1d.num_embeddings == 64
assert dist_embedding_1d.embed_dim == 32
assert dist_embedding_1d.embedding_dim == 32
# ensure state dict is reversibly loadable
embedding.load_state_dict(dist_embedding_1d.state_dict())
dist_embedding_1d.load_state_dict(embedding.state_dict())
# check embedding correctness
x = torch.randint(0, 128, (4, 32)).to('cuda')

View File

@ -1,14 +1,11 @@
import pytest
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn

View File

@ -3,9 +3,7 @@ import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global
from colossalai.testing import rerun_if_address_is_in_use, spawn
@ -31,22 +29,18 @@ def check_dtensor(rank, world_size, port):
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=target_sharding_spec,
entire_shape=original_tensor.shape)
d_tensor = DTensor(original_tensor, layout)
d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)
assert d_tensor.entire_shape == original_tensor.shape
assert d_tensor.data_type == original_tensor.dtype
assert get_global_shape(d_tensor) == original_tensor.shape
assert d_tensor.dtype == original_tensor.dtype
if rank in (0, 1):
assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
assert d_tensor.equal(original_tensor.narrow(0, 0, 2))
elif rank in (2, 3):
assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
assert d_tensor.equal(original_tensor.narrow(0, 2, 2))
else:
raise ValueError(f'rank {rank} is not in the device mesh')
assert d_tensor.to_global().equal(original_tensor)
assert to_global(d_tensor).equal(original_tensor)
output = test_model(d_tensor)
if rank in (0, 1):
@ -57,34 +51,29 @@ def check_dtensor(rank, world_size, port):
raise ValueError(f'rank {rank} is not in the device mesh')
new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
new_layout = Layout(device_mesh=device_mesh,
device_type=torch.device('cuda'),
sharding_spec=new_sharding_spec,
entire_shape=original_tensor.shape)
d_tensor.layout_convert(new_layout)
d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec)
if rank == 0:
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
assert d_tensor.equal(original_tensor.narrow(0, 0, 1))
elif rank == 1:
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
assert d_tensor.equal(original_tensor.narrow(0, 1, 1))
elif rank == 2:
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
assert d_tensor.equal(original_tensor.narrow(0, 2, 1))
elif rank == 3:
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
assert d_tensor.equal(original_tensor.narrow(0, 3, 1))
else:
raise ValueError(f'rank {rank} is not in the device mesh')
dtensor_from_local = distribute_tensor(original_tensor, new_layout)
if rank == 0:
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1))
elif rank == 1:
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
assert dtensor_from_local.equal(original_tensor.narrow(0, 1, 1))
elif rank == 2:
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
assert dtensor_from_local.equal(original_tensor.narrow(0, 2, 1))
elif rank == 3:
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1))
else:
raise ValueError(f'rank {rank} is not in the device mesh')

View File

@ -9,7 +9,7 @@ from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use, spawn
entire_shape = torch.Size((64, 32, 16))