mirror of https://github.com/hpcaitech/ColossalAI
[colotensor] use cpu memory to store state_dict (#1367)
parent 943a96323e
commit 87775a0682

@@ -318,7 +318,8 @@ class ZeroDDP(ColoDDP):
             self.chunk_manager.access_chunk(chunk)
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
-                destination[prefix + name] = fp32_p.clone() if keep_vars else fp32_p.clone().detach()
+                rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
+                destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
         for chunk in chunks:
             self.chunk_manager.release_chunk(chunk)
         for name, buf in self.named_buffers():

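The new lines copy each fp32 parameter into host memory before it is stored, so the returned state dict no longer holds references to GPU buffers. A minimal sketch of that pattern in plain PyTorch (the helper name cpu_state_entry is invented for illustration and is not part of the diff):

import torch

def cpu_state_entry(fp32_p: torch.Tensor, keep_vars: bool = False) -> torch.Tensor:
    # Clone if the source already lives on the CPU; otherwise .cpu() makes a host-side copy.
    rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
    # Follow torch's state_dict() convention: detach unless keep_vars is requested.
    return rec_p if keep_vars else rec_p.detach()

src = torch.randn(8, device='cuda' if torch.cuda.is_available() else 'cpu')
assert cpu_state_entry(src).device.type == 'cpu'
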
@@ -4,6 +4,20 @@ from colossalai.tensor import ColoTensor, ColoTensorSpec
 from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
 
 
+def robust_broadcast(tensor):
+    with torch.no_grad():
+        is_cpu_ten = tensor.device.type == 'cpu'
+        if is_cpu_ten:
+            b_data = tensor.cuda()
+        else:
+            b_data = tensor
+
+        dist.broadcast(b_data, 0)
+
+        if is_cpu_ten:
+            tensor.copy_(b_data)
+
+
 def gather_tensor(colo_tensor: ColoTensor) -> None:
     """Make colo_tensor replicated when the rank is 0
     """

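For context: the tests below launch with backend='nccl', and NCCL collectives only accept CUDA tensors, so calling dist.broadcast() directly on a CPU-resident tensor would fail. robust_broadcast stages the data through a temporary CUDA copy and writes the result back. A self-contained usage sketch, assuming two GPUs and a free local TCP port (the spawn launcher and the port number are illustrative, not part of the patch):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def robust_broadcast(tensor):    # as added in the hunk above
    with torch.no_grad():
        is_cpu_ten = tensor.device.type == 'cpu'
        if is_cpu_ten:
            b_data = tensor.cuda()       # stage the CPU data on the current GPU
        else:
            b_data = tensor

        dist.broadcast(b_data, 0)        # NCCL broadcast from rank 0

        if is_cpu_ten:
            tensor.copy_(b_data)         # write the received data back into the CPU tensor

def _worker(rank: int, world_size: int):
    dist.init_process_group('nccl', init_method='tcp://127.0.0.1:29512',
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    t = torch.full((4,), float(rank))    # CPU tensor, different on every rank
    robust_broadcast(t)
    assert torch.all(t == 0.0)           # every rank now holds rank 0's data, still on the CPU
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(_worker, args=(2,), nprocs=2)
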
@@ -27,7 +41,7 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
     """Reversal operation of `gather_tensor`.
     """
     if dist_spec.placement == DistPlacementPattern.REPLICATE:
-        dist.broadcast(colo_tensor.data, 0)
+        robust_broadcast(colo_tensor.data)
     else:
         global_size = colo_tensor.size_global()
 
@@ -35,7 +49,7 @@ def scatter_tensor(colo_tensor: ColoTensor, dist_spec: _DistSpec) -> None:
             entire_data = colo_tensor.data
         else:
             entire_data = torch.empty(global_size, device=colo_tensor.device)
-        dist.broadcast(entire_data, 0)
+        robust_broadcast(entire_data)
 
         if dist.get_rank() == 0:
             colo_tensor.set_dist_spec(dist_spec)

@@ -19,7 +19,13 @@ from colossalai.tensor import ProcessGroup, ColoParameter
 def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
     for (k1, t1), (k2, t2) in zip(state_dict.items(), other_state_dict.items()):
         assert k1 == k2
-        assert torch.allclose(t1, t2, atol=1e-3, rtol=1e-3)
+
+        if t1.device != t2.device:
+            temp_t2 = t2.to(t1.device)
+        else:
+            temp_t2 = t2
+
+        assert torch.allclose(t1, temp_t2, atol=1e-3, rtol=1e-3)
 
 
 def init_ddp(module: torch.nn.Module) -> ColoDDP:

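The extra branch is needed because the ZeroDDP change above now produces CPU tensors while the comparison target may still live on the GPU, and torch.allclose refuses operands on different devices. A minimal sketch (assumes a CUDA device is available):

import torch

t1 = torch.ones(2)                       # e.g. an entry from the CPU-resident state dict
t2 = torch.ones(2, device='cuda')        # e.g. a parameter gathered on the GPU
# torch.allclose(t1, t2)                 # would raise: tensors must be on the same device
assert torch.allclose(t1, t2.to(t1.device), atol=1e-3, rtol=1e-3)
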
@@ -17,7 +17,7 @@ from tests.test_tensor.common_utils import tensor_shard_equal
 def run_dist(rank, world_size, port, dp_degree, tp_degree):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(dp_degree=dp_degree, tp_degree=tp_degree)
-    x = torch.randn(4, 4, device=get_current_device())
+    x = torch.randn(4, 4)
     param = ColoTensor(torch.nn.Parameter(x), spec=ColoTensorSpec(pg))
     spec = ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)
     param.set_tensor_spec(*spec)

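Dropping device=get_current_device() means the test parameter now starts in host memory, so the gather/scatter round trip exercises the new CPU-aware broadcast path. A trivial sanity sketch of the device check the new branches rely on:

import torch

x = torch.randn(4, 4)            # no device argument: the tensor is allocated in host memory
assert x.device.type == 'cpu'    # the condition robust_broadcast and the state-dict path test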