[tensor] fixed non-serializable colo parameter during model checkpointing (#1153)

pull/1164/head
Frank Lee 2022-06-22 11:43:38 +08:00 committed by GitHub
parent ffa025e120
commit f8eec98ff5
1 changed file with 39 additions and 3 deletions

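For context, the failure this commit addresses: a `ColoParameter` that carries a distributed spec holds a reference to a process group, which cannot be pickled, so `torch.save(model.state_dict(), ...)` raised an exception. The patch below temporarily resets every distributed spec to a replicated one while the state dict is built. A minimal usage sketch, assuming the `ColoInitContext` import path of this Colossal-AI version; the toy model and file name are illustrative:

import torch
import torch.nn as nn
from colossalai.utils.model.colo_init_context import ColoInitContext  # import path assumed

# Parameters created inside the context become ColoParameters.
with ColoInitContext(device=torch.device('cpu')):
    model = nn.Linear(16, 16)

# With the patched state_dict, distributed parameters are exported with a
# replicated spec, so the checkpoint no longer contains a process group.
torch.save(model.state_dict(), 'checkpoint.pt')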

@@ -1,13 +1,13 @@
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter
+from colossalai.tensor import ColoTensor, ColoParameter, distspec, TensorSpec
 from colossalai.nn.parallel.layers import register_colo_module, \
     ColoLinear, ColoEmbedding
 from copy import copy
 from torch import nn
 from typing import Iterator, Tuple, Union
 from functools import partialmethod

 # find named_params includes replica
@@ -34,6 +34,38 @@ def ColoModulize(module):
     module._colo_visited = True


+def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_dict_func=None):
+    # build a param-to-spec mapping
+    mapping = dict()
+
+    # gather all params
+    has_dist_parameter = False
+    with torch.no_grad():
+        for param in self.parameters():
+            if isinstance(param, ColoParameter) and param.has_spec():
+                has_dist_parameter = True
+                mapping[id(param)] = copy(param.spec)
+                param.set_spec(TensorSpec(distspec.replicate()))
+
+    # TODO: fix when keep_vars = True
+    # when keep_vars = False, state_dict_func detaches the parameters into new tensors,
+    # but when keep_vars = True, restoring the specs below is also reflected in `ret`,
+    # so the final state dict would still reference the process group and raise an
+    # exception because it is not serializable
+    assert not (keep_vars and has_dist_parameter), 'keep_vars cannot be True when there are distributed ColoParameters.'
+
+    ret = state_dict_func(self, destination, prefix, keep_vars)
+
+    # recover the original specs
+    with torch.no_grad():
+        for param in self.parameters():
+            param_id = id(param)
+            if param_id in mapping:
+                spec = mapping[param_id]
+                param.set_spec(spec)
+    return ret
+
+
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

     def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
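The added `colo_state_dict` follows a save/patch/restore pattern: remember each distributed parameter's spec, swap in a serializable replicated spec, call the original `state_dict`, then put the saved specs back. A library-agnostic sketch of the same pattern, with a hypothetical `spec` attribute and a `try/finally` so the original state is restored even if serialization fails:

import copy

def run_with_serializable_state(objs, make_safe, run):
    # Remember each object's original (possibly non-picklable) state.
    saved = {id(o): copy.copy(o.spec) for o in objs if o.spec is not None}
    try:
        for o in objs:
            if id(o) in saved:
                o.spec = make_safe(o.spec)  # swap in a picklable stand-in
        return run()
    finally:
        for o in objs:
            if id(o) in saved:
                o.spec = saved[id(o)]  # always restore, even on error

The commit restores the specs unconditionally after the call; the `finally` here is only a defensive variant of the same idea.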
@@ -52,6 +84,10 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         register_colo_module(torch.nn.Linear, ColoLinear())
         register_colo_module(torch.nn.Embedding, ColoEmbedding())

+    def _pre_context_exec(self):
+        self.state_dict_func = nn.Module.state_dict
+        nn.Module.state_dict = partialmethod(colo_state_dict, state_dict_func=self.state_dict_func)
+
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
         The function to call at the end of the constructor of each module.
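The `_pre_context_exec` hunk monkey-patches `nn.Module.state_dict` via `functools.partialmethod`, which pre-binds the original implementation as `state_dict_func` while still passing `self` through on every call. A standalone illustration of the mechanism with toy classes (not Colossal-AI code):

from functools import partialmethod

class Model:
    def state_dict(self):
        return {'weight': [1, 2, 3]}

def patched_state_dict(self, state_dict_func=None):
    # Call the original implementation, then post-process its result.
    result = state_dict_func(self)
    result['patched'] = True
    return result

original = Model.state_dict  # keep a handle to the original method
Model.state_dict = partialmethod(patched_state_dict, state_dict_func=original)

print(Model().state_dict())  # {'weight': [1, 2, 3], 'patched': True}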