[utils] refactor parallel layers checkpoint and bcast model on loading checkpoint (#1548)

* refactor parallel layer

* broadcast rank0 model after load ckpt
pull/1534/head
ver217 2022-09-06 20:18:35 +08:00 committed by GitHub
parent 2bed096848
commit ae71036cd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 131 additions and 94 deletions

View File

@ -5,9 +5,11 @@ import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from contextlib import contextmanager
class ParallelLayer(nn.Module):
global_state_dict: bool = True
def __init__(self):
super().__init__()
@ -26,10 +28,35 @@ class ParallelLayer(nn.Module):
self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
ParallelMode.PIPELINE)
def _load_from_global_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
return super()._save_to_state_dict(destination, prefix, keep_vars)
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
missing_keys.clear()
unexpected_keys.clear()
if self.global_state_dict:
if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
missing_keys.clear()
unexpected_keys.clear()
return self._load_from_global_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs)
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
if self.global_state_dict:
return self._save_to_global_state_dict(destination, prefix, keep_vars)
return super()._save_to_state_dict(destination, prefix, keep_vars)
@classmethod
@contextmanager
def use_local_state_dict(cls):
try:
cls.global_state_dict = False
yield
finally:
cls.global_state_dict = True

View File

@ -189,7 +189,7 @@ class Classifier1D(ParallelLayer):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -215,9 +215,9 @@ class Classifier1D(ParallelLayer):
weight_key: True,
bias_key: False
})
super()._load_from_state_dict(local_state, prefix, *args)
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
@ -242,12 +242,12 @@ class Classifier1D(ParallelLayer):
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
else:
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
@ -326,7 +326,7 @@ class VocabParallelClassifier1D(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -352,9 +352,9 @@ class VocabParallelClassifier1D(ParallelLayer):
weight_key: True,
bias_key: True
})
super()._load_from_state_dict(local_state, prefix, *args)
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
@ -461,7 +461,7 @@ class Linear1D_Col(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -486,9 +486,9 @@ class Linear1D_Col(ParallelLayer):
weight_key: True,
bias_key: True
})
super()._load_from_state_dict(local_state, prefix, *args)
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
@ -598,7 +598,7 @@ class Linear1D_Row(ParallelLayer):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
def _load_from_state_dict(self, state_dict, prefix, *args):
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -623,9 +623,9 @@ class Linear1D_Row(ParallelLayer):
weight_key: True,
bias_key: False
})
super()._load_from_state_dict(local_state, prefix, *args)
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
@ -648,12 +648,12 @@ class Linear1D_Row(ParallelLayer):
# Set up backprop all-reduce.
if self.parallel_input:
assert input_.shape[-1] == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1])
input_ = input_
else:
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
@ -738,7 +738,7 @@ class Embedding1D(ParallelLayer):
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args):
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@ -751,9 +751,9 @@ class Embedding1D(ParallelLayer):
ParallelMode.PARALLEL_1D,
dims={weight_key: -1},
partition_states={weight_key: True})
super()._load_from_state_dict(local_state, prefix, *args)
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
local_state = gather_tensor_parallel_state_dict(local_state,
@ -773,7 +773,7 @@ class Embedding1D(ParallelLayer):
@LAYERS.register_module
class VocabParallelEmbedding1D(torch.nn.Module):
class VocabParallelEmbedding1D(ParallelLayer):
r"""Embedding parallelized in the vocabulary dimension.
Args:
@ -847,7 +847,7 @@ class VocabParallelEmbedding1D(torch.nn.Module):
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args):
def _load_from_global_state_dict(self, state_dict, prefix, *args):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@ -860,9 +860,9 @@ class VocabParallelEmbedding1D(torch.nn.Module):
ParallelMode.PARALLEL_1D,
dims={weight_key: 0},
partition_states={weight_key: True})
super()._load_from_state_dict(local_state, prefix, *args)
super()._load_from_global_state_dict(local_state, prefix, *args)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
local_state = gather_tensor_parallel_state_dict(local_state,

View File

@ -94,7 +94,7 @@ class Linear2D(ParallelLayer):
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -137,9 +137,9 @@ class Linear2D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
@ -252,7 +252,7 @@ class LayerNorm2D(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -294,9 +294,9 @@ class LayerNorm2D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
@ -443,7 +443,7 @@ class PatchEmbedding2D(ParallelLayer):
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -503,9 +503,9 @@ class PatchEmbedding2D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
@ -651,7 +651,7 @@ class Embedding2D(ParallelLayer):
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@ -676,9 +676,9 @@ class Embedding2D(ParallelLayer):
partition_states={weight_key: True},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
@ -712,7 +712,7 @@ class Embedding2D(ParallelLayer):
@LAYERS.register_module
class VocabParallelEmbedding2D(torch.nn.Module):
class VocabParallelEmbedding2D(ParallelLayer):
r"""Embedding parallelized in the vocabulary dimension.
Args:
@ -789,7 +789,7 @@ class VocabParallelEmbedding2D(torch.nn.Module):
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@ -814,9 +814,9 @@ class VocabParallelEmbedding2D(torch.nn.Module):
partition_states={weight_key: True},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
@ -924,7 +924,7 @@ class Classifier2D(ParallelLayer):
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL)
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -968,9 +968,9 @@ class Classifier2D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
@ -1095,7 +1095,7 @@ class VocabParallelClassifier2D(ParallelLayer):
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -1139,9 +1139,9 @@ class VocabParallelClassifier2D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()

View File

@ -96,7 +96,7 @@ class Linear2p5D(ParallelLayer):
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -143,9 +143,9 @@ class Linear2p5D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0:
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -314,9 +314,9 @@ class LayerNorm2p5D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
@ -463,7 +463,7 @@ class PatchEmbedding2p5D(ParallelLayer):
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -523,9 +523,9 @@ class PatchEmbedding2p5D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
@ -671,7 +671,7 @@ class Embedding2p5D(ParallelLayer):
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@ -696,9 +696,9 @@ class Embedding2p5D(ParallelLayer):
partition_states={weight_key: True},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
@ -733,7 +733,7 @@ class Embedding2p5D(ParallelLayer):
@LAYERS.register_module
class VocabParallelEmbedding2p5D(torch.nn.Module):
class VocabParallelEmbedding2p5D(ParallelLayer):
"""Embedding parallelized in the vocabulary dimension.
Args:
@ -810,7 +810,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@ -835,9 +835,9 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
partition_states={weight_key: True},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
@ -950,7 +950,7 @@ class Classifier2p5D(ParallelLayer):
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL)
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -994,9 +994,9 @@ class Classifier2p5D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
@ -1123,7 +1123,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -1167,7 +1167,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def forward(self, x: Tensor) -> Tensor:
# input: [m/dq, n/q, k/q]

View File

@ -70,7 +70,7 @@ class LayerNorm3D(ParallelLayer):
if self.bias is not None:
init.zeros_()(self.bias)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -105,9 +105,9 @@ class LayerNorm3D(ParallelLayer):
# broadcast in weight groups
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
@ -207,7 +207,7 @@ class Linear3D(ParallelLayer):
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -265,9 +265,9 @@ class Linear3D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
@ -400,7 +400,7 @@ class Classifier3D(ParallelLayer):
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -437,9 +437,9 @@ class Classifier3D(ParallelLayer):
# broadcast in weight groups
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict()
@ -551,7 +551,7 @@ class VocabParallelClassifier3D(ParallelLayer):
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -610,9 +610,9 @@ class VocabParallelClassifier3D(ParallelLayer):
},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
local_state = OrderedDict({weight_key: self.weight})
@ -763,7 +763,7 @@ class PatchEmbedding3D(ParallelLayer):
self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
@ -812,9 +812,9 @@ class PatchEmbedding3D(ParallelLayer):
# broadcast in weight groups
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
bias_key = prefix + 'bias'
cls_token_key = prefix + 'cls_token'
@ -937,7 +937,7 @@ class Embedding3D(ParallelLayer):
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@ -961,9 +961,9 @@ class Embedding3D(ParallelLayer):
# broadcast in weight groups
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})
@ -991,7 +991,7 @@ class Embedding3D(ParallelLayer):
@LAYERS.register_module
class VocabParallelEmbedding3D(torch.nn.Module):
class VocabParallelEmbedding3D(ParallelLayer):
r"""Embedding parallelized in the vocabulary dimension.
Args:
@ -1070,7 +1070,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
with torch.no_grad():
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
local_state = OrderedDict()
weight_key = prefix + 'weight'
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
@ -1104,9 +1104,9 @@ class VocabParallelEmbedding3D(torch.nn.Module):
partition_states={weight_key: True},
)
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
def _save_to_state_dict(self, destination, prefix, keep_vars):
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
weight_key = prefix + 'weight'
local_state = OrderedDict({weight_key: self.weight})

View File

@ -3,9 +3,9 @@ from itertools import chain
import torch
import torch.distributed as dist
from colossalai.communication.collective import scatter_object_list
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.constants import IS_TENSOR_PARALLEL
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
@ -190,6 +190,15 @@ def save_checkpoint(file,
torch.save(checkpoint, file, **kwargs)
def broadcast_model(model: torch.nn.Module):
src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
for p in model.parameters():
if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
ParallelMode.TENSOR)
dist.broadcast(p, src_rank, group=group)
def load_checkpoint(
file,
model: torch.nn.Module,
@ -225,6 +234,7 @@ def load_checkpoint(
model_state = partition_pipeline_parallel_state_dict(model, model_state)
try:
model.load_state_dict(model_state, strict=strict)
broadcast_model(model)
except RuntimeError as e:
error_msgs = str(e)
if error_msgs.startswith("Error(s) in loading state_dict for "):