# mirror of https://github.com/hpcaitech/ColossalAI
from typing import List, Union

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import ProcessGroup

from colossalai.lazy import LazyInitContext
from colossalai.shardformer.layer import ParallelModule

from .linear import W8A8B8O8Linear, W8A8BFP32O32LinearSiLU, W8A8BFP32OFP32Linear


def split_row_copy(smooth_linear, para_linear, tp_size=1, tp_rank=0, split_num=1):
    """Copy the output-dimension (row) shard owned by ``tp_rank`` from ``smooth_linear`` into ``para_linear``.

    The source weight is first split into ``split_num`` chunks along the output
    dimension; from each chunk, the slice belonging to ``tp_rank`` is copied into
    the matching position of the parallel module. The bias, stored with shape
    ``(1, out_features)``, is sharded the same way.
    """
    qweights = smooth_linear.weight.split(smooth_linear.out_features // split_num, dim=0)
    if smooth_linear.bias is not None:
        bias = smooth_linear.bias.split(smooth_linear.out_features // split_num, dim=0)

    smooth_split_out_features = para_linear.out_features // split_num

    for i in range(split_num):
        para_linear.weight[i * smooth_split_out_features : (i + 1) * smooth_split_out_features, :] = qweights[i][
            tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features, :
        ]

        if para_linear.bias is not None:
            para_linear.bias[:, i * smooth_split_out_features : (i + 1) * smooth_split_out_features] = bias[i][
                :, tp_rank * smooth_split_out_features : (tp_rank + 1) * smooth_split_out_features
            ]


def split_column_copy(smooth_linear, para_linear, tp_rank=0, split_num=1):
    """Copy the input-dimension (column) shard owned by ``tp_rank`` from ``smooth_linear`` into ``para_linear``.

    Only the weight is sharded along the input dimension; the full bias is copied,
    and callers that all-reduce the partial outputs divide it by the tensor-parallel
    size afterwards.
    """
    qweights = smooth_linear.weight.split(smooth_linear.in_features // split_num, dim=-1)

    smooth_split_in_features = para_linear.in_features // split_num

    for i in range(split_num):
        para_linear.weight[:, i * smooth_split_in_features : (i + 1) * smooth_split_in_features] = qweights[i][
            :, tp_rank * smooth_split_in_features : (tp_rank + 1) * smooth_split_in_features
        ]

    if smooth_linear.bias is not None:
        para_linear.bias.copy_(smooth_linear.bias)


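# --- Illustrative sketch (added for documentation; not used by the library) ---
# A tiny, self-contained example of what `split_row_copy` does, using plain
# `nn.Linear` stand-ins without bias (the real callers pass the quantized linear
# modules above, whose bias layout differs). `torch.no_grad()` is only needed
# because `nn.Linear` weights are autograd parameters; the helper itself just
# copies slices.
def _example_split_row_copy():
    full = nn.Linear(8, 8, bias=False)   # stands in for the unsharded module
    shard = nn.Linear(8, 4, bias=False)  # stands in for one rank's shard (tp_size = 2)
    with torch.no_grad():
        split_row_copy(full, shard, tp_size=2, tp_rank=1)
    # rank 1 owns the second half of the output rows
    assert torch.equal(shard.weight, full.weight[4:8, :])

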
class RowW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        # too few output features to shard: keep the module unsharded
        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features ({out_features}) is not an integer multiple of the tensor parallel size ({tp_size})!"
            )
        linear_1d = RowW8A8B8O8Linear(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()
        linear_1d.b = module.b.clone().detach()
        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d


class ColW8A8B8O8Linear(W8A8B8O8Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        # too few input features to shard: keep the module unsharded
        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features ({in_features}) is not an integer multiple of the tensor parallel size ({tp_size})!"
            )
        linear_1d = ColW8A8B8O8Linear(module.in_features // tp_size, module.out_features)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()
        linear_1d.b = module.b.clone().detach()

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        if linear_1d.bias is not None:
            # every rank adds the bias to its partial output before the all-reduce,
            # so pre-divide it to keep the reduced sum correct
            linear_1d.bias = linear_1d.bias // tp_size

        return linear_1d

    @torch.no_grad()
    def forward(self, x):
        output = super().forward(x)
        if self.tp_size > 1:
            # sum the partial results produced by the input-dimension shards
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
        return output


class RowW8A8BFP32O32LinearSiLU(W8A8BFP32O32LinearSiLU, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        # too few output features to shard: keep the module unsharded
        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features ({out_features}) is not an integer multiple of the tensor parallel size ({tp_size})!"
            )
        linear_1d = RowW8A8BFP32O32LinearSiLU(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d


class RowW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        out_features = module.out_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        # too few output features to shard: keep the module unsharded
        if out_features < tp_size:
            return module

        if out_features % tp_size != 0:
            raise ValueError(
                f"The size of out_features ({out_features}) is not an integer multiple of the tensor parallel size ({tp_size})!"
            )
        linear_1d = RowW8A8BFP32OFP32Linear(module.in_features, module.out_features // tp_size)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        return linear_1d


class ColW8A8BFP32OFP32Linear(W8A8BFP32OFP32Linear, ParallelModule):
    def __init__(self, in_features, out_features, alpha=1.0, beta=1.0):
        super().__init__(in_features, out_features, alpha, beta)
        self.process_group = None
        self.tp_size = 1
        self.tp_rank = 0

    @staticmethod
    def from_native_module(
        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
    ) -> ParallelModule:
        LazyInitContext.materialize(module)
        # get the attributes
        in_features = module.in_features

        # ensure only one process group is passed
        if isinstance(process_group, (list, tuple)):
            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
            process_group = process_group[0]

        tp_size = dist.get_world_size(process_group)
        tp_rank = dist.get_rank(process_group)

        # too few input features to shard: keep the module unsharded
        if in_features < tp_size:
            return module

        if in_features % tp_size != 0:
            raise ValueError(
                f"The size of in_features ({in_features}) is not an integer multiple of the tensor parallel size ({tp_size})!"
            )
        linear_1d = ColW8A8BFP32OFP32Linear(module.in_features // tp_size, module.out_features)
        linear_1d.tp_size = tp_size
        linear_1d.tp_rank = tp_rank
        linear_1d.process_group = process_group
        linear_1d.a = module.a.clone().detach()

        split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
        if linear_1d.bias is not None:
            # every rank adds the bias to its partial output before the all-reduce,
            # so pre-divide it to keep the reduced sum correct
            linear_1d.bias = linear_1d.bias / tp_size

        return linear_1d

    @torch.no_grad()
    def forward(self, x):
        output = super().forward(x)
        if self.tp_size > 1:
            # sum the partial results produced by the input-dimension shards
            dist.all_reduce(output, op=dist.ReduceOp.SUM, group=self.process_group)
        return output
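

# --- Minimal usage sketch (added for documentation; not part of the original module) ---
# Shows roughly how `from_native_module` converts an unsharded quantized linear into its
# tensor-parallel counterpart. It assumes the script is launched with torchrun (so a
# process group can be initialized) and that `W8A8B8O8Linear(16, 16)` can be constructed
# directly on the current device; adapt both assumptions to your setup.
if __name__ == "__main__":
    dist.init_process_group(backend="gloo")
    pg = dist.group.WORLD

    full = W8A8B8O8Linear(16, 16)
    sharded = RowW8A8B8O8Linear.from_native_module(full, process_group=pg)
    print(f"rank {dist.get_rank(pg)}: sharded weight shape {tuple(sharded.weight.shape)}")

    dist.destroy_process_group()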