|
|
|
@ -23,7 +23,8 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
|
|
|
|
|
|
|
|
|
|
if msa_len == 32 and pair_len == 64: |
|
|
|
|
if max_memory is None: |
|
|
|
|
target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191), (161, 166), (198, 203), (6, 69)] |
|
|
|
|
target_regions = [(142, 154), (366, 373), (233, 283), (301, 351), (127, 134), (204, 228), (167, 191), |
|
|
|
|
(161, 166), (198, 203), (6, 69)] |
|
|
|
|
elif max_memory == 20: |
|
|
|
|
target_regions = [(142, 154), (369, 373), (233, 269), (301, 351)] |
|
|
|
|
elif max_memory == 25: |
|
|
|
@ -36,24 +37,19 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
|
|
|
|
|
raise NotImplementedError() |
|
|
|
|
|
|
|
|
|
assert len(found_regions) == len( |
|
|
|
|
target_regions |
|
|
|
|
), "len of found regions %s doesn't equal len of target regions %s" % ( |
|
|
|
|
str(found_regions), |
|
|
|
|
str(target_regions), |
|
|
|
|
) |
|
|
|
|
target_regions), "len of found regions %s doesn't equal len of target regions %s" % ( |
|
|
|
|
str(found_regions), |
|
|
|
|
str(target_regions), |
|
|
|
|
) |
|
|
|
|
for region in target_regions: |
|
|
|
|
assert ( |
|
|
|
|
region in found_regions |
|
|
|
|
), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % ( |
|
|
|
|
assert (region in found_regions), "region:%s not in found regions for msa:%d, pair:%d, maxmem:%d" % ( |
|
|
|
|
str(region), |
|
|
|
|
msa_len, |
|
|
|
|
pair_len, |
|
|
|
|
max_memory, |
|
|
|
|
) |
|
|
|
|
for region in found_regions: |
|
|
|
|
assert ( |
|
|
|
|
region in target_regions |
|
|
|
|
), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % ( |
|
|
|
|
assert (region in target_regions), "region:%s should not be found for msa:%d, pair:%d, maxmem:%d" % ( |
|
|
|
|
str(region), |
|
|
|
|
msa_len, |
|
|
|
|
pair_len, |
|
|
|
@ -62,7 +58,7 @@ def assert_chunk_infos(chunk_infos, max_memory, msa_len, pair_len):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _test_autochunk_search(rank, msa_len, pair_len, max_memory): |
|
|
|
|
# launch colossalai to make sure we could execute colossalai.utils.checkpoint currectly |
|
|
|
|
# launch colossalai |
|
|
|
|
colossalai.launch( |
|
|
|
|
config={}, |
|
|
|
|
rank=rank, |
|
|
|
@ -77,11 +73,9 @@ def _test_autochunk_search(rank, msa_len, pair_len, max_memory):
|
|
|
|
|
node = torch.randn(1, msa_len, pair_len, 256).cuda() |
|
|
|
|
pair = torch.randn(1, pair_len, pair_len, 128).cuda() |
|
|
|
|
|
|
|
|
|
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace |
|
|
|
|
gm_prop = torch.fx.symbolic_trace(model) # must use symbolic_trace |
|
|
|
|
interp = MetaInfoProp(gm_prop) |
|
|
|
|
interp.propagate( |
|
|
|
|
MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0") |
|
|
|
|
) |
|
|
|
|
interp.propagate(MetaTensor(node, fake_device="cuda:0"), MetaTensor(pair, fake_device="cuda:0")) |
|
|
|
|
|
|
|
|
|
codegen = AutoChunkCodeGen(gm_prop, max_memory=max_memory) |
|
|
|
|
chunk_infos = codegen.chunk_infos |
|
|
|
|