mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
23 lines
658 B
23 lines
658 B
2 years ago
|
from dataclasses import dataclass
from typing import Optional

import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
|
||
|
|
||
|
|
||
|
@dataclass
class Layout:
    """Describe how a tensor is laid out across a device mesh.

    Attributes:
        device_mesh: the device mesh across which the tensor is distributed.
        device_type: the ``torch.device`` of the device mesh,
            e.g. ``torch.device('cpu')`` or ``torch.device('cuda')``.
        sharding_spec: the sharding specification describing how the tensor
            is sharded.
        entire_shape: the entire shape of the global tensor; ``None`` when
            it has not been provided.
    """

    device_mesh: DeviceMesh
    device_type: torch.device
    sharding_spec: ShardingSpec
    # The default is None, so the annotation must be Optional explicitly
    # (PEP 484 disallows implicit Optional).
    entire_shape: Optional[torch.Size] = None
|