mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support module saving and loading (#4062)
* [shardformer] support module saving and loading * polish codepull/4157/head
parent
7740c55c55
commit
8eb09a4c69
|
@ -10,7 +10,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||
|
||||
SAFE_WEIGHTS_NAME = "model.safetensors"
|
||||
WEIGHTS_NAME = "pytorch_model.bin"
|
||||
|
@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
|
|||
for key, weight in state_dict.items():
|
||||
ret_block = None
|
||||
ret_block_size = 0
|
||||
if type(weight) != DTensor:
|
||||
if is_distributed_tensor(weight):
|
||||
weight_size = calculate_tensor_size(weight)
|
||||
|
||||
# If this weight is going to tip up over the maximal size, we split.
|
||||
|
|
|
@ -8,8 +8,9 @@ from torch import Tensor
|
|||
from torch.utils._pytree import tree_map
|
||||
|
||||
from colossalai._analyzer._subclasses import MetaTensor
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.tensor.d_tensor import distribute_tensor
|
||||
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
|
||||
|
||||
# reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
|
||||
_NORMAL_FACTORY = [
|
||||
|
@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor):
|
|||
"""
|
||||
target = self._materialize_data()
|
||||
self.clean()
|
||||
local_tensor = DTensor(target, layout).local_tensor
|
||||
local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
|
||||
return _convert_cls(self, local_tensor)
|
||||
|
||||
def clean(self) -> None:
|
||||
|
|
|
@ -13,8 +13,7 @@ from torch.nn.parameter import Parameter
|
|||
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
|
||||
|
||||
from ._operation import gather_forward_split_backward, reduce_input
|
||||
from .parallel_module import ParallelModule
|
||||
|
@ -69,18 +68,17 @@ class Embedding1D(ParallelModule):
|
|||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.process_group = process_group
|
||||
self.num_partitions = dist.get_world_size(process_group)
|
||||
self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
|
||||
|
||||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
self.gather_output = gather_output
|
||||
|
||||
if device is None:
|
||||
device = get_current_device()
|
||||
|
||||
self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
|
||||
# Parameters.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
|
||||
sharded_weight = shard_colwise(weight, process_group)
|
||||
self.weight = sharded_tensor_to_param(sharded_weight)
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
|
@ -194,7 +192,7 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
**kwargs):
|
||||
super().__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embed_dim = embedding_dim
|
||||
self.embedding_dim = embedding_dim
|
||||
self.padding_idx = padding_idx
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
|
@ -208,8 +206,11 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
|
||||
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
|
||||
# parameter
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
|
||||
sharded_weight = shard_rowwise(weight, process_group)
|
||||
self.weight = sharded_tensor_to_param(sharded_weight)
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
|
@ -252,7 +253,7 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
|
||||
def reset_parameters(self, weight_initializer) -> None:
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
fan_in, fan_out = self.num_embeddings, self.embed_dim
|
||||
fan_in, fan_out = self.num_embeddings, self.embedding_dim
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
self._fill_padding_idx_with_zero()
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ from torch.nn.parameter import Parameter
|
|||
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
|
||||
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
|
||||
from ._operation import (
|
||||
|
@ -76,22 +76,21 @@ class Linear1D_Col(ParallelModule):
|
|||
self.skip_bias_add = skip_bias_add
|
||||
self.device = device
|
||||
self.process_group = process_group
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
self.out_features_per_partition = divide(out_features, self.num_partitions)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
if device is None:
|
||||
device = get_current_device()
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
|
||||
|
||||
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
|
||||
sharded_weight = shard_rowwise(weight, self.process_group)
|
||||
self.weight = sharded_tensor_to_param(sharded_weight)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
|
||||
bias = torch.empty(self.out_features, **factory_kwargs)
|
||||
sharded_bias = shard_colwise(bias, self.process_group)
|
||||
self.bias = sharded_tensor_to_param(sharded_bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
|
@ -128,7 +127,6 @@ class Linear1D_Col(ParallelModule):
|
|||
*args,
|
||||
**kwargs)
|
||||
|
||||
# TODO: copy the sharded weights
|
||||
with torch.no_grad():
|
||||
# the weigh to the linear layer is a transpose
|
||||
# thus shard on row is equal to shard on column
|
||||
|
@ -137,7 +135,6 @@ class Linear1D_Col(ParallelModule):
|
|||
if bias:
|
||||
sharded_bias = shard_colwise(module.bias.data, process_group)
|
||||
linear_1d.bias.copy_(sharded_bias)
|
||||
|
||||
return linear_1d
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
|
@ -212,21 +209,20 @@ class Linear1D_Row(ParallelModule):
|
|||
self.parallel_input = parallel_input
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.process_group = process_group
|
||||
self.num_partitions = dist.get_world_size(self.process_group)
|
||||
|
||||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.input_size_per_partition = divide(in_features, self.num_partitions)
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
if device is None:
|
||||
device = get_current_device()
|
||||
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
|
||||
|
||||
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
|
||||
sharded_weight = shard_colwise(weight, self.process_group)
|
||||
self.weight = sharded_tensor_to_param(sharded_weight)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
# TODO() work for inference only
|
||||
|
@ -340,3 +336,5 @@ class Linear1D_Row(ParallelModule):
|
|||
return output
|
||||
else:
|
||||
return output, self.bias
|
||||
return output, self.bias
|
||||
return output, self.bias
|
||||
|
|
|
@ -31,7 +31,6 @@ __all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']
|
|||
|
||||
class LinearConv1D_Col(ParallelModule):
|
||||
r"""Linear layer with column parallelism.
|
||||
Specially created for HuggingFace's GPT2 model.
|
||||
|
||||
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
|
||||
its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface.
|
||||
|
@ -189,7 +188,6 @@ class LinearConv1D_Col(ParallelModule):
|
|||
|
||||
class LinearConv1D_Row(ParallelModule):
|
||||
r""" Linear layer with row parallelism
|
||||
Specially created for HuggingFace's GPT2 model.
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
|
|
|
@ -1,11 +1,23 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module
|
||||
|
||||
from colossalai.tensor.d_tensor import (
|
||||
distribute_tensor,
|
||||
get_device_mesh,
|
||||
get_sharding_spec,
|
||||
is_distributed_tensor,
|
||||
sharded_tensor_to_param,
|
||||
to_global,
|
||||
)
|
||||
|
||||
__all__ = ['ParallelModule']
|
||||
|
||||
|
@ -25,3 +37,133 @@ class ParallelModule(nn.Module, ABC):
|
|||
in the ith axis of the device mesh. Defaults to None, which means the global process group.
|
||||
"""
|
||||
pass
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
|
||||
In rare cases, subclasses can achieve class-specific behavior by
|
||||
overriding this method with custom logic.
|
||||
|
||||
Args:
|
||||
destination (dict): a dict where state will be stored
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
"""
|
||||
for name, param in self._parameters.items():
|
||||
if param is not None:
|
||||
param_ = param if keep_vars else param.detach()
|
||||
|
||||
if is_distributed_tensor(param_):
|
||||
destination[prefix + name] = to_global(param_)
|
||||
else:
|
||||
destination[prefix + name] = param_
|
||||
|
||||
for name, buf in self._buffers.items():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
|
||||
destination[extra_state_key] = self.get_extra_state()
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into only
|
||||
this module, but not its descendants. This is called on every submodule
|
||||
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
|
||||
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
|
||||
For state dicts without metadata, :attr:`local_metadata` is empty.
|
||||
Subclasses can achieve class-specific backward compatible loading using
|
||||
the version number at `local_metadata.get("version", None)`.
|
||||
|
||||
.. note::
|
||||
:attr:`state_dict` is not the same object as the input
|
||||
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
|
||||
it can be modified.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
local_metadata (dict): a dict containing the metadata for this module.
|
||||
See
|
||||
strict (bool): whether to strictly enforce that the keys in
|
||||
:attr:`state_dict` with :attr:`prefix` match the names of
|
||||
parameters and buffers in this module
|
||||
missing_keys (list of str): if ``strict=True``, add missing keys to
|
||||
this list
|
||||
unexpected_keys (list of str): if ``strict=True``, add unexpected
|
||||
keys to this list
|
||||
error_msgs (list of str): error messages should be added to this
|
||||
list, and will be reported together in
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
"""
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
for name, param in local_state.items():
|
||||
key = prefix + name
|
||||
|
||||
if key in state_dict:
|
||||
input_param = state_dict[key]
|
||||
if not torch.overrides.is_tensor_like(input_param):
|
||||
error_msgs.append('While copying the parameter named "{}", '
|
||||
'expected torch.Tensor or Tensor-like object from checkpoint but '
|
||||
'received {}'.format(key, type(input_param)))
|
||||
continue
|
||||
|
||||
if is_distributed_tensor(param):
|
||||
# shard the input param
|
||||
device_mesh = get_device_mesh(param)
|
||||
sharding_spec = get_sharding_spec(param)
|
||||
sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
|
||||
input_param = sharded_tensor_to_param(sharded_tensor)
|
||||
|
||||
# This is used to avoid copying uninitialized parameters into
|
||||
# non-lazy modules, since they dont have the hook to do the checks
|
||||
# in such case, it will error when accessing the .shape attribute.
|
||||
is_param_lazy = torch.nn.parameter.is_lazy(param)
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
|
||||
input_param = input_param[0]
|
||||
|
||||
if not is_param_lazy and input_param.shape != param.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
|
||||
'the shape in current model is {}.'.format(key, input_param.shape, param.shape))
|
||||
continue
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append('While copying the parameter named "{}", '
|
||||
'whose dimensions in the model are {} and '
|
||||
'whose dimensions in the checkpoint are {}, '
|
||||
'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
|
||||
ex.args))
|
||||
elif strict:
|
||||
missing_keys.append(key)
|
||||
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
|
||||
if extra_state_key in state_dict:
|
||||
self.set_extra_state(state_dict[extra_state_key])
|
||||
elif strict:
|
||||
missing_keys.append(extra_state_key)
|
||||
elif strict and (extra_state_key in state_dict):
|
||||
unexpected_keys.append(extra_state_key)
|
||||
|
||||
if strict:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix) and key != extra_state_key:
|
||||
input_name = key[len(prefix):]
|
||||
input_name = input_name.split('.', 1)[0] # get the name of param/buffer/child
|
||||
if input_name not in self._modules and input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
from .api import (
|
||||
compute_global_numel,
|
||||
distribute_tensor,
|
||||
get_device_mesh,
|
||||
get_global_shape,
|
||||
get_layout,
|
||||
get_sharding_spec,
|
||||
is_distributed_tensor,
|
||||
is_sharded,
|
||||
redistribute,
|
||||
shard_colwise,
|
||||
shard_rowwise,
|
||||
sharded_tensor_to_param,
|
||||
to_global,
|
||||
)
|
||||
from .layout import Layout
|
||||
from .sharding_spec import ShardingSpec
|
||||
|
||||
__all__ = [
|
||||
'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
|
||||
'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
|
||||
'redistribute', 'get_layout'
|
||||
'Layout', 'ShardingSpec'
|
||||
]
|
|
@ -1,3 +1,6 @@
|
|||
import copy
|
||||
import operator
|
||||
from functools import reduce
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
@ -6,13 +9,165 @@ from torch.distributed import ProcessGroup
|
|||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
from .d_tensor import DTensor
|
||||
from .layout import Layout
|
||||
from .layout_converter import LayoutConverter
|
||||
from .sharding_spec import ShardingSpec
|
||||
|
||||
layout_converter = LayoutConverter()
|
||||
|
||||
def shard_rowwise(tensor: torch.Tensor,
|
||||
group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
|
||||
inplace: bool = False) -> DTensor:
|
||||
|
||||
def is_distributed_tensor(tensor: torch.Tensor) -> bool:
|
||||
"""
|
||||
Check whether the given tensor is a distributed tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be checked.
|
||||
|
||||
Returns:
|
||||
bool: Whether the given tensor is a distributed tensor.
|
||||
"""
|
||||
return hasattr(tensor, "dist_layout")
|
||||
|
||||
|
||||
def is_sharded(dtensor: torch.Tensor) -> bool:
|
||||
"""
|
||||
Check if a tensor is sharded.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be checked.
|
||||
|
||||
Returns:
|
||||
bool: True if the tensor is sharded, False otherwise.
|
||||
"""
|
||||
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
|
||||
return list(dtensor.shape) == list(dtensor.dist_layout.global_shape)
|
||||
|
||||
|
||||
def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be hijacked.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The hijacked tensor.
|
||||
"""
|
||||
dtensor._old_detach = dtensor.detach
|
||||
dtensor._old_clone = dtensor.clone
|
||||
|
||||
def new_detach(self):
|
||||
t_ = self._old_detach()
|
||||
t_.dist_layout = copy.deepcopy(self.dist_layout)
|
||||
return t_
|
||||
|
||||
def new_clone(self, *args, **kwargs):
|
||||
t_ = self._old_clone(*args, **kwargs)
|
||||
t_.dist_layout = copy.deepcopy(self.dist_layout)
|
||||
return t_
|
||||
|
||||
# bind the new methods to the tensor
|
||||
dtensor.detach = new_detach.__get__(dtensor)
|
||||
dtensor.clone = new_clone.__get__(dtensor)
|
||||
return dtensor
|
||||
|
||||
|
||||
def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
|
||||
'''
|
||||
Construct the default sharding specification for the tensor.
|
||||
|
||||
Args:
|
||||
tensor (`torch.Tensor`): the tensor to be sharded.
|
||||
|
||||
Returns:
|
||||
A `ShardingSpec` object without any sharding specified.
|
||||
'''
|
||||
return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
|
||||
|
||||
|
||||
def _apply_layout(tensor, layout):
|
||||
'''
|
||||
Apply the layout to the local tensor during initializing process.
|
||||
'''
|
||||
# layout converter requires a source and target laytout
|
||||
# we construct the source layer for an unsharded tensor
|
||||
# and use self.dist_layer as the targer layout for the sharded tensor
|
||||
source_spec = _construct_default_sharding_spec(tensor)
|
||||
source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape)
|
||||
sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout)
|
||||
return sharded_tensor
|
||||
|
||||
|
||||
def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
|
||||
"""
|
||||
Convert the given tensor to a distributed tensor.
|
||||
|
||||
Args:
|
||||
tensor (torch.Tensor): The tensor to be converted.
|
||||
device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
|
||||
sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The distributed tensor.
|
||||
"""
|
||||
assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
|
||||
dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
|
||||
|
||||
# shard tensor
|
||||
sharded_tensor = _apply_layout(tensor, dist_layout)
|
||||
|
||||
# hack some tensor methods
|
||||
_hijack_detach_and_clone(sharded_tensor)
|
||||
|
||||
return sharded_tensor
|
||||
|
||||
|
||||
def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
|
||||
'''
|
||||
Convert the layout of the tensor from source_spec to target_spec.
|
||||
This will update the `local_tensor` and `dist_layout` in place.
|
||||
|
||||
Args:
|
||||
dtensor (torch.Tensor): the distributed tensor to be converted.
|
||||
device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
|
||||
target_layout (Layout): the target layout specification.
|
||||
'''
|
||||
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
|
||||
global_shape = get_global_shape(dtensor)
|
||||
target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
|
||||
resharded_tensor = layout_converter.apply(tensor=dtensor,
|
||||
source_layout=dtensor.dist_layout,
|
||||
target_layout=target_layout)
|
||||
return resharded_tensor
|
||||
|
||||
|
||||
def to_global(dtensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert a distributed tensor to the global tensor with the given layout.
|
||||
This function returns a native `torch.Tensor` object.
|
||||
|
||||
Args:
|
||||
dtensor (torch.Tensor): the distributed tensor to be converted.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: the global tensor.
|
||||
"""
|
||||
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
|
||||
layout_converter = LayoutConverter()
|
||||
|
||||
global_sharding_spec = ShardingSpec(dtensor.dim(), {})
|
||||
device_mesh = get_device_mesh(dtensor)
|
||||
global_shape = get_global_shape(dtensor)
|
||||
global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)
|
||||
|
||||
global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
|
||||
return global_tensor
|
||||
|
||||
|
||||
def shard_rowwise(
|
||||
tensor: torch.Tensor,
|
||||
group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Shard the first dim of the given tensor.
|
||||
|
||||
|
@ -24,7 +179,7 @@ def shard_rowwise(tensor: torch.Tensor,
|
|||
inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
|
||||
|
||||
Returns:
|
||||
DTensor: The sharded tensor.
|
||||
torch.Tensor: The sharded tensor.
|
||||
"""
|
||||
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
|
||||
if group_or_device_mesh is None:
|
||||
|
@ -35,17 +190,13 @@ def shard_rowwise(tensor: torch.Tensor,
|
|||
else:
|
||||
assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
|
||||
device_mesh = group_or_device_mesh
|
||||
|
||||
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})
|
||||
|
||||
if not inplace:
|
||||
tensor = tensor.detach().clone()
|
||||
|
||||
return DTensor(tensor, device_mesh, sharding_spec)
|
||||
return distribute_tensor(tensor, device_mesh, sharding_spec)
|
||||
|
||||
|
||||
def shard_colwise(tensor: torch.Tensor,
|
||||
group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
|
||||
inplace: bool = False) -> DTensor:
|
||||
def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor:
|
||||
"""
|
||||
Shard the first dim of the given tensor.
|
||||
|
||||
|
@ -57,7 +208,7 @@ def shard_colwise(tensor: torch.Tensor,
|
|||
inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.
|
||||
|
||||
Returns:
|
||||
DTensor: The sharded tensor.
|
||||
torch.Tensor: The sharded tensor.
|
||||
"""
|
||||
# if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
|
||||
if group_or_device_mesh is None:
|
||||
|
@ -70,7 +221,87 @@ def shard_colwise(tensor: torch.Tensor,
|
|||
device_mesh = group_or_device_mesh
|
||||
sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})
|
||||
|
||||
if not inplace:
|
||||
tensor = tensor.detach().clone()
|
||||
return distribute_tensor(tensor, device_mesh, sharding_spec)
|
||||
|
||||
return DTensor(tensor, device_mesh, sharding_spec)
|
||||
|
||||
def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
|
||||
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
|
||||
param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
|
||||
|
||||
# make it distributed as well
|
||||
param.dist_layout = dtensor.dist_layout
|
||||
_hijack_detach_and_clone(param)
|
||||
|
||||
return param
|
||||
|
||||
|
||||
def compute_global_numel(dtensor: torch.Tensor) -> int:
|
||||
"""
|
||||
Compute the global number of elements in the distributed tensor.
|
||||
|
||||
Args:
|
||||
dtensor (torch.Tensor): The distributed tensor.
|
||||
|
||||
Returns:
|
||||
int: The global number of elements in the distributed tensor.
|
||||
"""
|
||||
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
|
||||
numel = reduce(operator.mul, dtensor.dist_layout.global_shape)
|
||||
return numel
|
||||
|
||||
|
||||
def get_layout(dtensor: torch.Tensor) -> Layout:
|
||||
"""
|
||||
Get the layout of the distributed tensor.
|
||||
|
||||
Args:
|
||||
dtensor (torch.Tensor): The distributed tensor.
|
||||
|
||||
Returns:
|
||||
Layout: The layout of the distributed tensor.
|
||||
|
||||
"""
|
||||
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
|
||||
return dtensor.dist_layout
|
||||
|
||||
|
||||
def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
|
||||
"""
|
||||
Get the global shape of the distributed tensor.
|
||||
|
||||
Args:
|
||||
dtensor (torch.Tensor): The distributed tensor.
|
||||
|
||||
Returns:
|
||||
torch.Size: The global shape of the distributed tensor.
|
||||
"""
|
||||
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
|
||||
return dtensor.dist_layout.global_shape
|
||||
|
||||
|
||||
def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
|
||||
"""
|
||||
Get the device mesh of the distributed tensor.
|
||||
|
||||
Args:
|
||||
dtensor (torch.Tensor): The distributed tensor.
|
||||
|
||||
Returns:
|
||||
DeviceMesh: The device mesh of the distributed tensor.
|
||||
"""
|
||||
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
|
||||
return dtensor.dist_layout.device_mesh
|
||||
|
||||
|
||||
def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
|
||||
"""
|
||||
Get the sharding spec of the distributed tensor.
|
||||
|
||||
Args:
|
||||
dtensor (torch.Tensor): The distributed tensor.
|
||||
|
||||
Returns:
|
||||
ShardingSpec: The sharding spec of the distributed tensor.
|
||||
"""
|
||||
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
|
||||
return dtensor.dist_layout.sharding_spec
|
||||
|
|
|
@ -1,12 +1,11 @@
|
|||
import operator
|
||||
from dataclasses import dataclass
|
||||
from functools import reduce
|
||||
|
||||
import torch
|
||||
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
|
||||
from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError
|
||||
from .misc import DuplicatedShardingDimensionError, ShardingNotDivisibleError
|
||||
from .sharding_spec import ShardingSpec
|
||||
|
||||
|
||||
|
|
|
@ -28,18 +28,6 @@ class LayoutConverterOptions:
|
|||
pass
|
||||
|
||||
|
||||
def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
|
||||
layout_converter = LayoutConverter()
|
||||
global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
|
||||
global_layout = Layout(device_mesh=layout.device_mesh,
|
||||
device_type=layout.device_type,
|
||||
sharding_spec=global_sharding_spec,
|
||||
entire_shape=layout.entire_shape)
|
||||
with torch.no_grad():
|
||||
global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
|
||||
return global_tensor
|
||||
|
||||
|
||||
def set_layout_converting_options(options: LayoutConverterOptions):
|
||||
"""
|
||||
Configure the shape consistency manager via function call.
|
||||
|
@ -553,4 +541,5 @@ class LayoutConverter(metaclass=SingletonMeta):
|
|||
_, comm_action_sequence = self.layout_converting(source_layout, target_layout)
|
||||
for comm_spec in comm_action_sequence:
|
||||
tensor = comm_spec.covert_spec_to_action(tensor)
|
||||
tensor.dist_layout = target_layout
|
||||
return tensor
|
||||
|
|
|
@ -29,7 +29,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = Fals
|
|||
# the comm size for all gather is the size of the gathered tensor
|
||||
gather_dim = comm_spec.gather_dim
|
||||
all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1]
|
||||
all_gather_size = device_mesh.mesh_shape[all_gather_axis]
|
||||
all_gather_size = device_mesh.shape[all_gather_axis]
|
||||
comm_size_for_all_gather = comm_size * all_gather_size
|
||||
forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis)
|
||||
# give a tiny cost to shard
|
||||
|
|
|
@ -0,0 +1 @@
|
|||
from colossalai.tensor.d_tensor.api import to_distributed_tensor
|
|
@ -7,7 +7,8 @@ import torch
|
|||
from packaging import version
|
||||
|
||||
from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
|
||||
from colossalai.tensor.d_tensor.layout_converter import to_global
|
||||
from colossalai.tensor.d_tensor import to_global
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from tests.kit.model_zoo.registry import ModelAttribute
|
||||
|
||||
SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')
|
||||
|
@ -91,6 +92,8 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
|
|||
assert n1 == n2
|
||||
t1 = t1.cuda()
|
||||
t2 = t2.cuda()
|
||||
if n2 in layout_dict:
|
||||
t2 = to_global(t2, layout_dict[n2])
|
||||
if n2 in sharding_spec_dict:
|
||||
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
|
||||
t2.dist_layout = layout
|
||||
t2 = to_global(t2)
|
||||
assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
|
||||
|
|
|
@ -14,6 +14,10 @@ def check_embedding_1d():
|
|||
|
||||
assert embedding_1d.weight.shape == torch.Size([32, 64])
|
||||
|
||||
# ensure state dict is reversibly loadable
|
||||
embedding.load_state_dict(embedding_1d.state_dict())
|
||||
embedding_1d.load_state_dict(embedding.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.randint(low=0, high=32, size=(4, 32)).cuda()
|
||||
out = embedding(x)
|
||||
|
|
|
@ -5,6 +5,7 @@ from torch.testing import assert_close
|
|||
|
||||
import colossalai
|
||||
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
|
||||
from colossalai.tensor.d_tensor import is_distributed_tensor
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
|
@ -12,9 +13,18 @@ def check_linear_1d_col():
|
|||
linear = nn.Linear(32, 128).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
|
||||
|
||||
# ensure that the parameters are distributed
|
||||
assert is_distributed_tensor(linear_col.weight)
|
||||
assert is_distributed_tensor(linear_col.bias)
|
||||
|
||||
# ensure the shape is correct
|
||||
assert linear_col.weight.shape == torch.Size([64, 32])
|
||||
assert linear_col.bias.shape == torch.Size([64])
|
||||
|
||||
# ensure state dict is reversibly loadable
|
||||
linear.load_state_dict(linear_col.state_dict())
|
||||
linear_col.load_state_dict(linear.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 32).cuda()
|
||||
out = linear(x)
|
||||
|
@ -55,7 +65,7 @@ def check_linear_1d_row():
|
|||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
check_linear_1d_col()
|
||||
check_linear_1d_row()
|
||||
# check_linear_1d_row()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
|
|
|
@ -14,7 +14,11 @@ def check_vocab_embedding_1d():
|
|||
|
||||
assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
|
||||
assert dist_embedding_1d.num_embeddings == 64
|
||||
assert dist_embedding_1d.embed_dim == 32
|
||||
assert dist_embedding_1d.embedding_dim == 32
|
||||
|
||||
# ensure state dict is reversibly loadable
|
||||
embedding.load_state_dict(dist_embedding_1d.state_dict())
|
||||
dist_embedding_1d.load_state_dict(embedding.state_dict())
|
||||
|
||||
# check embedding correctness
|
||||
x = torch.randint(0, 128, (4, 32)).to('cuda')
|
||||
|
|
|
@ -1,14 +1,11 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
|
||||
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
|
|
|
@ -3,9 +3,7 @@ import torch
|
|||
from colossalai.device.device_mesh import DeviceMesh
|
||||
from colossalai.initialize import launch
|
||||
from colossalai.logging import disable_existing_loggers
|
||||
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
|
@ -31,22 +29,18 @@ def check_dtensor(rank, world_size, port):
|
|||
|
||||
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
|
||||
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
|
||||
layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=target_sharding_spec,
|
||||
entire_shape=original_tensor.shape)
|
||||
d_tensor = DTensor(original_tensor, layout)
|
||||
d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)
|
||||
|
||||
assert d_tensor.entire_shape == original_tensor.shape
|
||||
assert d_tensor.data_type == original_tensor.dtype
|
||||
assert get_global_shape(d_tensor) == original_tensor.shape
|
||||
assert d_tensor.dtype == original_tensor.dtype
|
||||
|
||||
if rank in (0, 1):
|
||||
assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
|
||||
assert d_tensor.equal(original_tensor.narrow(0, 0, 2))
|
||||
elif rank in (2, 3):
|
||||
assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
|
||||
assert d_tensor.equal(original_tensor.narrow(0, 2, 2))
|
||||
else:
|
||||
raise ValueError(f'rank {rank} is not in the device mesh')
|
||||
assert d_tensor.to_global().equal(original_tensor)
|
||||
assert to_global(d_tensor).equal(original_tensor)
|
||||
output = test_model(d_tensor)
|
||||
|
||||
if rank in (0, 1):
|
||||
|
@ -57,34 +51,29 @@ def check_dtensor(rank, world_size, port):
|
|||
raise ValueError(f'rank {rank} is not in the device mesh')
|
||||
|
||||
new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
|
||||
new_layout = Layout(device_mesh=device_mesh,
|
||||
device_type=torch.device('cuda'),
|
||||
sharding_spec=new_sharding_spec,
|
||||
entire_shape=original_tensor.shape)
|
||||
|
||||
d_tensor.layout_convert(new_layout)
|
||||
d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec)
|
||||
|
||||
if rank == 0:
|
||||
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
|
||||
assert d_tensor.equal(original_tensor.narrow(0, 0, 1))
|
||||
elif rank == 1:
|
||||
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
|
||||
assert d_tensor.equal(original_tensor.narrow(0, 1, 1))
|
||||
elif rank == 2:
|
||||
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
|
||||
assert d_tensor.equal(original_tensor.narrow(0, 2, 1))
|
||||
elif rank == 3:
|
||||
assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
|
||||
assert d_tensor.equal(original_tensor.narrow(0, 3, 1))
|
||||
else:
|
||||
raise ValueError(f'rank {rank} is not in the device mesh')
|
||||
|
||||
dtensor_from_local = distribute_tensor(original_tensor, new_layout)
|
||||
|
||||
if rank == 0:
|
||||
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
|
||||
assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1))
|
||||
elif rank == 1:
|
||||
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
|
||||
assert dtensor_from_local.equal(original_tensor.narrow(0, 1, 1))
|
||||
elif rank == 2:
|
||||
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
|
||||
assert dtensor_from_local.equal(original_tensor.narrow(0, 2, 1))
|
||||
elif rank == 3:
|
||||
assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
|
||||
assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1))
|
||||
else:
|
||||
raise ValueError(f'rank {rank} is not in the device mesh')
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from colossalai.logging import disable_existing_loggers
|
|||
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
|
||||
from colossalai.tensor.d_tensor.layout import Layout
|
||||
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
|
||||
from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
|
||||
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
|
||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||
|
||||
entire_shape = torch.Size((64, 32, 16))
|
||||
|
|
Loading…
Reference in New Issue