mirror of https://github.com/hpcaitech/ColossalAI
update doc
parent 36ab2cb783
commit 61fdd3464a
@@ -40,20 +40,16 @@ def _test_fwd(model: torch.nn.Module, gm: ColoGraphModule, node, pair):
     non_fx_out = model(node, pair)
     fx_out = gm(node, pair)
 
-    assert torch.allclose(
-        non_fx_out[0], fx_out[0], atol=1e-4
-    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
-        torch.abs(non_fx_out[0] - fx_out[0])
-    )
-    assert torch.allclose(
-        non_fx_out[1], fx_out[1], atol=1e-4
-    ), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
-        torch.abs(non_fx_out[1] - fx_out[1])
-    )
+    assert torch.allclose(non_fx_out[0], fx_out[0],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[0] - fx_out[0]))
+    assert torch.allclose(non_fx_out[1], fx_out[1],
+                          atol=1e-4), "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
+                              torch.abs(non_fx_out[1] - fx_out[1]))
 
 
 def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
-    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
+    # launch colossalai
     colossalai.launch(
         config={},
         rank=rank,
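Both formattings above express the same check: compare the fx-generated outputs against the eager outputs within an absolute tolerance and report the mean absolute difference on failure. A minimal self-contained sketch of that pattern (the helper name is hypothetical):

    import torch

    def assert_close_with_diff(expected: torch.Tensor, actual: torch.Tensor, atol: float = 1e-4):
        # Same pattern as the test: tolerance check plus a mean-abs-diff report.
        assert torch.allclose(expected, actual, atol=atol), \
            "fx_out doesn't comply with original output, diff is %.2e" % torch.mean(
                torch.abs(expected - actual))

    # Usage with a tensor shaped like the test's node output:
    out = torch.randn(1, 32, 64, 256)
    assert_close_with_diff(out, out + 1e-6 * torch.randn_like(out))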
@@ -76,18 +72,14 @@ def _test_autochunk_codegen(rank, msa_len, pair_len, max_memory):
             "pair": pair.to(torch.device("meta")),
         },
     )
-    gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
+    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
     interp = MetaInfoProp(gm_prop)
-    interp.propagate(
-        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
-    )
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
 
     # now run it twice to get meta info in graph module, not necessary
     gm = torch.fx.GraphModule(model, graph)
     interp = MetaInfoProp(gm)
-    interp.propagate(
-        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
-    )
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
 
     codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
     graph.set_codegen(codegen)
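This hunk compacts the core AutoChunk setup: trace the model with torch.fx, run MetaInfoProp over MetaTensor inputs so activation memory can be estimated without real CUDA allocations, then attach AutoChunkCodeGen to the graph. A condensed sketch of that flow, assuming the import paths below and a model traceable by torch.fx.symbolic_trace:

    import torch
    from colossalai.fx.passes.meta_info_prop import MetaInfoProp          # assumed import path
    from colossalai.fx.profiler import MetaTensor                         # assumed import path
    from colossalai.autochunk.autochunk_codegen import AutoChunkCodeGen   # assumed import path

    def build_autochunk_module(model, node, pair, max_memory=None):
        # Trace eagerly; the test's comment notes symbolic_trace must be used here.
        gm_prop = torch.fx.symbolic_trace(model)
        # Propagate shape/memory metadata over fake "meta" tensors.
        interp = MetaInfoProp(gm_prop)
        interp.propagate(MetaTensor(node, fake_device="cuda:0"),
                         MetaTensor(pair, fake_device="cuda:0"))
        # Let AutoChunk search chunk regions under the memory budget and
        # regenerate the module's forward code accordingly.
        codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
        gm_prop.graph.set_codegen(codegen)
        gm_prop.recompile()
        return gm_prop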
@@ -23,7 +23,8 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
 
     if msa_len == 32 and pair_len == 64:
         if max_memory is None:
-            target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191), (161, 166), (198, 203), (6, 69)]
+            target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191),
+                              (161, 166), (198, 203), (6, 69)]
         elif max_memory == 20:
             target_regions = [(142, 154), (369, 373), (233, 269), (301, 351)]
         elif max_memory == 25:
@@ -36,24 +37,19 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
         raise NotImplementedError()
 
     assert len(found_regions) == len(
-        target_regions
-    ), "len of found regions %s doesn't equal len of target regions %s" % (
-        str(found_regions),
-        str(target_regions),
-    )
+        target_regions), "len of found regions %s doesn't equal len of target regions %s" % (
+            str(found_regions),
+            str(target_regions),
+        )
     for region in target_regions:
-        assert (
-            region in found_regions
-        ), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
+        assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % (
             str(region),
             msa_len,
             pair_len,
             max_memory,
         )
     for region in found_regions:
-        assert (
-            region in target_regions
-        ), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
+        assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % (
             str(region),
             msa_len,
             pair_len,
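Both versions of these assertions implement the same check: the regions AutoChunk found must exactly match the expected ones, irrespective of order. The whole block is equivalent to one set comparison, sketched here with a hypothetical helper:

    def check_regions(found_regions, target_regions, msa_len, pair_len, max_memory):
        # Equal length plus mutual membership is just set equality
        # (assuming no duplicate regions on either side).
        assert set(found_regions) == set(target_regions), \
            "found regions %s don't match target regions %s for msa:%d, pair:%d, maxmem:%s" % (
                str(found_regions), str(target_regions), msa_len, pair_len, str(max_memory))

    check_regions([(6, 69), (142, 154)], [(142, 154), (6, 69)], 32, 64, 20)

Formatting max_memory with %s rather than %d also sidesteps a latent TypeError in the original messages when max_memory is None and an assertion fires.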
@@ -62,7 +58,7 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
 
 
 def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
-    # launch colossalai to make sure we could execute colossalai.utils.checkpoint correctly
+    # launch colossalai
    colossalai.launch(
        config={},
        rank=rank,
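The colossalai.launch(...) calls are cut off at the hunk boundaries; they initialize the distributed context before the CUDA test body runs. A plausible completion under the usual single-process test setup (host, port, and backend values are assumptions, not taken from this diff):

    from functools import partial
    import torch.multiprocessing as mp
    import colossalai

    def run(rank, msa_len=32, pair_len=64, max_memory=None):
        colossalai.launch(config={},
                          rank=rank,
                          world_size=1,
                          host="localhost",
                          port=29500,        # assumed free port
                          backend="nccl")
        # ... test body ...

    if __name__ == "__main__":
        mp.spawn(partial(run, msa_len=32, pair_len=64), nprocs=1)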
@@ -77,11 +73,9 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
     node = torch.randn(1, msa_len, pair_len, 256).cuda()
     pair = torch.randn(1, pair_len, pair_len, 128).cuda()
 
-    gm_prop = torch.fx.symbolic_trace(model)  # must use symbolic_trace
+    gm_prop = torch.fx.symbolic_trace(model)    # must use symbolic_trace
     interp = MetaInfoProp(gm_prop)
-    interp.propagate(
-        MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")
-    )
+    interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0"))
 
     codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory)
     chunk_infos = codegen.chunk_infos
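The search test stops at codegen.chunk_infos rather than generating code, which is enough to validate which regions the planner picked. A sketch of pulling comparable (start, end) regions out of it, continuing from a traced and propagated gm_prop as above; the "region" key is an assumption about the entry layout:

    # Hypothetical inspection of the planner's output.
    codegen = AutoChunkCodeGen(gm_prop, max_memory=20)
    chunk_infos = codegen.chunk_infos
    found_regions = [(info["region"][0], info["region"][1]) for info in chunk_infos]
    # Per the expectations table above, msa 32 / pair 64 / maxmem 20 should yield
    # [(142, 154), (369, 373), (233, 269), (301, 351)].
    print(found_regions)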