diff --git a/colossalai/legacy/tensor/tensor_spec.py b/colossalai/legacy/tensor/tensor_spec.py index 5bdd384e5..44d8d04b9 100644 --- a/colossalai/legacy/tensor/tensor_spec.py +++ b/colossalai/legacy/tensor/tensor_spec.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional from colossalai.legacy.tensor.distspec import DistPlacementPattern, _DistSpec @@ -17,5 +17,5 @@ class ColoTensorSpec: """ pg: ProcessGroup - dist_attr: Optional[_DistSpec] = _DistSpec(DistPlacementPattern.REPLICATE) + dist_attr: Optional[_DistSpec] = field(default_factory=lambda: _DistSpec(DistPlacementPattern.REPLICATE)) compute_attr: Optional[ComputeSpec] = None