mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added namespace constraints (#1490)
parent a6c8749198
commit d39e11dffb
@@ -4,6 +4,8 @@ import torch
 from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
 from .operator_handler import OperatorHandler
 
+__all__ = ['ConvHandler']
+
 
 class ConvHandler(OperatorHandler):
     """
@@ -4,6 +4,8 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
 from .operator_handler import OperatorHandler
 from functools import reduce
 
+__all__ = ['DotHandler']
+
 
 class DotHandler(OperatorHandler):
     """
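Aside (not part of the diff): the `__all__` lines added above constrain the module namespace. Once `__all__` is defined, a wildcard import only pulls in the listed names, so helper symbols stay out of callers that do `from ... import *`. A minimal, self-contained sketch; the module and class names below are invented for illustration and are not ColossalAI code:

    # handlers_demo.py -- hypothetical module, only to illustrate __all__
    __all__ = ['ConvHandlerDemo']


    class ConvHandlerDemo:
        """Listed in __all__, so `from handlers_demo import *` exposes it."""


    class _ScratchHelper:
        """Not listed in __all__, so a wildcard import skips it."""


    # In another file:
    #   from handlers_demo import *
    #   ConvHandlerDemo    -> available
    #   _ScratchHelper     -> NameError (still importable explicitly)

Explicit imports such as `from handlers_demo import _ScratchHelper` are unaffected; `__all__` only governs the wildcard form and marks which names are considered the module's public API.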
@@ -1,3 +1,4 @@
+from webbrowser import Opera
 import torch
 import torch.nn as nn
 from abc import ABC, abstractmethod
@@ -9,6 +10,8 @@ from colossalai.tensor.sharding_spec import ShardingSpec
 
 from .sharding_strategy import StrategiesVector
 
+__all__ = ['OperatorHandler']
+
 
 class OperatorHandler(ABC):
     '''
@@ -48,6 +51,9 @@ class OperatorHandler(ABC):
 
     @abstractmethod
     def register_strategy(self) -> StrategiesVector:
+        """
+        Register
+        """
         pass
 
     def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
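In the last hunk, register_strategy stays abstract; the commit only adds a placeholder docstring. A toy sketch of why the @abstractmethod matters, using stand-in classes rather than the real OperatorHandler (whose constructor arguments are not shown in this diff):

    from abc import ABC, abstractmethod


    class StrategiesVectorStub(list):
        """Stand-in used here instead of the real StrategiesVector."""


    class HandlerSketch(ABC):
        """Mirrors the abstract-base pattern used by OperatorHandler."""

        @abstractmethod
        def register_strategy(self) -> StrategiesVectorStub:
            """Each concrete handler must build and return its strategies."""


    class ConvHandlerSketch(HandlerSketch):
        def register_strategy(self) -> StrategiesVectorStub:
            # A real handler would append ShardingStrategy objects here.
            return StrategiesVectorStub(['conv strategy placeholder'])


    ConvHandlerSketch().register_strategy()   # works
    # HandlerSketch()  # TypeError: can't instantiate abstract class HandlerSketch

Because the base class never implements register_strategy, any handler subclass that forgets to override it fails at instantiation time rather than silently returning nothing.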