# mypy: allow-untyped-defs
import torch.fx as fx

def set_trace(gm: fx.GraphModule) -> fx.GraphModule:
    """
    Sets a breakpoint in `gm`'s generated python code. It drops into pdb when
    `gm` gets run.

    Args:
        gm: graph module to insert breakpoint. It is then recompiled for it to
            take effect.

    Returns:
        the `gm` with breakpoint inserted.
    """
    def insert_pdb(body):
        return ["import pdb; pdb.set_trace()\n", *body]

    with gm.graph.on_generate_code(
        make_transformer=lambda cur_transform: (
            # new code transformer to register
            lambda body: (
                insert_pdb(
                    cur_transform(body) if cur_transform
                    else body
                )
            )
        )
    ):
        gm.recompile()

    return gm
