mirror of https://github.com/hpcaitech/ColossalAI
[refactor] remove gpc dependency in colotensor's _ops (#1189)
parent abf6a262dc
commit 060b917daf

@@ -1,6 +1,12 @@
import torch
from typing import Union, Optional
from colossalai.tensor import ColoTensor
import torch
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env

from colossalai.nn.layer.utils import divide
from colossalai.tensor import ProcessGroup

GeneralTensor = Union[ColoTensor, torch.Tensor]
Number = Union[int, float]
@@ -10,3 +16,182 @@ def convert_to_colo_tensor(tensor: Optional[GeneralTensor]) -> Optional[ColoTens
    if tensor is not None and not isinstance(tensor, ColoTensor):
        tensor = ColoTensor.from_torch_tensor(tensor)
    return tensor


def set_parallel_input(input_parallel: bool):
    env.parallel_input_1d = input_parallel


def get_parallel_input():
    return env.parallel_input_1d


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
    index_f = rank * per_partition_vocab_size
    index_l = index_f + per_partition_vocab_size
    return index_f, index_l


def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size):
    per_partition_vocab_size = divide(global_vocab_size, world_size)
    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank)


def _reduce(input_, pg: ProcessGroup):
    # skip if only one rank involved
    if pg.tp_world_size() == 1:
        return input_
    assert input_.device.type == 'cuda'
    group = pg.tp_process_group()
    dist.all_reduce(input_, group=group)

    return input_


def _split(input_, pg: ProcessGroup, dim=-1):
    # skip if only one rank involved
    world_size = pg.tp_world_size()
    if world_size == 1:
        return input_

    # Split along last dimension.
    dim_size = input_.size(dim)
    assert dim_size % world_size == 0, \
        f'The dimension to split ({dim_size}) is not a multiple of world size ({world_size}), ' \
        f'cannot split tensor evenly'

    tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
    rank = pg.tp_local_rank()
    output = tensor_list[rank].contiguous()

    return output


def _gather(input_, pg: ProcessGroup, dim=-1):
    # skip if only one rank involved
    world_size = pg.tp_world_size()
    if world_size == 1:
        return input_

    # all gather
    rank = pg.tp_local_rank()
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    assert input_.device.type == 'cuda'
    group = pg.tp_process_group()
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output


class _ReduceGrad(torch.autograd.Function):
    """
    Pass the input to the model parallel region.

    Args:
        input_: input matrix.
        process_group: parallel mode.
    """

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_, process_group):
        ctx.mode = process_group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output, ctx.mode), None


class _ReduceInput(torch.autograd.Function):
    """
    All-reduce the input from the model parallel region.

    Args:
        input_: input matrix.
        process_group: parallel mode.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_, process_group):
        return _reduce(input_, process_group)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None

class _SplitForwardGatherBackward(torch.autograd.Function):
    """
    Split the input and keep only the chunk corresponding to the rank.

    Args:
        input_: input matrix.
        process_group: parallel mode.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim):
        ctx.mode = process_group
        ctx.dim = dim
        return _split(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output, ctx.mode, ctx.dim), None, None


class _GatherForwardSplitBackward(torch.autograd.Function):
    """Gather the input from model parallel region and concatenate.

    Args:
        input_: input matrix.
        process_group: parallel mode.
        dim: dimension
    """

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)

    @staticmethod
    def forward(ctx, input_, process_group, dim):
        ctx.mode = process_group
        ctx.dim = dim
        return _gather(input_, process_group, dim)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output, ctx.mode, ctx.dim), None, None


def reduce_grad(input_, process_group):
    return _ReduceGrad.apply(input_, process_group)


def reduce_input(input_, process_group):
    return _ReduceInput.apply(input_, process_group)


def split_forward_gather_backward(input_, process_group, dim):
    return _SplitForwardGatherBackward.apply(input_, process_group, dim)


def gather_forward_split_backward(input_, process_group, dim):
    return _GatherForwardSplitBackward.apply(input_, process_group, dim)
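To see what the split/gather pair computes, here is a minimal single-process sketch in plain PyTorch; FakePG is a hypothetical stand-in for the tensor-parallel ProcessGroup accessors, invented for illustration only:

import torch

class FakePG:
    """Hypothetical single-process mock of the ProcessGroup TP accessors."""

    def __init__(self, rank, world_size):
        self._rank, self._ws = rank, world_size

    def tp_world_size(self):
        return self._ws

    def tp_local_rank(self):
        return self._rank

def fake_split(x, pg, dim=-1):
    # mirrors _split: each rank keeps one contiguous chunk along `dim`
    chunks = torch.split(x, x.size(dim) // pg.tp_world_size(), dim=dim)
    return chunks[pg.tp_local_rank()].contiguous()

x = torch.arange(8.).reshape(2, 4)
shards = [fake_split(x, FakePG(r, 2)) for r in range(2)]
# mirrors _gather: concatenating the per-rank shards restores the input
assert torch.equal(torch.cat(shards, dim=-1), x)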

@@ -1,10 +1,9 @@
import torch
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor
from colossalai.tensor import distspec
from colossalai.context import ParallelMode
from ._utils import GeneralTensor, Number, convert_to_colo_tensor
from ._utils import reduce_input, reduce_grad


def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTensor, beta: Number,
@@ -12,18 +11,16 @@ def colo_addmm_1Drow(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
    # mat1:S[1] x mat2:S[0] = Output:P
    # beta * input + alpha * All-Reduce(Output) = res

    mat1 = mat1.convert_to_dist_spec(
        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]))
    mat1 = mat1.convert_to_dist_spec(distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]))

    # Output:P
    partial_output = torch.mm(mat1, mat2)
    # Reduce(Output)
    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
    output = reduce_input(partial_output, mat1.get_process_group())
    # input
    assert not input_tensor.has_compute_spec(), 'Invalid input spec for 1Drow addmm op'
    output = beta * input_tensor + alpha * output
    output = ColoTensor.from_torch_tensor(output,
                                          spec=TensorSpec(distspec.replicate(mat2.tensor_spec.get_process_group())))
    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(mat2.get_process_group())))
    return output


@@ -31,13 +28,12 @@ def colo_addmm_1Dcol(input_tensor: ColoTensor, mat1: ColoTensor, mat2: ColoTenso
                     alpha: Number) -> ColoTensor:
    # mat1:B x mat2:S[1] + input:S[1] = Output:S[1]
    compute_spec = mat2.tensor_spec.compute_spec
    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.tensor_spec.get_process_group()))
    mat1 = reduce_grad(mat1, ParallelMode.PARALLEL_1D)
    mat1 = mat1.convert_to_dist_spec(distspec.replicate(mat2.get_process_group()))
    mat1 = reduce_grad(mat1, mat1.get_process_group())

    output_parallel = torch.addmm(input_tensor, mat1, mat2, beta=beta, alpha=alpha)
    output_spec = TensorSpec(
        distspec.shard(mat2.tensor_spec.get_process_group(), [-1], [mat2.tensor_spec.get_process_group_size()]),
        ComputeSpec(ComputePattern.TP1D))
    output_spec = TensorSpec(distspec.shard(mat2.get_process_group(), [-1], [mat2.get_tp_world_size()]),
                             ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if compute_spec.output_replicate:
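The 1Drow pattern relies on a simple identity: when the contraction dimension K is split across ranks, the per-rank partial products sum to the full product, which is exactly what the all-reduce restores. A minimal sketch in plain PyTorch, simulating two ranks in one process:

import torch

mat1 = torch.randn(2, 4)    # activations, sharded S[1] along K
mat2 = torch.randn(4, 3)    # weight, sharded S[0] along K
partials = [m1 @ m2 for m1, m2 in zip(mat1.chunk(2, dim=1), mat2.chunk(2, dim=0))]
assert torch.allclose(sum(partials), mat1 @ mat2, atol=1e-5)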

@@ -1,11 +1,8 @@
import torch.nn.functional as F
from typing import Optional
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input
from colossalai.core import global_context as gpc
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from colossalai.context import ParallelMode
from ._utils import GeneralTensor, convert_to_colo_tensor
from ._utils import GeneralTensor, convert_to_colo_tensor, reduce_input


def colo_embedding_1Dcol(input_tensor: ColoTensor,
@@ -17,7 +14,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                         sparse: bool = False) -> ColoTensor:
    # embedding_1Dcol splits the weight (lookup table) to (num_embeddings, embedding_dim/P)
    # Gather the split lookup table
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))

    output_parallel = F.embedding(input_tensor,
                                  weight,
@@ -26,9 +23,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
                                  norm_type=norm_type,
                                  scale_grad_by_freq=scale_grad_by_freq,
                                  sparse=sparse)
    output_spec = TensorSpec(
        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
        ComputeSpec(ComputePattern.TP1D))
    output_spec = TensorSpec(distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
                             ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    compute_spec = weight.tensor_spec.compute_spec
@@ -49,9 +45,10 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
    # embedding_1Drow splits the weight (lookup table) to (num_embeddings/P, embedding_dim)
    # Find the indices held in this shard and mask those that are not
    # Reduce all
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.get_process_group()))

    tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    tensor_parallel_rank = weight.tensor_spec.dist_spec.process_group.tp_local_rank()
    num_embeddings_per_partition = weight.size_local(0)
    vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
    vocab_end_index = vocab_start_index + num_embeddings_per_partition
@@ -75,9 +72,8 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
    # Mask the output embedding.
    partial_output[input_mask, :] = 0.
    # Reduce across all the model parallel GPUs.
    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
    output = ColoTensor.from_torch_tensor(output,
                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
    output = reduce_input(partial_output, weight.get_process_group())
    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.get_process_group())))
    return output
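The 1Drow lookup works because every rank owns a contiguous vocab slice, replaces out-of-range ids with a dummy local index, zeroes the corresponding output rows, and lets the all-reduce add the per-rank partial lookups back together. A minimal single-process sketch in plain PyTorch:

import torch

vocab, dim, world_size = 8, 4, 2
table = torch.randn(vocab, dim)
ids = torch.tensor([0, 3, 5, 7])

total = torch.zeros(len(ids), dim)
for rank in range(world_size):
    shard = table.chunk(world_size, dim=0)[rank]
    start = rank * shard.size(0)
    mask = (ids < start) | (ids >= start + shard.size(0))
    local_ids = (ids - start).masked_fill(mask, 0)
    partial = shard[local_ids]
    partial[mask, :] = 0.    # zero rows owned by other ranks
    total += partial         # stands in for the all-reduce

assert torch.allclose(total, table[ids])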

@@ -32,9 +32,8 @@ def colo_embedding_bag_1Dcol(input_tensor: ColoTensor,
                                      per_sample_weights=per_sample_weights,
                                      include_last_offset=include_last_offset,
                                      padding_idx=padding_idx)
    output_spec = TensorSpec(
        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]),
        ComputeSpec(ComputePattern.TP1D))
    output_spec = TensorSpec(distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]),
                             ComputeSpec(ComputePattern.TP1D))
    output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)

    if weight.tensor_spec.compute_spec.output_replicate:

@@ -17,7 +17,7 @@ def colo_layernorm(
    input_tensor, weight, bias = tuple(map(convert_to_colo_tensor, (input_tensor, weight, bias)))

    # TODO (ver217): check dist spec
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.tensor_spec.get_process_group()))
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(input_tensor.get_process_group()))

    output = F.layer_norm(input_tensor, normalized_shape, weight=weight, bias=bias, eps=eps)
    output = ColoTensor.from_torch_tensor(output, input_tensor.tensor_spec)

@@ -2,9 +2,8 @@ import torch.nn.functional as F
from typing import Optional
from ._utils import GeneralTensor, convert_to_colo_tensor
from colossalai.tensor.op_wrapper import colo_op_impl
from colossalai.nn.layer.parallel_1d._utils import reduce_input, reduce_grad
from ._utils import reduce_input, reduce_grad
from colossalai.tensor import ComputePattern, TensorSpec, ComputePattern, ComputeSpec, ColoTensor, distspec
from colossalai.context import ParallelMode
from colossalai.nn.graph import register_colo_graph, GraphOpNode, GraphGlobalEnv


@@ -13,19 +12,18 @@ def colo_linear_1Drow(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
    # All-Reduce(Output) + bias = res
    # Input:S[1]
    input_tensor = input_tensor.convert_to_dist_spec(
        distspec.shard(weight.tensor_spec.get_process_group(), [-1], [weight.tensor_spec.get_process_group_size()]))
        distspec.shard(weight.get_process_group(), [-1], [weight.get_tp_world_size()]))

    # Output:P
    partial_output = F.linear(input_tensor, weight)
    # Reduce(Output)
    output = reduce_input(partial_output, ParallelMode.PARALLEL_1D)
    output = reduce_input(partial_output, weight.get_process_group())
    # Bias
    if bias is not None:
        assert not bias.has_compute_spec(), 'Invalid bias spec for 1Drow Linear op'
        output = output + bias

    output = ColoTensor.from_torch_tensor(output,
                                          spec=TensorSpec(distspec.replicate(weight.tensor_spec.get_process_group())))
    output = ColoTensor.from_torch_tensor(output, spec=TensorSpec(distspec.replicate(weight.get_process_group())))
    return output


@@ -35,13 +33,13 @@ def colo_linear_1Dcol(input_tensor: ColoTensor, weight: ColoTensor, bias: Option
    # Input:B
    compute_spec = weight.tensor_spec.compute_spec
    input_tensor = input_tensor.convert_to_dist_spec(distspec.replicate(weight.tensor_spec.get_process_group()))
    input_parallel = reduce_grad(input_tensor, ParallelMode.PARALLEL_1D)
    input_parallel = reduce_grad(input_tensor, weight.tensor_spec.dist_spec.process_group)

    output_parallel = F.linear(input_parallel, weight, bias)
    output = ColoTensor.from_torch_tensor(output_parallel,
                                          spec=TensorSpec(
                                              distspec.shard(weight.tensor_spec.get_process_group(), [-1],
                                                             [weight.tensor_spec.get_process_group_size()]),
                                              distspec.shard(weight.get_process_group(), [-1],
                                                             [weight.get_tp_world_size()]),
                                              ComputeSpec(ComputePattern.TP1D)))
    if compute_spec.output_replicate:
        return output.to_replicate()

@@ -1,8 +1,6 @@
import torch
import itertools
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from functools import partial
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
from colossalai.gemini.chunk import TensorState, Chunk
@@ -12,6 +10,7 @@ from typing import Dict, Iterable, List, Optional
from colossalai.logging import get_dist_logger
from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .reducer import Reducer
try:
    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
@@ -45,8 +44,8 @@ class ColoDDP(torch.nn.Module):
        >>> from colossalai.core import global_context as gpc
        >>> from colossalai.context import ParallelMode
        >>> model = torch.nn.Linear(20, 1)
        >>> model = ColoDDP(model)
        >>> # model = ColoDDP(model, process_group=gpc.get_group(ParallelMode.DATA), cpu_process_group=gpc.get_cpu_group(ParallelMode.DATA))
        >>> pg = ProcessGroup(tp_degree=world_size // 2)
        >>> model = ColoDDP(model, pg)
        >>> logits = model(x)
        >>> loss = criterion(logits, labels)
        >>> model.backward(loss)
@@ -55,13 +54,13 @@ class ColoDDP(torch.nn.Module):
        module (torch.nn.Module): Module to apply DDP.
        process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses.
            If it's None, the default data parallel group will be used. Defaults to None.
        process_group (ColoProcessGroup): The process group which DDP uses.
        cpu_process_group (Optional[dist.ProcessGroup], optional): The process group which DDP uses for those parameters on CPU.
            If it's None, the default CPU data parallel group will be used. Defaults to None.
    """

    def __init__(self,
                 module: torch.nn.Module,
                 process_group: Optional[dist.ProcessGroup] = None,
                 process_group: ColoProcessGroup,
                 cpu_process_group: Optional[dist.ProcessGroup] = None,
                 bucket_cap_mb: int = 25,
                 rebuild_bucket: bool = True) -> None:
@@ -69,8 +68,9 @@ class ColoDDP(torch.nn.Module):
        super().__init__()
        self.module = module
        self.comm_stream: torch.cuda.Stream = torch.cuda.Stream()
        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
        self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
        assert process_group

        self.process_group = process_group.dp_process_group()
        self.dp_world_size = self.process_group.size()
        self.reducer = Reducer(bucket_cap_mb)
        self.rebuild_bucket = rebuild_bucket
@@ -120,6 +120,8 @@ class ColoDDP(torch.nn.Module):
            return empty_grad

        else:
            # TODO(jiaruifang) fixme
            raise NotImplementedError
            dist.all_reduce(grad, group=self.cpu_process_group)
            return grad

@@ -191,8 +193,11 @@ class ZeroDDP(ColoDDP):
    For more details, see the API reference of ``GeminiManager``.
    """

    def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
        super().__init__(module.half())
    def __init__(self,
                 module: torch.nn.Module,
                 gemini_manager: GeminiManager,
                 process_group: Optional[ColoProcessGroup] = None) -> None:
        super().__init__(module.half(), process_group=process_group)
        self.gemini_manager = gemini_manager
        self.chunk_manager = gemini_manager.chunk_manager
        self.param_op_hook = ZeROHookV2(gemini_manager)

@@ -52,5 +52,5 @@ class ColoModule(object):
    def get_param_names(self):
        return self._shard_params

    def register(self, compute_pattern):
    def register(self, compute_pattern, pg):
        raise NotImplementedError

@@ -1,5 +1,5 @@
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec
from colossalai.tensor import ComputePattern, distspec, ProcessGroup
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode


@@ -10,20 +10,18 @@ class ColoEmbedding(ColoModule):
        super(ColoEmbedding, self).__init__()
        self._register_shard_params(['weight'])

    def register(self, compute_pattern):
    def register(self, compute_pattern, pg: ProcessGroup):
        if not compute_pattern in self._allowed_patterns:
            if ComputePattern.TP1D == compute_pattern:
                self._set_TP1D()
                self._set_TP1D(pg)

    def _set_TP1D(self):
    def _set_TP1D(self, pg: ProcessGroup):
        # TP1D Row Linear
        _compute_pattern = ComputePattern.TP1D
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight':
                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
            },
            mode='row',
        )
@@ -32,9 +30,7 @@ class ColoEmbedding(ColoModule):
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight':
                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
            },
            mode='col',
        )

@@ -1,7 +1,5 @@
from .colo_module import ColoModule
from colossalai.tensor import ComputePattern, distspec
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ComputePattern, distspec, ProcessGroup


class ColoLinear(ColoModule):
@@ -10,22 +8,19 @@ class ColoLinear(ColoModule):
        super(ColoLinear, self).__init__()
        self._register_shard_params(['weight', 'bias'])

    def register(self, compute_pattern):
    def register(self, compute_pattern, pg: ProcessGroup):
        if not compute_pattern in self._allowed_patterns:
            if ComputePattern.TP1D == compute_pattern:
                self._set_TP1D()
                self._set_TP1D(pg)

    def _set_TP1D(self):
    def _set_TP1D(self, pg):
        # TP1D Row Linear
        _compute_pattern = ComputePattern.TP1D
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight':
                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                'bias':
                    None
                'weight': distspec.shard(pg, [-1], [pg.tp_world_size()]),
                'bias': None
            },
            mode='row',
        )
@@ -34,12 +29,8 @@ class ColoLinear(ColoModule):
        self._register_allowed_patterns(
            compute_pattern=_compute_pattern,
            dist_specs={
                'weight':
                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
                'bias':
                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
                'weight': distspec.shard(pg, [0], [pg.tp_world_size()]),
                'bias': distspec.shard(pg, [0], [pg.tp_world_size()])
            },
            mode='col',
        )

@@ -1,5 +1,5 @@
from typing import Dict
from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec
from colossalai.tensor import ColoParameter, ComputeSpec, TensorSpec, ProcessGroup
from . import ColoModule
import torch


@@ -29,7 +29,7 @@ def get_colo_module(module: torch.nn.Module):
    return None


def check_colo_module(module: torch.nn.Module, recursive=True):
def check_colo_module(module: torch.nn.Module, pg: ProcessGroup, recursive=True):
    if is_colo_module(module):
        colo_module = get_colo_module(module)
        param_names = colo_module.get_param_names()
@@ -50,7 +50,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                continue

        if compute_pattern is not None:
            colo_module.register(compute_pattern)
            colo_module.register(compute_pattern, pg)
            if not colo_module.has_compute_pattern(compute_pattern):
                raise Exception(
                    f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
@@ -76,16 +76,20 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
    if recursive == True:
        for submodule in module.children():
            check_colo_module(submodule, recursive=True)
            check_colo_module(submodule, pg=pg, recursive=True)


def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursive=True, mode='default'):
def init_colo_module(module: torch.nn.Module,
                     compute_spec: ComputeSpec,
                     pg: ProcessGroup,
                     recursive=True,
                     mode='default'):
    compute_pattern = compute_spec.compute_pattern
    if is_colo_module(module):
        # for each param
        # set DistSpec and ComputeSpec
        colo_module = get_colo_module(module)
        colo_module.register(compute_pattern)
        colo_module.register(compute_pattern, pg)
        if not colo_module.has_compute_pattern_with_mode(compute_pattern, mode=mode):
            raise NotImplementedError
        # a set for modules which update at least one param in the init process.
@@ -101,7 +105,7 @@ def init_colo_module(module: torch.nn.Module, compute_spec: ComputeSpec, recursi
            for mod in param.shared_param_modules:
                modules_update_param.add(mod)
        for mod in modules_update_param:
            check_colo_module(mod, recursive=False)
            check_colo_module(mod, pg, recursive=False)
    if recursive == True:
        for submodule in module.children():
            init_colo_module(submodule, compute_spec, recursive=True, mode=mode)
            init_colo_module(submodule, compute_spec, pg=pg, recursive=True, mode=mode)

@@ -78,6 +78,12 @@ class ColoTensor(torch.Tensor):
    def is_model_data(self) -> bool:
        return self._type == TensorType.MODEL

    def get_process_group(self) -> 'ProcessGroup':
        return self._tensor_spec.dist_spec.process_group

    def get_tp_world_size(self) -> int:
        return self._tensor_spec.dist_spec.process_group.tp_world_size()

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:

@@ -5,6 +5,7 @@ from contextlib import contextmanager
import torch
import torch.distributed as dist
from packaging import version
from colossalai.logging import get_dist_logger


# TODO(jiaruifang) circular import, move the divide to colossalai.commons.
@@ -64,7 +65,7 @@ class DistSpecManager:
        DistSpecManager._sanity_check(old_dist_spec, dist_spec)

        chunk = tensor
        idx = dist_spec.process_group.rank()
        idx = dist_spec.process_group.tp_local_rank()
        num_parts = prod(dist_spec.num_partitions)
        for i, dim in enumerate(dist_spec.dims):
            num_parts //= dist_spec.num_partitions[i]
@@ -91,8 +92,9 @@ class DistSpecManager:
        saved_dev = tensor.device
        tensor.data = tensor.data.cuda()

        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.size())]
        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group)
        buffer = [torch.empty_like(tensor) for _ in range(old_dist_spec.process_group.tp_world_size())]
        assert tensor.device.type == 'cuda'
        dist.all_gather(buffer, tensor, group=old_dist_spec.process_group.tp_process_group())
        for i in range(len(old_dist_spec.dims) - 1, -1, -1):
            new_buffer = []
            dim = old_dist_spec.dims[i]
@@ -108,14 +110,14 @@ class DistSpecManager:

    @staticmethod
    def _all_to_all(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
        world_size = old_dist_spec.process_group.size()
        world_size = old_dist_spec.process_group.tp_world_size()
        if world_size == 1:
            return tensor

        assert tensor.device.type == "cuda" and dist.get_backend(old_dist_spec.process_group) == "nccl", \
        assert tensor.device.type == "cuda" and old_dist_spec.process_group.backend == "nccl", \
            "Currently, only CUDA Tensor with NCCL backend is supported for the requested AlltoAll " \
            f"collective function, however, we got {tensor.device.type} device and " \
            f"{dist.get_backend(old_dist_spec.process_group)} backend"
            f"{old_dist_spec.process_group.backend} backend"

        gather_dim = old_dist_spec.dims[0]
        scatter_dim = dist_spec.dims[0]
@@ -126,7 +128,7 @@ class DistSpecManager:

        scatter_list = [t.contiguous() for t in torch.tensor_split(tensor, world_size, scatter_dim)]
        gather_list = [torch.empty(*shapes, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group)
        dist.all_to_all(gather_list, scatter_list, group=old_dist_spec.process_group.tp_process_group())

        output_ = torch.cat(gather_list, dim=gather_dim).contiguous()
        assert output_.shape[scatter_dim] == scattered_dim_size and output_.shape[gather_dim] == gathered_dim_size
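The all-to-all here is just a re-shard: each rank splits its shard along the scatter dimension, exchanges the pieces, and concatenates what it receives along the gather dimension. A minimal single-process sketch in plain PyTorch, simulating two ranks:

import torch

full = torch.arange(16.).reshape(4, 4)
world_size, gather_dim, scatter_dim = 2, 0, 1
row_shards = list(full.chunk(world_size, dim=gather_dim))

# rank r sends the c-th piece of its shard to rank c; rank c concatenates
# the received pieces along gather_dim
col_shards = [
    torch.cat([s.chunk(world_size, dim=scatter_dim)[c] for s in row_shards], dim=gather_dim)
    for c in range(world_size)
]
assert torch.equal(torch.cat(col_shards, dim=scatter_dim), full)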

@@ -1,5 +1,5 @@
from enum import Enum
from torch.distributed import ProcessGroup
from colossalai.tensor import ProcessGroup
from typing import Optional, List
from numpy import prod


@@ -51,8 +51,8 @@ def replicate(process_group: Optional[ProcessGroup] = None) -> _DistSpec:


def shard(process_group: ProcessGroup, dims: List[int], num_partitions: List[int]) -> _DistSpec:
    assert process_group is not None
    assert process_group is not None and isinstance(process_group, ProcessGroup)
    assert isinstance(dims, list) and isinstance(num_partitions, list)
    assert len(dims) == len(num_partitions)
    assert prod(num_partitions) == process_group.size(), f"{num_partitions} {process_group.size()}"
    assert prod(num_partitions) == process_group.tp_world_size(), f"{num_partitions} {process_group.tp_world_size()}"
    return _DistSpec(DistPlacementPattern.SHARD, process_group, dims=tuple(dims), num_partitions=tuple(num_partitions))

@@ -1,5 +1,6 @@
import torch
from typing import List, Optional
from colossalai.logging import get_dist_logger


class ProcessGroup:
@@ -41,12 +42,12 @@ class ProcessGroup:
        if dp_degree and not tp_degree:
            self._dp_degree = dp_degree
            assert self._world_size % self._dp_degree == 0, f"world size {self._world_size} should be divisible by DP degree {dp_degree} when TP degree is None"
            self._tp_degree = self._world_size / dp_degree
            self._tp_degree = self._world_size // dp_degree

        if not dp_degree and tp_degree:
            self._tp_degree = tp_degree
            assert self._world_size % self._tp_degree == 0, f"world size {self._world_size} should be divisible by TP degree {tp_degree} when DP degree is None"
            self._dp_degree = self._world_size / tp_degree
            self._dp_degree = self._world_size // tp_degree

        self._tp_rank_list = []
        self._dp_rank_list = []
@@ -58,12 +59,48 @@ class ProcessGroup:
            if rank_id // self._tp_degree == self._rank // self._tp_degree:
                self._tp_rank_list.append(rank_id)

        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list, backend=backend)
        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list, backend=backend)
        assert backend == 'nccl'
        self._tp_process_group = torch.distributed.new_group(ranks=self._tp_rank_list)
        self._dp_process_group = torch.distributed.new_group(ranks=self._dp_rank_list)

        self.logger = get_dist_logger('ProcessGroup')
        self.logger.info(f'{self._rank} initialized TP group on {self._tp_rank_list}, DP group on {self._dp_rank_list}')

    @property
    def backend(self):
        return self._backend

    def __eq__(self, obj: 'ProcessGroup') -> bool:
        if not isinstance(obj, ProcessGroup):
            return False
        if self._rank != obj._rank:
            assert False
        if self._rank_list != obj._rank_list:
            assert False
        if self._tp_rank_list != obj._tp_rank_list:
            assert False
        if self._dp_rank_list != obj._dp_rank_list:
            assert False
        if self._backend != obj._backend:
            assert False
        if self._tp_degree != obj._tp_degree:
            return False
        if self._dp_degree != obj._dp_degree:
            return False
        return True

    def rank(self):
        return self._rank

    def world_size(self):
        return self._world_size

    def tp_local_rank(self):
        return self._rank % self._tp_degree

    def dp_local_rank(self):
        return self._rank // self._tp_degree

    def dp_world_size(self):
        return len(self._dp_rank_list)
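ProcessGroup lays ranks out row-major with tensor parallelism as the fast axis, so the local-rank accessors above reduce to integer arithmetic. A minimal sketch checking that layout, with an assumed world size of 8 and TP degree of 4:

WORLD_SIZE, TP_DEGREE = 8, 4

for rank in range(WORLD_SIZE):
    tp_group = [r for r in range(WORLD_SIZE) if r // TP_DEGREE == rank // TP_DEGREE]
    dp_group = [r for r in range(WORLD_SIZE) if r % TP_DEGREE == rank % TP_DEGREE]
    assert rank % TP_DEGREE == tp_group.index(rank)     # tp_local_rank
    assert rank // TP_DEGREE == dp_group.index(rank)    # dp_local_rank

# e.g. rank 5 sits in TP group [4, 5, 6, 7] at slot 1 and in DP group [1, 5] at slot 1.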

@@ -17,11 +17,12 @@ class TensorSpec(object):
        self.compute_spec = compute_spec
        self.dist_spec = dist_spec

    # TODO(jiaruifang) actually need tp process group
    def get_process_group(self):
        return self.dist_spec.process_group

    def get_process_group_size(self):
        return dist.get_world_size(self.dist_spec.process_group)
        return dist.get_world_size(self.dist_spec.process_group.tp_process_group())

    def get_placement(self):
        return self.dist_spec.placement
@@ -30,7 +31,7 @@ class TensorSpec(object):
        return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
            or (len(self.dist_spec.num_partitions) == 1
                and self.dist_spec.num_partitions[0] == 1) \
            or (self.dist_spec.process_group.size() == 1)
            or (self.dist_spec.process_group.tp_world_size() == 1)

    def is_shard_1dcol(self):
        return self.dist_spec.placement == DistPlacementPattern.SHARD \

@@ -15,6 +15,7 @@ import torch.distributed as dist
import os
import random
import numpy as np
from colossalai.tensor import ProcessGroup


def set_seed(seed):
@@ -27,14 +28,16 @@ def set_seed(seed):


def init_ddp(module: torch.nn.Module) -> ColoDDP:
    return ColoDDP(module)
    pg = ProcessGroup()
    return ColoDDP(module, process_group=pg)


def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
    chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    return ZeroDDP(module, gemini_manager)
    pg = ProcessGroup()
    return ZeroDDP(module, gemini_manager, pg)


class Net(torch.nn.Module):

@@ -13,6 +13,7 @@ from colossalai.nn.parallel import ZeroDDP, ColoDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Callable
from collections import OrderedDict
from colossalai.tensor import ProcessGroup


def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
@@ -22,14 +23,16 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDic


def init_ddp(module: torch.nn.Module) -> ColoDDP:
    return ColoDDP(module)
    pg = ProcessGroup()
    return ColoDDP(module, process_group=pg)


def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
    chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero)
    gemini_manager = GeminiManager('cuda', chunk_manager)
    return ZeroDDP(module, gemini_manager)
    pg = ProcessGroup()
    return ZeroDDP(module, gemini_manager, process_group=pg)


def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):

@@ -41,7 +41,7 @@ def tensor_equal(A, B):
    return torch.allclose(A, B, rtol=1e-3, atol=1e-1)


def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
    assert tensor.ndim == shard.ndim
    if tensor.shape == shard.shape:
        return tensor_equal(tensor, shard)
@@ -50,8 +50,10 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor):
        if dims_not_eq.numel() == 1:
            # 1D shard
            dim = dims_not_eq.item()
            world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
            rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
            if world_size is None:
                world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
            if rank is None:
                rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
            return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
        else:
            raise NotImplementedError
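The updated helper takes the TP-local rank and world size explicitly instead of reading them from gpc, so it works with any ProcessGroup layout. A minimal sketch of the same comparison logic in plain PyTorch:

import torch

def shard_equal(tensor, shard, rank, world_size):
    if tensor.shape == shard.shape:
        return torch.allclose(tensor, shard)
    # find the single sharded dimension and compare the matching chunk
    dim = torch.nonzero(torch.tensor(tensor.shape) != torch.tensor(shard.shape)).item()
    return torch.allclose(tensor.chunk(world_size, dim)[rank], shard)

full = torch.randn(4, 6)
assert shard_equal(full, full.chunk(2, 1)[1], rank=1, world_size=2)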

@@ -3,14 +3,12 @@ import torch
import pytest
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor import distspec
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from colossalai.context import ParallelMode
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
from colossalai.core import global_context as gpc
from _utils import tensor_shard_equal, tensor_equal


@@ -38,18 +36,14 @@ class Conv1D(nn.Module):
        return x


def init_1d_row(weight, bias):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_row(weight, bias, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def init_1d_col(weight, bias):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_col(weight, bias, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)
        bias.set_tensor_spec(spec)
@@ -59,7 +53,9 @@ def run_with_spec(spec_init_func):
    model = Conv1D(4, 16).cuda()
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
    spec_init_func(weight, bias)
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    spec_init_func(weight, bias, pg)
    x = torch.rand(2, 16).cuda()
    out = model(x)
    colo_out = torch.addmm(bias, x, weight)
@@ -68,13 +64,12 @@ def run_with_spec(spec_init_func):
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    tensor_shard_equal(model.weight.grad, weight.grad)
    tensor_shard_equal(model.bias.grad, bias.grad)
    tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_spec(init_1d_row)
    run_with_spec(init_1d_col)

@@ -7,12 +7,12 @@ import torch.multiprocessing as mp
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import DistSpecManager, distspec
from colossalai.tensor import DistSpecManager, distspec, ProcessGroup
from functools import partial


def run():
    group = _get_default_group()
    group = ProcessGroup(tp_degree=dist.get_world_size())
    rank = dist.get_rank()
    size = dist.get_world_size()
    depth = int(math.sqrt(size))
@@ -34,7 +34,7 @@ def run():


def check_mem():
    group = _get_default_group()
    group = ProcessGroup(tp_degree=dist.get_world_size())
    size = dist.get_world_size()
    assert torch.cuda.memory_allocated() == 0
    x = torch.rand(32, 32).cuda()

@@ -1,6 +1,5 @@
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor, distspec, ColoParameter
from colossalai.tensor import distspec, ColoParameter
from torch.nn import functional as F
from functools import partial

@@ -10,23 +9,21 @@ import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal


def init_1d_col(weight):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_col(weight, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def run_with_spec(spec_init_func):
    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    model = torch.nn.EmbeddingBag(10, 4).cuda()
    weight = ColoParameter(model.weight.clone())
    spec_init_func(weight)
    spec_init_func(weight, pg)
    inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
    offsets = torch.tensor([0, 4]).cuda()
    out = model(inputs, offsets=offsets)
@@ -35,7 +32,7 @@ def run_with_spec(spec_init_func):
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model.weight.grad, weight.grad)
    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):

@@ -1,5 +1,4 @@
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import ColoTensor, distspec
from torch.nn import functional as F
from functools import partial
@@ -11,30 +10,26 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal


def init_1d_row(weight):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_row(weight, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def init_1d_col(weight):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_col(weight, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def run_with_spec(spec_init_func):
def run_with_spec(spec_init_func, pg: ProcessGroup):
    model = torch.nn.Embedding(12, 32).cuda()
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
    spec_init_func(weight)
    spec_init_func(weight, pg)
    x = torch.tensor((0, 3, 6, 9)).cuda()
    out = model(x)
    colo_out = F.embedding(x, weight)
@@ -42,14 +37,16 @@ def run_with_spec(spec_init_func):
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model.weight.grad, weight.grad)
    # compare grad inside a TP group
    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):
    config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_spec(init_1d_row)
    run_with_spec(init_1d_col)
    # config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    pg = ProcessGroup(tp_degree=world_size)
    run_with_spec(init_1d_row, pg)
    run_with_spec(init_1d_col, pg)


@pytest.mark.dist

@@ -1,51 +1,54 @@
import pytest

import colossalai
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup

from functools import partial
from _utils import tensor_equal, tensor_shard_equal, set_seed
from tests.components_to_test.registry import non_distributed_component_funcs
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.core import global_context as gpc
from colossalai.context.parallel_mode import ParallelMode


def init_1d_row_spec(model):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_row_spec(model, pg: ProcessGroup):
    tensor_spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n and 'ln' not in n:
                p.set_tensor_spec(spec)
                p.set_tensor_spec(tensor_spec)


def init_1d_col_spec(model):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_col_spec(model, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'ln' not in n and ('weight' in n or 'bias' in n):
                p.set_tensor_spec(spec)


def check_param_equal(model, torch_model):
def check_param_equal(model, torch_model, pg: ProcessGroup):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert tensor_shard_equal(torch_p, p)
        assert pg.tp_local_rank() is not None, f"{pg.rank()} {pg.tp_world_size()} {pg._tp_degree} {pg.tp_local_rank()}"
        assert pg.tp_world_size() is not None
        assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())


def check_grad_equal(model, torch_model):
def check_grad_equal(model, torch_model, pg: ProcessGroup):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        assert tensor_shard_equal(torch_p.grad, p.grad)
        assert tensor_shard_equal(torch_p.grad, p.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_gpt(init_spec_func, use_ddp):
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(dp_degree=(2 if (use_ddp and world_size >= 2) else 1))
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

@@ -54,21 +57,25 @@ def run_gpt(init_spec_func, use_ddp):
    model = model.cuda()
    torch_model = model_builder().cuda()
    if use_ddp:
        model = ColoDDP(model)
        # torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg)
        # torch.distributed.barrier()
        torch_model = DDP(torch_model,
                          device_ids=[gpc.get_global_rank()],
                          process_group=gpc.get_group(ParallelMode.DATA))

        model = ColoDDP(model, process_group=pg)
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)
    init_spec_func(model)
    check_param_equal(model, torch_model)
    init_spec_func(model, pg)
    check_param_equal(model, torch_model, pg)
    model.train()
    torch_model.train()
    set_seed(gpc.get_local_rank(ParallelMode.DATA))
    set_seed(pg.tp_local_rank())

    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        logits = model(input_ids, attn_mask)
        torch_logits = torch_model(input_ids, attn_mask)
        assert tensor_equal(torch_logits, logits)
        assert tensor_equal(torch_logits, logits), f"{torch_logits - logits}"
        loss = criterion(logits, input_ids)
        torch_loss = criterion(torch_logits, input_ids)
        if use_ddp:
@@ -76,7 +83,7 @@ def run_gpt(init_spec_func, use_ddp):
        else:
            loss.backward()
        torch_loss.backward()
        check_grad_equal(model, torch_model)
        check_grad_equal(model, torch_model, pg)
        if i > 0:
            break

@@ -87,11 +94,12 @@ def run_dist(rank, world_size, port, use_ddp):
    tp_world_size = world_size // 2 if use_ddp else world_size
    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_gpt(init_1d_row_spec, use_ddp)
    # run_gpt(init_1d_row_spec, use_ddp)
    run_gpt(init_1d_col_spec, use_ddp)


@pytest.mark.dist
@pytest.mark.skip("under development")
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@rerun_if_address_is_in_use()

@@ -1,88 +0,0 @@
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ComputePattern, ComputeSpec

from functools import partial
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode

from colossalai.nn.parallel.layers import init_colo_module
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.nn.optimizer import ColoOptimizer

import colossalai
import torch
import torch.multiprocessing as mp
import pytest


class Net(torch.nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        self.embed = torch.nn.Embedding(20, 4)
        self.proj = torch.nn.Linear(4, 8)

    def forward(self, x):
        # move the input to cpu and restore the output device afterwards
        current_dev = x.device
        x = x.to('cpu')
        x = self.embed(x)
        x = x.to(current_dev)

        x = self.proj(x)
        return x


def run_hybrid_device(use_ddp, mode):
    with ColoInitContext(device=get_current_device()):
        model = Net()

    real_model = model
    if use_ddp:
        model = ColoDDP(model)
        real_model = model.module

    print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
    # print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
    parallel_action = ComputeSpec(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, mode=mode)

    # use cpu gloo to handle embedding
    real_model.embed.to('cpu')
    gloo_group_tp = gpc.get_cpu_group(ParallelMode.PARALLEL_1D)
    real_model.embed.weight.spec.dist_spec.process_group = gloo_group_tp

    print(f'embedding weight size: {real_model.embed.weight.size()} | new device: {real_model.embed.weight.device}')
    # print(f'linear weight size: {real_model.proj.weight.size()} | new device: {real_model.proj.weight.device}')

    optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
    data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
    out = model(data)
    out.sum().backward()
    optimizer.step()


def run_dist(rank, world_size, port, use_ddp, mode):
    if use_ddp and world_size == 1:
        return
    tp_world_size = world_size // 2 if use_ddp else world_size
    config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),))
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_hybrid_device(use_ddp, mode)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@pytest.mark.parametrize('use_ddp', [False, True])
@pytest.mark.parametrize('mode', ['col', 'row'])
@rerun_if_address_is_in_use()
# simulates the embedding (CPU DP+TP) -> nn (GPU DP+TP) workflow
def _test_hybrid_device(world_size, use_ddp, mode):
    run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, mode=mode)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    _test_hybrid_device(4, True, 'row')
|
|
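The file deleted above is the last test wiring a CPU gloo group through gpc; the remaining hunks in this commit swap such calls for the ProcessGroup abstraction. A minimal sketch of the replacement pattern, using only names that appear in the hunks below:

import torch
from colossalai.tensor import ProcessGroup

# one TP group spanning every launched rank (the pattern used throughout this commit)
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
group = pg.tp_process_group()    # backend group handle for TP collectives
rank = pg.tp_local_rank()        # this rank's index inside its TP group
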
@@ -12,32 +12,29 @@ import torch.nn.functional as F
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal


def init_1d_row(weight, bias):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_row(weight, bias, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def init_1d_col(weight, bias):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_col(weight, bias, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)
        bias.set_tensor_spec(spec)


def run_with_spec(spec_init_func):
    pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
    model = torch.nn.Linear(4, 8).cuda()
    weight = ColoTensor(torch.nn.Parameter(model.weight.detach()))
    bias = ColoTensor(torch.nn.Parameter(model.bias.detach()))
    spec_init_func(weight, bias)
    spec_init_func(weight, bias, pg)
    x = torch.rand(2, 4).cuda()
    out = model(x)
    colo_out = F.linear(x, weight, bias)

@@ -46,8 +43,8 @@ def run_with_spec(spec_init_func):
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model.weight.grad, weight.grad)
    assert tensor_shard_equal(model.bias.grad, bias.grad)
    assert tensor_shard_equal(model.weight.grad, weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    assert tensor_shard_equal(model.bias.grad, bias.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_dist(rank, world_size, port):

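For orientation, the two layouts exercised by init_1d_row and init_1d_col above differ only in which weight dimension is split; a standalone sketch of the arithmetic (plain PyTorch, hypothetical sizes):

import torch

weight = torch.randn(8, 4)    # a Linear(4, 8) weight has shape (out, in)
tp_world_size = 2
row_shards = torch.chunk(weight, tp_world_size, dim=-1)    # row mode: split the input dim, each (8, 2)
col_shards = torch.chunk(weight, tp_world_size, dim=0)     # col mode: split the output dim, each (4, 4)
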
@@ -1,10 +1,12 @@
from colossalai.tensor.colo_parameter import ColoParameter
from tests.components_to_test.registry import non_distributed_component_funcs

import colossalai
import pytest
from functools import partial
from _utils import tensor_shard_equal, set_seed

import torch
import torch.multiprocessing as mp

from colossalai.tensor.colo_parameter import ColoParameter
import colossalai
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port

@@ -12,34 +14,30 @@ from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import distspec, TensorSpec, ComputePattern, \
    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
from colossalai.nn.optimizer import ColoOptimizer
from functools import partial
from _utils import tensor_shard_equal, set_seed

from tests.components_to_test.registry import non_distributed_component_funcs


def init_1d_row_linear(weight, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
                      ComputeSpec(ComputePattern.TP1D))
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def init_1d_col_linear(weight, pg):
    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
                      ComputeSpec(ComputePattern.TP1D))
    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def init_1d_row_embedding(weight, pg):
    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [0], [pg.tp_world_size()]),
                      ComputeSpec(ComputePattern.TP1D))
    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)


def init_1d_col_embedding(weight, pg):
    spec = TensorSpec(distspec.shard(pg.tp_process_group(), [-1], [pg.tp_world_size()]),
                      ComputeSpec(ComputePattern.TP1D))
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        weight.set_tensor_spec(spec)

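Note the transposed convention between the linear and embedding helpers above: a Linear weight is (out_features, in_features), so row mode splits dim -1, while an Embedding weight is (num_embeddings, embedding_dim), so row mode splits dim 0. A quick sanity sketch of the resulting shard shapes:

import torch

linear_w = torch.nn.Linear(4, 8).weight       # shape (8, 4)
embed_w = torch.nn.Embedding(20, 4).weight    # shape (20, 4)
linear_row = torch.chunk(linear_w, 2, dim=-1)    # row mode on Linear: each (8, 2)
embed_row = torch.chunk(embed_w, 2, dim=0)       # row mode on Embedding: each (10, 4)
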
@@ -142,7 +140,7 @@ def run_1d_hybrid_tp(model_name):
    with torch.no_grad():
        # check param
        for p, torch_p in zip(model.parameters(), model_torch.parameters()):
            assert tensor_shard_equal(torch_p, p)
            assert tensor_shard_equal(torch_p, p, pg.tp_local_rank(), pg.tp_world_size())

    if i > 5:
        break

@@ -13,12 +13,10 @@ import colossalai
from colossalai.utils.cuda import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext

from colossalai.context.parallel_mode import ParallelMode
from colossalai.tensor import distspec
from colossalai.tensor import distspec, ProcessGroup

from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.core import global_context as gpc

from tests.components_to_test.registry import non_distributed_component_funcs

@@ -26,7 +24,9 @@ from tests.components_to_test.registry import non_distributed_component_funcs
def run_model_with_spec(mode, model_name):
    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
    rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    rank = pg.rank()

    set_seed(1)
    with ColoInitContext(device=get_current_device()):

@@ -40,28 +40,28 @@ def run_model_with_spec(mode, model_name):
    for p1, p2 in zip(model.parameters(), model_seq.parameters()):
        p2.data.copy_(p1.data)

    parallel_action = ComputeSpec(ComputePattern.TP1D)
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    # Not every layer dimension in BERT is divisible by 4.
    # e.g. row-sharding every layer is invalid because the first dim of some layers is the classification label count, 2.
    if 'bert' == model_name:
        if 'col' == mode:
            init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode=mode)
            init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode)
            init_colo_module(model.classifier, parallel_action, recursive=True, mode='row')
            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode='row')
        elif 'row' == mode:
            init_colo_module(model.bert.embeddings, parallel_action, recursive=True, mode='col')
            init_colo_module(model.bert.encoder, parallel_action, recursive=True, mode=mode)
            init_colo_module(model.classifier, parallel_action, recursive=True, mode=mode)
            init_colo_module(model.bert.embeddings, compute_spec, pg=pg, recursive=True, mode='col')
            init_colo_module(model.bert.encoder, compute_spec, pg=pg, recursive=True, mode=mode)
            init_colo_module(model.classifier, compute_spec, pg=pg, recursive=True, mode=mode)
    elif 'simple_net' == model_name:
        init_colo_module(model, parallel_action, recursive=True, mode=mode)
        init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)

    model = model.cuda()
    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())

        torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
        torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
        torch.distributed.broadcast(data, 0, group=pg.tp_process_group())
        torch.distributed.broadcast(label, 0, group=pg.tp_process_group())

        if criterion:
            output = model(data)

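The broadcasts above pin every TP rank to the same batch before the forward pass; a self-contained sketch of the same idiom (assuming the process group is already initialized, e.g. by colossalai.launch):

import torch
import torch.distributed as dist

def broadcast_batch(data: torch.Tensor, label: torch.Tensor, group=None):
    # rank 0 of `group` is the source; afterwards all TP ranks hold identical tensors
    dist.broadcast(data, src=0, group=group)
    dist.broadcast(label, src=0, group=group)
    return data, label
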
@@ -113,9 +113,10 @@ def run_linear_with_spec(mode):
    model = torch.nn.Linear(4, 8)

    model_handy = copy(model)

    parallel_action = ComputeSpec(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, mode=mode)
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    init_colo_module(model, compute_spec, pg=pg, recursive=True, mode=mode)

    x = torch.rand(2, 4).cuda()
    out = model(x)

@@ -124,8 +125,8 @@ def run_linear_with_spec(mode):
    grad = torch.rand_like(out)
    out.backward(grad)
    colo_out.backward(grad)
    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)
    assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad, pg.tp_local_rank(), pg.tp_world_size())
    assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad, pg.tp_local_rank(), pg.tp_world_size())


def run_check_shared_param():

@@ -136,6 +137,10 @@ def run_check_shared_param():
    num_layer = 2
    vocab_size = 24

    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    rank = pg.rank()

    config = BertConfig(vocab_size=vocab_size,
                        hidden_size=hidden_dim,
                        intermediate_size=hidden_dim * 4,

@@ -148,18 +153,16 @@ def run_check_shared_param():
    model = BertForMaskedLM(config)

    model = model.cuda()
    parallel_action = ComputeSpec(ComputePattern.TP1D)
    compute_spec = ComputeSpec(ComputePattern.TP1D)
    # model.cls.predictions.decoder and model.cls.predictions share the bias, so they should have the same spec
    assert len(model.cls.predictions.decoder.bias.shared_param_modules) == 2
    # Both are Linear layers, so row mode is valid for each; this should pass the check.
    init_colo_module(model, parallel_action, recursive=True, mode='row')
    init_colo_module(model, compute_spec, pg=pg, recursive=True, mode='row')
    # This should be caught by the check: the weight cannot be sharded by row while the bias is sharded by column.
    col_spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
    col_spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    model.cls.predictions.bias.set_tensor_spec(col_spec)
    try:
        check_colo_module(model.cls.predictions.decoder, recursive=False)
        check_colo_module(model.cls.predictions.decoder, pg=pg, recursive=False)
    except Exception as e:
        assert 'incorrectly sharded' in str(e)

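Why the mixed spec above must trip the check: with a row-sharded weight, each rank computes a partial output of full width, so a bias split along the output dim has no matching shape to add. A toy illustration in plain PyTorch (hypothetical sizes):

import torch

out_f, in_f, tp = 8, 4, 2
weight = torch.randn(out_f, in_f)
x = torch.randn(2, in_f)

# row sharding: rank 0 holds weight[:, :in_f // tp] and the matching input slice;
# its partial output already has the FULL output width and is all-reduced later
partial = x[:, :in_f // tp] @ weight[:, :in_f // tp].t()    # shape (2, 8)
col_bias_shard = torch.randn(out_f // tp)                   # col-spec bias shard: shape (4,)
# partial + col_bias_shard would raise a shape error -> 'incorrectly sharded'
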
@@ -4,10 +4,9 @@ import colossalai
import torch.nn.functional as F
import torch.multiprocessing as mp
from functools import partial
from colossalai.tensor import ColoTensor, ColoParameter
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.utils import get_current_device
from torch.nn import Parameter
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec, TensorSpec

@@ -43,9 +42,10 @@ def check_spec_eq(tensor, other):


def check_element_wise_ops():
    pg = _get_default_group()
    world_size = torch.distributed.get_world_size()
    pg = ProcessGroup(tp_degree=world_size)
    t = torch.rand(2, 2)
    x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.size()])))
    x = ColoTensor(t, spec=TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()])))
    check_spec_eq(x, x.cuda())
    assert torch.equal(x.cuda(), t.cuda())
    check_spec_eq(x, torch.abs(x))

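Element-wise ops such as torch.abs preserve the shard layout, which is exactly what check_spec_eq asserts above; the invariant in plain PyTorch terms:

import torch

t = torch.randn(4, 2)
shards = torch.chunk(t, 2, dim=0)
# applying the op shard-by-shard equals applying it to the whole tensor
assert torch.equal(torch.cat([s.abs() for s in shards], dim=0), t.abs())
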
@@ -11,7 +11,6 @@ import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import distspec, TensorSpec, ColoTensor, ProcessGroup
from colossalai.context import ParallelMode
from functools import partial

@@ -55,11 +54,9 @@ def test_operand():
def _run_view(world_size):
    t_ref = torch.randn(4, 5)
    rank = gpc.get_global_rank()
    pg = ProcessGroup(rank, list(range(world_size)))
    assert pg.dp_world_size() == world_size, f"{pg.dp_world_size()} vs {world_size}"
    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
    t = ColoTensor.from_torch_tensor(
        t_ref,
        TensorSpec(distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])))
        t_ref, TensorSpec(distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])))

    assert t.size_global()[0] == 4 * world_size
    assert t.size_global(1) == 5

@@ -77,12 +74,12 @@ def _run_tensor_shard_init(world_size):
    t_ref = torch.randn(4, 5)

    rank = gpc.get_global_rank()
    pg = ProcessGroup(rank, list(range(world_size)))
    shard_spec = distspec.shard(process_group=pg.dp_process_group(), dims=[0], num_partitions=[pg.dp_world_size()])
    pg = ProcessGroup(rank, list(range(world_size)), tp_degree=world_size)
    shard_spec = distspec.shard(process_group=pg, dims=[0], num_partitions=[pg.tp_world_size()])
    tensor_spec = TensorSpec(shard_spec)
    t = ColoTensor.from_torch_tensor(t_ref.clone(), tensor_spec)
    t.set_tensor_spec(TensorSpec(dist_spec=distspec.replicate()))
    assert t.shape == torch.Size((4 * world_size, 5))
    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape} vs ({4 * world_size, 5})"


def _run_tensor_replicated_init(world_size):

@@ -92,11 +89,19 @@ def _run_tensor_replicated_init(world_size):
    assert t.shape == torch.Size((4 * world_size, 5)), f"{t.shape}"


def _run_process_group(world_size):
    pg1 = ProcessGroup()
    pg2 = ProcessGroup()

    assert pg1 == pg2


def run_dist_tests(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    _run_tensor_shard_init(world_size)
    _run_tensor_replicated_init(world_size)
    _run_view(world_size)
    _run_process_group(world_size)


@pytest.mark.dist

@@ -2,13 +2,11 @@ import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.gemini import ChunkManager
from colossalai.core import global_context as gpc
from functools import partial
from _utils import tensor_equal, set_seed, tensor_shard_equal
from tests.components_to_test.registry import non_distributed_component_funcs

@@ -19,20 +17,22 @@ from colossalai.zero import ZeroOptimizer
from colossalai.testing import parameterize
from colossalai.amp import convert_to_apex_amp
from colossalai.gemini.gemini_mgr import GeminiManager
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec
from colossalai.tensor import TensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup


def check_param_equal(model, torch_model):
def check_param_equal(model, torch_model, pg: ProcessGroup):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        if p.storage().size() > 0:
            assert p.dtype == torch.half
            assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p), f'{torch_p} vs {p}'
            assert tensor_shard_equal(torch_p.to(dtype=p.dtype, device=p.device), p, pg.tp_local_rank(),
                                      pg.tp_world_size()), f'{torch_p} vs {p}'


def check_grad_equal(model, torch_model):
def check_grad_equal(model, torch_model, pg: ProcessGroup):
    for p, torch_p in zip(model.parameters(), torch_model.parameters()):
        if p.grad is not None:
            assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad)
            assert tensor_shard_equal(torch_p.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
                                      pg.tp_local_rank(), pg.tp_world_size())


def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):

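The extra rank/world-size arguments threaded through tensor_shard_equal above let the helper slice the full reference tensor down to this rank's shard before comparing; a plausible sketch of those semantics (the real _utils implementation may differ):

import torch

def tensor_shard_equal(full: torch.Tensor, shard: torch.Tensor, rank: int, world_size: int) -> bool:
    # unsharded case: plain comparison
    if full.shape == shard.shape:
        return torch.allclose(full, shard)
    # 1-D sharded case: find the split dim and compare this rank's chunk
    for dim in range(full.dim()):
        if full.size(dim) == shard.size(dim) * world_size:
            return torch.allclose(full.chunk(world_size, dim=dim)[rank], shard)
    return False
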
@@ -44,20 +44,16 @@ def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
    return logits


def init_1d_row_spec(model):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_row_spec(model, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n and 'ln' not in n:
                p.set_tensor_spec(spec)


def init_1d_col_spec(model):
    spec = TensorSpec(
        distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
        ComputeSpec(ComputePattern.TP1D))
def init_1d_col_spec(model, pg: ProcessGroup):
    spec = TensorSpec(distspec.shard(pg, [-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'ln' not in n and ('weight' in n or 'bias' in n):

@@ -79,44 +75,51 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
    for torch_p, p in zip(torch_model.parameters(), model.parameters()):
        torch_p.data.copy_(p)

    world_size = torch.distributed.get_world_size()

    # with world size 4, use dp = 2 and tp = 2 to construct hybrid parallelism
    if world_size == 4:
        pg = ProcessGroup(tp_degree=2)
    else:
        pg = ProcessGroup(tp_degree=world_size)

    if tp_init_spec_func:
        tp_init_spec_func(model)
        tp_init_spec_func(model, pg)

    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)
    model = ZeroDDP(model, gemini_manager, pg)
    optim = HybridAdam(model.parameters(), lr=1e-3)
    optim = ZeroOptimizer(optim, model, initial_scale=32)

    amp_config = dict(opt_level='O2', keep_batchnorm_fp32=False, loss_scale=32)
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
    torch_model = DDP(torch_model, device_ids=[gpc.get_global_rank()], process_group=gpc.get_group(ParallelMode.DATA))
    torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())

    print(chunk_manager)
    check_param_equal(model, torch_model)
    # print(chunk_manager)
    check_param_equal(model, torch_model, pg)
    model.train()
    torch_model.train()
    set_seed(gpc.get_local_rank(ParallelMode.DATA))
    set_seed(pg.dp_local_rank())
    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
        if i > 2:
            break

        logits = run_fwd_bwd(model, criterion, optim, input_ids, attn_mask)
        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
        assert tensor_equal(logits, torch_logits)
        check_grad_equal(model, torch_model)
        check_grad_equal(model, torch_model, pg)
        optim.step()
        torch_optim.step()
        check_param_equal(model, torch_model)
        check_param_equal(model, torch_model, pg)


def run_dist(rank, world_size, port):
    config = {}
    if world_size == 4:
        config['parallel'] = {'tensor': {'mode': '1d', 'size': 2}}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    if world_size == 4:
        run_gpt(tp_init_spec_func=init_1d_col_spec)

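With 4 launched ranks, ProcessGroup(tp_degree=2) above implies dp_degree = 2; one possible rank layout (the diff does not pin down which axis is contiguous):

world_size, tp_degree = 4, 2
dp_degree = world_size // tp_degree    # = 2

# e.g. TP groups over contiguous ranks, DP groups striding across them
tp_groups = [[g * tp_degree + i for i in range(tp_degree)] for g in range(dp_degree)]
dp_groups = [[i * tp_degree + g for i in range(dp_degree)] for g in range(tp_degree)]
print(tp_groups)    # [[0, 1], [2, 3]]
print(dp_groups)    # [[0, 2], [1, 3]]
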
@@ -126,6 +129,7 @@ def run_dist(rank, world_size, port):


@pytest.mark.dist
@pytest.mark.skip("under development")
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_gpt(world_size):

@@ -1,12 +1,10 @@
import pytest
import colossalai
import torch
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.core import global_context as gpc
from functools import partial
from tests.test_tensor._utils import set_seed
from tests.components_to_test.registry import non_distributed_component_funcs

@@ -16,6 +14,7 @@ from colossalai.zero.init_ctx import ZeroInitContext
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from colossalai.tensor import ProcessGroup


def init_zero(model_builder, placement_policy):

@@ -64,7 +63,8 @@ def run_nested_model(placement_policy):

    model.train()
    model_copy.train()
    set_seed(gpc.get_local_rank(ParallelMode.DATA))
    pg = ProcessGroup()
    set_seed(pg.dp_local_rank())
    data_iter = iter(train_dataloader)

    data, label = map(lambda x: x.cuda(), next(data_iter))

@@ -16,6 +16,7 @@ from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.testing import parameterize
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import ZeroOptimizer
from colossalai.tensor import ProcessGroup


def init_zero(model, use_chunk, use_zero, placement_policy):

@@ -24,7 +25,8 @@ def init_zero(model, use_chunk, use_zero, placement_policy):
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    return ZeroDDP(model, gemini_manager)
    pg = ProcessGroup()
    return ZeroDDP(model, gemini_manager, pg)


def run_step(model, optim, criterion, data, label):