From e3fde4ee6bed24406899c5737e96006fd9e77925 Mon Sep 17 00:00:00 2001 From: ver217 Date: Thu, 2 Jun 2022 13:48:22 +0800 Subject: [PATCH] fix import error in sharded model v2 (#1053) --- colossalai/zero/sharded_model/sharded_model_v2.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index 5f087ecab..d61ea5373 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -2,7 +2,6 @@ import functools from collections import OrderedDict from typing import Any, Optional, Iterator, Tuple from copy import deepcopy -from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX import itertools import torch import torch.distributed as dist @@ -28,6 +27,11 @@ from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFacto from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, get_gradient_predivide_factor) +try: + from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX +except ImportError: + _EXTRA_STATE_KEY_SUFFIX = '_extra_state' + class ShardedModelV2(nn.Module): """