[autoparallel] added namespace constraints (#1490)

pull/1493/head
Frank Lee 2022-08-24 15:44:07 +08:00 committed by GitHub
parent a6c8749198
commit d39e11dffb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 0 deletions

View File

@ -4,6 +4,8 @@ import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
__all__ = ['ConvHandler']
class ConvHandler(OperatorHandler):
"""

View File

@ -4,6 +4,8 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
from .operator_handler import OperatorHandler
from functools import reduce
__all__ = ['DotHandler']
class DotHandler(OperatorHandler):
"""

View File

@ -1,3 +1,4 @@
from webbrowser import Opera
import torch
import torch.nn as nn
from abc import ABC, abstractmethod
@ -9,6 +10,8 @@ from colossalai.tensor.sharding_spec import ShardingSpec
from .sharding_strategy import StrategiesVector
__all__ = ['OperatorHandler']
class OperatorHandler(ABC):
'''
@ -48,6 +51,9 @@ class OperatorHandler(ABC):
@abstractmethod
def register_strategy(self) -> StrategiesVector:
"""
Register
"""
pass
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec: