mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] added namespace constraints (#1490)
parent
a6c8749198
commit
d39e11dffb
|
@ -4,6 +4,8 @@ import torch
|
|||
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
|
||||
from .operator_handler import OperatorHandler
|
||||
|
||||
__all__ = ['ConvHandler']
|
||||
|
||||
|
||||
class ConvHandler(OperatorHandler):
|
||||
"""
|
||||
|
|
|
@ -4,6 +4,8 @@ from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy,
|
|||
from .operator_handler import OperatorHandler
|
||||
from functools import reduce
|
||||
|
||||
__all__ = ['DotHandler']
|
||||
|
||||
|
||||
class DotHandler(OperatorHandler):
|
||||
"""
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from webbrowser import Opera
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from abc import ABC, abstractmethod
|
||||
|
@ -9,6 +10,8 @@ from colossalai.tensor.sharding_spec import ShardingSpec
|
|||
|
||||
from .sharding_strategy import StrategiesVector
|
||||
|
||||
__all__ = ['OperatorHandler']
|
||||
|
||||
|
||||
class OperatorHandler(ABC):
|
||||
'''
|
||||
|
@ -48,6 +51,9 @@ class OperatorHandler(ABC):
|
|||
|
||||
@abstractmethod
|
||||
def register_strategy(self) -> StrategiesVector:
|
||||
"""
|
||||
Register
|
||||
"""
|
||||
pass
|
||||
|
||||
def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
|
||||
|
|
Loading…
Reference in New Issue