mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] support inplace sharding (#4251)
* [shardformer] embedding support inplace sharding
* [shardformer] linear support inplace sharding
* [shardformer] layernorm support inplace sharding
* [shardformer] qkv support inplace sharding
* [test] update shardformer layer test
* [shardformer] fix shared param sharding
* [shardformer] fix bert policy
* [shardformer] fix bloom policy
* [shardformer] fix llama policy
* [shardformer] fix opt policy
* [shardformer] fix t5 policy
* [shardformer] fix fused qkv linear
* [shardformer] fix bugs
* force sync
* [test] fix bugs
* [test] fix transformer version

pull/4445/head
parent 2a2eacfaf1
commit d921ce8391
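What the commit changes, in one sketch: instead of allocating a fresh sharded parameter and copying the original weight into it, each parallel layer now keeps the module's existing nn.Parameter and swaps its .data for a shard. A minimal illustration of the pattern, assuming a 1D tensor-parallel process group has already been launched (as in the tests touched below); the d_tensor helpers are the ones added in this diff, while the wrapper function itself is hypothetical:

    import torch.nn as nn

    from colossalai.tensor.d_tensor.api import (
        is_distributed_tensor,
        shard_colwise,
        sharded_tensor_to_existing_param,
    )


    def shard_embedding_weight_inplace(module: nn.Embedding, process_group) -> nn.Embedding:
        # hypothetical helper: shard the embedding weight along the hidden dimension
        # without replacing the Parameter object
        weight = module.weight
        if not is_distributed_tensor(weight):
            sharded_weight = shard_colwise(weight.data, process_group)
            # graft the shard back onto the *same* Parameter, so any module that
            # ties or shares this parameter keeps pointing at the sharded weight
            sharded_tensor_to_existing_param(sharded_weight, weight)
        return module

Because the parameter object survives sharding, weight tying (for example word embeddings tied to an LM head) survives as well, which is why the per-policy postprocess re-binding hooks are removed throughout this diff.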
@@ -3,10 +3,11 @@ from .embedding import Embedding1D, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row
 from .loss import cross_entropy_1d
 from .normalization import FusedLayerNorm, FusedRMSNorm
+from .parallel_module import ParallelModule
 from .qkv_fused_linear import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
 
 __all__ = [
     "Embedding1D", "VocabParallelEmbedding1D", "Linear1D_Col", "Linear1D_Row", 'GPT2FusedLinearConv1D_Col',
     'GPT2FusedLinearConv1D_Row', 'DropoutForParallelInput', 'DropoutForReplicatedInput', "cross_entropy_1d",
-    'FusedLayerNorm', 'FusedRMSNorm'
+    'FusedLayerNorm', 'FusedRMSNorm', 'ParallelModule'
 ]
@ -1,7 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from typing import Callable, List, Union
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -13,7 +13,12 @@ from torch.distributed import ProcessGroup
|
|||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.tensor.d_tensor.api import shard_colwise, shard_rowwise, sharded_tensor_to_param
|
||||
from colossalai.tensor.d_tensor.api import (
|
||||
is_distributed_tensor,
|
||||
shard_colwise,
|
||||
shard_rowwise,
|
||||
sharded_tensor_to_existing_param,
|
||||
)
|
||||
|
||||
from ._operation import gather_forward_split_backward, reduce_forward
|
||||
from .parallel_module import ParallelModule
|
||||
|
@ -60,6 +65,7 @@ class Embedding1D(ParallelModule):
|
|||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
gather_output: bool = True,
|
||||
weight: Optional[nn.Parameter] = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs):
|
||||
|
@ -74,18 +80,24 @@ class Embedding1D(ParallelModule):
|
|||
self.embed_kwargs = kwargs
|
||||
self.gather_output = gather_output
|
||||
|
||||
# Parameters.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
|
||||
sharded_weight = shard_colwise(weight, process_group)
|
||||
self.weight = sharded_tensor_to_param(sharded_weight)
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
self.reset_parameters(weight_initializer)
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
if not is_distributed_tensor(self.weight):
|
||||
sharded_weight = shard_colwise(self.weight.data, process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
||||
if weight is None:
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Embedding,
|
||||
|
@ -121,14 +133,10 @@ class Embedding1D(ParallelModule):
|
|||
norm_type=norm_type,
|
||||
scale_grad_by_freq=scale_grad_by_freq,
|
||||
sparse=sparse,
|
||||
weight=module.weight,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
# copy the weight
|
||||
with torch.no_grad():
|
||||
sharded_weight = shard_colwise(module.weight.data, process_group)
|
||||
embedding.weight.copy_(sharded_weight)
|
||||
|
||||
return embedding
|
||||
|
||||
def reset_parameters(self, weight_initializer) -> None:
|
||||
|
@ -143,7 +151,6 @@ class Embedding1D(ParallelModule):
|
|||
|
||||
def forward(self, input_: Tensor) -> Tensor:
|
||||
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
if self.gather_output:
|
||||
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
|
||||
return output
|
||||
|
@ -188,6 +195,7 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
weight: Optional[nn.Parameter] = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
*args,
|
||||
**kwargs):
|
||||
|
@ -207,16 +215,23 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
|
||||
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
|
||||
|
||||
# parameter
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
weight = torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs)
|
||||
sharded_weight = shard_rowwise(weight, process_group)
|
||||
self.weight = sharded_tensor_to_param(sharded_weight)
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
||||
# parameter
|
||||
if weight is None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
if not is_distributed_tensor(self.weight):
|
||||
sharded_weight = shard_rowwise(self.weight.data, process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
||||
if weight is None:
|
||||
self.reset_parameters(weight_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
|
@ -243,15 +258,10 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
padding_idx=padding_idx,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
# shard and slice the weight along the vocabulary(num_embeddings) dimension
|
||||
# the shape of the weight is (num_embeddings, embedding_dim)
|
||||
shard_weight = shard_rowwise(module.weight.data, process_group)
|
||||
vocab_embedding_1d.weight.data.copy_(shard_weight)
|
||||
|
||||
return vocab_embedding_1d
|
||||
|
||||
def reset_parameters(self, weight_initializer) -> None:
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from typing import Callable, List, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
@ -15,7 +15,12 @@ from torch.nn.parameter import Parameter
|
|||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.nn import init as init
|
||||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.tensor.d_tensor import shard_colwise, shard_rowwise, sharded_tensor_to_param
|
||||
from colossalai.tensor.d_tensor.api import (
|
||||
is_distributed_tensor,
|
||||
shard_colwise,
|
||||
shard_rowwise,
|
||||
sharded_tensor_to_existing_param,
|
||||
)
|
||||
|
||||
from ._operation import (
|
||||
gather_forward_split_backward,
|
||||
|
@ -65,6 +70,8 @@ class Linear1D_Col(ParallelModule):
|
|||
process_group: ProcessGroup = None,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
|
@ -80,26 +87,42 @@ class Linear1D_Col(ParallelModule):
|
|||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# Parameters.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
|
||||
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
|
||||
sharded_weight = shard_rowwise(weight, self.process_group)
|
||||
self.weight = sharded_tensor_to_param(sharded_weight)
|
||||
|
||||
if bias:
|
||||
bias = torch.empty(self.out_features, **factory_kwargs)
|
||||
sharded_bias = shard_colwise(bias, self.process_group)
|
||||
self.bias = sharded_tensor_to_param(sharded_bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
# init weights
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
|
||||
else:
|
||||
assert bias_ is None, 'bias_ must be None if weight is None'
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
if not is_distributed_tensor(self.weight):
|
||||
sharded_weight = shard_rowwise(self.weight.data, self.process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
||||
if bias:
|
||||
if bias_ is None:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||
self.bias = bias_
|
||||
if not is_distributed_tensor(self.bias):
|
||||
sharded_bias = shard_colwise(self.bias.data, self.process_group)
|
||||
sharded_tensor_to_existing_param(sharded_bias, self.bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
if weight is None:
|
||||
# init weights
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
|
@ -125,17 +148,11 @@ class Linear1D_Col(ParallelModule):
|
|||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
# the weight to the linear layer is a transpose
|
||||
# thus shard on row is equal to shard on column
|
||||
sharded_weight = shard_rowwise(module.weight.data, process_group)
|
||||
linear_1d.weight.data.copy_(sharded_weight)
|
||||
if bias:
|
||||
sharded_bias = shard_colwise(module.bias.data, process_group)
|
||||
linear_1d.bias.copy_(sharded_bias)
|
||||
return linear_1d
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
|
@ -198,6 +215,8 @@ class Linear1D_Row(ParallelModule):
|
|||
process_group: ProcessGroup = None,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1):
|
||||
|
@ -216,27 +235,44 @@ class Linear1D_Row(ParallelModule):
|
|||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
|
||||
sharded_weight = shard_colwise(weight, self.process_group)
|
||||
self.weight = sharded_tensor_to_param(sharded_weight)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
# TODO() work for inference only
|
||||
self.chunk_weight()
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
|
||||
else:
|
||||
assert bias_ is None, 'bias_ must be None if weight is None'
|
||||
|
||||
# Parameters.
|
||||
if weight is None:
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
if not is_distributed_tensor(self.weight):
|
||||
sharded_weight = shard_colwise(self.weight.data, self.process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
# TODO() work for inference only
|
||||
self.chunk_weight()
|
||||
|
||||
if bias:
|
||||
if bias_ is None:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||
self.bias = bias_
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
if weight is None:
|
||||
with self.randomizer.fork_rng(enable_cpu=True):
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
|
@ -262,19 +298,11 @@ class Linear1D_Row(ParallelModule):
|
|||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
# TODO: copy the sharded weights
|
||||
with torch.no_grad():
|
||||
# the weigh to the linear layer is a transpose
|
||||
# thus shard on col is equal to shard on row
|
||||
sharded_weight = shard_colwise(module.weight.data, process_group)
|
||||
linear_1d.weight.data.copy_(sharded_weight)
|
||||
|
||||
if bias:
|
||||
linear_1d.bias.copy_(module.bias.data)
|
||||
|
||||
return linear_1d
|
||||
|
||||
def chunk_weight(self):
|
||||
|
|
|
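At the layer level the same idea surfaces in from_native_module: the original module's weight and bias are passed into the parallel layer and sharded in place, so the parallel layer ends up holding the very same Parameter objects. A short usage sketch mirroring the updated layer tests (it assumes a 2-way tensor-parallel group launched through colossalai, as those tests do):

    import torch.nn as nn

    from colossalai.shardformer.layer import Linear1D_Col

    linear = nn.Linear(32, 128).cuda()
    linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)

    # the parallel layer re-uses the original Parameter objects, which now hold shards
    assert linear_col.weight is linear.weight
    assert linear_col.bias is linear.bias
    assert linear_col.weight.shape == (64, 32)    # 128 output features split over 2 ranks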
@@ -60,10 +60,8 @@ class FusedLayerNorm():
         layernorm = ApexFusedLayerNorm(normalized_shape, eps=eps,
                                        elementwise_affine=elementwise_affine).to(dtype).to(device)
 
-        with torch.no_grad():
-            # copy weight and bias
-            layernorm.weight.copy_(module.weight)
-            layernorm.bias.copy_(module.bias)
+        layernorm.weight = module.weight
+        layernorm.bias = module.bias
         return layernorm
 
 
@@ -101,8 +99,6 @@ class FusedRMSNorm():
 
         rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
 
-        with torch.no_grad():
-            # copy weight and bias
-            rmsnorm.weight.copy_(module.weight)
+        rmsnorm.weight = module.weight
 
         return rmsnorm
@ -2,12 +2,11 @@
|
|||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import math
|
||||
from typing import Callable, List, Tuple, Union
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
@ -16,10 +15,12 @@ from colossalai.lazy import LazyInitContext
|
|||
from colossalai.nn import init as init
|
||||
from colossalai.nn.layer.utils import divide
|
||||
from colossalai.tensor.d_tensor.api import (
|
||||
customized_distributed_tensor_to_param,
|
||||
customized_distributed_tensor_to_existing_param,
|
||||
distribute_tensor_with_customization,
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
shard_rowwise,
|
||||
sharded_tensor_to_param,
|
||||
sharded_tensor_to_existing_param,
|
||||
)
|
||||
|
||||
from ._operation import (
|
||||
|
@ -173,6 +174,8 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
n_fused: int = 3,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||
super().__init__()
|
||||
|
@ -190,40 +193,56 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
|
||||
else:
|
||||
assert bias_ is None, 'bias_ must be None if weight is None'
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
|
||||
if weight is None:
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
|
||||
def shard_fn(tensor):
|
||||
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
|
||||
|
||||
def gather_fn(tensor):
|
||||
return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, True)
|
||||
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, True)
|
||||
|
||||
with torch.no_grad():
|
||||
sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
|
||||
self.weight = customized_distributed_tensor_to_param(sharded_weight)
|
||||
if not is_customized_distributed_tensor(self.weight):
|
||||
with torch.no_grad():
|
||||
sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
|
||||
customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
||||
if bias:
|
||||
bias = torch.empty(self.out_features, **factory_kwargs)
|
||||
|
||||
with torch.no_grad():
|
||||
sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
|
||||
self.bias = customized_distributed_tensor_to_param(sharded_bias)
|
||||
if bias_ is None:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||
self.bias = bias_
|
||||
if not is_customized_distributed_tensor(self.bias):
|
||||
with torch.no_grad():
|
||||
sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn)
|
||||
customized_distributed_tensor_to_existing_param(sharded_bias, self.bias)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
# init weights
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
if weight is None:
|
||||
# init weights
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
|
||||
*args, **kwargs) -> ParallelModule:
|
||||
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
**kwargs) -> ParallelModule:
|
||||
r"""
|
||||
Convert a huggingface layer `Conv1D` in gpt2 to a parallelized linear layer.
|
||||
|
||||
|
@ -250,24 +269,11 @@ class GPT2FusedLinearConv1D_Col(ParallelModule):
|
|||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
# TODO: copy the sharded weights
|
||||
with torch.no_grad():
|
||||
sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
|
||||
n_fused=n_fused,
|
||||
process_group=process_group,
|
||||
is_transposed=True)
|
||||
linear_1d.weight.data.copy_(sharded_weight.data)
|
||||
|
||||
if bias:
|
||||
sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
|
||||
n_fused=n_fused,
|
||||
process_group=process_group,
|
||||
is_transposed=True)
|
||||
linear_1d.bias.data.copy_(sharded_bias.data)
|
||||
|
||||
return linear_1d
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
|
@ -333,6 +339,8 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
process_group: ProcessGroup = None,
|
||||
parallel_input: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
stream_chunk_num: int = 1):
|
||||
|
@ -351,30 +359,46 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
if skip_bias_add and not bias:
|
||||
raise ValueError('cannot skip bias addition if bias is None')
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.input_size_per_partition = divide(in_features, self.num_partitions)
|
||||
|
||||
# sanity check
|
||||
if weight is not None:
|
||||
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
|
||||
else:
|
||||
assert bias_ is None, 'bias_ must be None if weight is None'
|
||||
|
||||
# Parameters.
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
weight = torch.empty(self.in_features, self.out_features, **factory_kwargs)
|
||||
sharded_weight = shard_rowwise(weight, self.process_group)
|
||||
self.weight = sharded_tensor_to_param(sharded_weight)
|
||||
if weight is None:
|
||||
# Initialize weight.
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
self.weight = Parameter(torch.empty(self.in_features, self.out_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
if not is_distributed_tensor(self.weight):
|
||||
sharded_weight = shard_rowwise(self.weight.data, self.process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
||||
if self.stream_chunk_num > 1:
|
||||
# TODO() work for inference only
|
||||
self.chunk_weight()
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
if bias_ is None:
|
||||
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||
self.bias = bias_
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
# offset the seed with randomizer index and rank
|
||||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
# init weights
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
if weight is None:
|
||||
# init weights
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
|
||||
|
@ -400,19 +424,11 @@ class GPT2FusedLinearConv1D_Row(ParallelModule):
|
|||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs)
|
||||
|
||||
# TODO: copy the sharded weights
|
||||
with torch.no_grad():
|
||||
# the weigh to the linear layer is a transpose
|
||||
# thus shard on col is equal to shard on row
|
||||
sharded_weight = shard_rowwise(module.weight.data, process_group)
|
||||
linear_1d.weight.data.copy_(sharded_weight.data)
|
||||
|
||||
if bias:
|
||||
linear_1d.bias.copy_(module.bias.data)
|
||||
|
||||
return linear_1d
|
||||
|
||||
def chunk_weight(self):
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import CrossEntropyLoss, Module
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
MultipleChoiceModelOutput,
|
||||
|
@ -28,12 +26,11 @@ from transformers.models.bert.modeling_bert import (
|
|||
BertLMHeadModel,
|
||||
BertModel,
|
||||
)
|
||||
from transformers.utils import ModelOutput, logging
|
||||
from transformers.utils import logging
|
||||
|
||||
import colossalai.shardformer.layer as col_nn
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
@ -177,6 +174,17 @@ class BertPolicy(Policy):
|
|||
target_key=BertLMPredictionHead)
|
||||
return base_policy
|
||||
|
||||
def add_lm_prediction_policy(self, base_policy):
|
||||
from transformers.models.bert.modeling_bert import BertLMPredictionHead
|
||||
method_replacement = {
|
||||
'_save_to_state_dict': col_nn.ParallelModule._save_to_state_dict,
|
||||
'_load_from_state_dict': col_nn.ParallelModule._load_from_state_dict,
|
||||
}
|
||||
self.append_or_create_method_replacement(description=method_replacement,
|
||||
policy=base_policy,
|
||||
target_key=BertLMPredictionHead)
|
||||
return base_policy
|
||||
|
||||
def postprocess(self):
|
||||
return self.model
|
||||
|
||||
|
@ -240,6 +248,7 @@ class BertForPreTrainingPolicy(BertPolicy):
|
|||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
policy = self.add_lm_head_policy(policy)
|
||||
policy = self.add_lm_prediction_policy(policy)
|
||||
from transformers.models.bert.modeling_bert import BertForPreTraining
|
||||
self.set_pipeline_forward(model_cls=BertForPreTraining, new_forward=bert_for_pretraining_forward, policy=policy)
|
||||
return policy
|
||||
|
@ -266,21 +275,13 @@ class BertForPreTrainingPolicy(BertPolicy):
|
|||
model = self.model
|
||||
if self.pipeline_stage_manager:
|
||||
if id(model.bert.embeddings.word_embeddings.weight) == id(model.cls.predictions.decoder.weight):
|
||||
#tie weights
|
||||
# tie weights
|
||||
return [{
|
||||
0: model.bert.embeddings.word_embeddings.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: model.cls.predictions.decoder.weight
|
||||
}]
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
setattr_(self.model, v, param)
|
||||
return self.model
|
||||
|
||||
|
||||
# BertLMHeadModel
|
||||
class BertLMHeadModelPolicy(BertPolicy):
|
||||
|
@ -291,6 +292,7 @@ class BertLMHeadModelPolicy(BertPolicy):
|
|||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
policy = self.add_lm_head_policy(policy)
|
||||
policy = self.add_lm_prediction_policy(policy)
|
||||
from transformers.models.bert.modeling_bert import BertLMHeadModel
|
||||
self.set_pipeline_forward(model_cls=BertLMHeadModel, new_forward=bert_lm_head_model_forward, policy=policy)
|
||||
return policy
|
||||
|
@ -316,21 +318,13 @@ class BertLMHeadModelPolicy(BertPolicy):
|
|||
bert_model = self.model.bert
|
||||
if self.pipeline_stage_manager:
|
||||
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
|
||||
#tie weights
|
||||
# tie weights
|
||||
return [{
|
||||
0: bert_model.embeddings.word_embeddings.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
|
||||
}]
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
setattr_(self.model, v, param)
|
||||
return self.model
|
||||
|
||||
|
||||
# BertForMaskedLM
|
||||
class BertForMaskedLMPolicy(BertPolicy):
|
||||
|
@ -341,6 +335,7 @@ class BertForMaskedLMPolicy(BertPolicy):
|
|||
def module_policy(self):
|
||||
policy = super().module_policy()
|
||||
policy = self.add_lm_head_policy(policy)
|
||||
mpolicy = self.add_lm_prediction_policy(policy)
|
||||
from transformers.models.bert.modeling_bert import BertForMaskedLM
|
||||
self.set_pipeline_forward(model_cls=BertForMaskedLM, new_forward=bert_for_masked_lm_forward, policy=policy)
|
||||
return policy
|
||||
|
@ -366,21 +361,13 @@ class BertForMaskedLMPolicy(BertPolicy):
|
|||
bert_model = self.model.bert
|
||||
if self.pipeline_stage_manager:
|
||||
if id(bert_model.embeddings.word_embeddings.weight) == id(self.model.cls.predictions.decoder.weight):
|
||||
#tie weights
|
||||
# tie weights
|
||||
return [{
|
||||
0: bert_model.embeddings.word_embeddings.weight,
|
||||
self.pipeline_stage_manager.num_stages - 1: self.model.cls.predictions.decoder.weight
|
||||
}]
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
binding_map = {"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight"}
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
setattr_(self.model, v, param)
|
||||
return self.model
|
||||
|
||||
|
||||
# BertForSequenceClassification
|
||||
class BertForSequenceClassificationPolicy(BertPolicy):
|
||||
|
@ -1032,6 +1019,7 @@ def bert_for_masked_lm_forward(
|
|||
stage_manager: Optional[PipelineStageManager] = None,
|
||||
stage_index: Optional[List[int]] = None,
|
||||
):
|
||||
# -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
||||
|
@ -1109,7 +1097,7 @@ def bert_for_next_sentence_prediction_forward(
|
|||
stage_index: Optional[List[int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
#-> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
|
||||
# -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
|
||||
r"""
|
||||
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
||||
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
import warnings
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
@ -27,7 +25,6 @@ from transformers.utils import logging
|
|||
import colossalai.shardformer.layer as col_nn
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from ..modeling.bloom import build_bloom_alibi_tensor_fn
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
|
@ -229,20 +226,10 @@ class BloomForCausalLMPolicy(BloomPolicy):
|
|||
# tie weights
|
||||
return [{
|
||||
0: bloom_model.transformer.word_embeddings.weight,
|
||||
self.stage_manager.num_stages - 1: bloom_model.lm_head.weight
|
||||
self.pipeline_stage_manager.num_stages - 1: bloom_model.lm_head.weight
|
||||
}]
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism and self.pipeline_stage_manager is None:
|
||||
binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
|
||||
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
# tie weights
|
||||
setattr_(self.model, v, param)
|
||||
return self.model
|
||||
|
||||
|
||||
class BloomForSequenceClassificationPolicy(BloomPolicy):
|
||||
|
||||
|
@ -692,7 +679,7 @@ def bloom_for_sequence_classification_forward(
|
|||
all_cross_attentions = None
|
||||
if stage_manager.is_last_stage():
|
||||
batch_size = hidden_states.shape[0]
|
||||
#update batch size
|
||||
# update batch size
|
||||
hidden_states = transformer_outputs[0]
|
||||
logits = self.score(hidden_states)
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
|||
import colossalai.shardformer.layer as col_nn
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
|
@ -56,42 +55,42 @@ class GPT2Policy(Policy):
|
|||
"attn.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
"attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
|
||||
},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.c_attn",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 3,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.c_proj",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.c_fc",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 1,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.c_proj",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.attn_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.resid_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
])
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.c_attn",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 3,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.c_proj",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.c_fc",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Col,
|
||||
kwargs={
|
||||
"n_fused": 1,
|
||||
},
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.c_proj",
|
||||
target_module=col_nn.GPT2FusedLinearConv1D_Row,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.attn_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="attn.resid_dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="mlp.dropout",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
),
|
||||
])
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
|
@ -99,8 +98,8 @@ class GPT2Policy(Policy):
|
|||
suffix="ln_f",
|
||||
target_module=col_nn.FusedLayerNorm,
|
||||
),
|
||||
policy=policy,
|
||||
target_key=GPT2Model)
|
||||
policy=policy,
|
||||
target_key=GPT2Model)
|
||||
|
||||
self.append_or_create_submodule_replacement(description=[
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -115,8 +114,8 @@ class GPT2Policy(Policy):
|
|||
target_module=col_nn.FusedLayerNorm,
|
||||
ignore_if_not_exist=True)
|
||||
],
|
||||
policy=policy,
|
||||
target_key=GPT2Block)
|
||||
policy=policy,
|
||||
target_key=GPT2Block)
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
|
@ -227,15 +226,6 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||
else:
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism \
|
||||
and self.pipeline_stage_manager is None:
|
||||
binding_map = {"transformer.wte.weight": "lm_head.weight"}
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
setattr_(self.model, v, param)
|
||||
return self.model
|
||||
|
||||
|
||||
# GPT2DoubleHeadsModel
|
||||
class GPT2DoubleHeadsModelPolicy(GPT2Policy):
|
||||
|
@ -286,15 +276,6 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
|
|||
else:
|
||||
return []
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism \
|
||||
and self.pipeline_stage_manager is None:
|
||||
binding_map = {"transformer.wte.weight": "lm_head.weight"}
|
||||
for k, v in binding_map.items():
|
||||
param = getattr_(self.model, k)
|
||||
setattr_(self.model, v, param)
|
||||
return self.model
|
||||
|
||||
|
||||
# GPT2ForQuestionAnswering
|
||||
class GPT2ForQuestionAnsweringPolicy(GPT2Policy):
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
import math
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -9,14 +7,11 @@ from torch import Tensor
|
|||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss
|
||||
from transformers.modeling_outputs import (
|
||||
BaseModelOutputWithPast,
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
BaseModelOutputWithPoolingAndCrossAttentions,
|
||||
CausalLMOutputWithCrossAttentions,
|
||||
CausalLMOutputWithPast,
|
||||
SequenceClassifierOutputWithPast,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
|
||||
from transformers.utils import ModelOutput, logging
|
||||
from transformers.utils import logging
|
||||
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = [
|
||||
|
@ -116,19 +115,6 @@ class OPTForCausalLMPolicy(OPTPolicy):
|
|||
target_key=OPTForCausalLM)
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
binding_map = {
|
||||
'model.decoder.embed_tokens': 'lm_head',
|
||||
}
|
||||
|
||||
for k, v in binding_map.items():
|
||||
src_mod = getattr_(self.model, k)
|
||||
dst_mod = getattr_(self.model, v)
|
||||
dst_mod.weight = src_mod.weight
|
||||
|
||||
return self.model
|
||||
|
||||
|
||||
class OPTForSequenceClassificationPolicy(OPTPolicy):
|
||||
|
||||
|
|
|
@ -8,7 +8,6 @@ from colossalai.shardformer.layer import (
|
|||
)
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
|
||||
|
||||
from .._utils import getattr_, setattr_
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
||||
__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]
|
||||
|
@ -53,7 +52,7 @@ class T5BasePolicy(Policy):
|
|||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=Embedding1D,
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
])
|
||||
policy[T5LayerSelfAttention] = ModulePolicyDescription(sub_module_replacement=[
|
||||
|
@ -165,12 +164,6 @@ class T5BasePolicy(Policy):
|
|||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]
|
||||
|
||||
for k, v in binding_map:
|
||||
mod = getattr_(self.model, k)
|
||||
setattr_(self.model, v, mod)
|
||||
return self.model
|
||||
|
||||
|
||||
|
@ -211,18 +204,6 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
|||
target_key=T5ForConditionalGeneration)
|
||||
return policy
|
||||
|
||||
def postprocess(self):
|
||||
super().postprocess()
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
binding_map = {"shared": "lm_head"}
|
||||
|
||||
for k, v in binding_map.items():
|
||||
src_mod = getattr_(self.model, k)
|
||||
dst_mod = getattr_(self.model, v)
|
||||
dst_mod.weight = src_mod.weight
|
||||
|
||||
return self.model
|
||||
|
||||
|
||||
class T5EncoderPolicy(T5BasePolicy):
|
||||
|
||||
|
@ -239,14 +220,3 @@ class T5EncoderPolicy(T5BasePolicy):
|
|||
policy=base_policy,
|
||||
target_key=T5EncoderModel)
|
||||
return base_policy
|
||||
|
||||
def postprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
binding_map = [
|
||||
["shared", "encoder.embed_tokens"],
|
||||
]
|
||||
|
||||
for k, v in binding_map:
|
||||
mod = getattr_(self.model, k)
|
||||
setattr_(self.model, v, mod)
|
||||
return self.model
|
||||
|
|
|
@@ -37,11 +37,13 @@ class ModelSharder(object):
         self.policy.set_model(self.model)
         self.policy.set_shard_config(self.shard_config)
         self._preprocess()
+        # get shared params before release unheld layers, this avoid misjudgement of shared params (None is None)
+        shared_params = self.policy.get_shared_params()
         self._release_unheld_layers()
         self._replace_module()
         self._materialize()
         self._postprocess()
-        return self.policy.get_shared_params()
+        return shared_params
 
     def _preprocess(self) -> None:
         self.model = self.policy.preprocess()
@@ -235,6 +235,14 @@ def sharded_tensor_to_param(dtensor: torch.Tensor, requires_grad: bool = True):
     return param
 
 
+def sharded_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter) -> None:
+    assert is_distributed_tensor(dtensor), 'The input tensor is not a distributed tensor.'
+    param.data = dtensor
+    # make it distributed as well
+    param.dist_layout = dtensor.dist_layout
+    _hijack_detach_and_clone(param)
+
+
 def compute_global_numel(dtensor: torch.Tensor) -> int:
     """
     Compute the global number of elements in the distributed tensor.

@@ -432,3 +440,15 @@ def customized_distributed_tensor_to_param(dtensor: torch.Tensor, requires_grad:
     param.gather_fn = dtensor.gather_fn
     _hijack_detach_and_clone_for_customized_distributed_tensor(param)
     return param
+
+
+def customized_distributed_tensor_to_existing_param(dtensor: torch.Tensor, param: torch.nn.Parameter):
+    """
+    Convert the given customized distributed tensor to an existing parameter.
+    """
+    assert is_customized_distributed_tensor(dtensor), 'The input tensor is not a customized distributed tensor.'
+
+    param.data = dtensor.data
+    param.shard_fn = dtensor.shard_fn
+    param.gather_fn = dtensor.gather_fn
+    _hijack_detach_and_clone_for_customized_distributed_tensor(param)
@ -17,3 +17,4 @@ requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggi
|
|||
SentencePiece
|
||||
ninja
|
||||
flash_attn>=2.0
|
||||
datasets
|
||||
|
|
|
@ -15,11 +15,13 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|||
def check_embedding_1d(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
embedding = nn.Embedding(32, 128).cuda()
|
||||
with ctx:
|
||||
embedding = nn.Embedding(32, 128).cuda()
|
||||
embedding_1d = Embedding1D.from_native_module(embedding, process_group=None)
|
||||
embedding_copy = nn.Embedding(32, 128).cuda()
|
||||
embedding_1d = Embedding1D.from_native_module(embedding_copy, process_group=None)
|
||||
|
||||
assert embedding_1d.weight.shape == torch.Size([32, 64])
|
||||
assert embedding_1d.weight is embedding_copy.weight
|
||||
|
||||
# ensure state dict is reversibly loadable
|
||||
embedding.load_state_dict(embedding_1d.state_dict())
|
||||
|
|
|
@ -14,11 +14,14 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|||
def check_layernorm(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
norm = nn.LayerNorm(128, 0.00001).cuda()
|
||||
with ctx:
|
||||
norm = nn.LayerNorm(128, 0.00001).cuda()
|
||||
norm1d = FusedLayerNorm.from_native_module(norm, process_group=None)
|
||||
norm_copy = nn.LayerNorm(128, 0.00001).cuda()
|
||||
norm1d = FusedLayerNorm.from_native_module(norm_copy, process_group=None)
|
||||
|
||||
assert norm1d.weight.shape == torch.Size([128])
|
||||
assert norm_copy.weight is norm1d.weight
|
||||
assert norm_copy.bias is norm1d.bias
|
||||
|
||||
# ensure state dict is reversibly loadable
|
||||
norm.load_state_dict(norm1d.state_dict())
|
||||
|
|
|
@ -15,14 +15,16 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|||
@parameterize('lazy_init', [False, True])
|
||||
def check_linear_1d_col(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
with ctx:
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear, process_group=None, gather_output=True)
|
||||
linear_copy = nn.Linear(32, 128).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
|
||||
|
||||
# ensure that the parameters are distributed
|
||||
assert is_distributed_tensor(linear_col.weight)
|
||||
assert is_distributed_tensor(linear_col.bias)
|
||||
assert linear_copy.weight is linear_col.weight
|
||||
assert linear_copy.bias is linear_col.bias
|
||||
|
||||
# ensure the shape is correct
|
||||
assert linear_col.weight.shape == torch.Size([64, 32])
|
||||
|
@ -61,12 +63,18 @@ def check_linear_1d_col(lazy_init: bool):
|
|||
def check_linear_1d_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
with ctx:
|
||||
linear = nn.Linear(32, 128).cuda()
|
||||
linear_row = Linear1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
||||
linear_copy = nn.Linear(32, 128).cuda()
|
||||
linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
|
||||
|
||||
assert linear_row.weight.shape == torch.Size([128, 16])
|
||||
assert linear_row.bias.shape == torch.Size([128])
|
||||
assert linear_copy.weight is linear_row.weight
|
||||
assert linear_copy.bias is linear_row.bias
|
||||
|
||||
linear.load_state_dict(linear_row.state_dict())
|
||||
linear_row.load_state_dict(linear.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 32).cuda()
|
||||
|
@ -98,11 +106,19 @@ def check_linear_1d_row(lazy_init: bool):
|
|||
def check_linear_col_plus_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear_1 = nn.Linear(32, 128).cuda()
|
||||
linear_2 = nn.Linear(128, 32).cuda()
|
||||
|
||||
with ctx:
|
||||
linear_1 = nn.Linear(32, 128).cuda()
|
||||
linear_2 = nn.Linear(128, 32).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear_1, process_group=None, gather_output=False)
|
||||
linear_row = Linear1D_Row.from_native_module(linear_2, process_group=None, parallel_input=True)
|
||||
linear_1_copy = nn.Linear(32, 128).cuda()
|
||||
linear_2_copy = nn.Linear(128, 32).cuda()
|
||||
linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
|
||||
linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
|
||||
|
||||
linear_1.load_state_dict(linear_col.state_dict())
|
||||
linear_col.load_state_dict(linear_1.state_dict())
|
||||
linear_2.load_state_dict(linear_row.state_dict())
|
||||
linear_row.load_state_dict(linear_2.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 32).cuda()
|
||||
|
|
|
@ -56,10 +56,10 @@ def rearrange(tensor: torch.Tensor, dim: int):
|
|||
@parameterize('lazy_init', [False, True])
|
||||
def check_linear_conv_1d_col(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
with ctx:
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
|
||||
linear_copy = Conv1D(192, 48).cuda()
|
||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
|
||||
process_group=None,
|
||||
gather_output=True,
|
||||
n_fused=3)
|
||||
|
@ -68,6 +68,8 @@ def check_linear_conv_1d_col(lazy_init: bool):
|
|||
assert linear.bias.shape == torch.Size([192])
|
||||
assert linear_conv_col.weight.shape == torch.Size([48, 96])
|
||||
assert linear_conv_col.bias.shape == torch.Size([96])
|
||||
assert linear_copy.weight is linear_conv_col.weight
|
||||
assert linear_copy.bias is linear_conv_col.bias
|
||||
|
||||
# ensure weights are reversibly loadable
|
||||
linear_conv_col.load_state_dict(linear.state_dict())
|
||||
|
@ -91,13 +93,20 @@ def check_linear_conv_1d_col(lazy_init: bool):
|
|||
def check_linear_conv_1d_row(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
with ctx:
|
||||
linear = Conv1D(192, 48).cuda()
|
||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
||||
linear_copy = Conv1D(192, 48).cuda()
|
||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
|
||||
|
||||
assert linear.weight.shape == torch.Size([48, 192])
|
||||
assert linear_row.weight.shape == torch.Size([24, 192])
|
||||
assert linear_row.bias.shape == torch.Size([192])
|
||||
assert linear_copy.weight is linear_row.weight
|
||||
assert linear_copy.bias is linear_row.bias
|
||||
|
||||
# ensure weights are reversibly loadable
|
||||
linear_row.load_state_dict(linear.state_dict())
|
||||
linear.load_state_dict(linear_row.state_dict())
|
||||
|
||||
# check computation correctness
|
||||
x = torch.rand(4, 48).cuda()
|
||||
|
|
|
@ -7,8 +7,7 @@ from torch.testing import assert_close
|
|||
|
||||
import colossalai
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||
from colossalai.shardformer.layer import VocabParallelEmbedding1D
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
|
||||
|
||||
|
@ -16,13 +15,15 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
|||
def check_vocab_embedding_1d(lazy_init: bool):
|
||||
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||
|
||||
embedding = nn.Embedding(128, 32).to('cuda')
|
||||
with ctx:
|
||||
embedding = nn.Embedding(128, 32).to('cuda')
|
||||
dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding, process_group=None)
|
||||
embedding_copy = nn.Embedding(128, 32).to('cuda')
|
||||
dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None)
|
||||
|
||||
assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
|
||||
assert dist_embedding_1d.num_embeddings == 64
|
||||
assert dist_embedding_1d.embedding_dim == 32
|
||||
assert embedding_copy.weight is dist_embedding_1d.weight
|
||||
|
||||
# ensure state dict is reversibly loadable
|
||||
embedding.load_state_dict(dist_embedding_1d.state_dict())
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import copy
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
|
||||
from colossalai.lazy import LazyInitContext
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.shardformer import ShardConfig, ShardFormer
|
||||
|
||||
|
||||
|
@@ -61,3 +63,14 @@ def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn,
     shard_output = output_transform_fn(shard_output)
     shard_loss = loss_fn(shard_output)
     return org_output, org_loss, shard_output, shard_loss
+
+
+def check_state_dict(org_model: Module, sharded_model: Module, name: str = ''):
+    org_sd = org_model.state_dict()
+    shard_sd = sharded_model.state_dict()
+    for k, v in org_sd.items():
+        assert k in shard_sd, f'{name} {k} not in sharded model'
+        shard_v = shard_sd[k]
+        assert v.shape == shard_v.shape, f'{name} {k} shape mismatch, {v.shape} vs {shard_v.shape}'
+        assert v.dtype == shard_v.dtype, f'{name} {k} dtype mismatch, {v.dtype} vs {shard_v.dtype}'
+        assert torch.equal(v, shard_v), f'{name} {k} value mismatch'
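The helper's contract is that both models expose identical state dicts, key by key; the in-place sharded model can satisfy it because its state dict is saved at full size (the test files below load it straight back into the unsharded modules). A trivial usage sketch with no sharding involved, assuming it is run from the repository root so the test package is importable:

    import torch.nn as nn

    from tests.test_shardformer.test_model._utils import check_state_dict

    org = nn.Linear(4, 4)
    copy = nn.Linear(4, 4)
    copy.load_state_dict(org.state_dict())
    check_state_dict(org, copy, name='toy')    # same keys, shapes, dtypes and values, so nothing fires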
@ -12,7 +12,7 @@ from colossalai.testing import (
|
|||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
|
@ -75,6 +75,7 @@ def run_bert_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
|||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
@ -12,7 +12,7 @@ from colossalai.testing import (
|
|||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
|
@ -75,6 +75,7 @@ def run_bloom_test(enable_fused_normalization, enable_tensor_parallelism, use_la
|
|||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ from colossalai.testing import (
|
|||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
|
@ -77,6 +77,7 @@ def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_laz
|
|||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
@ -14,7 +14,7 @@ from colossalai.testing import (
|
|||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
||||
|
@ -78,6 +78,7 @@ def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_la
|
|||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ from colossalai.testing import (
|
|||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||
|
||||
|
@ -77,6 +77,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
|||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ from colossalai.testing import (
|
|||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
||||
from tests.test_shardformer.test_model._utils import build_model, check_state_dict, run_forward
|
||||
|
||||
|
||||
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||
|
@ -88,6 +88,7 @@ def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_
|
|||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
|
||||
use_lazy_init)
|
||||
check_state_dict(org_model, sharded_model, name=name)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
|