From ae71036cd2210b6e60805357f4bd059674e316bc Mon Sep 17 00:00:00 2001
From: ver217
Date: Tue, 6 Sep 2022 20:18:35 +0800
Subject: [PATCH] [utils] refactor parallel layers checkpoint and bcast model on loading checkpoint (#1548)

* refactor parallel layer

* broadcast rank0 model after load ckpt
---
 colossalai/nn/layer/base_layer.py           | 37 ++++++++++++++---
 colossalai/nn/layer/parallel_1d/layers.py   | 46 ++++++++++-----------
 colossalai/nn/layer/parallel_2d/layers.py   | 44 ++++++++++----------
 colossalai/nn/layer/parallel_2p5d/layers.py | 42 +++++++++----------
 colossalai/nn/layer/parallel_3d/layers.py   | 44 ++++++++++----------
 colossalai/utils/checkpointing.py           | 12 +++++-
 6 files changed, 131 insertions(+), 94 deletions(-)

diff --git a/colossalai/nn/layer/base_layer.py b/colossalai/nn/layer/base_layer.py
index 041f0fdf0..c85f53cc4 100644
--- a/colossalai/nn/layer/base_layer.py
+++ b/colossalai/nn/layer/base_layer.py
@@ -5,9 +5,11 @@
 import torch.nn as nn
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from contextlib import contextmanager
 
 
 class ParallelLayer(nn.Module):
+    global_state_dict: bool = True
 
     def __init__(self):
         super().__init__()
@@ -26,10 +28,35 @@ class ParallelLayer(nn.Module):
         self.pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
             ParallelMode.PIPELINE)
 
+    def _load_from_global_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                     error_msgs):
+        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                             error_msgs)
+
+    def _save_to_global_state_dict(self, destination, prefix, keep_vars):
+        return super()._save_to_state_dict(destination, prefix, keep_vars)
+
     def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                               error_msgs):
-        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
-                                      error_msgs)
-        if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
-            missing_keys.clear()
-            unexpected_keys.clear()
+        if self.global_state_dict:
+            if gpc.get_local_rank(ParallelMode.TENSOR) != 0:
+                missing_keys.clear()
+                unexpected_keys.clear()
+            return self._load_from_global_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
+                                                     unexpected_keys, error_msgs)
+        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
+                                             error_msgs)
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars):
+        if self.global_state_dict:
+            return self._save_to_global_state_dict(destination, prefix, keep_vars)
+        return super()._save_to_state_dict(destination, prefix, keep_vars)
+
+    @classmethod
+    @contextmanager
+    def use_local_state_dict(cls):
+        try:
+            cls.global_state_dict = False
+            yield
+        finally:
+            cls.global_state_dict = True
diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py
index d2a466e4c..7b89c5e1f 100644
--- a/colossalai/nn/layer/parallel_1d/layers.py
+++ b/colossalai/nn/layer/parallel_1d/layers.py
@@ -189,7 +189,7 @@ class Classifier1D(ParallelLayer):
         num_partition = gpc.get_world_size(ParallelMode.TENSOR)
         set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
 
-    def _load_from_state_dict(self, state_dict, prefix, *args):
+    def _load_from_global_state_dict(self, state_dict, prefix, *args):
         local_state = OrderedDict()
         weight_key = prefix + 'weight'
         bias_key = prefix + 'bias'
@@ -215,9 +215,9
@@ class Classifier1D(ParallelLayer): weight_key: True, bias_key: False }) - super()._load_from_state_dict(local_state, prefix, *args) + super()._load_from_global_state_dict(local_state, prefix, *args) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict() @@ -242,12 +242,12 @@ class Classifier1D(ParallelLayer): # Set up backprop all-reduce. if self.parallel_input: assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) input_ = input_ else: assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( + 'Invalid shapes in Classifier1D forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) @@ -326,7 +326,7 @@ class VocabParallelClassifier1D(ParallelLayer): if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, num_partition) - def _load_from_state_dict(self, state_dict, prefix, *args): + def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -352,9 +352,9 @@ class VocabParallelClassifier1D(ParallelLayer): weight_key: True, bias_key: True }) - super()._load_from_state_dict(local_state, prefix, *args) + super()._load_from_global_state_dict(local_state, prefix, *args) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict() @@ -461,7 +461,7 @@ class Linear1D_Col(ParallelLayer): if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, num_partition) - def _load_from_state_dict(self, state_dict, prefix, *args): + def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -486,9 +486,9 @@ class Linear1D_Col(ParallelLayer): weight_key: True, bias_key: True }) - super()._load_from_state_dict(local_state, prefix, *args) + super()._load_from_global_state_dict(local_state, prefix, *args) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) @@ -598,7 +598,7 @@ class Linear1D_Row(ParallelLayer): num_partition = gpc.get_world_size(ParallelMode.TENSOR) set_tensor_parallel_attribute_by_partition(self.weight, num_partition) - def _load_from_state_dict(self, state_dict, prefix, *args): + def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -623,9 +623,9 @@ class Linear1D_Row(ParallelLayer): weight_key: True, bias_key: False }) - super()._load_from_state_dict(local_state, prefix, *args) + 
super()._load_from_global_state_dict(local_state, prefix, *args) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) @@ -648,12 +648,12 @@ class Linear1D_Row(ParallelLayer): # Set up backprop all-reduce. if self.parallel_input: assert input_.shape[-1] == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1]) input_ = input_ else: assert divide(input_.shape[-1], gpc.tensor_parallel_size) == self.weight.shape[-1], \ - 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( + 'Invalid shapes in Linear1D_Row forward: input={}, weight={}. Expected last dim of input {}.'.format( input_.shape, self.weight.shape, self.weight.shape[-1] * gpc.tensor_parallel_size) input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) @@ -738,7 +738,7 @@ class Embedding1D(ParallelLayer): with torch.no_grad(): self.weight[self.padding_idx].fill_(0) - def _load_from_state_dict(self, state_dict, prefix, *args): + def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -751,9 +751,9 @@ class Embedding1D(ParallelLayer): ParallelMode.PARALLEL_1D, dims={weight_key: -1}, partition_states={weight_key: True}) - super()._load_from_state_dict(local_state, prefix, *args) + super()._load_from_global_state_dict(local_state, prefix, *args) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) local_state = gather_tensor_parallel_state_dict(local_state, @@ -773,7 +773,7 @@ class Embedding1D(ParallelLayer): @LAYERS.register_module -class VocabParallelEmbedding1D(torch.nn.Module): +class VocabParallelEmbedding1D(ParallelLayer): r"""Embedding parallelized in the vocabulary dimension. 
Args: @@ -847,7 +847,7 @@ class VocabParallelEmbedding1D(torch.nn.Module): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - def _load_from_state_dict(self, state_dict, prefix, *args): + def _load_from_global_state_dict(self, state_dict, prefix, *args): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -860,9 +860,9 @@ class VocabParallelEmbedding1D(torch.nn.Module): ParallelMode.PARALLEL_1D, dims={weight_key: 0}, partition_states={weight_key: True}) - super()._load_from_state_dict(local_state, prefix, *args) + super()._load_from_global_state_dict(local_state, prefix, *args) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) local_state = gather_tensor_parallel_state_dict(local_state, diff --git a/colossalai/nn/layer/parallel_2d/layers.py b/colossalai/nn/layer/parallel_2d/layers.py index 5fc5c63e5..f3a4d2bbb 100644 --- a/colossalai/nn/layer/parallel_2d/layers.py +++ b/colossalai/nn/layer/parallel_2d/layers.py @@ -94,7 +94,7 @@ class Linear2D(ParallelLayer): if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -137,9 +137,9 @@ class Linear2D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) @@ -252,7 +252,7 @@ class LayerNorm2D(ParallelLayer): if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -294,9 +294,9 @@ class LayerNorm2D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) @@ -443,7 +443,7 @@ class PatchEmbedding2D(ParallelLayer): bias_initializer(self.bias, fan_in=fan_in) position_embed_initializer(self.pos_embed) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -503,9 +503,9 @@ class PatchEmbedding2D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, 
prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' cls_token_key = prefix + 'cls_token' @@ -651,7 +651,7 @@ class Embedding2D(ParallelLayer): with torch.no_grad(): self.weight[self.padding_idx].fill_(0) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -676,9 +676,9 @@ class Embedding2D(ParallelLayer): partition_states={weight_key: True}, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) @@ -712,7 +712,7 @@ class Embedding2D(ParallelLayer): @LAYERS.register_module -class VocabParallelEmbedding2D(torch.nn.Module): +class VocabParallelEmbedding2D(ParallelLayer): r"""Embedding parallelized in the vocabulary dimension. Args: @@ -789,7 +789,7 @@ class VocabParallelEmbedding2D(torch.nn.Module): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -814,9 +814,9 @@ class VocabParallelEmbedding2D(torch.nn.Module): partition_states={weight_key: True}, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) @@ -924,7 +924,7 @@ class Classifier2D(ParallelLayer): broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2D_COL) broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2D_ROW) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -968,9 +968,9 @@ class Classifier2D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict() @@ -1095,7 +1095,7 @@ class VocabParallelClassifier2D(ParallelLayer): if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -1139,9 +1139,9 @@ class VocabParallelClassifier2D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def 
_save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict() diff --git a/colossalai/nn/layer/parallel_2p5d/layers.py b/colossalai/nn/layer/parallel_2p5d/layers.py index f26efcc61..f849cbbe7 100644 --- a/colossalai/nn/layer/parallel_2p5d/layers.py +++ b/colossalai/nn/layer/parallel_2p5d/layers.py @@ -96,7 +96,7 @@ class Linear2p5D(ParallelLayer): if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -143,9 +143,9 @@ class Linear2p5D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): if gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP) == 0: weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -272,7 +272,7 @@ class LayerNorm2p5D(ParallelLayer): if self.bias is not None: set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -314,9 +314,9 @@ class LayerNorm2p5D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) @@ -463,7 +463,7 @@ class PatchEmbedding2p5D(ParallelLayer): bias_initializer(self.bias, fan_in=fan_in) position_embed_initializer(self.pos_embed) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -523,9 +523,9 @@ class PatchEmbedding2p5D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' cls_token_key = prefix + 'cls_token' @@ -671,7 +671,7 @@ class Embedding2p5D(ParallelLayer): with torch.no_grad(): self.weight[self.padding_idx].fill_(0) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -696,9 +696,9 @@ class Embedding2p5D(ParallelLayer): partition_states={weight_key: True}, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, 
*args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) @@ -733,7 +733,7 @@ class Embedding2p5D(ParallelLayer): @LAYERS.register_module -class VocabParallelEmbedding2p5D(torch.nn.Module): +class VocabParallelEmbedding2p5D(ParallelLayer): """Embedding parallelized in the vocabulary dimension. Args: @@ -810,7 +810,7 @@ class VocabParallelEmbedding2p5D(torch.nn.Module): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -835,9 +835,9 @@ class VocabParallelEmbedding2p5D(torch.nn.Module): partition_states={weight_key: True}, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) @@ -950,7 +950,7 @@ class Classifier2p5D(ParallelLayer): broadcast(self.bias, col_src_rank, ParallelMode.PARALLEL_2P5D_COL) broadcast(self.bias, row_src_rank, ParallelMode.PARALLEL_2P5D_ROW) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -994,9 +994,9 @@ class Classifier2p5D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict() @@ -1123,7 +1123,7 @@ class VocabParallelClassifier2p5D(ParallelLayer): if self.bias is not None: bias_initializer(self.bias, fan_in=fan_in) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -1167,7 +1167,7 @@ class VocabParallelClassifier2p5D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) def forward(self, x: Tensor) -> Tensor: # input: [m/dq, n/q, k/q] diff --git a/colossalai/nn/layer/parallel_3d/layers.py b/colossalai/nn/layer/parallel_3d/layers.py index 33f358241..037a09763 100644 --- a/colossalai/nn/layer/parallel_3d/layers.py +++ b/colossalai/nn/layer/parallel_3d/layers.py @@ -70,7 +70,7 @@ class LayerNorm3D(ParallelLayer): if self.bias is not None: init.zeros_()(self.bias) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -105,9 +105,9 @@ class LayerNorm3D(ParallelLayer): # broadcast in weight groups 
local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) @@ -207,7 +207,7 @@ class Linear3D(ParallelLayer): broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, output_src_rank, self.output_parallel_mode) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -265,9 +265,9 @@ class Linear3D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) @@ -400,7 +400,7 @@ class Classifier3D(ParallelLayer): broadcast(self.bias, output_src_rank, self.output_parallel_mode) broadcast(self.bias, input_src_rank, self.input_parallel_mode) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -437,9 +437,9 @@ class Classifier3D(ParallelLayer): # broadcast in weight groups local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict() @@ -551,7 +551,7 @@ class VocabParallelClassifier3D(ParallelLayer): broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, output_src_rank, self.output_parallel_mode) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' bias_key = prefix + 'bias' @@ -610,9 +610,9 @@ class VocabParallelClassifier3D(ParallelLayer): }, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' local_state = OrderedDict({weight_key: self.weight}) @@ -763,7 +763,7 @@ class PatchEmbedding3D(ParallelLayer): self.cls_token.register_hook(self._sync_grad_hook) self.pos_embed.register_hook(self._sync_grad_hook) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 
'weight' bias_key = prefix + 'bias' @@ -812,9 +812,9 @@ class PatchEmbedding3D(ParallelLayer): # broadcast in weight groups local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' bias_key = prefix + 'bias' cls_token_key = prefix + 'cls_token' @@ -937,7 +937,7 @@ class Embedding3D(ParallelLayer): with torch.no_grad(): self.weight[self.padding_idx].fill_(0) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -961,9 +961,9 @@ class Embedding3D(ParallelLayer): # broadcast in weight groups local_state = broadcast_state_dict(local_state, self.weight_parallel_mode) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) @@ -991,7 +991,7 @@ class Embedding3D(ParallelLayer): @LAYERS.register_module -class VocabParallelEmbedding3D(torch.nn.Module): +class VocabParallelEmbedding3D(ParallelLayer): r"""Embedding parallelized in the vocabulary dimension. Args: @@ -1070,7 +1070,7 @@ class VocabParallelEmbedding3D(torch.nn.Module): with torch.no_grad(): self.weight[self.padding_idx - self.vocab_start_index].fill_(0) - def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): + def _load_from_global_state_dict(self, state_dict, prefix, *args, **kwargs): local_state = OrderedDict() weight_key = prefix + 'weight' if gpc.get_local_rank(ParallelMode.TENSOR) == 0: @@ -1104,9 +1104,9 @@ class VocabParallelEmbedding3D(torch.nn.Module): partition_states={weight_key: True}, ) - super()._load_from_state_dict(local_state, prefix, *args, **kwargs) + super()._load_from_global_state_dict(local_state, prefix, *args, **kwargs) - def _save_to_state_dict(self, destination, prefix, keep_vars): + def _save_to_global_state_dict(self, destination, prefix, keep_vars): weight_key = prefix + 'weight' local_state = OrderedDict({weight_key: self.weight}) diff --git a/colossalai/utils/checkpointing.py b/colossalai/utils/checkpointing.py index 2ce959568..d1c6b6370 100644 --- a/colossalai/utils/checkpointing.py +++ b/colossalai/utils/checkpointing.py @@ -3,9 +3,9 @@ from itertools import chain import torch import torch.distributed as dist -from colossalai.communication.collective import scatter_object_list from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc +from colossalai.constants import IS_TENSOR_PARALLEL try: from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX except ImportError: @@ -190,6 +190,15 @@ def save_checkpoint(file, torch.save(checkpoint, file, **kwargs) +def broadcast_model(model: torch.nn.Module): + src_rank = gpc.get_ranks_in_group(ParallelMode.TENSOR)[0] + for p in model.parameters(): + if not getattr(p, IS_TENSOR_PARALLEL, False) and p.storage().size() > 0: + group = gpc.get_group(ParallelMode.TENSOR) 
if p.device.type == 'cuda' else gpc.get_cpu_group(
+                ParallelMode.TENSOR)
+            dist.broadcast(p, src_rank, group=group)
+
+
 def load_checkpoint(
     file,
     model: torch.nn.Module,
@@ -225,6 +234,7 @@ def load_checkpoint(
         model_state = partition_pipeline_parallel_state_dict(model, model_state)
     try:
         model.load_state_dict(model_state, strict=strict)
+        broadcast_model(model)
     except RuntimeError as e:
         error_msgs = str(e)
         if error_msgs.startswith("Error(s) in loading state_dict for "):
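
With this patch, load_checkpoint() (colossalai/utils/checkpointing.py) calls broadcast_model() right after load_state_dict(): every parameter without the IS_TENSOR_PARALLEL attribute is broadcast from the first rank of the TENSOR group, over the device group for CUDA tensors and the CPU group otherwise, so replicated weights end up bit-identical on all ranks even when only rank 0 holds meaningful values after loading. Callers keep the usual one-liner; a sketch, where the path and the model object are illustrative and any remaining arguments are left at their (assumed) defaults:

    # Sketch: the broadcast now happens inside load_checkpoint(), so no extra
    # call is needed on the user side ('model.pt' and model are assumed here).
    from colossalai.utils.checkpointing import load_checkpoint

    load_checkpoint('model.pt', model)   # replicated params are re-synced from
                                         # the first rank of the TENSOR group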
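On the layer side, ParallelLayer keeps its default "global" checkpoint behaviour (partitions gathered to, and scattered from, the first rank of the TENSOR group), while the new use_local_state_dict() context manager switches every parallel layer to plain per-rank state dicts. A minimal sketch of the two modes; the model object and file name are illustrative, and an initialized ColossalAI run with a TENSOR parallel group is assumed:

    # Sketch: global vs. local checkpoints for a model built from ParallelLayer
    # subclasses such as Linear1D_Col or Embedding1D (model is assumed).
    import torch
    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc
    from colossalai.nn.layer.base_layer import ParallelLayer

    # Default ("global") mode: partitions are gathered, so the state dict on
    # the first TENSOR rank holds the full, unpartitioned tensors.
    global_sd = model.state_dict()

    # Local mode: each rank keeps only its own partition and skips the gather.
    rank = gpc.get_local_rank(ParallelMode.TENSOR)
    with ParallelLayer.use_local_state_dict():
        torch.save(model.state_dict(), f'ckpt_rank{rank}.pt')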
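One design note: global_state_dict is a class attribute rather than an instance attribute, so the context manager flips every ParallelLayer subclass in the process at once, and the try/finally restores the default even if saving raises. A self-contained toy (made-up names, no ColossalAI required) showing why a class-level flag behaves that way:

    # Toy illustration of a class-level switch shared by all subclasses.
    from contextlib import contextmanager


    class Base:
        global_flag = True

        @classmethod
        @contextmanager
        def use_local(cls):
            try:
                cls.global_flag = False
                yield
            finally:
                cls.global_flag = True


    class Child(Base):
        pass


    assert Child.global_flag is True
    with Base.use_local():
        assert Child.global_flag is False    # subclasses see the shared flag
    assert Child.global_flag is True         # restored afterwards, even on error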