@@ -36,24 +36,6 @@ NUM_HEADS = 4
 TOP_K = 1
 
 
-def register_hooks(module: torch.nn.Module):
-
-    def fwd_hook(module, input, output):
-        torch.cuda.synchronize()
-        name = module._name if hasattr(module, "_name") else module
-        print(f"Fwd hook {name} \n output {output}")
-
-    def bwd_hook(module, grad_input, grad_output):
-        torch.cuda.synchronize()
-
-    def bwd_pre_hook(module, grad_output):
-        torch.cuda.synchronize()
-
-    module.register_forward_hook(fwd_hook)
-    # module.register_backward_hook(bwd_hook)
-    # module.register_full_backward_pre_hook(bwd_pre_hook)
-
-
 class MlpModel(nn.Module):
     def __init__(
         self,
@@ -1068,7 +1050,6 @@ def run_with_booster_hybridplugin(config: Tuple[int, ...]):
     all_inputs = [input_embeddings.clone() for _ in range(dp_size)]
     dist.all_gather(all_inputs, input_embeddings, group=plugin.dp_group)
     torch_output_sum = 0
-    # torch_model.apply(register_hooks) # register hook for base model
     for input_data_ in all_inputs:
         torch_output = torch_model(inputs_embeds=input_data_.to(dtype)).last_hidden_state.mean()
         torch_output.backward()
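
For reference, a minimal self-contained sketch of the hook pattern the removed `register_hooks` helper implements: a forward hook attached to every submodule via `Module.apply`, printing each module's output as it is produced. The `register_debug_hooks` name and the toy `nn.Sequential` model below are illustrative assumptions, not part of the test file.

```python
import torch
import torch.nn as nn


def register_debug_hooks(module: nn.Module):
    """Attach a forward hook that logs each submodule's output (debugging sketch)."""

    def fwd_hook(mod, inputs, output):
        # Synchronize only when CUDA is in use so the printed values are up to date.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        name = getattr(mod, "_name", mod.__class__.__name__)
        print(f"Fwd hook {name}\n output {output}")

    module.register_forward_hook(fwd_hook)


# Hypothetical toy model, just to show the hooks firing.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
model.apply(register_debug_hooks)  # same pattern as torch_model.apply(register_hooks)
out = model(torch.randn(2, 8)).mean()
out.backward()  # backward hooks could be registered analogously if gradients need inspection
```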