@@ -18,15 +18,15 @@ from .gptq_op import CaiGPTQLinearOp

 HAS_GPTQ_CUDA = False
 try:
     from colossalai.kernel.op_builder.gptq import GPTQBuilder
     gptq_cuda = GPTQBuilder().load()
     HAS_GPTQ_CUDA = True
 except ImportError:
-    warnings.warn('CUDA gptq is not installed')
+    warnings.warn("CUDA gptq is not installed")
     HAS_GPTQ_CUDA = False


 class CaiQuantLinear(nn.Module):
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
         super().__init__()
         if bits not in [2, 4, 8]:
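
The guarded import above is the standard optional-extension pattern: the module still imports on machines without the CUDA toolchain, and every CUDA call site later checks HAS_GPTQ_CUDA (see the forward hunk below). A minimal sketch of the same pattern; fast_mm here is a hypothetical extension module, not part of ColossalAI:

import warnings

HAS_FAST_MM = False
try:
    import fast_mm  # hypothetical compiled extension, stands in for the GPTQ kernel

    HAS_FAST_MM = True
except ImportError:
    warnings.warn("fast_mm is not installed; using the pure-PyTorch fallback")


def mm(a, b):
    # Dispatch to the compiled kernel only if the import succeeded.
    return fast_mm.mm(a, b) if HAS_FAST_MM else a @ b
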
@@ -37,23 +37,28 @@ class CaiQuantLinear(nn.Module):
         self.maxq = 2**self.bits - 1
         self.groupsize = groupsize if groupsize != -1 else infeatures

-        self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
+        self.register_buffer("qweight", torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
         self.register_buffer(
-            'qzeros',
-            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
-        self.register_buffer('scales',
-                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
+            "qzeros",
+            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32),
+        )
+        self.register_buffer(
+            "scales", torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)
+        )
         if row_split:
             self.register_buffer(
-                'g_idx',
-                torch.tensor([(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)],
-                             dtype=torch.int32))
+                "g_idx",
+                torch.tensor(
+                    [(i + (tp_rank * self.infeatures)) // self.groupsize for i in range(infeatures)], dtype=torch.int32
+                ),
+            )
         else:
-            self.register_buffer('g_idx',
-                                 torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
+            self.register_buffer(
+                "g_idx", torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32)
+            )

         if bias:
-            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
+            self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
         else:
             self.bias = None
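
For orientation, the buffer shapes all follow from the 32 // bits packing factor. A quick shape check; bits=4, groupsize=128, and a 4096x4096 layer are assumptions chosen for the example, not values from the code:

import math

bits, groupsize, infeatures, outfeatures = 4, 128, 4096, 4096

qweight_shape = (infeatures // 32 * bits, outfeatures)  # (512, 4096): eight 4-bit weights per int32
qzeros_shape = (math.ceil(infeatures / groupsize), outfeatures // 32 * bits)  # (32, 512)
scales_shape = (math.ceil(infeatures / groupsize), outfeatures)  # (32, 4096): one fp16 scale per group and column
print(qweight_shape, qzeros_shape, scales_shape)
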
@@ -66,9 +71,11 @@ class CaiQuantLinear(nn.Module):
         self.row_split = row_split

     def pack(self, linear, scales, zeros, g_idx=None):
-        g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
-            [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
+        g_idx = (
+            g_idx.clone()
+            if g_idx is not None
+            else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)
+        )

         scales = scales.t().contiguous()
         zeros = zeros.t().contiguous()
@@ -79,7 +86,6 @@ class CaiQuantLinear(nn.Module):
         if linear.bias is not None:
             self.bias = linear.bias.clone().half()
-        wn = 8
         pbits = 32
         ptype = torch.int32
         unsign_type = np.uint32
@@ -88,9 +94,10 @@ class CaiQuantLinear(nn.Module):
         intweight = []
         for idx in range(self.infeatures):
             intweight.append(
-                torch.round(
-                    (linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:,
-                                                                                                                None])
+                torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[
+                    :, None
+                ]
+            )
         intweight = torch.cat(intweight, dim=1)
         intweight = intweight.t().contiguous()
         intweight = intweight.numpy().astype(unsign_type)
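
The loop quantizes one input column at a time: each weight is shifted by its group's zero-point and divided by its group's scale, then rounded to an integer. The later packing step (outside this hunk) crams 32 // bits of these integers into each 32-bit word. A self-contained NumPy sketch of that bit-packing idea, illustrative rather than the kernel's exact layout:

import numpy as np

bits = 4
vals = np.arange(8, dtype=np.uint32)  # eight 4-bit values: 0..7

# Pack 32 // bits = 8 values into one 32-bit word, lowest value in the lowest bits.
packed = np.uint32(0)
for i, v in enumerate(vals):
    packed |= v << (bits * i)

# Unpack to verify the round-trip.
mask = (1 << bits) - 1
unpacked = [(int(packed) >> (bits * i)) & mask for i in range(32 // bits)]
assert unpacked == list(vals)
print(hex(int(packed)))  # 0x76543210
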
@@ -144,13 +151,16 @@ class CaiQuantLinear(nn.Module):
                 torch.tensor(
                     [(i + (self.tp_rank * self.infeatures)) // self.groupsize for i in range(self.infeatures)],
                     dtype=torch.int32,
-                    device=self.g_idx.device)):
+                    device=self.g_idx.device,
+                ),
+            ):
                 self.g_idx = None
             elif torch.equal(
                 self.g_idx,
-                torch.tensor([i // self.groupsize for i in range(self.infeatures)],
-                             dtype=torch.int32,
-                             device=self.g_idx.device)):
+                torch.tensor(
+                    [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32, device=self.g_idx.device
+                ),
+            ):
                 self.g_idx = None

         if self.g_idx is not None:
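
Both torch.equal checks detect the trivial ordering: when g_idx equals the default row // groupsize pattern (with or without the tensor-parallel rank offset), a row's group is implied by its index, so g_idx is dropped and the act-order reindexing path can be skipped. An illustration with assumed sizes:

import torch

groupsize, infeatures = 128, 512
g_idx = torch.tensor([i // groupsize for i in range(infeatures)], dtype=torch.int32)
# -> [0]*128 + [1]*128 + [2]*128 + [3]*128, i.e. group = row // groupsize
print(g_idx[::128])  # tensor([0, 1, 2, 3], dtype=torch.int32)
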
@@ -165,7 +175,6 @@ class CaiQuantLinear(nn.Module):
-
         outshape = x.shape[:-1] + (self.outfeatures,)

         if HAS_GPTQ_CUDA and self.bits == 4:
             if self.q4 is None:
                 self.init_q4()
@@ -191,7 +200,6 @@ class CaiQuantLinear(nn.Module):
 def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
-
     qweights = gptq_linear.qweight.split(gptq_linear.out_features // split_num, dim=-1)
     qzeros = gptq_linear.qzeros.split(gptq_linear.out_features // (32 // cai_linear.bits) // split_num, dim=-1)
     scales = gptq_linear.scales.split(gptq_linear.out_features // split_num, dim=-1)
@@ -203,24 +211,24 @@ def split_column_copy(gptq_linear, cai_linear, tp_size=1, tp_rank=0, split_num=1):
     zero_split_block = cai_linear.outfeatures // (32 // cai_linear.bits) // split_num

     for i in range(split_num):
-        cai_linear.qweight[:, i * cai_split_out_features:(i + 1) *
-                           cai_split_out_features] = qweights[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
-                                                                 cai_split_out_features]
-        cai_linear.qzeros[:, i * zero_split_block:(i + 1) *
-                          zero_split_block] = qzeros[i][:, tp_rank * zero_split_block:(tp_rank + 1) * zero_split_block]
-        cai_linear.scales[:, i * cai_split_out_features:(i + 1) *
-                          cai_split_out_features] = scales[i][:, tp_rank * cai_split_out_features:(tp_rank + 1) *
-                                                              cai_split_out_features]
+        cai_linear.qweight[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = qweights[i][
+            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+        ]
+        cai_linear.qzeros[:, i * zero_split_block : (i + 1) * zero_split_block] = qzeros[i][
+            :, tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block
+        ]
+        cai_linear.scales[:, i * cai_split_out_features : (i + 1) * cai_split_out_features] = scales[i][
+            :, tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+        ]
         if cai_linear.bias is not None:
-            cai_linear.bias[i * cai_split_out_features:(i + 1) *
-                            cai_split_out_features] = bias[i][tp_rank * cai_split_out_features:(tp_rank + 1) *
-                                                              cai_split_out_features]
+            cai_linear.bias[i * cai_split_out_features : (i + 1) * cai_split_out_features] = bias[i][
+                tp_rank * cai_split_out_features : (tp_rank + 1) * cai_split_out_features
+            ]

     cai_linear.g_idx.copy_(g_idx)


 def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
     qweights = gptq_linear.qweight.split(gptq_linear.in_features // split_num, dim=0)
     qzeros = gptq_linear.qzeros.split(gptq_linear.in_features // split_num, dim=0)
     scales = gptq_linear.scales.split(gptq_linear.in_features // split_num, dim=0)
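
split_column_copy shards the output dimension: qweight, qzeros, scales, and bias are sliced on their last axis, while g_idx (one entry per input row) is copied whole; split_row_copy in the next hunk mirrors this on the input dimension. A shape check, again with assumed 4-bit 4096x4096 sizes at tp_size=2:

bits, tp_size, in_f, out_f, groupsize = 4, 2, 4096, 4096, 128

shard_qweight = (in_f // 32 * bits, out_f // tp_size)  # (512, 2048): output dim halved
shard_qzeros = (in_f // groupsize, out_f // (32 // bits) // tp_size)  # (32, 256)
shard_scales = (in_f // groupsize, out_f // tp_size)  # (32, 2048)
# g_idx keeps shape (in_f,) on every rank: column sharding leaves the input rows intact.
print(shard_qweight, shard_qzeros, shard_scales)
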
@@ -231,47 +239,40 @@ def split_row_copy(gptq_linear, cai_linear, tp_rank=0, split_num=1):
     idx_split_features = cai_linear.infeatures // split_num

     for i in range(split_num):
-        cai_linear.qweight[i * cai_split_in_features:(i + 1) *
-                           cai_split_in_features, :] = qweights[i][tp_rank * cai_split_in_features:(tp_rank + 1) *
-                                                                   cai_split_in_features, :]
-        cai_linear.qzeros[i * zero_split_block:(i + 1) *
-                          zero_split_block, :] = qzeros[i][tp_rank * zero_split_block:(tp_rank + 1) *
-                                                           zero_split_block, :]
-        cai_linear.scales[i * zero_split_block:(i + 1) *
-                          zero_split_block, :] = scales[i][tp_rank * zero_split_block:(tp_rank + 1) *
-                                                           zero_split_block, :]
-        cai_linear.g_idx[i * idx_split_features:(i + 1) *
-                         idx_split_features] = g_idxs[i][tp_rank * idx_split_features:(tp_rank + 1) *
-                                                         idx_split_features]
+        cai_linear.qweight[i * cai_split_in_features : (i + 1) * cai_split_in_features, :] = qweights[i][
+            tp_rank * cai_split_in_features : (tp_rank + 1) * cai_split_in_features, :
+        ]
+        cai_linear.qzeros[i * zero_split_block : (i + 1) * zero_split_block, :] = qzeros[i][
+            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
+        ]
+        cai_linear.scales[i * zero_split_block : (i + 1) * zero_split_block, :] = scales[i][
+            tp_rank * zero_split_block : (tp_rank + 1) * zero_split_block, :
+        ]
+        cai_linear.g_idx[i * idx_split_features : (i + 1) * idx_split_features] = g_idxs[i][
+            tp_rank * idx_split_features : (tp_rank + 1) * idx_split_features
+        ]
     if cai_linear.bias is not None:
         cai_linear.bias.copy_(gptq_linear.bias)


 class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
-        super().__init__(bits,
-                         groupsize,
-                         infeatures,
-                         outfeatures,
-                         bias,
-                         tp_size=tp_size,
-                         tp_rank=tp_rank,
-                         row_split=row_split)
+        super().__init__(
+            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
+        )
         self.process_group = None

     @staticmethod
-    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
         LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features

         # ensure only one process group is passed
         if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
             process_group = process_group[0]

         tp_size = dist.get_world_size(process_group)
@@ -282,15 +283,18 @@ class RowCaiQuantLinear(CaiQuantLinear, ParallelModule):
         if in_features % tp_size != 0:
             raise ValueError(
-                f"The size of in_features: {in_features} is not integer multiples of tensor parallel size: {tp_size}!")
-        linear_1d = RowCaiQuantLinear(module.bits,
-                                      module.group_size,
-                                      module.in_features // tp_size,
-                                      module.out_features,
-                                      module.bias is not None,
-                                      tp_size=tp_size,
-                                      tp_rank=tp_rank,
-                                      row_split=True)
+                f"The size of in_features: {in_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = RowCaiQuantLinear(
+            module.bits,
+            module.group_size,
+            module.in_features // tp_size,
+            module.out_features,
+            module.bias is not None,
+            tp_size=tp_size,
+            tp_rank=tp_rank,
+            row_split=True,
+        )
         linear_1d.process_group = process_group
         split_row_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)
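
For reference, a hedged usage sketch of the conversion entry point; it assumes torch.distributed is already initialized and that gptq_linear is an AutoGPTQ-style module exposing bits, group_size, qweight, qzeros, scales, g_idx, and bias:

import torch.distributed as dist

# Assumes dist.init_process_group(...) already ran and `gptq_linear` is an
# AutoGPTQ-style quantized module (attribute names follow the code above).
tp_group = dist.new_group(ranks=list(range(dist.get_world_size())))
row_linear = RowCaiQuantLinear.from_native_module(gptq_linear, tp_group)
# Each rank now holds in_features // tp_size rows of the weight, so the
# row-parallel matmul's partial outputs must be all-reduced across tp_group.
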
@@ -306,30 +310,23 @@
 class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, tp_size=1, tp_rank=0, row_split=False):
-        super().__init__(bits,
-                         groupsize,
-                         infeatures,
-                         outfeatures,
-                         bias,
-                         tp_size=tp_size,
-                         tp_rank=tp_rank,
-                         row_split=row_split)
+        super().__init__(
+            bits, groupsize, infeatures, outfeatures, bias, tp_size=tp_size, tp_rank=tp_rank, row_split=row_split
+        )
         self.process_group = None

     @staticmethod
-    def from_native_module(module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args,
-                           **kwargs) -> ParallelModule:
+    def from_native_module(
+        module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
         LazyInitContext.materialize(module)
         # get the attributes
         in_features = module.in_features

         # ensure only one process group is passed
         if isinstance(process_group, (list, tuple)):
-            assert len(process_group) == 1, \
-                f'Expected only one process group, got {len(process_group)}.'
+            assert len(process_group) == 1, f"Expected only one process group, got {len(process_group)}."
             process_group = process_group[0]

         tp_size = dist.get_world_size(process_group)
@@ -340,14 +337,17 @@ class ColCaiQuantLinear(CaiQuantLinear, ParallelModule):
         if in_features % tp_size != 0:
             raise ValueError(
-                f"The size of in_features: {in_features} is not integer multiples of tensor parallel size: {tp_size}!")
-        linear_1d = ColCaiQuantLinear(module.bits,
-                                      module.group_size,
-                                      module.in_features,
-                                      module.out_features // tp_size,
-                                      module.bias is not None,
-                                      tp_size=tp_size,
-                                      tp_rank=tp_rank)
+                f"The size of in_features: {in_features} is not integer multiples of tensor parallel size: {tp_size}!"
+            )
+        linear_1d = ColCaiQuantLinear(
+            module.bits,
+            module.group_size,
+            module.in_features,
+            module.out_features // tp_size,
+            module.bias is not None,
+            tp_size=tp_size,
+            tp_rank=tp_rank,
+        )
         linear_1d.process_group = process_group
         split_column_copy(module, linear_1d, tp_rank=tp_rank, **kwargs)