mirror of https://github.com/hpcaitech/ColossalAI
[zero] polish ShardedOptimV2 unittest (#385)
* place params on cpu after zero init context * polish code * bucketzed cpu gpu tensor transter * find a bug in sharded optim unittest * add offload unittest for ShardedOptimV2. * polish code and make it more robustpull/394/head
parent
ce7b2c9ae3
commit
3af13a2c3e
|
@ -79,6 +79,10 @@ class ShardedModelV2(nn.Module):
|
||||||
self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
|
self.reducer = ReduceScatterBucketer(reduce_scatter_bucket_size_mb)
|
||||||
self._require_backward_grad_sync: bool = True
|
self._require_backward_grad_sync: bool = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cpu_offload(self):
|
||||||
|
return self._cpu_offload
|
||||||
|
|
||||||
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
||||||
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
|
args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
|
||||||
outputs = self.module(*args, **kwargs)
|
outputs = self.module(*args, **kwargs)
|
||||||
|
|
|
@ -44,6 +44,10 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
super().__init__(optimizer)
|
super().__init__(optimizer)
|
||||||
self.shard_strategy = shard_strategy
|
self.shard_strategy = shard_strategy
|
||||||
self.model: ShardedModelV2 = sharded_model
|
self.model: ShardedModelV2 = sharded_model
|
||||||
|
if cpu_offload and not sharded_model.cpu_offload:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"ShardedOptimizerV2 using cpu_offload, but the sharded_model used to initialize it dose not use cpu_offload"
|
||||||
|
)
|
||||||
self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
|
self.device = torch.cuda.current_device() if not cpu_offload else torch.device('cpu')
|
||||||
self.optim_state: OptimState = OptimState.UNSCALED
|
self.optim_state: OptimState = OptimState.UNSCALED
|
||||||
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
|
self.dp_process_group = dp_process_group or gpc.get_group(ParallelMode.DATA)
|
||||||
|
|
|
@ -24,8 +24,12 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
|
||||||
model.train()
|
model.train()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
||||||
|
if criterion:
|
||||||
y = model(data)
|
y = model(data)
|
||||||
loss = criterion(y, label)
|
loss = criterion(y, label)
|
||||||
|
else:
|
||||||
|
loss = model(data, label)
|
||||||
|
|
||||||
loss = loss.float()
|
loss = loss.float()
|
||||||
if isinstance(model, ShardedModelV2):
|
if isinstance(model, ShardedModelV2):
|
||||||
optimizer.backward(loss)
|
optimizer.backward(loss)
|
||||||
|
@ -34,19 +38,7 @@ def run_step(model, optimizer, data, label, criterion, enable_autocast=False):
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
|
||||||
def run_step_no_criterion(model, optimizer, data, label, enable_autocast=False):
|
def run_dist(rank, world_size, port, cpu_offload):
|
||||||
model.train()
|
|
||||||
optimizer.zero_grad()
|
|
||||||
with torch.cuda.amp.autocast(enabled=enable_autocast):
|
|
||||||
loss = model(data, label)
|
|
||||||
if isinstance(model, ShardedModelV2):
|
|
||||||
optimizer.backward(loss)
|
|
||||||
else:
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
|
|
||||||
def run_dist(rank, world_size, port):
|
|
||||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
|
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
|
||||||
for model_name in test_models:
|
for model_name in test_models:
|
||||||
|
@ -54,33 +46,33 @@ def run_dist(rank, world_size, port):
|
||||||
shard_strategy = TensorShardStrategy()
|
shard_strategy = TensorShardStrategy()
|
||||||
model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
|
model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
|
||||||
model = model(checkpoint=True).cuda()
|
model = model(checkpoint=True).cuda()
|
||||||
zero_model = ShardedModelV2(copy.deepcopy(model), shard_strategy)
|
zero_model = ShardedModelV2(copy.deepcopy(model),
|
||||||
|
shard_strategy,
|
||||||
|
offload_config=dict(device='cpu') if cpu_offload else None)
|
||||||
if dist.get_world_size() > 1:
|
if dist.get_world_size() > 1:
|
||||||
model = DDP(model)
|
model = DDP(model)
|
||||||
optim = Adam(model.parameters(), lr=1e-3)
|
optim = Adam(model.parameters(), lr=1e-3)
|
||||||
sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
|
sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
|
||||||
zero_model,
|
zero_model,
|
||||||
shard_strategy,
|
shard_strategy,
|
||||||
|
cpu_offload=cpu_offload,
|
||||||
initial_scale=2**5)
|
initial_scale=2**5)
|
||||||
for i, (data, label) in enumerate(train_dataloader):
|
for i, (data, label) in enumerate(train_dataloader):
|
||||||
if i > 2:
|
if i > 2:
|
||||||
break
|
break
|
||||||
data, label = data.cuda(), label.cuda()
|
data, label = data.cuda(), label.cuda()
|
||||||
if criterion is None:
|
|
||||||
run_step_no_criterion(model, optim, data, label, False)
|
|
||||||
run_step_no_criterion(zero_model, sharded_optim, data, label, False)
|
|
||||||
else:
|
|
||||||
run_step(model, optim, data, label, criterion, False)
|
run_step(model, optim, data, label, criterion, False)
|
||||||
run_step(zero_model, sharded_optim, data, label, criterion, False)
|
run_step(zero_model, sharded_optim, data, label, criterion, False)
|
||||||
check_sharded_params_padding(model, zero_model, loose=True)
|
check_sharded_params_padding(model, zero_model, loose=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.dist
|
@pytest.mark.dist
|
||||||
@pytest.mark.parametrize("world_size", [1, 2, 4])
|
@pytest.mark.parametrize("world_size", [1, 2])
|
||||||
def test_sharded_optim_v2(world_size):
|
@pytest.mark.parametrize("cpu_offload", [True, False])
|
||||||
run_func = partial(run_dist, world_size=world_size, port=free_port())
|
def test_sharded_optim_v2(world_size, cpu_offload):
|
||||||
|
run_func = partial(run_dist, world_size=world_size, port=free_port(), cpu_offload=cpu_offload)
|
||||||
mp.spawn(run_func, nprocs=world_size)
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test_sharded_optim_v2(world_size=2)
|
test_sharded_optim_v2(world_size=2, cpu_offload=True)
|
||||||
|
|
Loading…
Reference in New Issue