mirror of https://github.com/hpcaitech/ColossalAI
[utils] refactor parallel layers checkpoint and bcast model on loading checkpoint (#1548)
* refactor parallel layer * broadcast rank0 model after load ckptpull/1534/head
parent
2bed096848
commit
ae71036cd2
|
@ -5,9 +5,11 @@ import torch.nn as nn
|
|||
|
||||
from colossalai.context import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class ParallelLayer(nn.Module):
|
||||
global_state_dict: bool = True
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
@ -26,10 +28,35 @@ class ParallelLayer(nn.Module):
|
|||
self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
|
||||
ParallelMode.PIPELINE)
|
||||
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs):
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs)
|
||||
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
return super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs):
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs)
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
|
||||
missing_keys.clear()
|
||||
unexpected_keys.clear()
|
||||
if self.global_state_dict:
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
|
||||
missing_keys.clear()
|
||||
unexpected_keys.clear()
|
||||
return self._load_from_global_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
|
||||
unexpected_keys, error_msgs)
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
|
||||
error_msgs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
if self.global_state_dict:
|
||||
return self._save_to_global_state_dict(destination, prefix, keep_vars)
|
||||
return super()._save_to_state_dict(destination, prefix, keep_vars)
|
||||
|
||||
@classmethod
|
||||
@contextmanager
|
||||
def use_local_state_dict(cls):
|
||||
try:
|
||||
cls.global_state_dict = False
|
||||
yield
|
||||
finally:
|
||||
cls.global_state_dict = True
|
||||
|
|
|
@ -189,7 +189,7 @@ class Classifier1D(ParallelLayer):
|
|||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -215,9 +215,9 @@ class Classifier1D(ParallelLayer):
|
|||
weight_key: True,
|
||||
bias_key: False
|
||||
})
|
||||
super()._load_from_state_dict(local_state, prefix, *args)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict()
|
||||
|
@ -242,12 +242,12 @@ class Classifier1D(ParallelLayer):
|
|||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
input_ = input_
|
||||
else:
|
||||
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
|
||||
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
|
@ -326,7 +326,7 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -352,9 +352,9 @@ class VocabParallelClassifier1D(ParallelLayer):
|
|||
weight_key: True,
|
||||
bias_key: True
|
||||
})
|
||||
super()._load_from_state_dict(local_state, prefix, *args)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict()
|
||||
|
@ -461,7 +461,7 @@ class Linear1D_Col(ParallelLayer):
|
|||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -486,9 +486,9 @@ class Linear1D_Col(ParallelLayer):
|
|||
weight_key: True,
|
||||
bias_key: True
|
||||
})
|
||||
super()._load_from_state_dict(local_state, prefix, *args)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
@ -598,7 +598,7 @@ class Linear1D_Row(ParallelLayer):
|
|||
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
|
||||
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -623,9 +623,9 @@ class Linear1D_Row(ParallelLayer):
|
|||
weight_key: True,
|
||||
bias_key: False
|
||||
})
|
||||
super()._load_from_state_dict(local_state, prefix, *args)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
@ -648,12 +648,12 @@ class Linear1D_Row(ParallelLayer):
|
|||
# Set up backprop all-reduce.
|
||||
if self.parallel_input:
|
||||
assert input_.shape[-1] == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1])
|
||||
input_ = input_
|
||||
else:
|
||||
assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format(
|
||||
input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size)
|
||||
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
|
||||
|
||||
|
@ -738,7 +738,7 @@ class Embedding1D(ParallelLayer):
|
|||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
|
@ -751,9 +751,9 @@ class Embedding1D(ParallelLayer):
|
|||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: -1},
|
||||
partition_states={weight_key: True})
|
||||
super()._load_from_state_dict(local_state, prefix, *args)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||
|
@ -773,7 +773,7 @@ class Embedding1D(ParallelLayer):
|
|||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VocabParallelEmbedding1D(torch.nn.Module):
|
||||
class VocabParallelEmbedding1D(ParallelLayer):
|
||||
r"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Args:
|
||||
|
@ -847,7 +847,7 @@ class VocabParallelEmbedding1D(torch.nn.Module):
|
|||
with torch.no_grad():
|
||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
|
@ -860,9 +860,9 @@ class VocabParallelEmbedding1D(torch.nn.Module):
|
|||
ParallelMode.PARALLEL_1D,
|
||||
dims={weight_key: 0},
|
||||
partition_states={weight_key: True})
|
||||
super()._load_from_state_dict(local_state, prefix, *args)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
local_state = gather_tensor_parallel_state_dict(local_state,
|
||||
|
|
|
@ -94,7 +94,7 @@ class Linear2D(ParallelLayer):
|
|||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -137,9 +137,9 @@ class Linear2D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
@ -252,7 +252,7 @@ class LayerNorm2D(ParallelLayer):
|
|||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -294,9 +294,9 @@ class LayerNorm2D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
@ -443,7 +443,7 @@ class PatchEmbedding2D(ParallelLayer):
|
|||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
position_embed_initializer(self.pos_embed)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -503,9 +503,9 @@ class PatchEmbedding2D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
cls_token_key = prefix + 'cls_token'
|
||||
|
@ -651,7 +651,7 @@ class Embedding2D(ParallelLayer):
|
|||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
|
@ -676,9 +676,9 @@ class Embedding2D(ParallelLayer):
|
|||
partition_states={weight_key: True},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
||||
|
@ -712,7 +712,7 @@ class Embedding2D(ParallelLayer):
|
|||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VocabParallelEmbedding2D(torch.nn.Module):
|
||||
class VocabParallelEmbedding2D(ParallelLayer):
|
||||
r"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Args:
|
||||
|
@ -789,7 +789,7 @@ class VocabParallelEmbedding2D(torch.nn.Module):
|
|||
with torch.no_grad():
|
||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
|
@ -814,9 +814,9 @@ class VocabParallelEmbedding2D(torch.nn.Module):
|
|||
partition_states={weight_key: True},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
||||
|
@ -924,7 +924,7 @@ class Classifier2D(ParallelLayer):
|
|||
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL)
|
||||
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -968,9 +968,9 @@ class Classifier2D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict()
|
||||
|
@ -1095,7 +1095,7 @@ class VocabParallelClassifier2D(ParallelLayer):
|
|||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -1139,9 +1139,9 @@ class VocabParallelClassifier2D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict()
|
||||
|
|
|
@ -96,7 +96,7 @@ class Linear2p5D(ParallelLayer):
|
|||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -143,9 +143,9 @@ class Linear2p5D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0:
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer):
|
|||
if self.bias is not None:
|
||||
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -314,9 +314,9 @@ class LayerNorm2p5D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
@ -463,7 +463,7 @@ class PatchEmbedding2p5D(ParallelLayer):
|
|||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
position_embed_initializer(self.pos_embed)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -523,9 +523,9 @@ class PatchEmbedding2p5D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
cls_token_key = prefix + 'cls_token'
|
||||
|
@ -671,7 +671,7 @@ class Embedding2p5D(ParallelLayer):
|
|||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
|
@ -696,9 +696,9 @@ class Embedding2p5D(ParallelLayer):
|
|||
partition_states={weight_key: True},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
||||
|
@ -733,7 +733,7 @@ class Embedding2p5D(ParallelLayer):
|
|||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VocabParallelEmbedding2p5D(torch.nn.Module):
|
||||
class VocabParallelEmbedding2p5D(ParallelLayer):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Args:
|
||||
|
@ -810,7 +810,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
|
|||
with torch.no_grad():
|
||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
|
@ -835,9 +835,9 @@ class VocabParallelEmbedding2p5D(torch.nn.Module):
|
|||
partition_states={weight_key: True},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
||||
|
@ -950,7 +950,7 @@ class Classifier2p5D(ParallelLayer):
|
|||
broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL)
|
||||
broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -994,9 +994,9 @@ class Classifier2p5D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict()
|
||||
|
@ -1123,7 +1123,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
|
|||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -1167,7 +1167,7 @@ class VocabParallelClassifier2p5D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
# input: [m/dq, n/q, k/q]
|
||||
|
|
|
@ -70,7 +70,7 @@ class LayerNorm3D(ParallelLayer):
|
|||
if self.bias is not None:
|
||||
init.zeros_()(self.bias)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -105,9 +105,9 @@ class LayerNorm3D(ParallelLayer):
|
|||
# broadcast in weight groups
|
||||
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
@ -207,7 +207,7 @@ class Linear3D(ParallelLayer):
|
|||
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -265,9 +265,9 @@ class Linear3D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
@ -400,7 +400,7 @@ class Classifier3D(ParallelLayer):
|
|||
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
||||
broadcast(self.bias, input_src_rank, self.input_parallel_mode)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -437,9 +437,9 @@ class Classifier3D(ParallelLayer):
|
|||
# broadcast in weight groups
|
||||
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict()
|
||||
|
@ -551,7 +551,7 @@ class VocabParallelClassifier3D(ParallelLayer):
|
|||
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
|
||||
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -610,9 +610,9 @@ class VocabParallelClassifier3D(ParallelLayer):
|
|||
},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
@ -763,7 +763,7 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
self.cls_token.register_hook(self._sync_grad_hook)
|
||||
self.pos_embed.register_hook(self._sync_grad_hook)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
|
@ -812,9 +812,9 @@ class PatchEmbedding3D(ParallelLayer):
|
|||
# broadcast in weight groups
|
||||
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
bias_key = prefix + 'bias'
|
||||
cls_token_key = prefix + 'cls_token'
|
||||
|
@ -937,7 +937,7 @@ class Embedding3D(ParallelLayer):
|
|||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
|
@ -961,9 +961,9 @@ class Embedding3D(ParallelLayer):
|
|||
# broadcast in weight groups
|
||||
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
||||
|
@ -991,7 +991,7 @@ class Embedding3D(ParallelLayer):
|
|||
|
||||
|
||||
@LAYERS.register_module
|
||||
class VocabParallelEmbedding3D(torch.nn.Module):
|
||||
class VocabParallelEmbedding3D(ParallelLayer):
|
||||
r"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Args:
|
||||
|
@ -1070,7 +1070,7 @@ class VocabParallelEmbedding3D(torch.nn.Module):
|
|||
with torch.no_grad():
|
||||
self.weight[self.padding_idx - self.vocab_start_index].fill_(0)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
local_state = OrderedDict()
|
||||
weight_key = prefix + 'weight'
|
||||
if gpc.get_local_rank(ParallelMode.TENSOR) == 0:
|
||||
|
@ -1104,9 +1104,9 @@ class VocabParallelEmbedding3D(torch.nn.Module):
|
|||
partition_states={weight_key: True},
|
||||
)
|
||||
|
||||
super()._load_from_state_dict(local_state, prefix, *args, **kwargs)
|
||||
super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs)
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
def _save_to_global_state_dict(self, destination, prefix, keep_vars):
|
||||
weight_key = prefix + 'weight'
|
||||
local_state = OrderedDict({weight_key: self.weight})
|
||||
|
||||
|
|
|
@ -3,9 +3,9 @@ from itertools import chain
|
|||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.communication.collective import scatter_object_list
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.constants import IS_TENSOR_PARALLEL
|
||||
try:
|
||||
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
|
||||
except ImportError:
|
||||
|
@ -190,6 +190,15 @@ def save_checkpoint(file,
|
|||
torch.save(checkpoint, file, **kwargs)
|
||||
|
||||
|
||||
def broadcast_model(model: torch.nn.Module):
|
||||
src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0]
|
||||
for p in model.parameters():
|
||||
if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0:
|
||||
group = gpc.get_group(ParallelMode.TENSOR) if p.device.type == 'cuda' else gpc.get_cpu_group(
|
||||
ParallelMode.TENSOR)
|
||||
dist.broadcast(p, src_rank, group=group)
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
file,
|
||||
model: torch.nn.Module,
|
||||
|
@ -225,6 +234,7 @@ def load_checkpoint(
|
|||
model_state = partition_pipeline_parallel_state_dict(model, model_state)
|
||||
try:
|
||||
model.load_state_dict(model_state, strict=strict)
|
||||
broadcast_model(model)
|
||||
except RuntimeError as e:
|
||||
error_msgs = str(e)
|
||||
if error_msgs.startswith("Error(s) in loading state_dict for "):
|
||||
|
|
Loading…
Reference in New Issue