mirror of https://github.com/hpcaitech/ColossalAI
update some module with new api version
parent
879301d0da
commit
726541afe2
|
@ -537,10 +537,11 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
gather_output: bool = False,
|
gather_output: bool = False,
|
||||||
skip_bias_add: bool = False,
|
skip_bias_add: bool = False,
|
||||||
n_fused: int = 3,
|
n_fused: int = 3,
|
||||||
|
weight: Optional[Parameter] = None,
|
||||||
|
bias_: Optional[Parameter] = None,
|
||||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Keep input parameters
|
# Keep input parameters
|
||||||
self.in_features = in_features
|
self.in_features = in_features
|
||||||
self.out_features = out_features
|
self.out_features = out_features
|
||||||
|
@ -554,36 +555,52 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
if skip_bias_add and not bias:
|
if skip_bias_add and not bias:
|
||||||
raise ValueError('cannot skip bias addition if bias is None')
|
raise ValueError('cannot skip bias addition if bias is None')
|
||||||
|
|
||||||
|
# offset the seed with randomizer index and rank
|
||||||
|
seed = torch.random.initial_seed()
|
||||||
|
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||||
|
|
||||||
|
# sanity check
|
||||||
|
if weight is not None:
|
||||||
|
assert not bias or bias_ is not None, 'bias_ must be provided if bias is True when weight is not None'
|
||||||
|
else:
|
||||||
|
assert bias_ is None, 'bias_ must be None if weight is None'
|
||||||
|
|
||||||
# Parameters.
|
# Parameters.
|
||||||
# Initialize weight.
|
if weight is None:
|
||||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
# Initialize weight.
|
||||||
weight = torch.empty(self.out_features, self.in_features, **factory_kwargs)
|
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||||
|
self.weight = Parameter(torch.empty(self.out_features, self.in_features, **factory_kwargs))
|
||||||
|
else:
|
||||||
|
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||||
|
self.weight = weight
|
||||||
|
|
||||||
def shard_fn(tensor):
|
def shard_fn(tensor):
|
||||||
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
|
return split_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
|
||||||
|
|
||||||
def gather_fn(tensor):
|
def gather_fn(tensor):
|
||||||
return gather_fused_qkv_in_gpt2_style(tensor, 3, self.process_group, False)
|
return gather_fused_qkv_in_gpt2_style(tensor, self.n_fused, self.process_group, False)
|
||||||
|
|
||||||
with torch.no_grad():
|
if not is_customized_distributed_tensor(self.weight):
|
||||||
sharded_weight = distribute_tensor_with_customization(weight, shard_fn, gather_fn)
|
with torch.no_grad():
|
||||||
self.weight = customized_distributed_tensor_to_param(sharded_weight)
|
sharded_weight = distribute_tensor_with_customization(self.weight.data, shard_fn, gather_fn)
|
||||||
|
customized_distributed_tensor_to_existing_param(sharded_weight, self.weight)
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
bias = torch.empty(self.out_features, **factory_kwargs)
|
if bias_ is None:
|
||||||
|
self.bias = Parameter(torch.empty(self.out_features, **factory_kwargs))
|
||||||
with torch.no_grad():
|
else:
|
||||||
sharded_bias = distribute_tensor_with_customization(bias, shard_fn, gather_fn)
|
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||||
self.bias = customized_distributed_tensor_to_param(sharded_bias)
|
self.bias = bias_
|
||||||
|
if not is_customized_distributed_tensor(self.bias):
|
||||||
|
with torch.no_grad():
|
||||||
|
sharded_bias = distribute_tensor_with_customization(self.bias.data, shard_fn, gather_fn)
|
||||||
|
customized_distributed_tensor_to_existing_param(sharded_bias, self.bias)
|
||||||
else:
|
else:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
|
|
||||||
# offset the seed with randomizer index and rank
|
if weight is None:
|
||||||
seed = torch.random.initial_seed()
|
# init weights
|
||||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
self.reset_parameters(weight_initializer, bias_initializer)
|
||||||
|
|
||||||
# init weights
|
|
||||||
self.reset_parameters(weight_initializer, bias_initializer)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
|
def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], n_fused: int,
|
||||||
|
@ -613,24 +630,26 @@ class FusedLinear1D_Col(ParallelModule):
|
||||||
bias=bias,
|
bias=bias,
|
||||||
device=device,
|
device=device,
|
||||||
process_group=process_group,
|
process_group=process_group,
|
||||||
|
weight=module.weight,
|
||||||
|
bias_=module.bias,
|
||||||
*args,
|
*args,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
# TODO: copy the sharded weights
|
# # TODO: copy the sharded weights
|
||||||
with torch.no_grad():
|
# with torch.no_grad():
|
||||||
sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
|
# sharded_weight = split_fused_qkv_in_gpt2_style(module.weight.data,
|
||||||
n_fused=n_fused,
|
# n_fused=n_fused,
|
||||||
process_group=process_group,
|
# process_group=process_group,
|
||||||
is_transposed=False)
|
# is_transposed=False)
|
||||||
linear_1d.weight.data.copy_(sharded_weight.data)
|
# linear_1d.weight.data.copy_(sharded_weight.data)
|
||||||
|
|
||||||
if bias:
|
|
||||||
sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
|
|
||||||
n_fused=n_fused,
|
|
||||||
process_group=process_group,
|
|
||||||
is_transposed=False)
|
|
||||||
linear_1d.bias.data.copy_(sharded_bias.data)
|
|
||||||
|
|
||||||
|
# if bias:
|
||||||
|
# sharded_bias = split_fused_qkv_in_gpt2_style(module.bias.data,
|
||||||
|
# n_fused=n_fused,
|
||||||
|
# process_group=process_group,
|
||||||
|
# is_transposed=False)
|
||||||
|
# linear_1d.bias.data.copy_(sharded_bias.data)
|
||||||
|
print(linear_1d.weight.shape)
|
||||||
return linear_1d
|
return linear_1d
|
||||||
|
|
||||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||||
|
|
|
@ -4,7 +4,7 @@ import colossalai.shardformer.layer as col_nn
|
||||||
|
|
||||||
from .._utils import getattr_, setattr_
|
from .._utils import getattr_, setattr_
|
||||||
from ..modeling.blip2 import forward_fn
|
from ..modeling.blip2 import forward_fn
|
||||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = ['BlipPolicy', 'BlipModelPolicy']
|
__all__ = ['BlipPolicy', 'BlipModelPolicy']
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
||||||
|
|
||||||
import colossalai.shardformer.layer as col_nn
|
import colossalai.shardformer.layer as col_nn
|
||||||
|
|
||||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
|
__all__ = ['ChatGLMModelPolicy', 'ChatGLMForConditionalGenerationPolicy']
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import colossalai.shardformer.layer as col_nn
|
||||||
|
|
||||||
from .._utils import getattr_, setattr_
|
from .._utils import getattr_, setattr_
|
||||||
from ..modeling.sam import forward_fn
|
from ..modeling.sam import forward_fn
|
||||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = ['SamPolicy', 'SamModelPolicy']
|
__all__ = ['SamPolicy', 'SamModelPolicy']
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import torch.nn as nn
|
||||||
import colossalai.shardformer.layer as col_nn
|
import colossalai.shardformer.layer as col_nn
|
||||||
|
|
||||||
from .._utils import getattr_, setattr_
|
from .._utils import getattr_, setattr_
|
||||||
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification'
|
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification'
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
|
from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.testing import assert_close
|
from torch.testing import assert_close
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
|
from colossalai.lazy import LazyInitContext
|
||||||
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
from colossalai.shardformer.layer import GPT2FusedLinearConv1D_Col, GPT2FusedLinearConv1D_Row
|
||||||
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
from colossalai.shardformer.layer.qkv_fused_linear import split_fused_qkv_in_gpt2_style
|
||||||
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
|
||||||
# This code is copied from https://github.com/huggingface/transformers
|
# This code is copied from https://github.com/huggingface/transformers
|
||||||
|
@ -50,9 +53,13 @@ def rearrange(tensor: torch.Tensor, dim: int):
|
||||||
return rearanged_tensor
|
return rearanged_tensor
|
||||||
|
|
||||||
|
|
||||||
def check_gpt2_linear_conv_1d_col():
|
@parameterize('lazy_init', [False, True])
|
||||||
|
def check_linear_conv_1d_col(lazy_init: bool):
|
||||||
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
linear = Conv1D(192, 48).cuda()
|
linear = Conv1D(192, 48).cuda()
|
||||||
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear,
|
with ctx:
|
||||||
|
linear_copy = Conv1D(192, 48).cuda()
|
||||||
|
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
|
||||||
process_group=None,
|
process_group=None,
|
||||||
gather_output=True,
|
gather_output=True,
|
||||||
n_fused=3)
|
n_fused=3)
|
||||||
|
@ -61,6 +68,8 @@ def check_gpt2_linear_conv_1d_col():
|
||||||
assert linear.bias.shape == torch.Size([192])
|
assert linear.bias.shape == torch.Size([192])
|
||||||
assert linear_conv_col.weight.shape == torch.Size([48, 96])
|
assert linear_conv_col.weight.shape == torch.Size([48, 96])
|
||||||
assert linear_conv_col.bias.shape == torch.Size([96])
|
assert linear_conv_col.bias.shape == torch.Size([96])
|
||||||
|
assert linear_copy.weight is linear_conv_col.weight
|
||||||
|
assert linear_copy.bias is linear_conv_col.bias
|
||||||
|
|
||||||
# ensure weights are reversibly loadable
|
# ensure weights are reversibly loadable
|
||||||
linear_conv_col.load_state_dict(linear.state_dict())
|
linear_conv_col.load_state_dict(linear.state_dict())
|
||||||
|
@ -80,13 +89,24 @@ def check_gpt2_linear_conv_1d_col():
|
||||||
assert_close(target_grad, linear_conv_col.weight.grad)
|
assert_close(target_grad, linear_conv_col.weight.grad)
|
||||||
|
|
||||||
|
|
||||||
def check_gpt2_linear_conv_1d_row():
|
@parameterize('lazy_init', [False, True])
|
||||||
|
def check_linear_conv_1d_row(lazy_init: bool):
|
||||||
|
ctx = LazyInitContext() if lazy_init else nullcontext()
|
||||||
|
|
||||||
linear = Conv1D(192, 48).cuda()
|
linear = Conv1D(192, 48).cuda()
|
||||||
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear, process_group=None, parallel_input=False)
|
with ctx:
|
||||||
|
linear_copy = Conv1D(192, 48).cuda()
|
||||||
|
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
|
||||||
|
|
||||||
assert linear.weight.shape == torch.Size([48, 192])
|
assert linear.weight.shape == torch.Size([48, 192])
|
||||||
assert linear_row.weight.shape == torch.Size([24, 192])
|
assert linear_row.weight.shape == torch.Size([24, 192])
|
||||||
assert linear_row.bias.shape == torch.Size([192])
|
assert linear_row.bias.shape == torch.Size([192])
|
||||||
|
assert linear_copy.weight is linear_row.weight
|
||||||
|
assert linear_copy.bias is linear_row.bias
|
||||||
|
|
||||||
|
# ensure weights are reversibly loadable
|
||||||
|
linear_row.load_state_dict(linear.state_dict())
|
||||||
|
linear.load_state_dict(linear_row.state_dict())
|
||||||
|
|
||||||
# check computation correctness
|
# check computation correctness
|
||||||
x = torch.rand(4, 48).cuda()
|
x = torch.rand(4, 48).cuda()
|
||||||
|
@ -107,14 +127,14 @@ def run_dist(rank, world_size, port):
|
||||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
|
||||||
# test for linear conv
|
# test for linear conv
|
||||||
check_gpt2_linear_conv_1d_col()
|
check_linear_conv_1d_col()
|
||||||
check_gpt2_linear_conv_1d_row()
|
check_linear_conv_1d_row()
|
||||||
|
|
||||||
|
|
||||||
@rerun_if_address_is_in_use()
|
@rerun_if_address_is_in_use()
|
||||||
def test_gpt2_linearconv():
|
def test_linearconv():
|
||||||
spawn(run_dist, nprocs=2)
|
spawn(run_dist, nprocs=2)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_gpt2_linearconv()
|
test_linearconv()
|
||||||
|
|
|
@ -84,9 +84,10 @@ def run_chatglm_test(enable_fused_normalization, enable_tensor_parallelism):
|
||||||
model_copy = copy.deepcopy(org_model)
|
model_copy = copy.deepcopy(org_model)
|
||||||
shard_former = ShardFormer(shard_config=shard_config)
|
shard_former = ShardFormer(shard_config=shard_config)
|
||||||
if name == "transformers_chatglm":
|
if name == "transformers_chatglm":
|
||||||
sharded_model = shard_former.optimize(model_copy, ChatGLMModelPolicy()).cuda()
|
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMModelPolicy())
|
||||||
else:
|
else:
|
||||||
sharded_model = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy()).cuda()
|
sharded_model, _ = shard_former.optimize(model_copy, ChatGLMForConditionalGenerationPolicy())
|
||||||
|
sharded_model = sharded_model.cuda()
|
||||||
|
|
||||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
Loading…
Reference in New Issue