fix import error in sharded model v2 (#1053)

pull/1063/head
ver217 2022-06-02 13:48:22 +08:00 committed by GitHub
parent e1922ea4f6
commit e3fde4ee6b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 1 deletions

View File

@ -2,7 +2,6 @@ import functools
from collections import OrderedDict
from typing import Any, Optional, Iterator, Tuple
from copy import deepcopy
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
import itertools
import torch
import torch.distributed as dist
@ -28,6 +27,11 @@ from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFacto
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor)
try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
except ImportError:
_EXTRA_STATE_KEY_SUFFIX = '_extra_state'
class ShardedModelV2(nn.Module):
"""