|
|
|
@ -1,14 +1,16 @@
|
|
|
|
|
from abc import ABC, abstractmethod |
|
|
|
|
from typing import Dict, List |
|
|
|
|
from webbrowser import Opera |
|
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
import torch.nn as nn |
|
|
|
|
from abc import ABC, abstractmethod |
|
|
|
|
from torch.fx.node import Node |
|
|
|
|
from typing import Dict, List |
|
|
|
|
|
|
|
|
|
from colossalai.auto_parallel.tensor_shard.deprecated.constants import * |
|
|
|
|
from colossalai.device.device_mesh import DeviceMesh |
|
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpec |
|
|
|
|
from .._utils import generate_resharding_costs, generate_sharding_spec |
|
|
|
|
from colossalai.auto_parallel.tensor_shard.deprecated.constants import * |
|
|
|
|
|
|
|
|
|
from .._utils import generate_resharding_costs, generate_sharding_spec |
|
|
|
|
from ..sharding_strategy import StrategiesVector |
|
|
|
|
|
|
|
|
|
__all__ = ['OperatorHandler'] |
|
|
|
@ -60,7 +62,7 @@ class OperatorHandler(ABC):
|
|
|
|
|
@abstractmethod |
|
|
|
|
def register_strategy(self) -> StrategiesVector: |
|
|
|
|
""" |
|
|
|
|
Register |
|
|
|
|
Register |
|
|
|
|
""" |
|
|
|
|
pass |
|
|
|
|
|
|
|
|
|