mirror of https://github.com/hpcaitech/ColossalAI
[polish] add update directory in gemini; rename AgChunk to ChunkV2 (#1432)
parent
f20cb4e893
commit
039b7ed3bc
|
@ -0,0 +1 @@
|
|||
from .chunkv2 import ChunkV2
|
|
@ -8,7 +8,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF
|
|||
free_storage, alloc_storage
|
||||
|
||||
|
||||
class AgChunk:
|
||||
class ChunkV2:
|
||||
def __init__(self,
|
||||
chunk_size: int,
|
||||
process_group: ColoProcessGroup,
|
|
@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device
|
|||
from colossalai.tensor import ProcessGroup as ColoProcessGroup
|
||||
from colossalai.tensor import ColoParameter
|
||||
from colossalai.gemini import TensorState
|
||||
from colossalai.gemini.ag_chunk import AgChunk
|
||||
from colossalai.gemini.update import ChunkV2
|
||||
|
||||
|
||||
def dist_sum(x):
|
||||
|
@ -38,7 +38,7 @@ def check_euqal(param, param_cp):
|
|||
def exam_chunk_basic(init_device, keep_gathered, pin_memory):
|
||||
world_size = torch.distributed.get_world_size()
|
||||
pg = ColoProcessGroup()
|
||||
my_chunk = AgChunk(
|
||||
my_chunk = ChunkV2(
|
||||
chunk_size=1024,
|
||||
process_group=pg,
|
||||
dtype=torch.float32,
|
Loading…
Reference in New Issue