[shardformer] support inplace sharding (#4251)

* [shardformer] embedding support inplace sharding

* [shardformer] linear support inplace sharding

* [shardformer] layernorm support inplace sharding

* [shardformer] qkv support inplace sharding

* [test] update shardformer layer test

* [shardformer] fix shared param sharding

* [shardformer] fix bert policy

* [shardformer] fix bloom policy

* [shardformer] fix llama policy

* [shardformer] fix opt policy

* [shardformer] fix t5 policy

* [shardformer] fix fused qkv linear

* [shardformer] fix bugs

* force sync

* [test] fix bugs

* [test] fix transformer version
pull/4445/head
Hongxin Liu 2023-07-20 10:39:06 +08:00
parent 2a2eacfaf1
commit d921ce8391
26 changed files with 371 additions and 340 deletions
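
The theme of the commit: parallel layers now shard an existing nn.Parameter in place by replacing its .data with the local shard, instead of allocating a fresh parameter and copying into it. Because the Parameter object itself survives, weight tying and any external references keep working. A minimal torch-only sketch of the idea (simplified; this is not the actual colossalai API):

import torch
import torch.nn as nn

linear = nn.Linear(8, 4)
tied_ref = linear.weight                  # e.g. a tied lm_head / embedding weight

# pretend this rank keeps only the first 2 of the 4 output rows
local_shard = linear.weight.data[:2].clone()

# in-place conversion: swap .data, keep the Parameter object
linear.weight.data = local_shard

assert linear.weight is tied_ref          # identity (and weight tying) survives
assert linear.weight.shape == torch.Size([2, 8])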

View File

@ -3,10 +3,11 @@ from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm
from .parallel_module import ParallelModule
from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
__all__ = [
"Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
'FusedLayerNorm', 'FusedRMSNorm'
'FusedLayerNorm', 'FusedRMSNorm', 'ParallelModule'
]

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Callable, List, Union
from typing import Callable, List, Optional, Union
import torch
import torch.distributed as dist
@ -13,7 +13,12 @@ from torch.distributed import ProcessGroup
from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
from colossalai.tensor.d_tensor.api import (
is_distributed_tensor,
shard_colwise,
shard_rowwise,
sharded_tensor_to_existing_param,
)
from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
@ -60,6 +65,7 @@ class Embedding1D(ParallelModule):
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = True,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
@ -74,18 +80,24 @@ class Embedding1D(ParallelModule):
self.embed_kwargs = kwargs
self.gather_output = gather_output
# Parameters.
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
sharded_weight = shard_colwise(weight, process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer)
# Parameters.
if weight is None:
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_colwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
if weight is None:
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer)
@staticmethod
def from_native_module(module: nn.Embedding,
@ -121,14 +133,10 @@ class Embedding1D(ParallelModule):
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse,
weight=module.weight,
*args,
**kwargs)
# copy the weight
with torch.no_grad():
sharded_weight = shard_colwise(module.weight.data, process_group)
embedding.weight.copy_(sharded_weight)
return embedding
def reset_parameters(self, weight_initializer) -> None:
@ -143,7 +151,6 @@ class Embedding1D(ParallelModule):
def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
@ -188,6 +195,7 @@ class VocabParallelEmbedding1D(ParallelModule):
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
@ -207,16 +215,23 @@ class VocabParallelEmbedding1D(ParallelModule):
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
# parameter
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
sharded_weight = shard_rowwise(weight, process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
self.reset_parameters(weight_initializer)
# parameter
if weight is None:
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
if weight is None:
self.reset_parameters(weight_initializer)
@staticmethod
def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@ -243,15 +258,10 @@ class VocabParallelEmbedding1D(ParallelModule):
padding_idx=padding_idx,
device=device,
process_group=process_group,
weight=module.weight,
*args,
**kwargs)
with torch.no_grad():
# shard and slice the weight along the vocabulary(num_embeddings) dimension
# the shape of the weight is (num_embeddings, embedding_dim)
shard_weight = shard_rowwise(module.weight.data, process_group)
vocab_embedding_1d.weight.data.copy_(shard_weight)
return vocab_embedding_1d
def reset_parameters(self, weight_initializer) -> None:
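
With weight=module.weight forwarded here, from_native_module makes the wrapper adopt and shard the original parameter object rather than copying it. A rough usage sketch, assuming colossalai is installed and launched with a 2-way tensor-parallel group (process_group=None falls back to the default group, as in the updated tests below):

import torch
import torch.nn as nn
from colossalai.shardformer.layer import Embedding1D

embedding = nn.Embedding(32, 128).cuda()
embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)

# the original Parameter object is reused and sharded along embedding_dim
assert embedding_1d.weight is embedding.weight
assert embedding_1d.weight.shape == torch.Size([32, 64])   # 128 // 2 per rank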

View File

@ -2,7 +2,7 @@
# -*- encoding: utf-8 -*-
import math
from typing import Callable, List, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
@ -15,7 +15,12 @@ from torch.nn.parameter import Parameter
from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
from colossalai.tensor.d_tensor.api import (
is_distributed_tensor,
shard_colwise,
shard_rowwise,
sharded_tensor_to_existing_param,
)
from ._operation import (
gather_forward_split_backward,
@ -65,6 +70,8 @@ class Linear1D_Col(ParallelModule):
process_group: ProcessGroup = None,
gather_output: bool = False,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
@ -80,26 +87,42 @@ class Linear1D_Col(ParallelModule):
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Parameters.
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
sharded_weight = shard_rowwise(weight, self.process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
if bias:
bias = torch.empty(self.out_features, **factory_kwargs)
sharded_bias = shard_colwise(bias, self.process_group)
self.bias = sharded_tensor_to_param(sharded_bias)
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
else:
assert bias_ is None, 'bias_ must be None if weight is None'
# Parameters.
if weight is None:
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, self.process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
if not is_distributed_tensor(self.bias):
sharded_bias = shard_colwise(self.bias.data, self.process_group)
sharded_tensor_to_existing_param(sharded_bias, self.bias)
else:
self.bias = None
if weight is None:
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@ -125,17 +148,11 @@ class Linear1D_Col(ParallelModule):
bias=bias,
device=device,
process_group=process_group,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs)
with torch.no_grad():
# the weight to the linear layer is a transpose
# thus shard on row is equal to shard on column
sharded_weight = shard_rowwise(module.weight.data, process_group)
linear_1d.weight.data.copy_(sharded_weight)
if bias:
sharded_bias = shard_colwise(module.bias.data, process_group)
linear_1d.bias.copy_(sharded_bias)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@ -198,6 +215,8 @@ class Linear1D_Row(ParallelModule):
process_group: ProcessGroup = None,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
@ -216,27 +235,44 @@ class Linear1D_Row(ParallelModule):
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# Parameters.
# Initialize weight.
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
sharded_weight = shard_colwise(weight, self.process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
else:
assert bias_ is None, 'bias_ must be None if weight is None'
# Parameters.
if weight is None:
# Initialize weight.
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_colwise(self.weight.data, self.process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
if weight is None:
with self.randomizer.fork_rng(enable_cpu=True):
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@ -262,19 +298,11 @@ class Linear1D_Row(ParallelModule):
bias=bias,
device=device,
process_group=process_group,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
# the weight to the linear layer is a transpose
# thus shard on col is equal to shard on row
sharded_weight = shard_colwise(module.weight.data, process_group)
linear_1d.weight.data.copy_(sharded_weight)
if bias:
linear_1d.bias.copy_(module.bias.data)
return linear_1d
def chunk_weight(self):
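
Both Linear1D_Col and Linear1D_Row now take optional weight/bias_ arguments; when a weight is supplied, the random re-initialization is skipped and the sanity check requires bias_ whenever bias=True. A hedged sketch of the two construction paths (a launched distributed environment is assumed; process_group=None means the default group):

import torch.nn as nn
from colossalai.shardformer.layer import Linear1D_Col

# path 1: fresh parameters, allocated, initialized and sharded internally
fresh = Linear1D_Col(in_features=32, out_features=128, bias=True, process_group=None)

# path 2: adopt an existing module's parameters and shard them in place
native = nn.Linear(32, 128)
adopted = Linear1D_Col(in_features=32,
                       out_features=128,
                       bias=True,
                       process_group=None,
                       weight=native.weight,
                       bias_=native.bias)
assert adopted.weight is native.weight     # same object, no copy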

View File

@ -60,10 +60,8 @@ class FusedLayerNorm():
layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps,
elementwise_affine=elementwise_affine).to(dtype).to(device)
with torch.no_grad():
# copy weight and bias
layernorm.weight.copy_(module.weight)
layernorm.bias.copy_(module.bias)
layernorm.weight = module.weight
layernorm.bias = module.bias
return layernorm
@ -101,8 +99,6 @@ class FusedRMSNorm():
rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
with torch.no_grad():
# copy weight and bias
rmsnorm.weight.copy_(module.weight)
rmsnorm.weight = module.weight
return rmsnorm

View File

@ -2,12 +2,11 @@
# -*- encoding: utf-8 -*-
import math
from typing import Callable, List, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@ -16,10 +15,12 @@ from colossalai.lazy import LazyInitContext
from colossalai.nn import init as init
from colossalai.nn.layer.utils import divide
from colossalai.tensor.d_tensor.api import (
customized_distributed_tensor_to_param,
customized_distributed_tensor_to_existing_param,
distribute_tensor_with_customization,
is_customized_distributed_tensor,
is_distributed_tensor,
shard_rowwise,
sharded_tensor_to_param,
sharded_tensor_to_existing_param,
)
from ._operation import (
@ -173,6 +174,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
gather_output: bool = False,
skip_bias_add: bool = False,
n_fused: int = 3,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
@ -190,40 +193,56 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
else:
assert bias_ is None, 'bias_ must be None if weight is None'
# Parameters.
# Initialize weight.
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
if weight is None:
# Initialize weight.
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
def shard_fn(tensor):
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
def gather_fn(tensor):
return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True)
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
with torch.no_grad():
sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
self.weight = customized_distributed_tensor_to_param(sharded_weight)
if not is_customized_distributed_tensor(self.weight):
with torch.no_grad():
sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
if bias:
bias = torch.empty(self.out_features, **factory_kwargs)
with torch.no_grad():
sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
self.bias = customized_distributed_tensor_to_param(sharded_bias)
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
if not is_customized_distributed_tensor(self.bias):
with torch.no_grad():
sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn)
customized_distributed_tensor_to_existing_param(sharded_bias, self.bias)
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
if weight is None:
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
*args, **kwargs) -> ParallelModule:
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
**kwargs) -> ParallelModule:
r"""
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
@ -250,24 +269,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
bias=bias,
device=device,
process_group=process_group,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
n_fused=n_fused,
process_group=process_group,
is_transposed=True)
linear_1d.weight.data.copy_(sharded_weight.data)
if bias:
sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
n_fused=n_fused,
process_group=process_group,
is_transposed=True)
linear_1d.bias.data.copy_(sharded_bias.data)
return linear_1d
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
@ -333,6 +339,8 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
process_group: ProcessGroup = None,
parallel_input: bool = True,
skip_bias_add: bool = False,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
stream_chunk_num: int = 1):
@ -351,30 +359,46 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
if skip_bias_add and not bias:
raise ValueError('cannot skip bias addition if bias is None')
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# Divide the weight matrix along the last dimension.
self.input_size_per_partition = divide(in_features, self.num_partitions)
# sanity check
if weight is not None:
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
else:
assert bias_ is None, 'bias_ must be None if weight is None'
# Parameters.
# Initialize weight.
factory_kwargs = {'device': device, 'dtype': dtype}
weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
sharded_weight = shard_rowwise(weight, self.process_group)
self.weight = sharded_tensor_to_param(sharded_weight)
if weight is None:
# Initialize weight.
factory_kwargs = {'device': device, 'dtype': dtype}
self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, self.process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
if self.stream_chunk_num > 1:
# TODO() work for inference only
self.chunk_weight()
if bias:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
if bias_ is None:
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
self.bias = bias_
else:
self.bias = None
# offset the seed with randomizer index and rank
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
if weight is None:
# init weights
self.reset_parameters(weight_initializer, bias_initializer)
@staticmethod
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
@ -400,19 +424,11 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
bias=bias,
device=device,
process_group=process_group,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs)
# TODO: copy the sharded weights
with torch.no_grad():
# the weight to the linear layer is a transpose
# thus shard on col is equal to shard on row
sharded_weight = shard_rowwise(module.weight.data, process_group)
linear_1d.weight.data.copy_(sharded_weight.data)
if bias:
linear_1d.bias.copy_(module.bias.data)
return linear_1d
def chunk_weight(self):
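
from_native_module for the fused-QKV layers no longer takes n_fused positionally; it now travels through kwargs to the constructor, which is how the GPT-2 policy passes it (kwargs={"n_fused": 3}). A sketch mirroring the updated test, assuming a 2-way tensor-parallel setup:

import torch
from transformers.pytorch_utils import Conv1D
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col

conv = Conv1D(192, 48).cuda()          # HF Conv1D stores weight as (in_features, out_features)
qkv_col = GPT2FusedLinearConv1D_Col.from_native_module(conv,
                                                       process_group=None,
                                                       gather_output=True,
                                                       n_fused=3)
assert qkv_col.weight is conv.weight                    # adopted in place
assert qkv_col.weight.shape == torch.Size([48, 96])     # fused qkv columns split over 2 ranks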

View File

@ -1,13 +1,11 @@
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module
from transformers.modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
MultipleChoiceModelOutput,
@ -28,12 +26,11 @@ from transformers.models.bert.modeling_bert import (
BertLMHeadModel,
BertModel,
)
from transformers.utils import ModelOutput, logging
from transformers.utils import logging
import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
from .._utils import getattr_, setattr_
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
logger = logging.get_logger(__name__)
@ -177,6 +174,17 @@ class BertPolicy(Policy):
target_key=BertLMPredictionHead)
return base_policy
def add_lm_prediction_policy(self, base_policy):
from transformers.models.bert.modeling_bert import BertLMPredictionHead
method_replacement = {
'_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict,
'_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict,
}
self.append_or_create_method_replacement(description=method_replacement,
policy=base_policy,
target_key=BertLMPredictionHead)
return base_policy
def postprocess(self):
return self.model
@ -240,6 +248,7 @@ class BertForPreTrainingPolicy(BertPolicy):
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
policy = self.add_lm_prediction_policy(policy)
from transformers.models.bert.modeling_bert import BertForPreTraining
self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy)
return policy
@ -266,21 +275,13 @@ class BertForPreTrainingPolicy(BertPolicy):
model = self.model
if self.pipeline_stage_manager:
if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight):
#tie weights
# tie weights
return [{
0: model.bert.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight
}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
@ -291,6 +292,7 @@ class BertLMHeadModelPolicy(BertPolicy):
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
policy = self.add_lm_prediction_policy(policy)
from transformers.models.bert.modeling_bert import BertLMHeadModel
self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy)
return policy
@ -316,21 +318,13 @@ class BertLMHeadModelPolicy(BertPolicy):
bert_model = self.model.bert
if self.pipeline_stage_manager:
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
#tie weights
# tie weights
return [{
0: bert_model.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy):
@ -341,6 +335,7 @@ class BertForMaskedLMPolicy(BertPolicy):
def module_policy(self):
policy = super().module_policy()
policy = self.add_lm_head_policy(policy)
policy = self.add_lm_prediction_policy(policy)
from transformers.models.bert.modeling_bert import BertForMaskedLM
self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy)
return policy
@ -366,21 +361,13 @@ class BertForMaskedLMPolicy(BertPolicy):
bert_model = self.model.bert
if self.pipeline_stage_manager:
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
#tie weights
# tie weights
return [{
0: bert_model.embeddings.word_embeddings.weight,
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
@ -1032,6 +1019,7 @@ def bert_for_masked_lm_forward(
stage_manager: Optional[PipelineStageManager] = None,
stage_index: Optional[List[int]] = None,
):
# -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@ -1109,7 +1097,7 @@ def bert_for_next_sentence_prediction_forward(
stage_index: Optional[List[int]] = None,
**kwargs,
):
#-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
# -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair

View File

@ -1,9 +1,7 @@
import warnings
from functools import partial
from types import MethodType
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
@ -27,7 +25,6 @@ from transformers.utils import logging
import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
from .._utils import getattr_, setattr_
from ..modeling.bloom import build_bloom_alibi_tensor_fn
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -229,20 +226,10 @@ class BloomForCausalLMPolicy(BloomPolicy):
# tie weights
return [{
0: bloom_model.transformer.word_embeddings.weight,
self.stage_manager.num_stages - 1: bloom_model.lm_head.weight
self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight
}]
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
# tie weights
setattr_(self.model, v, param)
return self.model
class BloomForSequenceClassificationPolicy(BloomPolicy):
@ -692,7 +679,7 @@ def bloom_for_sequence_classification_forward(
all_cross_attentions = None
if stage_manager.is_last_stage():
batch_size = hidden_states.shape[0]
#update batch size
# update batch size
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)

View File

@ -8,7 +8,6 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import colossalai.shardformer.layer as col_nn
from colossalai.pipeline.stage_manager import PipelineStageManager
from .._utils import getattr_, setattr_
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@ -56,42 +55,42 @@ class GPT2Policy(Policy):
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="attn.c_attn",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 3,
},
),
SubModuleReplacementDescription(
suffix="attn.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="mlp.c_fc",
target_module=col_nn.GPT2FusedLinearConv1D_Col,
kwargs={
"n_fused": 1,
},
),
SubModuleReplacementDescription(
suffix="mlp.c_proj",
target_module=col_nn.GPT2FusedLinearConv1D_Row,
),
SubModuleReplacementDescription(
suffix="attn.attn_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="attn.resid_dropout",
target_module=col_nn.DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="mlp.dropout",
target_module=col_nn.DropoutForParallelInput,
),
])
# optimization configuration
if self.shard_config.enable_fused_normalization:
@ -99,8 +98,8 @@ class GPT2Policy(Policy):
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
),
policy=policy,
target_key=GPT2Model)
policy=policy,
target_key=GPT2Model)
self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
@ -115,8 +114,8 @@ class GPT2Policy(Policy):
target_module=col_nn.FusedLayerNorm,
ignore_if_not_exist=True)
],
policy=policy,
target_key=GPT2Block)
policy=policy,
target_key=GPT2Block)
return policy
def postprocess(self):
@ -227,15 +226,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
else:
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism \
and self.pipeline_stage_manager is None:
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# GPT2DoubleHeadsModel
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
@ -286,15 +276,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
else:
return []
def postprocess(self):
if self.shard_config.enable_tensor_parallelism \
and self.pipeline_stage_manager is None:
binding_map = {"transformer.wte.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)
setattr_(self.model, v, param)
return self.model
# GPT2ForQuestionAnswering
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):

View File

@ -1,7 +1,5 @@
import math
from functools import partial
from types import MethodType
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
@ -9,14 +7,11 @@ from torch import Tensor
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
CausalLMOutputWithPast,
SequenceClassifierOutputWithPast,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
from transformers.utils import ModelOutput, logging
from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

View File

@ -1,6 +1,5 @@
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .._utils import getattr_, setattr_
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
@ -116,19 +115,6 @@ class OPTForCausalLMPolicy(OPTPolicy):
target_key=OPTForCausalLM)
return policy
def postprocess(self):
if self.shard_config.enable_tensor_parallelism:
binding_map = {
'model.decoder.embed_tokens': 'lm_head',
}
for k, v in binding_map.items():
src_mod = getattr_(self.model, k)
dst_mod = getattr_(self.model, v)
dst_mod.weight = src_mod.weight
return self.model
class OPTForSequenceClassificationPolicy(OPTPolicy):

View File

@ -8,7 +8,6 @@ from colossalai.shardformer.layer import (
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
from .._utils import getattr_, setattr_
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
@ -53,7 +52,7 @@ class T5BasePolicy(Policy):
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=Embedding1D,
target_module=VocabParallelEmbedding1D,
)
])
policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
@ -165,12 +164,6 @@ class T5BasePolicy(Policy):
return policy
def postprocess(self):
if self.shard_config.enable_tensor_parallelism:
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
for k, v in binding_map:
mod = getattr_(self.model, k)
setattr_(self.model, v, mod)
return self.model
@ -211,18 +204,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
target_key=T5ForConditionalGeneration)
return policy
def postprocess(self):
super().postprocess()
if self.shard_config.enable_tensor_parallelism:
binding_map = {"shared": "lm_head"}
for k, v in binding_map.items():
src_mod = getattr_(self.model, k)
dst_mod = getattr_(self.model, v)
dst_mod.weight = src_mod.weight
return self.model
class T5EncoderPolicy(T5BasePolicy):
@ -239,14 +220,3 @@ class T5EncoderPolicy(T5BasePolicy):
policy=base_policy,
target_key=T5EncoderModel)
return base_policy
def postprocess(self):
if self.shard_config.enable_tensor_parallelism:
binding_map = [
["shared", "encoder.embed_tokens"],
]
for k, v in binding_map:
mod = getattr_(self.model, k)
setattr_(self.model, v, mod)
return self.model

View File

@ -37,11 +37,13 @@ class ModelSharder(object):
self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self._preprocess()
# get shared params before releasing unheld layers; this avoids misjudgement of shared params (None is None)
shared_params = self.policy.get_shared_params()
self._release_unheld_layers()
self._replace_module()
self._materialize()
self._postprocess()
return self.policy.get_shared_params()
return shared_params
def _preprocess(self) -> None:
self.model = self.policy.preprocess()

View File

@ -235,6 +235,14 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
return param
def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
param.data = dtensor
# make it distributed as well
param.dist_layout = dtensor.dist_layout
_hijack_detach_and_clone(param)
def compute_global_numel(dtensor: torch.Tensor) -> int:
"""
Compute the global number of elements in the distributed tensor.
@ -432,3 +440,15 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad:
param.gather_fn = dtensor.gather_fn
_hijack_detach_and_clone_for_customized_distributed_tensor(param)
return param
def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter):
"""
Convert the given customized distributed tensor to an existing parameter.
"""
assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
param.data = dtensor.data
param.shard_fn = dtensor.shard_fn
param.gather_fn = dtensor.gather_fn
_hijack_detach_and_clone_for_customized_distributed_tensor(param)
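
These *_to_existing_param helpers are the core of the in-place path: they swap the local shard into an existing Parameter's .data and attach the distribution metadata, instead of building a new Parameter. A usage sketch mirroring how linear.py calls it (a launched distributed job and a 1D tensor-parallel process group are assumed; None selects the default group):

import torch.nn as nn
from colossalai.tensor.d_tensor.api import (
    is_distributed_tensor,
    shard_rowwise,
    sharded_tensor_to_existing_param,
)

process_group = None                      # assumed: default group of a launched job
linear = nn.Linear(32, 128)
param = linear.weight                     # an ordinary nn.Parameter

if not is_distributed_tensor(param):
    sharded = shard_rowwise(param.data, process_group)   # shard dim 0 across the group
    sharded_tensor_to_existing_param(sharded, param)

assert linear.weight is param             # same object, now carrying dist_layout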

View File

@ -17,3 +17,4 @@ requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggi
SentencePiece
ninja
flash_attn>=2.0
datasets

View File

@ -15,11 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_embedding_1d(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
embedding = nn.Embedding(32, 128).cuda()
with ctx:
embedding = nn.Embedding(32, 128).cuda()
embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)
embedding_copy = nn.Embedding(32, 128).cuda()
embedding_1d = Embedding1D.from_native_module(embedding_copy, process_group=None)
assert embedding_1d.weight.shape == torch.Size([32, 64])
assert embedding_1d.weight is embedding_copy.weight
# ensure state dict is reversibly loadable
embedding.load_state_dict(embedding_1d.state_dict())

View File

@ -14,11 +14,14 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_layernorm(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
norm = nn.LayerNorm(128, 0.00001).cuda()
with ctx:
norm = nn.LayerNorm(128, 0.00001).cuda()
norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
norm_copy = nn.LayerNorm(128, 0.00001).cuda()
norm1d = FusedLayerNorm.from_native_module(norm_copy, process_group=None)
assert norm1d.weight.shape == torch.Size([128])
assert norm_copy.weight is norm1d.weight
assert norm_copy.bias is norm1d.bias
# ensure state dict is reversibly loadable
norm.load_state_dict(norm1d.state_dict())

View File

@ -15,14 +15,16 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize('lazy_init', [False, True])
def check_linear_1d_col(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
linear_copy = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
# ensure that the parameters are distributed
assert is_distributed_tensor(linear_col.weight)
assert is_distributed_tensor(linear_col.bias)
assert linear_copy.weight is linear_col.weight
assert linear_copy.bias is linear_col.bias
# ensure the shape is correct
assert linear_col.weight.shape == torch.Size([64, 32])
@ -61,12 +63,18 @@ def check_linear_1d_col(lazy_init: bool):
def check_linear_1d_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear = nn.Linear(32, 128).cuda()
linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
linear_copy = nn.Linear(32, 128).cuda()
linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
assert linear_row.weight.shape == torch.Size([128, 16])
assert linear_row.bias.shape == torch.Size([128])
assert linear_copy.weight is linear_row.weight
assert linear_copy.bias is linear_row.bias
linear.load_state_dict(linear_row.state_dict())
linear_row.load_state_dict(linear.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
@ -98,11 +106,19 @@ def check_linear_1d_row(lazy_init: bool):
def check_linear_col_plus_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear_1 = nn.Linear(32, 128).cuda()
linear_2 = nn.Linear(128, 32).cuda()
with ctx:
linear_1 = nn.Linear(32, 128).cuda()
linear_2 = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
linear_1_copy = nn.Linear(32, 128).cuda()
linear_2_copy = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
linear_1.load_state_dict(linear_col.state_dict())
linear_col.load_state_dict(linear_1.state_dict())
linear_2.load_state_dict(linear_row.state_dict())
linear_row.load_state_dict(linear_2.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()

View File

@ -56,10 +56,10 @@ def rearrange(tensor: torch.Tensor, dim: int):
@parameterize('lazy_init', [False, True])
def check_linear_conv_1d_col(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
linear = Conv1D(192, 48).cuda()
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
linear_copy = Conv1D(192, 48).cuda()
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
n_fused=3)
@ -68,6 +68,8 @@ def check_linear_conv_1d_col(lazy_init: bool):
assert linear.bias.shape == torch.Size([192])
assert linear_conv_col.weight.shape == torch.Size([48, 96])
assert linear_conv_col.bias.shape == torch.Size([96])
assert linear_copy.weight is linear_conv_col.weight
assert linear_copy.bias is linear_conv_col.bias
# ensure weights are reversibly loadable
linear_conv_col.load_state_dict(linear.state_dict())
@ -91,13 +93,20 @@ def check_linear_conv_1d_col(lazy_init: bool):
def check_linear_conv_1d_row(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
linear = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
linear_copy = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
assert linear_row.bias.shape == torch.Size([192])
assert linear_copy.weight is linear_row.weight
assert linear_copy.bias is linear_row.bias
# ensure weights are reversibly loadable
linear_row.load_state_dict(linear.state_dict())
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()

View File

@ -7,8 +7,7 @@ from torch.testing import assert_close
import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
from colossalai.shardformer.layer import VocabParallelEmbedding1D
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@ -16,13 +15,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
def check_vocab_embedding_1d(lazy_init: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
embedding = nn.Embedding(128, 32).to('cuda')
with ctx:
embedding = nn.Embedding(128, 32).to('cuda')
dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)
embedding_copy = nn.Embedding(128, 32).to('cuda')
dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None)
assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
assert dist_embedding_1d.num_embeddings == 64
assert dist_embedding_1d.embedding_dim == 32
assert embedding_copy.weight is dist_embedding_1d.weight
# ensure state dict is reversibly loadable
embedding.load_state_dict(dist_embedding_1d.state_dict())

View File

@ -1,8 +1,10 @@
import copy
from contextlib import nullcontext
import torch
from torch.nn import Module
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
@ -61,3 +63,14 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
shard_output = output_transform_fn(shard_output)
shard_loss = loss_fn(shard_output)
return org_output, org_loss, shard_output, shard_loss
def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
org_sd = org_model.state_dict()
shard_sd = sharded_model.state_dict()
for k, v in org_sd.items():
assert k in shard_sd, f'{name} {k} not in sharded model'
shard_v = shard_sd[k]
assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
assert torch.equal(v, shard_v), f'{name} {k} value mismatch'

View File

@ -12,7 +12,7 @@ from colossalai.testing import (
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@ -75,6 +75,7 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@ -12,7 +12,7 @@ from colossalai.testing import (
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@ -75,6 +75,7 @@ def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_la
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@ -12,7 +12,7 @@ from colossalai.testing import (
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@ -77,6 +77,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@ -14,7 +14,7 @@ from colossalai.testing import (
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@ -78,6 +78,7 @@ def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_la
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@ -15,7 +15,7 @@ from colossalai.testing import (
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
@ -77,6 +77,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@ -14,7 +14,7 @@ from colossalai.testing import (
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
@ -88,6 +88,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
use_lazy_init)
check_state_dict(org_model, sharded_model, name=name)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()