mirror of https://github.com/hpcaitech/ColossalAI
[zero] polish sharded param name (#484)
* [zero] polish sharded param name
* polish code
* polish
* polish code
* polish
* polish
* polish

branch: pull/491/head
parent 9caa8b6481
commit b334822163
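In short, the ShardedParamV2 accessor that used to be exposed as `col_attr.data` is now `col_attr.sharded_data_tensor` (backed by `_sharded_data_tensor` instead of `_data_sharded_tensor`), and the legacy `ShardedParam` wrapper is removed. A minimal sketch of the new access pattern, assuming the parameter has already been wrapped so that `col_attr` is a `ShardedParamV2`:

import torch

def sharded_payload(param: torch.nn.Parameter) -> torch.Tensor:
    # Sketch only: assumes `param.col_attr` was attached by the ZeRO init path
    # shown in the hunks below, i.e. it is a ShardedParamV2. Before this commit
    # the same tensor was reached via `param.col_attr.data.payload`.
    return param.col_attr.sharded_data_tensor.payload
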
@@ -34,13 +34,13 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.data.device != self.computing_device:
-                param.col_attr.data.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
-            param.data = param.col_attr.data.payload
+            if param.col_attr.sharded_data_tensor.device != self.computing_device:
+                param.col_attr.sharded_data_tensor.to(self.computing_device)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            param.data = param.col_attr.sharded_data_tensor.payload

         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

@@ -49,7 +49,7 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters():
             param.col_attr.remove_torch_payload()

@@ -58,13 +58,13 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.gather(tensor_list, self.process_group)
         for param in module.parameters():
-            if param.col_attr.data.device != self.computing_device:
-                param.col_attr.data.to(self.computing_device)
-                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
-            param.data = param.col_attr.data.payload
+            if param.col_attr.sharded_data_tensor.device != self.computing_device:
+                param.col_attr.sharded_data_tensor.to(self.computing_device)
+                GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
+            param.data = param.col_attr.sharded_data_tensor.payload
             # Store local accumulated grad shard
             if param.grad is not None:
                 if param.col_attr.bwd_count == 0:

@@ -75,7 +75,7 @@ class ZeroHook(BaseOpHook):
                 else:
                     # We have stored local accumulated grad
                     # The grad here must be locally computed full grad in this backward pass
-                    assert param.grad.shape == param.col_attr.data.origin_shape
+                    assert param.grad.shape == param.col_attr.sharded_data_tensor.origin_shape
             param.col_attr.bwd_count += 1
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()

@@ -84,7 +84,7 @@ class ZeroHook(BaseOpHook):
         tensor_list = []
         for param in module.parameters():
             assert hasattr(param, 'col_attr')
-            tensor_list.append(param.col_attr.data)
+            tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
         for param in module.parameters():
             param.col_attr.remove_torch_payload()

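All four ZeroHook hunks above follow the same pattern: gather the shards of a submodule's parameters before its op runs, point `param.data` at the gathered payload, then re-shard and drop the dense payload afterwards. A condensed, hedged sketch of that lifecycle (`op`, `shard_strategy` and `process_group` stand in for the hook's real collaborators; this is not the hook code itself):

def run_op_with_sharded_params(module, shard_strategy, process_group, op):
    # Materialize full parameters for this submodule only.
    tensors = [p.col_attr.sharded_data_tensor for p in module.parameters()]
    shard_strategy.gather(tensors, process_group)
    for p in module.parameters():
        p.data = p.col_attr.sharded_data_tensor.payload   # hijack param.data
    out = op()                                            # forward or backward op
    # Return to the sharded state and free the dense torch payload.
    shard_strategy.shard(tensors, process_group)
    for p in module.parameters():
        p.col_attr.remove_torch_payload()
    return out
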
@@ -50,7 +50,7 @@ class BucketizedTensorCopy(object):
         self._cuda_buffer.copy_(self._cpu_buffer)
         flush_offset = 0
         for sparam, numel in zip(self._buffered_param_list, self._numel_list):
-            sparam.data.copy_payload(self._cpu_buffer.narrow(0, flush_offset, numel))
+            sparam.sharded_data_tensor.copy_payload(self._cpu_buffer.narrow(0, flush_offset, numel))
             flush_offset += numel

         self.reset()

@@ -160,8 +160,8 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         self.initialized_param_list.append(param)

         if self.shard_param:
-            self.shard_strategy.shard([param.col_attr._data_sharded_tensor], self.dp_process_group)
-            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
+            self.shard_strategy.shard([param.col_attr.sharded_data_tensor], self.dp_process_group)
+            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.sharded_data_tensor.payload)
         # if param.col_attr.grad and self.shard_grad:
         # self.shard_strategy.shard([param.col_attr._grad_sharded_tensor], self.dp_process_group)
         # GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)

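ZeroInitContext shards each parameter's `sharded_data_tensor` across the data-parallel group as soon as the parameter is built and registers the shard with GLOBAL_MODEL_DATA_TRACER. A hedged usage sketch; the import path and full constructor signature of ZeroInitContext are not shown in this diff, so everything below except `shard_strategy`, `shard_param`, and the `is_sharded` check is an assumption:

from colossalai.zero.shard_utils import TensorShardStrategy

# Sketch only: ZeroInitContext's import and extra keyword arguments are assumed;
# only `shard_strategy` and `shard_param` appear in the hunk above.
with ZeroInitContext(shard_strategy=TensorShardStrategy(), shard_param=True):
    model = build_model()   # hypothetical model factory

for param in model.parameters():
    assert param.col_attr.sharded_data_tensor.is_sharded
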
@@ -165,7 +165,7 @@ class ShardedModelV2(nn.Module):
         if self.shard_param:
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
-                    self.shard_strategy.shard([p.col_attr.data], self.process_group)
+                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.process_group)
         for p in self.module.parameters():
             p.col_attr.bwd_count = 0
             if not p.requires_grad:

@@ -249,13 +249,15 @@ class ShardedModelV2(nn.Module):
         param.col_attr.fp16_grad = reduced_grad.data

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
-        self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()], self.process_group)
+        self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+                                   self.process_group)
         prev_params = {}
         for p in self.module.parameters():
             prev_params[p] = p.data
-            p.data = p.col_attr.data.payload
+            p.data = p.col_attr.sharded_data_tensor.payload
         gathered_state_dict = self.module.state_dict(destination, prefix, keep_vars)
-        self.shard_strategy.shard([p.col_attr.data for p in self.module.parameters()], self.process_group)
+        self.shard_strategy.shard([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
+                                  self.process_group)
         for p in self.module.parameters():
             p.data = prev_params[p]
         return gathered_state_dict

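ShardedModelV2.state_dict() gathers every sharded_data_tensor, temporarily points `p.data` at the full payload so the wrapped module's own state_dict sees complete tensors, then re-shards and restores `p.data`. A small usage sketch under that reading (the model variable and checkpoint path are hypothetical):

import torch

# `sharded_model` is assumed to be a ShardedModelV2 instance.
full_state = sharded_model.state_dict()   # gathers shards; values are full tensors
torch.save(full_state, 'checkpoint.pt')   # parameters are already re-sharded again here
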
@@ -11,9 +11,9 @@ def col_model_deepcopy(sharded_model: ShardedModelV2, other_model: torch.nn.Modu
     """
     for zero_param, param in zip(sharded_model.parameters(), other_model.parameters()):
         assert hasattr(zero_param, 'col_attr')
-        shard_flag = zero_param.col_attr.data.is_sharded
+        shard_flag = zero_param.col_attr.sharded_data_tensor.is_sharded
         if shard_flag:
-            sharded_model.shard_strategy.gather([zero_param.col_attr.data])
-        param.data = copy.deepcopy(zero_param.col_attr.data.payload)
+            sharded_model.shard_strategy.gather([zero_param.col_attr.sharded_data_tensor])
+        param.data = copy.deepcopy(zero_param.col_attr.sharded_data_tensor.payload)
         if shard_flag:
-            sharded_model.shard_strategy.shard([zero_param.col_attr.data])
+            sharded_model.shard_strategy.shard([zero_param.col_attr.sharded_data_tensor])

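col_model_deepcopy gathers each sharded parameter, deep-copies its payload into the matching parameter of an unsharded model with the same architecture, and re-shards afterwards. A hedged usage sketch (both model variables are assumed to be built from the same definition):

# Sketch only: `zero_model` is a ShardedModelV2 and `ref_model` is a plain
# torch.nn.Module with the same structure.
col_model_deepcopy(zero_model, ref_model)
# ref_model now holds the gathered fp16 values; zero_model stays sharded.
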
@@ -109,17 +109,17 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         for group in self.optim.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
-                is_param_sharded = p.col_attr.data.is_sharded
+                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # TODO (ver217): we may not use shard / gather here
                     # Param is no sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
-                    self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.data.payload).to(self.device)
+                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                self.master_params[p] = cast_tensor_to_fp32(p.col_attr.sharded_data_tensor.payload).to(self.device)
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
-                    self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
+                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)

     def step(self, *args, **kwargs):
         # unscale grads if scaled

@@ -149,24 +149,24 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # a chunk.
         for group in self.optim.param_groups:
             for p in group['params']:
-                is_param_sharded = p.col_attr.data.is_sharded
+                is_param_sharded = p.col_attr.sharded_data_tensor.is_sharded
                 if not is_param_sharded:
                     # We use ZeRO-2 here
-                    # The `p.col_attr.data` saves full fp16 param
+                    # The `p.col_attr.sharded_data_tensor` saves full fp16 param
                     # But we only have updated fp32 param shard here
                     # So we first shard full fp16 param and copy fp32 param shard to it
                     # Then we will gather them
-                    self.shard_strategy.shard([p.col_attr.data], self.dp_process_group)
+                    self.shard_strategy.shard([p.col_attr.sharded_data_tensor], self.dp_process_group)
                 # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.col_attr.data is fp16
+                # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16

                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.col_attr.data.copy_payload(p.data)
+                p.col_attr.sharded_data_tensor.copy_payload(p.data)

                 if not is_param_sharded:
                     # We gather full fp16 param here
-                    self.shard_strategy.gather([p.col_attr.data], self.dp_process_group)
-                    p.data = p.col_attr.data.payload
+                    self.shard_strategy.gather([p.col_attr.sharded_data_tensor], self.dp_process_group)
+                    p.data = p.col_attr.sharded_data_tensor.payload
         return ret

     def backward(self, loss: Tensor) -> None:

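The comments in the step() hunk carry the ZeRO bookkeeping: the optimizer updates an fp32 master shard held in `p.data`, writes it back into the fp16 sharded_data_tensor with `copy_payload` (which copies across dtype/device rather than swapping the storage, as `reset_payload` would), and only shards/gathers when the fp16 param was never sharded (the ZeRO-2 case). A condensed sketch of that write-back, not the real method:

def write_back_fp16(p, shard_strategy, dp_process_group):
    # Sketch of the post-step write-back shown above.
    was_sharded = p.col_attr.sharded_data_tensor.is_sharded
    if not was_sharded:                      # ZeRO-2: the fp16 param is still full
        shard_strategy.shard([p.col_attr.sharded_data_tensor], dp_process_group)
    # p.data holds the updated fp32 master shard; copy_payload copies it
    # (with the fp32 -> fp16 cast) into the existing fp16 shard in place.
    p.col_attr.sharded_data_tensor.copy_payload(p.data)
    if not was_sharded:
        shard_strategy.gather([p.col_attr.sharded_data_tensor], dp_process_group)
        p.data = p.col_attr.sharded_data_tensor.payload
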
@@ -1,4 +1,4 @@
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
-from colossalai.zero.sharded_param.sharded_param import ShardedParam, ShardedParamV2
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2

-__all__ = ['ShardedParam', 'ShardedTensor', 'ShardedParamV2']
+__all__ = ['ShardedTensor', 'ShardedParamV2']

@@ -1,12 +1,7 @@
-from typing import Optional, Tuple, Union
-
-import numpy
 import torch
 import torch.distributed as dist
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param import ShardedTensor
+from typing import Optional


 class ShardedParamV2(object):

@@ -15,7 +10,7 @@ class ShardedParamV2(object):
                  param: torch.nn.Parameter,
                  process_group: Optional[dist.ProcessGroup] = None,
                  rm_torch_payload=False) -> None:
-        self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
+        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
         self.fp16_grad: Optional[torch.Tensor] = None
         self.fp32_grad: Optional[torch.Tensor] = None

@@ -37,105 +32,9 @@ class ShardedParamV2(object):
         self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

     @property
-    def data(self):
-        return self._data_sharded_tensor
+    def sharded_data_tensor(self):
+        return self._sharded_data_tensor

     @property
     def param_is_sharded(self):
-        return self._data_sharded_tensor.is_sharded
-
-
-class ShardedParam(object):
-    r"""
-    A wrapper to torch.nn.Parameter. Shard a param
-    on memory space of different processes.
-    """
-
-    def __init__(self,
-                 other: Union[torch.nn.Parameter, Tuple[int, ...]],
-                 process_group: Optional[dist.ProcessGroup] = None,
-                 is_sharded: bool = False,
-                 device: Optional[torch.device] = None) -> None:
-        r"""
-        other: either an existing torch parameter or a tuple, indicate allocate a new param with the tuple as shape.
-        process_group: the process group storing the shared data.
-        is_sharded: is shared the param during __init__.
-        device: the device to place param data payload on
-        """
-        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
-        self.world_size = dist.get_world_size(self.process_group)
-        self.local_rank = dist.get_rank(self.process_group)
-        self.is_sharded = False
-        self.device = device
-
-        # Hijack the data payload of param
-        if isinstance(other, torch.nn.Parameter):
-            self._param_payload = other.data.to(device)
-            self._origin_shape = other.shape
-            self._origin_numel = other.numel()
-            if is_sharded:
-                self.shard()
-        elif isinstance(other, tuple):
-            self._origin_shape = other
-            self._origin_numel = numpy.prod(other)
-
-            # TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
-            assert device is not None, "You have to assign a device to initialize a ShardParam from a shape tuple"
-            self._param_payload = torch.empty(self._origin_shape, device=device)
-            if is_sharded:
-                self.shard()
-        else:
-            raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(other)}")
-
-        self._payload_numel = None
-
-    def payload(self, target_device: Optional[torch.device] = None):
-        r"""
-        get the payload and move it to target device
-        """
-        if target_device is not None:
-            return self._param_payload.to(target_device)
-        return self._param_payload
-
-    def set_payload(self, data: torch.Tensor):
-        r"""
-        set payload as data
-        """
-        assert self._param_payload.shape == data.shape
-        self._param_payload.copy_(data)
-
-    def shard(self):
-        r"""
-        Distributed the payload of param to all processes.
-        """
-        if self.is_sharded:
-            return
-        self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
-        self.is_sharded = True
-
-    def gather(self):
-        r"""
-        Collect the payload of param from different processes to process of local rank.
-        The payload has to be moved to cuda memory before communication.
-        """
-        if not self.is_sharded:
-            return
-
-        buffer_list = []
-        payload_numel = self._param_payload.numel()
-        for i in range(self.world_size):
-            if i == self.local_rank:
-                buffer_list.append(self._param_payload.cuda())
-            else:
-                buffer_list.append(torch.zeros(payload_numel).cuda())
-
-        torch.distributed.all_gather(buffer_list,
-                                     buffer_list[self.local_rank],
-                                     group=self.process_group,
-                                     async_op=False)
-        self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
-        self.is_sharded = False
-
-    @property
-    def origin_dtype(self):
-        return self._origin_dtype
+        return self._sharded_data_tensor.is_sharded

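After this hunk, ShardedParamV2 exposes its ShardedTensor only through the `sharded_data_tensor` property (backed by `_sharded_data_tensor`) plus `param_is_sharded`, and the old ShardedParam wrapper with its payload/shard/gather methods is gone. A minimal sketch of the surviving surface, mirroring the distributed test further down (it assumes an initialized process group, e.g. via colossalai.launch):

import torch
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2

# Run inside an initialized distributed context, as in _run_shard_param_v2 below.
param = torch.nn.Parameter(torch.randn(2, 3))
sparam = ShardedParamV2(param=param, process_group=None)

assert sparam.sharded_data_tensor.payload.shape == (2, 3)  # full payload until sharded
sparam.remove_torch_payload()                              # param.data becomes a 0-d placeholder
assert param.data.numel() == 1
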
@@ -17,7 +17,6 @@ def test_bucket_copy():
     for shape in shape_list:
         # on CPU
         src_param = torch.nn.Parameter(torch.randn(shape, dtype=torch.float, device=torch.device('cpu')))
-        print(src_param)
         # on GPU
         tgt_param = ShardedParamV2(torch.nn.Parameter(torch.ones(shape, dtype=torch.half, device=torch.device('cuda'))))

@@ -29,9 +28,10 @@ def test_bucket_copy():
     copyer.flush()

     for src_param, tgt_param in zip(src_param_list, tgt_param_list):
-        print(tgt_param.data.payload)
-        diff = src_param.cpu().float() - tgt_param.data.payload.cpu().float()
-        assert torch.allclose(src_param.cpu().float(), tgt_param.data.payload.cpu().float(), rtol=1e-03,
+        diff = src_param.cpu().float() - tgt_param.sharded_data_tensor.payload.cpu().float()
+        assert torch.allclose(src_param.cpu().float(),
+                              tgt_param.sharded_data_tensor.payload.cpu().float(),
+                              rtol=1e-03,
                               atol=1e-03), f"diff {diff}"


@@ -119,7 +119,7 @@ def check_params_padding(model, zero_model, loose=False):
 def check_sharded_params_padding(model, zero_model, loose=False):
     rank = dist.get_rank()
     for p, zero_p in zip(model.parameters(), zero_model.parameters()):
-        zero_p = zero_p.col_attr.data.payload.to(p.device).float()
+        zero_p = zero_p.col_attr.sharded_data_tensor.payload.to(p.device).float()
         chunks = torch.flatten(p).chunk(dist.get_world_size())
         if rank >= len(chunks):
             continue

@@ -34,10 +34,10 @@ def run_model_test(init_device, shard_strategy_class):

     for param in model.parameters():
         assert hasattr(param, 'col_attr')
-        assert param.col_attr.data.dtype == torch.half
-        assert param.col_attr.data.is_sharded
-        assert param.col_attr.data.payload.device.type == init_device.type, \
-            f'{param.col_attr.data.payload.device.type} vs. {init_device.type}'
+        assert param.col_attr.sharded_data_tensor.dtype == torch.half
+        assert param.col_attr.sharded_data_tensor.is_sharded
+        assert param.col_attr.sharded_data_tensor.payload.device.type == init_device.type, \
+            f'{param.col_attr.sharded_data_tensor.payload.device.type} vs. {init_device.type}'

     print(f'cuda usgae {GLOBAL_MODEL_DATA_TRACER.cuda_usage}')
     print(f'numel {model_numel_tensor}')

@@ -1,6 +1,3 @@
-#!/usr/bin/env python
-# -*- encoding: utf-8 -*-
-
 from copy import deepcopy
 from functools import partial

@@ -8,13 +5,11 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.logging import disable_existing_loggers, get_dist_logger
 from colossalai.testing import parameterize
 from colossalai.utils import free_port
 from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardStrategy)
-from colossalai.zero.sharded_param import ShardedParam, ShardedTensor
+from colossalai.zero.sharded_param import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
-from tests.components_to_test.registry import non_distributed_component_funcs
 from tests.test_zero_data_parallel.common import CONFIG, allclose


@@ -52,7 +47,7 @@ def _run_shard_param_v2(rank, world_size, port):
     param_ref = deepcopy(param)
     sparam = ShardedParamV2(param=param, process_group=None)

-    allclose(sparam.data.payload, param_ref.data)
+    allclose(sparam.sharded_data_tensor.payload, param_ref.data)

     sparam.remove_torch_payload()
     assert (param.data.numel() == 1)

@@ -65,69 +60,6 @@ def test_shard_param_v2(world_size):
     mp.spawn(run_func, nprocs=world_size)


-def _run_test_shard_param(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-
-    param = torch.nn.Parameter(torch.randn(2, 3))
-    param_ref = deepcopy(param)
-    sparam = ShardedParamV2(param=param, process_group=None)
-    print(sparam.data)
-    print(param_ref.data)
-
-    logger = get_dist_logger()
-    for get_components_func in non_distributed_component_funcs:
-        model_builder, *_ = get_components_func()
-        model = model_builder(checkpoint=True)
-        # add an attribute as col_attr to hijack the access to param.data
-        for _, param in model.named_parameters():
-            numel_ref = (param.numel() + world_size - 1) // world_size
-            param.col_attr = ShardedParam(param)
-            param.col_attr.shard()
-            param_data = param.col_attr.payload(torch.device('cpu'))
-            assert (numel_ref == param_data.numel())
-
-        for _, param in model.named_parameters():
-            param.col_attr.gather()
-            param_data = param.col_attr.payload(torch.device('cpu'))
-
-    disable_existing_loggers([logger])
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 2])
-def test_shard_param(world_size):
-    run_func = partial(_run_test_shard_param, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
-def _run_init_shard_param(rank, world_size, port):
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    param = torch.nn.Parameter(data=torch.rand(world_size, 3))
-    sparam = ShardedParam(param, None, True)
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
-    del sparam
-
-    param_shape = (world_size, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
-
-    param_shape = (world_size, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [world_size, 3])
-
-
-@pytest.mark.dist
-@pytest.mark.parametrize("world_size", [1, 4])
-def test_init_shard_param(world_size):
-    run_func = partial(_run_init_shard_param, world_size=world_size, port=free_port())
-    mp.spawn(run_func, nprocs=world_size)
-
-
 if __name__ == '__main__':
     test_shard_tensor(2)
-    test_shard_param(2)
     test_shard_param_v2(2)
-    test_init_shard_param(4)