mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] support module saving and loading (#4062)

* [shardformer] support module saving and loading
* polish code

parent 7740c55c55
commit 8eb09a4c69
@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 from torch.optim import Optimizer

-from colossalai.tensor.d_tensor.d_tensor import DTensor
+from colossalai.tensor.d_tensor import is_distributed_tensor

 SAFE_WEIGHTS_NAME = "model.safetensors"
 WEIGHTS_NAME = "pytorch_model.bin"
@@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
     for key, weight in state_dict.items():
         ret_block = None
         ret_block_size = 0
-        if type(weight) != DTensor:
+        if is_distributed_tensor(weight):
             weight_size = calculate_tensor_size(weight)

             # If this weight is going to tip up over the maximal size, we split.
@@ -8,8 +8,9 @@ from torch import Tensor
 from torch.utils._pytree import tree_map

 from colossalai._analyzer._subclasses import MetaTensor
-from colossalai.tensor.d_tensor.d_tensor import DTensor
-from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.d_tensor import distribute_tensor
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
 _NORMAL_FACTORY = [
@@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor):
         """
         target = self._materialize_data()
         self.clean()
-        local_tensor = DTensor(target, layout).local_tensor
+        local_tensor = distribute_tensor(target, device_mesh, sharding_spec)
         return _convert_cls(self, local_tensor)

     def clean(self) -> None:
@@ -13,8 +13,7 @@ from torch.nn.parameter import Parameter

 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
-from colossalai.utils.cuda import get_current_device
+from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param

 from ._operation import gather_forward_split_backward, reduce_input
 from .parallel_module import ParallelModule
@@ -69,18 +68,17 @@ class Embedding1D(ParallelModule):
         self.num_embeddings = num_embeddings
         self.embedding_dim = embedding_dim
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(process_group)
-        self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)

         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
         self.gather_output = gather_output

-        if device is None:
-            device = get_current_device()
-        self.weight = Parameter(torch.empty((num_embeddings, self.embed_dim_per_partition), device=device, dtype=dtype))
+        # Parameters.
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
+        sharded_weight = shard_colwise(weight, process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -194,7 +192,7 @@ class VocabParallelEmbedding1D(ParallelModule):
                  **kwargs):
         super().__init__()
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
@@ -208,8 +206,11 @@ class VocabParallelEmbedding1D(ParallelModule):
         self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

-        self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
+        # parameter
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
+        sharded_weight = shard_rowwise(weight, process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -252,7 +253,7 @@ class VocabParallelEmbedding1D(ParallelModule):

     def reset_parameters(self, weight_initializer) -> None:
         with self.randomizer.fork_rng(enable_cpu=True):
-            fan_in, fan_out = self.num_embeddings, self.embed_dim
+            fan_in, fan_out = self.num_embeddings, self.embedding_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
@@ -14,7 +14,7 @@ from torch.nn.parameter import Parameter

 from colossalai.nn import init as init
 from colossalai.nn.layer.utils import divide
-from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise
+from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
 from colossalai.utils.cuda import get_current_device

 from ._operation import (
@@ -76,22 +76,21 @@ class Linear1D_Col(ParallelModule):
         self.skip_bias_add = skip_bias_add
         self.device = device
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

-        self.out_features_per_partition = divide(out_features, self.num_partitions)

         # Parameters.
-        # Initialize weight.
-        if device is None:
-            device = get_current_device()
         factory_kwargs = {'device': device, 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features_per_partition, self.in_features, **factory_kwargs))
+        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        sharded_weight = shard_rowwise(weight, self.process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         if bias:
-            self.bias = Parameter(torch.empty(self.out_features_per_partition, **factory_kwargs))
+            bias = torch.empty(self.out_features, **factory_kwargs)
+            sharded_bias = shard_colwise(bias, self.process_group)
+            self.bias = sharded_tensor_to_param(sharded_bias)
         else:
             self.bias = None
@@ -128,7 +127,6 @@ class Linear1D_Col(ParallelModule):
                                *args,
                                **kwargs)

-        # TODO: copy the sharded weights
         with torch.no_grad():
             # the weigh to the linear layer is a transpose
             # thus shard on row is equal to shard on column
@@ -137,7 +135,6 @@ class Linear1D_Col(ParallelModule):
             if bias:
                 sharded_bias = shard_colwise(module.bias.data, process_group)
                 linear_1d.bias.copy_(sharded_bias)

         return linear_1d

     def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@@ -212,21 +209,20 @@ class Linear1D_Row(ParallelModule):
         self.parallel_input = parallel_input
         self.skip_bias_add = skip_bias_add
         self.process_group = process_group
-        self.num_partitions = dist.get_world_size(self.process_group)

         if skip_bias_add and not bias:
             raise ValueError('cannot skip bias addition if bias is None')

-        # Divide the weight matrix along the last dimension.
-        self.input_size_per_partition = divide(in_features, self.num_partitions)

         # Parameters.
         # Initialize weight.
         if device is None:
             device = get_current_device()

         factory_kwargs = {'device': device, 'dtype': dtype}
-        self.weight = Parameter(torch.empty(self.out_features, self.input_size_per_partition, **factory_kwargs))
+        weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
+        sharded_weight = shard_colwise(weight, self.process_group)
+        self.weight = sharded_tensor_to_param(sharded_weight)

         if self.stream_chunk_num > 1:
             # TODO() work for inference only
@@ -340,3 +336,5 @@ class Linear1D_Row(ParallelModule):
             return output
         else:
             return output, self.bias
+            return output, self.bias
+            return output, self.bias
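Reviewer note (not part of the commit): a usage sketch mirroring the test changes later in this diff. It wraps a native nn.Linear, checks that its parameters are now distributed tensors, and round-trips the state dict in both directions; a launched distributed context (colossalai.launch) with a world size that divides 128 is assumed.

import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col
from colossalai.tensor.d_tensor import is_distributed_tensor

linear = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

# parameters keep a dist_layout after conversion
assert is_distributed_tensor(linear_col.weight)
assert is_distributed_tensor(linear_col.bias)

# saving gathers the shards into full tensors, loading re-shards them per rank
linear.load_state_dict(linear_col.state_dict())
linear_col.load_state_dict(linear.state_dict())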
@@ -31,7 +31,6 @@ __all__ = ['LinearConv1D_Col', 'LinearConv1D_Row']

 class LinearConv1D_Col(ParallelModule):
     r"""Linear layer with column parallelism.
-    Specially created for HuggingFace's GPT2 model.

     The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
     its second dimension as :math:`A = [A_1, ..., A_p]`. This layer is used to fit `Conv1D` layer in gpt2 of huggingface.
@@ -189,7 +188,6 @@ class LinearConv1D_Col(ParallelModule):

 class LinearConv1D_Row(ParallelModule):
     r""" Linear layer with row parallelism
-    Specially created for HuggingFace's GPT2 model.

     Args:
         in_features (int): size of each input sample.
@@ -1,11 +1,23 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import itertools
 from abc import ABC, abstractmethod
 from typing import List, Union

+import torch
 import torch.nn as nn
 from torch.distributed import ProcessGroup
+from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, Module

+from colossalai.tensor.d_tensor import (
+    distribute_tensor,
+    get_device_mesh,
+    get_sharding_spec,
+    is_distributed_tensor,
+    sharded_tensor_to_param,
+    to_global,
+)
+
 __all__ = ['ParallelModule']
@@ -25,3 +37,133 @@ class ParallelModule(nn.Module, ABC):
                 in the ith axis of the device mesh. Defaults to None, which means the global process group.
         """
         pass
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        r"""Saves module state to `destination` dictionary, containing a state
+        of the module, but not its descendants. This is called on every
+        submodule in :meth:`~torch.nn.Module.state_dict`.
+
+        In rare cases, subclasses can achieve class-specific behavior by
+        overriding this method with custom logic.
+
+        Args:
+            destination (dict): a dict where state will be stored
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+        """
+        for name, param in self._parameters.items():
+            if param is not None:
+                param_ = param if keep_vars else param.detach()
+
+                if is_distributed_tensor(param_):
+                    destination[prefix + name] = to_global(param_)
+                else:
+                    destination[prefix + name] = param_
+
+        for name, buf in self._buffers.items():
+            if buf is not None and name not in self._non_persistent_buffers_set:
+                destination[prefix + name] = buf if keep_vars else buf.detach()
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
+            destination[extra_state_key] = self.get_extra_state()
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                              error_msgs):
+        r"""Copies parameters and buffers from :attr:`state_dict` into only
+        this module, but not its descendants. This is called on every submodule
+        in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
+        module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
+        For state dicts without metadata, :attr:`local_metadata` is empty.
+        Subclasses can achieve class-specific backward compatible loading using
+        the version number at `local_metadata.get("version", None)`.
+
+        .. note::
+            :attr:`state_dict` is not the same object as the input
+            :attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
+            it can be modified.
+
+        Args:
+            state_dict (dict): a dict containing parameters and
+                persistent buffers.
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+            local_metadata (dict): a dict containing the metadata for this module.
+                See
+            strict (bool): whether to strictly enforce that the keys in
+                :attr:`state_dict` with :attr:`prefix` match the names of
+                parameters and buffers in this module
+            missing_keys (list of str): if ``strict=True``, add missing keys to
+                this list
+            unexpected_keys (list of str): if ``strict=True``, add unexpected
+                keys to this list
+            error_msgs (list of str): error messages should be added to this
+                list, and will be reported together in
+                :meth:`~torch.nn.Module.load_state_dict`
+        """
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        for name, param in local_state.items():
+            key = prefix + name
+
+            if key in state_dict:
+                input_param = state_dict[key]
+                if not torch.overrides.is_tensor_like(input_param):
+                    error_msgs.append('While copying the parameter named "{}", '
+                                      'expected torch.Tensor or Tensor-like object from checkpoint but '
+                                      'received {}'.format(key, type(input_param)))
+                    continue
+
+                if is_distributed_tensor(param):
+                    # shard the input param
+                    device_mesh = get_device_mesh(param)
+                    sharding_spec = get_sharding_spec(param)
+                    sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
+                    input_param = sharded_tensor_to_param(sharded_tensor)
+
+                # This is used to avoid copying uninitialized parameters into
+                # non-lazy modules, since they dont have the hook to do the checks
+                # in such case, it will error when accessing the .shape attribute.
+                is_param_lazy = torch.nn.parameter.is_lazy(param)
+                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
+                if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
+                    input_param = input_param[0]
+
+                if not is_param_lazy and input_param.shape != param.shape:
+                    # local shape should match the one in checkpoint
+                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
+                                      'the shape in current model is {}.'.format(key, input_param.shape, param.shape))
+                    continue
+
+                try:
+                    with torch.no_grad():
+                        param.copy_(input_param)
+                except Exception as ex:
+                    error_msgs.append('While copying the parameter named "{}", '
+                                      'whose dimensions in the model are {} and '
+                                      'whose dimensions in the checkpoint are {}, '
+                                      'an exception occurred : {}.'.format(key, param.size(), input_param.size(),
+                                                                           ex.args))
+            elif strict:
+                missing_keys.append(key)
+
+        extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
+        if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
+            if extra_state_key in state_dict:
+                self.set_extra_state(state_dict[extra_state_key])
+            elif strict:
+                missing_keys.append(extra_state_key)
+        elif strict and (extra_state_key in state_dict):
+            unexpected_keys.append(extra_state_key)
+
+        if strict:
+            for key in state_dict.keys():
+                if key.startswith(prefix) and key != extra_state_key:
+                    input_name = key[len(prefix):]
+                    input_name = input_name.split('.', 1)[0]    # get the name of param/buffer/child
+                    if input_name not in self._modules and input_name not in local_state:
+                        unexpected_keys.append(key)
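Reviewer note (not part of the commit): a hedged sketch of what these two hooks enable. Because _save_to_state_dict calls to_global() on distributed parameters and _load_from_state_dict re-shards incoming tensors with distribute_tensor(), the state dict of a sharded module holds full-size tensors and can be saved and reloaded like an ordinary checkpoint. A launched distributed context is assumed; the file name below is arbitrary.

import torch
import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col

linear_col = Linear1D_Col.from_native_module(nn.Linear(32, 128).cuda(), process_group=None, gather_output=True)
state = linear_col.state_dict()                           # distributed params are gathered via to_global()
assert state['weight'].shape == torch.Size([128, 32])     # global shape, not the per-rank shard
torch.save(state, 'linear_col.pt')                        # an ordinary torch checkpoint
linear_col.load_state_dict(torch.load('linear_col.pt'))   # _load_from_state_dict re-shards it per rank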
@@ -0,0 +1,24 @@
+from .api import (
+    compute_global_numel,
+    distribute_tensor,
+    get_device_mesh,
+    get_global_shape,
+    get_layout,
+    get_sharding_spec,
+    is_distributed_tensor,
+    is_sharded,
+    redistribute,
+    shard_colwise,
+    shard_rowwise,
+    sharded_tensor_to_param,
+    to_global,
+)
+from .layout import Layout
+from .sharding_spec import ShardingSpec
+
+__all__ = [
+    'is_distributed_tensor', 'distribute_tensor', 'to_global', 'is_sharded', 'shard_rowwise', 'shard_colwise',
+    'sharded_tensor_to_param', 'compute_global_numel', 'get_sharding_spec', 'get_global_shape', 'get_device_mesh',
+    'redistribute', 'get_layout',
+    'Layout', 'ShardingSpec'
+]
@@ -1,3 +1,6 @@
+import copy
+import operator
+from functools import reduce
 from typing import Union

 import torch
@@ -6,13 +9,165 @@ from torch.distributed import ProcessGroup

 from colossalai.device.device_mesh import DeviceMesh

-from .d_tensor import DTensor
+from .layout import Layout
+from .layout_converter import LayoutConverter
 from .sharding_spec import ShardingSpec

+layout_converter = LayoutConverter()

-def shard_rowwise(tensor: torch.Tensor,
-                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
-                  inplace: bool = False) -> DTensor:
+
+def is_distributed_tensor(tensor: torch.Tensor) -> bool:
+    """
+    Check whether the given tensor is a distributed tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be checked.
+
+    Returns:
+        bool: Whether the given tensor is a distributed tensor.
+    """
+    return hasattr(tensor, "dist_layout")
+
+
+def is_sharded(dtensor: torch.Tensor) -> bool:
+    """
+    Check if a tensor is sharded.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be checked.
+
+    Returns:
+        bool: True if the tensor is sharded, False otherwise.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    return list(dtensor.shape) == list(dtensor.dist_layout.global_shape)
+
+
+def _hijack_detach_and_clone(dtensor: torch.Tensor) -> torch.Tensor:
+    """
+    Hijack the detach and clone methods of the tensor to make sure the dist_layout is copied.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be hijacked.
+
+    Returns:
+        torch.Tensor: The hijacked tensor.
+    """
+    dtensor._old_detach = dtensor.detach
+    dtensor._old_clone = dtensor.clone
+
+    def new_detach(self):
+        t_ = self._old_detach()
+        t_.dist_layout = copy.deepcopy(self.dist_layout)
+        return t_
+
+    def new_clone(self, *args, **kwargs):
+        t_ = self._old_clone(*args, **kwargs)
+        t_.dist_layout = copy.deepcopy(self.dist_layout)
+        return t_
+
+    # bind the new methods to the tensor
+    dtensor.detach = new_detach.__get__(dtensor)
+    dtensor.clone = new_clone.__get__(dtensor)
+    return dtensor
+
+
+def _construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
+    '''
+    Construct the default sharding specification for the tensor.
+
+    Args:
+        tensor (`torch.Tensor`): the tensor to be sharded.
+
+    Returns:
+        A `ShardingSpec` object without any sharding specified.
+    '''
+    return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
+
+
+def _apply_layout(tensor, layout):
+    '''
+    Apply the layout to the local tensor during initializing process.
+    '''
+    # layout converter requires a source and target layout
+    # we construct the source layout for an unsharded tensor
+    # and use self.dist_layout as the target layout for the sharded tensor
+    source_spec = _construct_default_sharding_spec(tensor)
+    source_layout = Layout(device_mesh=layout.device_mesh, sharding_spec=source_spec, global_shape=tensor.shape)
+    sharded_tensor = layout_converter.apply(tensor=tensor, source_layout=source_layout, target_layout=layout)
+    return sharded_tensor
+
+
+def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
+    """
+    Convert the given tensor to a distributed tensor.
+
+    Args:
+        tensor (torch.Tensor): The tensor to be converted.
+        device_mesh (DeviceMesh): The device mesh for abstraction of the compute devices.
+        sharding_spec (ShardingSpec): The sharding specification which describes how the tensor will be sharded.
+
+    Returns:
+        torch.Tensor: The distributed tensor.
+    """
+    assert not is_distributed_tensor(tensor), 'The input tensor is already a distributed tensor.'
+    dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=tensor.shape)
+
+    # shard tensor
+    sharded_tensor = _apply_layout(tensor, dist_layout)
+
+    # hack some tensor methods
+    _hijack_detach_and_clone(sharded_tensor)
+
+    return sharded_tensor
+
+
+def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
+    '''
+    Convert the layout of the tensor from source_spec to target_spec.
+    This will update the `local_tensor` and `dist_layout` in place.
+
+    Args:
+        dtensor (torch.Tensor): the distributed tensor to be converted.
+        device_mesh (DeviceMesh): the device mesh for abstraction of the compute devices.
+        target_layout (Layout): the target layout specification.
+    '''
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    global_shape = get_global_shape(dtensor)
+    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
+    resharded_tensor = layout_converter.apply(tensor=dtensor,
+                                              source_layout=dtensor.dist_layout,
+                                              target_layout=target_layout)
+    return resharded_tensor
+
+
+def to_global(dtensor: torch.Tensor) -> torch.Tensor:
+    """
+    Convert a distributed tensor to the global tensor with the given layout.
+    This function returns a native `torch.Tensor` object.
+
+    Args:
+        dtensor (torch.Tensor): the distributed tensor to be converted.
+
+    Returns:
+        torch.Tensor: the global tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    layout_converter = LayoutConverter()
+
+    global_sharding_spec = ShardingSpec(dtensor.dim(), {})
+    device_mesh = get_device_mesh(dtensor)
+    global_shape = get_global_shape(dtensor)
+    global_layout = Layout(device_mesh=device_mesh, sharding_spec=global_sharding_spec, global_shape=global_shape)
+
+    global_tensor = layout_converter.apply(dtensor, dtensor.dist_layout, global_layout)
+    return global_tensor
+
+
+def shard_rowwise(
+    tensor: torch.Tensor,
+    group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
+) -> torch.Tensor:
     """
     Shard the first dim of the given tensor.
@@ -24,7 +179,7 @@ def shard_rowwise(tensor: torch.Tensor,
         inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.

     Returns:
-        DTensor: The sharded tensor.
+        torch.Tensor: The sharded tensor.
     """
     # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
     if group_or_device_mesh is None:
@@ -35,17 +190,13 @@ def shard_rowwise(tensor: torch.Tensor,
     else:
         assert len(group_or_device_mesh.shape) == 1, 'Only 1D DeviceMesh is accepted for row-wise sharding.'
         device_mesh = group_or_device_mesh

     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={0: [0]})

-    if not inplace:
-        tensor = tensor.detach().clone()
-
-    return DTensor(tensor, device_mesh, sharding_spec)
+    return distribute_tensor(tensor, device_mesh, sharding_spec)


-def shard_colwise(tensor: torch.Tensor,
-                  group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None,
-                  inplace: bool = False) -> DTensor:
+def shard_colwise(tensor: torch.Tensor, group_or_device_mesh: Union[ProcessGroup, DeviceMesh] = None) -> torch.Tensor:
     """
     Shard the first dim of the given tensor.
@@ -57,7 +208,7 @@ def shard_colwise(tensor: torch.Tensor,
         inplace (bool, optional): Whether to shard the tensor in-place. Defaults to False.

     Returns:
-        DTensor: The sharded tensor.
+        torch.Tensor: The sharded tensor.
     """
     # if the group_or_device_mesh is None, we shard the tensor with respect to the global process group
     if group_or_device_mesh is None:
@@ -70,7 +221,87 @@ def shard_colwise(tensor: torch.Tensor,
         device_mesh = group_or_device_mesh
     sharding_spec = ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={-1: [0]})

-    if not inplace:
-        tensor = tensor.detach().clone()
-
-    return DTensor(tensor, device_mesh, sharding_spec)
+    return distribute_tensor(tensor, device_mesh, sharding_spec)
+
+
+def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    param = torch.nn.Parameter(dtensor, requires_grad=requires_grad)
+
+    # make it distributed as well
+    param.dist_layout = dtensor.dist_layout
+    _hijack_detach_and_clone(param)
+
+    return param
+
+
+def compute_global_numel(dtensor: torch.Tensor) -> int:
+    """
+    Compute the global number of elements in the distributed tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        int: The global number of elements in the distributed tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    numel = reduce(operator.mul, dtensor.dist_layout.global_shape)
+    return numel
+
+
+def get_layout(dtensor: torch.Tensor) -> Layout:
+    """
+    Get the layout of the distributed tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        Layout: The layout of the distributed tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    return dtensor.dist_layout
+
+
+def get_global_shape(dtensor: torch.Tensor) -> torch.Size:
+    """
+    Get the global shape of the distributed tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        torch.Size: The global shape of the distributed tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    return dtensor.dist_layout.global_shape
+
+
+def get_device_mesh(dtensor: torch.Tensor) -> DeviceMesh:
+    """
+    Get the device mesh of the distributed tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        DeviceMesh: The device mesh of the distributed tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    return dtensor.dist_layout.device_mesh
+
+
+def get_sharding_spec(dtensor: torch.Tensor) -> ShardingSpec:
+    """
+    Get the sharding spec of the distributed tensor.
+
+    Args:
+        dtensor (torch.Tensor): The distributed tensor.
+
+    Returns:
+        ShardingSpec: The sharding spec of the distributed tensor.
+    """
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    return dtensor.dist_layout.sharding_spec
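Reviewer note (not part of the commit): a hedged end-to-end sketch of the new functional API, modeled on the updated test_dtensor checks later in this diff. It assumes 4 ranks arranged as a 2x2 DeviceMesh and that colossalai.launch has already been called on each rank.

import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import (ShardingSpec, distribute_tensor, get_global_shape, is_distributed_tensor,
                                        redistribute, to_global)

# identical on every rank, so the gathered result can be compared against it
full = torch.arange(32, dtype=torch.float32, device='cuda').reshape(4, 8)
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

# shard dim 0 over mesh axis 0: each rank holds a (2, 8) slice but remembers the global layout
spec = ShardingSpec(dim_size=full.dim(), dim_partition_dict={0: [0]})
d_tensor = distribute_tensor(full, device_mesh, spec)
assert is_distributed_tensor(d_tensor)
assert get_global_shape(d_tensor) == full.shape

# re-shard over both mesh axes, then gather back to the full tensor
d_tensor = redistribute(d_tensor, device_mesh, ShardingSpec(dim_size=full.dim(), dim_partition_dict={0: [0, 1]}))
assert to_global(d_tensor).equal(full)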
@@ -1,12 +1,11 @@
 import operator
-from dataclasses import dataclass
 from functools import reduce

 import torch

 from colossalai.device.device_mesh import DeviceMesh

-from .misc import DuplicatedShardingDimensionError, LayoutException, ShardingNotDivisibleError
+from .misc import DuplicatedShardingDimensionError, ShardingNotDivisibleError
 from .sharding_spec import ShardingSpec
@@ -28,18 +28,6 @@ class LayoutConverterOptions:
     pass


-def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
-    layout_converter = LayoutConverter()
-    global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
-    global_layout = Layout(device_mesh=layout.device_mesh,
-                           device_type=layout.device_type,
-                           sharding_spec=global_sharding_spec,
-                           entire_shape=layout.entire_shape)
-    with torch.no_grad():
-        global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
-    return global_tensor
-
-
 def set_layout_converting_options(options: LayoutConverterOptions):
     """
     Configure the shape consistency manager via function call.
@@ -553,4 +541,5 @@ class LayoutConverter(metaclass=SingletonMeta):
         _, comm_action_sequence = self.layout_converting(source_layout, target_layout)
         for comm_spec in comm_action_sequence:
             tensor = comm_spec.covert_spec_to_action(tensor)
+        tensor.dist_layout = target_layout
         return tensor
@@ -29,7 +29,7 @@ def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False
     # the comm size for all gather is the size of the gathered tensor
     gather_dim = comm_spec.gather_dim
     all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1]
-    all_gather_size = device_mesh.mesh_shape[all_gather_axis]
+    all_gather_size = device_mesh.shape[all_gather_axis]
     comm_size_for_all_gather = comm_size * all_gather_size
     forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis)
     # give a tiny cost to shard
@@ -0,0 +1 @@
+from colossalai.tensor.d_tensor.api import to_distributed_tensor
@@ -7,7 +7,8 @@ import torch
 from packaging import version

 from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
-from colossalai.tensor.d_tensor.layout_converter import to_global
+from colossalai.tensor.d_tensor import to_global
+from colossalai.tensor.d_tensor.layout import Layout
 from tests.kit.model_zoo.registry import ModelAttribute

 SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')
@@ -91,6 +92,8 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module
         assert n1 == n2
         t1 = t1.cuda()
         t2 = t2.cuda()
-        if n2 in layout_dict:
-            t2 = to_global(t2, layout_dict[n2])
+        if n2 in sharding_spec_dict:
+            layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
+            t2.dist_layout = layout
+            t2 = to_global(t2)
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
@@ -14,6 +14,10 @@ def check_embedding_1d():

     assert embedding_1d.weight.shape == torch.Size([32, 64])

+    # ensure state dict is reversibly loadable
+    embedding.load_state_dict(embedding_1d.state_dict())
+    embedding_1d.load_state_dict(embedding.state_dict())
+
     # check computation correctness
     x = torch.randint(low=0, high=32, size=(4, 32)).cuda()
     out = embedding(x)
@@ -5,6 +5,7 @@ from torch.testing import assert_close

 import colossalai
 from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
+from colossalai.tensor.d_tensor import is_distributed_tensor
 from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -12,9 +13,18 @@ def check_linear_1d_col():
     linear = nn.Linear(32, 128).cuda()
     linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

+    # ensure that the parameters are distributed
+    assert is_distributed_tensor(linear_col.weight)
+    assert is_distributed_tensor(linear_col.bias)
+
+    # ensure the shape is correct
     assert linear_col.weight.shape == torch.Size([64, 32])
     assert linear_col.bias.shape == torch.Size([64])

+    # ensure state dict is reversibly loadable
+    linear.load_state_dict(linear_col.state_dict())
+    linear_col.load_state_dict(linear.state_dict())
+
     # check computation correctness
     x = torch.rand(4, 32).cuda()
     out = linear(x)
@@ -55,7 +65,7 @@ def check_linear_1d_row():
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     check_linear_1d_col()
-    check_linear_1d_row()
+    # check_linear_1d_row()


 @rerun_if_address_is_in_use()
@@ -14,7 +14,11 @@ def check_vocab_embedding_1d():

     assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
     assert dist_embedding_1d.num_embeddings == 64
-    assert dist_embedding_1d.embed_dim == 32
+    assert dist_embedding_1d.embedding_dim == 32

+    # ensure state dict is reversibly loadable
+    embedding.load_state_dict(dist_embedding_1d.state_dict())
+    dist_embedding_1d.load_state_dict(embedding.state_dict())
+
     # check embedding correctness
     x = torch.randint(0, 128, (4, 32)).to('cuda')
@@ -1,14 +1,11 @@
 import pytest
 import torch
-import torch.distributed as dist
-from torch.distributed import ReduceOp

 from colossalai.core import global_context as gpc
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
-from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -3,9 +3,7 @@ import torch
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
-from colossalai.tensor.d_tensor.layout import Layout
-from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
+from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, get_global_shape, redistribute, to_global
 from colossalai.testing import rerun_if_address_is_in_use, spawn
@@ -31,22 +29,18 @@ def check_dtensor(rank, world_size, port):

     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
-    layout = Layout(device_mesh=device_mesh,
-                    device_type=torch.device('cuda'),
-                    sharding_spec=target_sharding_spec,
-                    entire_shape=original_tensor.shape)
-    d_tensor = DTensor(original_tensor, layout)
+    d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)

-    assert d_tensor.entire_shape == original_tensor.shape
-    assert d_tensor.data_type == original_tensor.dtype
+    assert get_global_shape(d_tensor) == original_tensor.shape
+    assert d_tensor.dtype == original_tensor.dtype

     if rank in (0, 1):
-        assert d_tensor.to_local().equal(original_tensor.narrow(0, 0, 2))
+        assert d_tensor.equal(original_tensor.narrow(0, 0, 2))
     elif rank in (2, 3):
-        assert d_tensor.to_local().equal(original_tensor.narrow(0, 2, 2))
+        assert d_tensor.equal(original_tensor.narrow(0, 2, 2))
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')
-    assert d_tensor.to_global().equal(original_tensor)
+    assert to_global(d_tensor).equal(original_tensor)
     output = test_model(d_tensor)

     if rank in (0, 1):
@@ -57,34 +51,29 @@ def check_dtensor(rank, world_size, port):
         raise ValueError(f'rank {rank} is not in the device mesh')

     new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
-    new_layout = Layout(device_mesh=device_mesh,
-                        device_type=torch.device('cuda'),
-                        sharding_spec=new_sharding_spec,
-                        entire_shape=original_tensor.shape)
-
-    d_tensor.layout_convert(new_layout)
+    d_tensor = redistribute(d_tensor, device_mesh, new_sharding_spec)

     if rank == 0:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 0, 1))
     elif rank == 1:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 1, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 1, 1))
     elif rank == 2:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 2, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 2, 1))
     elif rank == 3:
-        assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 3, 1))
+        assert d_tensor.equal(original_tensor.narrow(0, 3, 1))
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')

     dtensor_from_local = distribute_tensor(original_tensor, new_layout)

     if rank == 0:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 0, 1))
     elif rank == 1:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 1, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 1, 1))
     elif rank == 2:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 2, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 2, 1))
     elif rank == 3:
-        assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 3, 1))
+        assert dtensor_from_local.equal(original_tensor.narrow(0, 3, 1))
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')
@@ -9,7 +9,7 @@ from colossalai.logging import disable_existing_loggers
 from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
 from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
-from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn

 entire_shape = torch.Size((64, 32, 16))