mirror of https://github.com/hpcaitech/ColossalAI
[zero] bucketized tensor cpu gpu copy (#368)
parent
44e4891f57
commit
00670c870e
|
@ -4,10 +4,6 @@ repos:
|
||||||
hooks:
|
hooks:
|
||||||
- id: yapf
|
- id: yapf
|
||||||
args: ['--style=.style.yapf', '--parallel', '--in-place']
|
args: ['--style=.style.yapf', '--parallel', '--in-place']
|
||||||
- repo: https://github.com/pycqa/flake8
|
|
||||||
rev: '4.0.1'
|
|
||||||
hooks:
|
|
||||||
- id: flake8
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-clang-format
|
- repo: https://github.com/pre-commit/mirrors-clang-format
|
||||||
rev: v13.0.1
|
rev: v13.0.1
|
||||||
hooks:
|
hooks:
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .bucket_tensor_copy import BucketizedTensorCopy
|
||||||
|
|
||||||
|
__all__ = ['BucketizedTensorCopy']
|
|
@ -0,0 +1,61 @@
|
||||||
|
import torch
|
||||||
|
from colossalai.zero.sharded_param import ShardedParamV2
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
class BucketizedTensorCopy(object):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chunk_size: int,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
torch.nn.Parameter CPU (fp32) -> ShardedParam GPU (fp16)
|
||||||
|
TODO(jiaruifang) The class is a little bit hardcoded
|
||||||
|
I will make it more general later.
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.chunk_size = chunk_size
|
||||||
|
self._offset = 0
|
||||||
|
self._cpu_buffer = torch.empty(chunk_size, dtype=torch.float, device=torch.device("cpu:0"), pin_memory=True)
|
||||||
|
self._cuda_buffer = torch.empty(chunk_size,
|
||||||
|
dtype=torch.half,
|
||||||
|
device=torch.device(f"cuda:{get_current_device()}"))
|
||||||
|
|
||||||
|
self._buffered_param_list: List[ShardedParamV2] = []
|
||||||
|
self._numel_list = []
|
||||||
|
|
||||||
|
def copy(self, src_param: torch.nn.Parameter, target_param: ShardedParamV2):
|
||||||
|
assert isinstance(target_param, ShardedParamV2)
|
||||||
|
assert isinstance(src_param, torch.nn.Parameter)
|
||||||
|
|
||||||
|
numel = src_param.numel()
|
||||||
|
|
||||||
|
if self._offset + numel > self.chunk_size:
|
||||||
|
self.flush()
|
||||||
|
|
||||||
|
assert src_param.data.device.type == 'cpu'
|
||||||
|
self._cpu_buffer.narrow(0, self._offset, numel).copy_(src_param.data.view(-1))
|
||||||
|
|
||||||
|
self._buffered_param_list.append(target_param)
|
||||||
|
self._numel_list.append(numel)
|
||||||
|
|
||||||
|
self._offset += numel
|
||||||
|
|
||||||
|
def flush(self):
|
||||||
|
"""
|
||||||
|
flush to cuda memory
|
||||||
|
"""
|
||||||
|
self._cuda_buffer.copy_(self._cpu_buffer)
|
||||||
|
flush_offset = 0
|
||||||
|
for sparam, numel in zip(self._buffered_param_list, self._numel_list):
|
||||||
|
sparam.data.copy_payload(self._cpu_buffer.narrow(0, flush_offset, numel))
|
||||||
|
flush_offset += numel
|
||||||
|
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self._buffered_param_list = []
|
||||||
|
self._numel_list = []
|
||||||
|
self._offset = 0
|
|
@ -88,14 +88,19 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
self.zero_grad()
|
self.zero_grad()
|
||||||
return
|
return
|
||||||
|
|
||||||
# Write master param to p.data
|
# assign master param pointers to p.data.
|
||||||
|
# We will not trigger data copy here.
|
||||||
for group in self.optim.param_groups:
|
for group in self.optim.param_groups:
|
||||||
for p in group['params']:
|
for p in group['params']:
|
||||||
p.data = self.master_params[p]
|
p.data = self.master_params[p]
|
||||||
# Now p.data is sharded
|
# Now p.data is sharded
|
||||||
# So optimizer states are sharded naturally
|
# So optimizer states are sharded naturally
|
||||||
|
|
||||||
ret = self.optim.step(*args, **kwargs)
|
ret = self.optim.step(*args, **kwargs)
|
||||||
# Write master param to payload
|
|
||||||
|
# Copy master param data (fp32) to payload of col_attr (fp16)
|
||||||
|
# TODO() improve efficiency by gathering tensors into a chunk and transfering
|
||||||
|
# a chunk.
|
||||||
for group in self.optim.param_groups:
|
for group in self.optim.param_groups:
|
||||||
for p in group['params']:
|
for p in group['params']:
|
||||||
is_param_sharded = p.col_attr.data.is_sharded
|
is_param_sharded = p.col_attr.data.is_sharded
|
||||||
|
@ -108,7 +113,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
self.shard_strategy.shard([p.col_attr.data])
|
self.shard_strategy.shard([p.col_attr.data])
|
||||||
# We have to use `copy_payload` instead of `reset_payload`
|
# We have to use `copy_payload` instead of `reset_payload`
|
||||||
# Since p.data is fp32 and p.col_attr.data is fp16
|
# Since p.data is fp32 and p.col_attr.data is fp16
|
||||||
|
|
||||||
|
# TODO() optimize this line
|
||||||
p.col_attr.data.copy_payload(p.data)
|
p.col_attr.data.copy_payload(p.data)
|
||||||
|
|
||||||
if not is_param_sharded:
|
if not is_param_sharded:
|
||||||
# We gather full fp16 param here
|
# We gather full fp16 param here
|
||||||
self.shard_strategy.gather([p.col_attr.data])
|
self.shard_strategy.gather([p.col_attr.data])
|
||||||
|
|
|
@ -14,7 +14,6 @@ class ShardedTensor(object):
|
||||||
self.world_size = dist.get_world_size(self.process_group)
|
self.world_size = dist.get_world_size(self.process_group)
|
||||||
self.local_rank = dist.get_rank(self.process_group)
|
self.local_rank = dist.get_rank(self.process_group)
|
||||||
self._is_sharded = False
|
self._is_sharded = False
|
||||||
self._payload = tensor
|
|
||||||
|
|
||||||
self._origin_shape = tensor.shape
|
self._origin_shape = tensor.shape
|
||||||
self._origin_numel = tensor.numel()
|
self._origin_numel = tensor.numel()
|
||||||
|
@ -41,7 +40,7 @@ class ShardedTensor(object):
|
||||||
return self._payload
|
return self._payload
|
||||||
|
|
||||||
def copy_payload(self, tensor):
|
def copy_payload(self, tensor):
|
||||||
self._payload.copy_(tensor)
|
self._payload.view(-1).copy_(tensor.view(-1))
|
||||||
|
|
||||||
def reset_payload(self, tensor):
|
def reset_payload(self, tensor):
|
||||||
del self._payload
|
del self._payload
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
from colossalai.utils.commons import BucketizedTensorCopy
|
||||||
|
from colossalai.zero.sharded_param import ShardedParamV2
|
||||||
|
from colossalai.utils import free_port
|
||||||
|
import torch
|
||||||
|
import colossalai
|
||||||
|
|
||||||
|
|
||||||
|
def test_bucket_copy():
|
||||||
|
# init dist env
|
||||||
|
colossalai.launch(config={}, rank=0, world_size=1, host='localhost', port=free_port(), backend='nccl')
|
||||||
|
|
||||||
|
copyer = BucketizedTensorCopy(20)
|
||||||
|
|
||||||
|
shape_list = [(2, 3), (5), (8), (12)]
|
||||||
|
src_param_list = []
|
||||||
|
tgt_param_list = []
|
||||||
|
for shape in shape_list:
|
||||||
|
# on CPU
|
||||||
|
src_param = torch.nn.Parameter(torch.randn(shape, dtype=torch.float, device=torch.device('cpu')))
|
||||||
|
print(src_param)
|
||||||
|
# on GPU
|
||||||
|
tgt_param = ShardedParamV2(torch.nn.Parameter(torch.ones(shape, dtype=torch.half, device=torch.device('cuda'))))
|
||||||
|
|
||||||
|
src_param_list.append(src_param)
|
||||||
|
tgt_param_list.append(tgt_param)
|
||||||
|
|
||||||
|
copyer.copy(src_param, tgt_param)
|
||||||
|
|
||||||
|
copyer.flush()
|
||||||
|
|
||||||
|
for src_param, tgt_param in zip(src_param_list, tgt_param_list):
|
||||||
|
print(tgt_param.data.payload)
|
||||||
|
diff = src_param.cpu().float() - tgt_param.data.payload.cpu().float()
|
||||||
|
assert torch.allclose(src_param.cpu().float(), tgt_param.data.payload.cpu().float(), rtol=1e-03,
|
||||||
|
atol=1e-03), f"diff {diff}"
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_bucket_copy()
|
Loading…
Reference in New Issue