You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_elixir/test_tracer/test_tf_order.py

38 lines
1.3 KiB

from colossalai.elixir.tracer.param_tracer import generate_tf_order
from colossalai.testing import run_on_environment_flag
from tests.test_elixir.utils import TEST_MODELS
@run_on_environment_flag('ELX')
def test_tf_forward_backward():
    """Check the training-flow order traced for the ``gpt2_micro`` test model.

    ``generate_tf_order`` is run twice on the same model and data: first on
    the model as built, then after ``model.gradient_checkpointing_enable()``.
    The length of the returned ``params_per_step`` is pinned for each run,
    and the input tensors and every model parameter must be on the CPU once
    both traces have finished.

    The test is gated by ``run_on_environment_flag('ELX')`` (presumably it is
    skipped unless an ``ELX`` environment flag is set -- see
    ``colossalai.testing``).
    """
    # Step counts pinned for gpt2_micro; presumably they must be updated
    # whenever that test model's definition changes.
    expected_steps_plain = 32
    expected_steps_checkpointed = 44

    model_fn, data_fn = TEST_MODELS.get('gpt2_micro')
    model = model_fn()
    data = data_fn()

    def forward_backward_fn(local_model, local_input):
        # Assumes the model returns a scalar tensor (e.g. a loss), so that
        # ``.backward()`` needs no explicit gradient argument -- TODO confirm.
        local_model(**local_input).backward()

    # Trace 1: plain forward/backward, no gradient checkpointing.
    tf_order = generate_tf_order(model, data, forward_backward_fn)
    params_per_step = tf_order['params_per_step']
    assert len(params_per_step) == expected_steps_plain, (
        f'expected {expected_steps_plain} steps without checkpointing, '
        f'got {len(params_per_step)}')

    # Trace 2: same model and data with gradient checkpointing enabled; the
    # traced order is expected to be longer and also carries 'checkpoint_info'.
    model.gradient_checkpointing_enable()
    tf_order = generate_tf_order(model, data, forward_backward_fn)
    params_per_step = tf_order['params_per_step']
    checkpoint_info = tf_order['checkpoint_info']
    # Dump the traced order so a mismatch in the count below can be diagnosed.
    for i, step in enumerate(params_per_step):
        print(f'step {i}: {step}')
    for c in checkpoint_info:
        print(f'checkpoint info: {c}')
    assert len(params_per_step) == expected_steps_checkpointed, (
        f'expected {expected_steps_checkpointed} steps with gradient '
        f'checkpointing, got {len(params_per_step)}')

    # After both traces, the inputs and all parameters must be on the CPU.
    for key in ('input_ids', 'attention_mask'):
        assert data[key].device.type == 'cpu', (
            f'data[{key!r}] is on {data[key].device}, expected cpu')
    for name, param in model.named_parameters():
        assert param.device.type == 'cpu', (
            f'parameter {name!r} is on {param.device}, expected cpu')


if __name__ == '__main__':
    # Allow running this check directly, outside of pytest.
    test_tf_forward_backward()