[polish] add update directory in gemini; rename AgChunk to ChunkV2 (#1432)

pull/1437/head
HELSON 2022-08-10 16:40:29 +08:00 committed by GitHub
parent f20cb4e893
commit 039b7ed3bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 4 additions and 3 deletions

View File

@ -0,0 +1 @@
from .chunkv2 import ChunkV2

View File

@ -8,7 +8,7 @@ from colossalai.gemini.chunk import TensorState, STATE_TRANS, TensorInfo, ChunkF
free_storage, alloc_storage free_storage, alloc_storage
class AgChunk: class ChunkV2:
def __init__(self, def __init__(self,
chunk_size: int, chunk_size: int,
process_group: ColoProcessGroup, process_group: ColoProcessGroup,

View File

@ -9,7 +9,7 @@ from colossalai.utils import free_port, get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.tensor import ColoParameter from colossalai.tensor import ColoParameter
from colossalai.gemini import TensorState from colossalai.gemini import TensorState
from colossalai.gemini.ag_chunk import AgChunk from colossalai.gemini.update import ChunkV2
def dist_sum(x): def dist_sum(x):
@ -38,7 +38,7 @@ def check_euqal(param, param_cp):
def exam_chunk_basic(init_device, keep_gathered, pin_memory): def exam_chunk_basic(init_device, keep_gathered, pin_memory):
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
pg = ColoProcessGroup() pg = ColoProcessGroup()
my_chunk = AgChunk( my_chunk = ChunkV2(
chunk_size=1024, chunk_size=1024,
process_group=pg, process_group=pg,
dtype=torch.float32, dtype=torch.float32,