mirror of https://github.com/hpcaitech/ColossalAI
[test] update shardformer tests
parent
b0b8ad2823
commit
2d6cc07feb
|
@ -12,8 +12,8 @@ def build_model(model_fn, enable_fused_normalization=True, enable_tensor_paralle
|
||||||
enable_tensor_parallelism=enable_tensor_parallelism)
|
enable_tensor_parallelism=enable_tensor_parallelism)
|
||||||
model_copy = copy.deepcopy(org_model)
|
model_copy = copy.deepcopy(org_model)
|
||||||
shard_former = ShardFormer(shard_config=shard_config)
|
shard_former = ShardFormer(shard_config=shard_config)
|
||||||
sharded_model = shard_former.optimize(model_copy).cuda()
|
sharded_model, shared_params = shard_former.optimize(model_copy)
|
||||||
return org_model, sharded_model
|
return org_model, sharded_model.cuda()
|
||||||
|
|
||||||
|
|
||||||
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
|
||||||
|
|
|
@ -44,7 +44,7 @@ def check_shardformer_with_ddp(rank, world_size, port):
|
||||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||||
# create and shard model
|
# create and shard model
|
||||||
model = model_fn().cuda()
|
model = model_fn().cuda()
|
||||||
sharded_model = shardformer.optimize(model)
|
sharded_model, _ = shardformer.optimize(model)
|
||||||
|
|
||||||
# add ddp
|
# add ddp
|
||||||
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)
|
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)
|
||||||
|
|
Loading…
Reference in New Issue