@@ -36,24 +36,6 @@ NUM_HEADS = 4
 TOP_K = 1


-def register_hooks(module: torch.nn.Module):
-
-    def fwd_hook(module, input, output):
-        torch.cuda.synchronize()
-        name = module._name if hasattr(module, "_name") else module
-        print(f"Fwd hook {name} \n output {output}")
-
-    def bwd_hook(module, grad_input, grad_output):
-        torch.cuda.synchronize()
-
-    def bwd_pre_hook(module, grad_output):
-        torch.cuda.synchronize()
-
-    module.register_forward_hook(fwd_hook)
-    # module.register_backward_hook(bwd_hook)
-    # module.register_full_backward_pre_hook(bwd_pre_hook)
-
-
 class MlpModel(nn.Module):
     def __init__(
         self,
@@ -1068,7 +1050,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
     dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
     torch_output_sum = 0
-    # torch_model.apply(register_hooks) # register hook for base model
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
         torch_output.backward()
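For reference, below is a minimal self-contained sketch of the forward-hook debugging pattern that the deleted register_hooks helper implemented: Module.apply walks every submodule and a forward hook prints each module's output. The toy nn.Sequential model, the CUDA-availability guard around torch.cuda.synchronize(), and the __main__ driver are illustrative assumptions, not part of the test file.

import torch
import torch.nn as nn


def register_hooks(module: nn.Module):
    # Same pattern as the removed helper: print each submodule's forward output.
    def fwd_hook(module, input, output):
        if torch.cuda.is_available():
            # Flush pending kernels so the printed tensors reflect finished work.
            torch.cuda.synchronize()
        name = module._name if hasattr(module, "_name") else module
        print(f"Fwd hook {name} \n output {output}")

    module.register_forward_hook(fwd_hook)


if __name__ == "__main__":
    # Toy stand-in for the real model; Module.apply visits every submodule,
    # mirroring the commented-out torch_model.apply(register_hooks) call.
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    model.apply(register_hooks)
    model(torch.randn(1, 4))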