# mypy: allow-untyped-defs
import dataclasses
import importlib
import logging
import os
from typing import (
    Any,
    Callable,
    Dict,
    Final,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import TypeAlias

import torch
import torch._C
import torch._ops
import torch._prims.executor
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.fx._compatibility import compatibility
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.tools_common import CALLABLE_NODE_OPS
from torch.utils import _pytree


if TYPE_CHECKING:
    import onnx
    import onnxruntime
    from onnxruntime.capi import _pybind_state as ORTC

    import torch.onnx
    import torch.onnx._internal
    import torch.onnx._internal._exporter_legacy
    import torch.onnx._internal.diagnostics
    import torch.onnx._internal.fx.decomposition_table
    import torch.onnx._internal.fx.passes  # noqa: TCH004


# Memoized result of `is_onnxrt_backend_supported()`: ``None`` until the first
# call probes the optional dependencies, then ``True`` or ``False``.
_SUPPORT_ONNXRT: Optional[bool] = None

# Public API of this module.
__all__ = [
    "is_onnxrt_backend_supported",
    "torch_compile_backend",
    "OrtExecutionProvider",
    "OrtBackendOptions",
    "OrtBackend",
]


def is_onnxrt_backend_supported() -> bool:
    """Report whether the ONNX Runtime backend for TorchDynamo is usable.

    The backend is usable only if all of these can be imported:

    * ``onnxruntime`` and ``onnxruntime.capi._pybind_state``,
    * ``onnxscript``,
    * the internal ``torch.onnx`` exporter modules.

    The answer is computed on the first call and memoized in the module-level
    ``_SUPPORT_ONNXRT`` flag.

    Returns:
        ``True`` if every import succeeded, ``False`` if any of them raised
        ``ImportError``.

    Example::

        # xdoctest: +REQUIRES(env:TORCH_DOCTEST_ONNX)
        >>> import torch
        >>> if torch.onnx.is_onnxrt_backend_supported():
        ...     @torch.compile(backend="onnxrt")
        ...     def f(x):
        ...         return x * x
        ...     print(f(torch.randn(10)))
        ... else:
        ...     print("pip install onnx onnxscript onnxruntime")
        ...
    """
    global _SUPPORT_ONNXRT

    # Fast path: the dependency probe has already run.
    if _SUPPORT_ONNXRT is not None:
        return _SUPPORT_ONNXRT

    # `onnxruntime` may import many other runtime packages (e.g. apex,
    # deepspeed, transformers), so it is imported lazily here to avoid
    # possible circular imports.
    try:
        for module_name in (
            "onnxruntime",
            "onnxruntime.capi._pybind_state",
            # DORT does not use onnxscript directly, but the underlying
            # exporter needs it, so its presence is verified too.
            "onnxscript",
        ):
            importlib.import_module(module_name)

        import torch.onnx  # noqa: F401
        import torch.onnx._internal  # noqa: F401
        import torch.onnx._internal._exporter_legacy  # noqa: F401
        import torch.onnx._internal.diagnostics  # noqa: F401
        from torch.onnx._internal.fx import (  # noqa: F401
            decomposition_table,
            fx_onnx_interpreter,
            passes,
            type_utils,
        )
    except ImportError:
        _SUPPORT_ONNXRT = False
    else:
        _SUPPORT_ONNXRT = True

    return _SUPPORT_ONNXRT


# Per-prefix counter for `_dump_onnx_model`: maps an ONNXRT_DUMP_PATH prefix to
# the index of the last model file written with that prefix.
_dumped_onnx_model: Dict[str, int] = {}


def _dump_onnx_model(
    model_string: bytes, graph_module: Optional[torch.fx.GraphModule] = None
) -> str:
    """Stores the onnx model into a file.
    The name is "{ONNXRT_DUMP_PATH}{N}.onnx"
    where *N* is the number of files already stored with
    this prefix.
    If graph_module is not None, the graph is stored as a string with
    the same filename except the extension (.txt).
    """
    prefix = os.environ.get("ONNXRT_DUMP_PATH", None)
    if not prefix:
        return ""
    n = _dumped_onnx_model.get(prefix, -1) + 1
    filename = f"{prefix}{n}.onnx"
    with open(filename, "wb") as f:
        f.write(model_string)
    _dumped_onnx_model[prefix] = n
    if graph_module is not None:
        filename_txt = f"{prefix}{n}.txt"
        with open(filename_txt, "w", encoding="utf-8") as f:
            f.write(str(graph_module.graph))
    return filename


def _infer_default_eps() -> Sequence[str]:
    # TODO: select a good default based on the capabilities of the host
    # e.g. DML on Windows, etc.
    return ["CPUExecutionProvider"]


def _nvtx_range_push(name: str):
    """If PyTorch is installed with CUDA support, this starts NVTX range.

    Check torch.cuda.nvtx.range_push's document for more details.
    """
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_push(name)


def _nvtx_range_pop():
    """If PyTorch is installed with CUDA support, this terminates NVTX range.

    Check torch.cuda.nvtx.range_pop's document for more details.
    """
    if torch.cuda.is_available():
        torch.cuda.nvtx.range_pop()


def _get_ort_device_type(device_type: str):
    """Translate a torch device type string into an ONNX Runtime ``OrtDevice``.

    Accepted values are ``"cuda"``, ``"cpu"`` and ``"maia"``. ``"maia"`` (the
    ORT PyTorch device) maps to ORT's NPU device type.

    Raises:
        ValueError: for any other device type.
    """
    from onnxruntime.capi import _pybind_state as ORTC

    # Store factory *names* so only the requested OrtDevice attribute is read.
    factory_name = {
        "cuda": "cuda",
        "cpu": "cpu",
        "maia": "npu",
    }.get(device_type)
    if factory_name is None:
        raise ValueError("Unsupported device type: " + device_type)
    return getattr(ORTC.OrtDevice, factory_name)()


# Module-level logger; e.g. OrtOperatorSupport reports here which node targets
# are (un)supported.
logger = logging.getLogger(__name__)
# Uncomment the following lines to print out development info.
# logging.basicConfig(level=logging.WARNING)
# logger.setLevel(logging.WARNING)


class OrtOperatorSupport(OperatorSupport):
    """Operator support for the ONNX Runtime backend.

    A node is accepted through one of two lookups:

    1. ``support_dict``: a set of ``call_function`` targets handled directly
       by the ONNX exporter. This class checks it itself.
    2. ``extra_support_dict``: forwarded to the base ``OperatorSupport`` and
       evaluated by ``OperatorSupport.is_node_supported``.
    """

    def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]):
        # In extra_support_dict, mapping op_name to None means op_name is
        # supported for all input types. See SupportDict in
        # operator_support.py for how to restrict the accepted input types.
        super().__init__(extra_support_dict)
        self._onnx_support_dict = support_dict

    def is_node_supported(
        self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
    ) -> bool:
        """Return whether ``node`` can be executed by ONNX Runtime.

        Every decision on a callable node is logged: at ``info`` level for
        supported targets and at ``warning`` level for unsupported ones.
        """
        # The base class accepts non-callable nodes (e.g. placeholders and
        # outputs), but ORT cannot execute them, so they are rejected here.
        if node.op not in CALLABLE_NODE_OPS:
            return False

        target = node.target
        # support_dict is the one and only place that decides aten-op support.
        handled_by_exporter = (
            node.op == "call_function" and target in self._onnx_support_dict
        )
        if handled_by_exporter:
            logger.info(
                "support_dict supports node.target: %s (type: %s)",
                target,
                type(target),
            )
            return True

        # Not in support_dict: torch.jit.script may still convert it to ONNX.
        # The base mechanism checks that through extra_support_dict.
        if not super().is_node_supported(submodules, node):
            logger.warning(
                "support_dict and extra_support_dict don't support node.target: %s (type: %s)",
                target,
                type(target),
            )
            return False

        logger.info(
            "extra_support_dict supports node.target: %s (type: %s)",
            target,
            type(target),
        )
        return True


def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None:
    """
    In torch.fx.Graph, placeholder is a special assignment node. If it's not
    executed in the beginning, it could overwrite values computed by upstream
    nodes.
    """

    graph = graph_module.graph
    placeholders = []
    first_not_placeholder = None
    for node in graph.nodes:
        if node.op == "placeholder":
            placeholders.append(node)
        if first_not_placeholder is None and node.op != "placeholder":
            first_not_placeholder = node
    if first_not_placeholder is None:
        return
    for placeholder in placeholders:
        first_not_placeholder.prepend(placeholder)


def _infer_ep_from_device(*args) -> Tuple[str, ...]:
    """Return the first valid device (i.e., GPU or CPU) in argument list."""
    eps = []
    for arg in args:
        if hasattr(arg, "device"):
            device = arg.device
            if device.type == "cuda":
                eps.append("CUDAExecutionProvider")
            elif device.type == "cpu":
                eps.append("CPUExecutionProvider")
    return tuple(eps)


def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]:
    placeholders = []
    for node in graph_module.graph.nodes:
        if node.op == "placeholder":
            if hasattr(node, "meta") and "val" in node.meta:
                assert isinstance(node.meta["val"], torch.Tensor)
            placeholders.append(node)
    return tuple(placeholders)


def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any:
    """Collect "val" fields from outputs metadata in this torch.fx.GraphModule."""
    for node in graph_module.graph.nodes:
        if node.op == "output":
            # Output node is unique. Let's retrieve output values from
            # this node's input list. And then just return.
            return node.args[0]
    raise ValueError("No output node found in this torch.fx.GraphModule.")


def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]:
    """Infer execution providers from the devices of ``graph_module``'s outputs.

    The output structure is flattened. Entries without an example value in
    ``meta["val"]`` carry no device information and are skipped. The
    remaining example values are passed to ``_infer_ep_from_device``.
    """
    graph_outputs = _extract_graph_module_outputs(graph_module)
    flat_outputs, _ = _pytree.tree_flatten(graph_outputs)
    example_values = []
    for output_arg in flat_outputs:
        if hasattr(output_arg, "meta") and "val" in output_arg.meta:
            example_values.append(output_arg.meta["val"])
    return _infer_ep_from_device(*example_values)


def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]:
    """Sort execution providers in eps based on pre-set priority."""

    def get_execution_provider_priority(ep: str) -> int:
        if ep == "CPUExecutionProvider":
            # Lowest priority.
            return 2
        if ep == "CUDAExecutionProvider":
            # Higher priority than CPU but lower than
            # other specialized EPs.
            return 1
        # Highest priority.
        return 0

    unique_eps = set(eps)
    return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True))


def _get_onnx_devices(
    values: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple["ORTC.OrtDevice", ...]:
    """Return the ONNX Runtime ``OrtDevice`` that holds each of ``values``.

    * Tensors map to the ORT device of their ``torch.device``; a ``None``
      device index is treated as 0.
    * Python and symbolic scalars map to CPU device 0.
    * When ``values`` is empty, a single CPU device is returned.

    Raises:
        ValueError: if a value has an unsupported type, or (via
            ``_get_ort_device_type``) lives on an unsupported device type.
    """
    from onnxruntime.capi import _pybind_state as ORTC

    def _device_id_or_zero(device_id: Optional[int]) -> int:
        # torch.device.index may be None (no explicit index); use device 0.
        return device_id or 0

    def _map_tensor_or_sym_to_device(
        value: Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
    ) -> "ORTC.OrtDevice":
        if isinstance(value, torch.Tensor):
            return ORTC.OrtDevice(
                _get_ort_device_type(value.device.type),
                ORTC.OrtDevice.default_memory(),
                _device_id_or_zero(value.device.index),
            )
        elif isinstance(
            value, (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool)
        ):
            # Scalars are fed to ORT as CPU tensors.
            return ORTC.OrtDevice(
                _get_ort_device_type("cpu"), ORTC.OrtDevice.default_memory(), 0
            )
        else:
            raise ValueError("Unsupported value type: " + str(type(value)))

    if len(values) > 0:
        ort_devices = tuple(_map_tensor_or_sym_to_device(value) for value in values)
        return ort_devices
    else:
        # NOTE(review): with no values, a single CPU device is returned, as
        # if for a dummy scalar. Presumably the caller needs a non-empty
        # device tuple; confirm.
        return (_map_tensor_or_sym_to_device(1),)


def _get_ortvalues_from_torch_tensors(
    tensors: Tuple[torch.Tensor, ...], devices: Tuple["ORTC.OrtDevice", ...]
) -> "ORTC.OrtValueVector":
    """Wrap ``tensors`` into a single ``ORTC.OrtValueVector``.

    For every tensor, the function collects:

    * its NumPy-equivalent dtype,
    * its shape,
    * its raw data pointer.

    These are passed, together with ``devices``, to
    ``OrtValueVector.push_back_batch``. Only data pointers are handed over, so
    the ORT values presumably alias the tensors' storage (TODO confirm against
    the ORT pybind implementation). Keep ``tensors`` alive while the returned
    vector is in use.

    Args:
        tensors: Tensors to wrap. Their dtypes must be keys of
            ``_TORCH_DTYPE_TO_NUMPY_DTYPE``; otherwise ``KeyError`` is raised.
        devices: ORT device of each tensor, aligned with ``tensors``.

    Returns:
        The populated ``ORTC.OrtValueVector``.
    """
    from onnxruntime.capi import _pybind_state as ORTC

    from torch.onnx._internal.fx.type_utils import _TORCH_DTYPE_TO_NUMPY_DTYPE

    ortvalues = ORTC.OrtValueVector()
    ortvalues.reserve(len(tensors))
    dtypes = []
    shapes = []
    data_ptrs = []

    for tensor in tensors:
        dtypes.append(_TORCH_DTYPE_TO_NUMPY_DTYPE[tensor.dtype])
        shapes.append(tensor.size())
        data_ptrs.append(tensor.data_ptr())
    ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices)
    return ortvalues


def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor:
    if tensor.is_sparse:
        raise ValueError("sparse tensor is not yet supported.")
    out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device)
    return out


def _adjust_scalar_from_fx_to_onnx(
    dynamo_value: Union[
        torch.Tensor,
        int,
        float,
        bool,
    ],
    value_info: "onnx.ValueInfoProto",  # type: ignore[name-defined]
) -> torch.Tensor:
    """Helper function to wrap PyTorch variables as torch.Tensor"""
    if (
        isinstance(dynamo_value, torch.Tensor)
        and len(value_info.type.tensor_type.shape.dim) == 0
        and dynamo_value.shape == (1,)
    ):
        # ONNX expect a scalar with empty shape.
        # In contrast, PyTorch usually allows implicit
        # conversion between shape=() and shape=(1,).
        #
        # Below, PyTorch's shape (1,) is reshaped to ().
        return torch.squeeze(dynamo_value)
    elif isinstance(dynamo_value, int):
        return torch.tensor(dynamo_value, dtype=torch.int64)
    elif isinstance(dynamo_value, float):
        return torch.tensor(dynamo_value, dtype=torch.float32)
    elif isinstance(dynamo_value, bool):
        return torch.tensor(dynamo_value, dtype=torch.bool)
    else:
        assert isinstance(dynamo_value, torch.Tensor)
        return dynamo_value.contiguous()


def _adjust_scalar_from_onnx_to_fx(
    tensor: torch.Tensor,
    prim_value: Union[
        torch.Tensor,
        torch.SymInt,
        int,
        torch.SymFloat,
        float,
        torch.SymBool,
        bool,
    ],
) -> Union[
    torch.Tensor,
    int,
    float,
    bool,
]:
    """Helper function to wrap ORT-produced torch.Tensor as PyTorch variables"""
    assert isinstance(tensor, torch.Tensor), "ORT's output must be tensor."
    if isinstance(
        prim_value,
        (torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool),
    ):
        # Convert tensor back to scalar to match Dynamo's expectation.
        return tensor.item()
    return tensor


def _run_onnx_session_with_ortvaluevector(
    sess: "onnxruntime.InferenceSession",
    input_names: Tuple[str, ...],
    inputs: Tuple[torch.Tensor, ...],
    input_devices: Tuple["ORTC.OrtDevice", ...],
    output_names: Tuple[str, ...],
    outputs: Tuple[torch.Tensor, ...],
    output_devices: Tuple["ORTC.OrtDevice", ...],
    preallocate_output: bool,
    input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
    normalized_prim_outputs: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
    """Run ``sess`` through ``run_with_ortvaluevector`` and return PyTorch values.

    Steps:

    1. Normalize inputs with ``_adjust_scalar_from_fx_to_onnx`` and wrap them
       as ORT values on ``input_devices``.
    2. If ``preallocate_output`` is ``True``, allocate output buffers on the
       PyTorch side and let ORT write into them. FakeTensors in ``outputs``
       are first materialized with ``_to_real_tensor``.
    3. Otherwise ORT allocates the outputs, which are then converted to torch
       tensors.
    4. Convert each output with ``_adjust_scalar_from_onnx_to_fx``, so it
       becomes a Python scalar wherever ``normalized_prim_outputs`` holds a
       scalar.

    Each phase is wrapped in an NVTX range (emitted only when CUDA is
    available).

    NOTE(review): the NVTX ranges are not popped if an exception is raised
    between a push and its pop (there is no try/finally).

    Args:
        sess: ONNX Runtime session to execute.
        input_names: ONNX input names, aligned with ``inputs``.
        inputs: Runtime input values (tensors or Python scalars).
        input_devices: ORT device of each input.
        output_names: ONNX output names to fetch.
        outputs: Example outputs; only used to preallocate output buffers.
        output_devices: ORT device of each output.
        preallocate_output: Whether output buffers are allocated by PyTorch.
        input_value_infos: ONNX type information of each input.
        normalized_prim_outputs: Outputs of the original graph, used to decide
            which ORT outputs are turned back into Python scalars.

    Returns:
        One value per output: a tensor or a Python scalar.
    """
    import onnxruntime
    from onnxruntime.capi import _pybind_state as ORTC

    _nvtx_range_push("contiguous")
    # Scalars become 0-d tensors, (1,) tensors may be squeezed to (), and
    # other tensors are made contiguous.
    inputs = tuple(
        _adjust_scalar_from_fx_to_onnx(arg, value_info)
        for arg, value_info in zip(inputs, input_value_infos)
    )
    _nvtx_range_pop()

    _nvtx_range_push("push_back_batch")
    ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices)

    # preallocate output pytorch Tensors and use the buffers affined to the torch device for the output ortvalue.
    # Because the output ortvalue is not allocated and owned by ort, it does not need to convert the output ortvalue
    # to torch Tensor transferring the ownership.
    if preallocate_output:
        pth_outputs = tuple(
            _to_real_tensor(t) if isinstance(t, FakeTensor) else t for t in outputs
        )
        ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices)
    else:
        # An empty vector; ORT allocates and fills in the outputs during the run.
        ort_outputs = ORTC.OrtValueVector()
    _nvtx_range_pop()

    _nvtx_range_push("run_with_ortvaluevector")
    run_options = onnxruntime.RunOptions()
    # The key name suggests ORT skips synchronizing execution providers after
    # the run; TODO confirm the exact semantics in the ORT run-config docs.
    run_options.add_run_config_entry("disable_synchronize_execution_providers", "1")
    sess.run_with_ortvaluevector(
        run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices
    )
    _nvtx_range_pop()

    # Post-processing step:
    #  wrap ORT's outputs to the schema represented by
    #  `prim_output` (obtained by running the original
    #  torch.fx.GraphModule).
    if preallocate_output:
        # Profile the ORT-to-PyTorch type cast below
        _nvtx_range_push("after run_with_ortvaluevector")
        # Outputs are stored on pre-allocated torch.Tensors' memory,
        # so this case doesn't need to convert ORTValue to torch.Tensor.
        pth_outputs = tuple(
            _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output)  # type: ignore[misc]
            for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
        )
        _nvtx_range_pop()
        return pth_outputs
    else:
        # Profile the two ORT-to-PyTorch type casts below
        _nvtx_range_push("after run_with_ortvaluevector")
        # Map ORTValue to torch.Tensor.
        # NOTE(review): this reaches `onnxruntime.training` through attribute
        # access after a plain `import onnxruntime`; confirm that attribute
        # path exists in the installed onnxruntime build.
        pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor(
            ort_outputs
        )
        # Change some torch.Tensor to int, float, bool.
        pth_outputs = tuple(
            _adjust_scalar_from_onnx_to_fx(onnx_output, prim_output)  # type: ignore[misc]
            for onnx_output, prim_output in zip(pth_outputs, normalized_prim_outputs)
        )
        _nvtx_range_pop()
        return pth_outputs


def _run_onnx_session_with_fetch(
    sess: "onnxruntime.InferenceSession",
    input_names: Tuple[str, ...],
    inputs: Tuple[torch.Tensor, ...],
    input_devices: Tuple["ORTC.OrtDevice", ...],
    output_names: Tuple[str, ...],
    outputs: Tuple[torch.Tensor, ...],
    output_devices: Tuple["ORTC.OrtDevice", ...],
    preallocate_output: bool,
    input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
    normalized_prim_outputs: Tuple[
        Union[
            torch.Tensor, torch.SymInt, int, torch.SymFloat, float, torch.SymBool, bool
        ],
        ...,
    ],
) -> Tuple[Union[torch.Tensor, int, float, bool], ...]:
    """Run ``sess`` with the plain ``InferenceSession.run`` API.

    Inputs are first normalized with ``_adjust_scalar_from_fx_to_onnx`` and
    then copied to CPU NumPy arrays before being fed to ORT. Outputs come back
    as NumPy arrays, are wrapped with ``torch.from_numpy``, and are turned into
    Python scalars where ``normalized_prim_outputs`` holds scalars.

    ``input_devices``, ``outputs``, ``output_devices`` and
    ``preallocate_output`` are unused here. They exist so this function has
    the same signature as ``_run_onnx_session_with_ortvaluevector``.
    """
    import onnxruntime

    adjusted_inputs = tuple(
        _adjust_scalar_from_fx_to_onnx(arg, value_info)
        for arg, value_info in zip(inputs, input_value_infos)
    )
    feed = {}
    for input_name, input_tensor in zip(input_names, adjusted_inputs):
        feed[input_name] = onnxruntime.OrtValue.ortvalue_from_numpy(
            input_tensor.cpu().numpy()
        )
    numpy_outputs = sess.run(output_names, feed)
    return tuple(
        _adjust_scalar_from_onnx_to_fx(torch.from_numpy(array), prim_output)
        for array, prim_output in zip(numpy_outputs, normalized_prim_outputs)
    )


class OrtExecutionInfoPerSession:
    """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession"""

    def __init__(
        self,
        session: "onnxruntime.InferenceSession",
        input_names: Tuple[str, ...],
        input_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
        output_names: Tuple[str, ...],
        output_value_infos: Tuple["onnx.ValueInfoProto", ...],  # type: ignore[name-defined]
        input_devices: Tuple["ORTC.OrtDevice", ...],
        output_devices: Tuple["ORTC.OrtDevice", ...],
        example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor],
    ):
        # Carrier of ONNX model and its executor.
        self.session: onnxruntime.InferenceSession = session
        # For the ONNX model stored in self.session, self.input_names[i] is the
        # name of the i-th positional input.
        self.input_names: Tuple[str, ...] = input_names
        # self.input_name[i]'s type information is stored in self.input_value_infos[i].
        self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos  # type: ignore[name-defined]
        # Similar to self.input_names, but for outputs.
        self.output_names: Tuple[str, ...] = output_names
        # Similar to self.input_value_infos but for outputs.
        self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos  # type: ignore[name-defined]
        # For the ONNX model stored in self.session, self.input_devices[i] is the
        # i-th positional input's device.
        self.input_devices: Tuple[ORTC.OrtDevice, ...] = input_devices
        # Similar to self.input_devices, but for outputs.
        self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices
        # This is the outputs of executing the original torch.fx.GraphModule with example inputs
        # (i.e., args passed into OrtBackend._ort_acclerated_call).
        self.example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor] = (
            example_outputs
        )

    def is_supported(self, *args):
        """Return ``True`` if this session's ONNX input schema accepts ``args``.

        ``args`` must match ``self.input_value_infos`` one-to-one.

        * Python ``int``/``float``/``bool`` values must map to the declared
          ONNX element type, and the ONNX input must be rank-0.
        * Tensors must have the declared element type. When the ONNX input
          declares a shape, the tensor must also have the same rank. Each
          dimension must either equal the static ONNX ``dim_value`` or line up
          with a symbolic ONNX ``dim_param``; ``torch.SymInt`` dimensions
          require a ``dim_param``.
        * A tensor of shape ``(1,)`` is accepted for a rank-0 ONNX input,
          because ``_adjust_scalar_from_fx_to_onnx`` squeezes it before the run.
        * Any other argument type is rejected.

        Returning ``False`` makes ``search_reusable_session_execution_info``
        skip this session, so the check errs on the side of rejecting.
        """
        from torch.onnx._internal.fx.type_utils import (
            _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE,
            from_python_type_to_onnx_tensor_element_type,
        )

        if len(args) != len(self.input_value_infos):
            return False
        for arg, value_info in zip(args, self.input_value_infos):
            if not isinstance(arg, (torch.Tensor, float, int)):
                return False

            tensor_type = value_info.type.tensor_type

            # Check Python scalars such as int, float, and bool.
            if isinstance(arg, (int, float, bool)):
                # Map, e.g., float to onnx.TensorProto.FLOAT.
                onnx_dtype = from_python_type_to_onnx_tensor_element_type(type(arg))
                if onnx_dtype != tensor_type.elem_type:
                    return False
                if len(tensor_type.shape.dim) != 0:
                    return False
                continue

            # Check tensor.
            onnx_dtype = _TORCH_DTYPE_TO_ONNX_TENSOR_ELEMENT_TYPE[arg.dtype]
            if onnx_dtype != tensor_type.elem_type:
                return False
            onnx_dims = tensor_type.shape.dim
            # zip() below silently truncates, so compare ranks explicitly.
            # The check only applies when the ONNX input declares a shape
            # (unknown rank accepts anything, as before). The (1,) -> () scalar
            # squeeze done by _adjust_scalar_from_fx_to_onnx stays allowed.
            if (
                tensor_type.HasField("shape")
                and len(arg.shape) != len(onnx_dims)
                and not (len(onnx_dims) == 0 and arg.shape == (1,))
            ):
                return False
            for dim, onnx_dim in zip(arg.shape, onnx_dims):
                if isinstance(dim, int) and (
                    onnx_dim.dim_value == dim or onnx_dim.dim_param
                ):
                    continue
                elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param:
                    continue
                else:
                    return False
        return True


@dataclasses.dataclass
class OrtExecutionInfoForAllGraphModules:
    """Cache of ONNX Runtime sessions, keyed by the GraphModule they came from.

    The same ``torch.fx.GraphModule`` may be exported several times with
    different inputs. Each key therefore maps to a list of
    ``OrtExecutionInfoPerSession`` entries, kept in insertion order.

    NOTE(review): ``@dataclasses.dataclass`` finds no fields here and keeps the
    explicit ``__init__``. As a result, the generated ``__eq__`` compares every
    pair of instances as equal, and ``__hash__`` is set to ``None``. Confirm
    this is intended.
    """

    def __init__(self) -> None:
        # Sessions (and their related information) created by exporting the
        # same GraphModule with different inputs.
        self.execution_info_per_graph_module: Dict[
            torch.fx.GraphModule, List[OrtExecutionInfoPerSession]
        ] = {}

    def search_reusable_session_execution_info(
        self, graph_module: torch.fx.GraphModule, *args
    ):
        """Return the first cached session for ``graph_module`` that accepts ``args``.

        Returns ``None`` if ``graph_module`` has no cached sessions, or if none
        of them accepts this input schema.
        """
        candidates = self.execution_info_per_graph_module.get(graph_module, ())
        return next(
            (candidate for candidate in candidates if candidate.is_supported(*args)),
            None,
        )

    def cache_session_execution_info(
        self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession
    ):
        """Append ``info`` to the sessions cached for ``graph_module``."""
        self.execution_info_per_graph_module.setdefault(graph_module, []).append(info)


# Type alias for the entries of OrtBackendOptions.preferred_execution_providers
# and OrtBackendOptions.default_execution_providers.
OrtExecutionProvider: TypeAlias = Union[str, Tuple[str, Mapping[str, Any]]]
"""Either the name of an ONNX Runtime execution provider as a string or
a 2-tuple of the name and a dictionary of execution provider options.

Examples::

    >>> "CPUExecutionProvider"

    >>> ("CUDAExecutionProvider", {"device_id": 3})

"""


@dataclasses.dataclass(frozen=True)
@compatibility(is_backward_compatible=False)
class OrtBackendOptions:
    """Options for constructing an ``OrtBackend``, the ONNX Runtime
    backend (``"onnxrt"``) for ``torch.compile``.

    Instances are frozen dataclasses: all fields are set at construction time.

    Example::

        >>> @torch.compile(
        ...     backend="onnxrt",
        ...     options=torch.onnx._OrtBackendOptions(...),
        ... )
        ... def ort_function(x):
        ...     return x ** x
    """

    preferred_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
    """An optional sequence of execution providers to be prioritized ahead of any
    execution providers that may be inferred (see ``infer_execution_providers``).
    """

    infer_execution_providers: bool = True
    """Whether to infer an execution provider from ``torch.device`` bound to inputs or found in the graph."""

    default_execution_providers: Optional[Sequence[OrtExecutionProvider]] = None
    """The default fallback execution providers. If not specified, one will be
    selected based on the host environment (most likely ``"CPUExecutionProvider"``).
    """

    # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession
    # in order to avoid internal allocation of output buffers in InferenceSession.
    # If output ortvalue returned from InferenceSession is allocated internally,
    # it needs to be converted to torch Tensor for return, and the torch Tensor should hold the ownership.
    # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor
    # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device.
    # It can be avoided by allowing for preallocation for output buffers allocated by a custom aten allocator,
    # and use the preallocated output buffers for InferenceSession not holding any ownership for them.
    # TODO(wschin): Make it to inference session level flag.
    # See https://github.com/pytorch/pytorch/issues/106869.
    preallocate_output: bool = False
    """If ``True``, allocate memory for ONNX Runtime's outputs on the PyTorch side."""

    use_aot_autograd: bool = True
    """Whether to wrap the ``OrtBackend`` with TorchDynamo's aot_autograd backend
    to support training (i.e., backward graphs are also sent to ``OrtBackend``).

    Symbolic execution is used to capture the forward pass and backward passes as a single graph.
    Then, a selected graph partition algorithm (``min_cut_rematerialization_partition``) is used
    to split the entire graph into forward sub-graph and backward sub-graph. Finally, both
    sub-graphs are compiled by ``OrtBackend``.
    """

    export_options: Optional["torch.onnx.ExportOptions"] = None
    """Options for the TorchDynamo-based ONNX exporter used by the ``OrtBackend``."""

    ort_session_options: Optional["onnxruntime.SessionOptions"] = None
    """Options for the ``onnxruntime.InferenceSession`` used by the ``OrtBackend``."""

    pre_ort_model_transforms: Optional[  # type: ignore[name-defined]
        Sequence[Callable[["onnx.ModelProto"], None]]
    ] = None
    """A list of graph transforms to be applied to the ONNX model before it
    is fed to ONNXRuntime's InferenceSession."""


@compatibility(is_backward_compatible=False)
class OrtBackend:
    """A backend that compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls.

    The compiler entry point is OrtBackend.compile, which
        1. partitions the original graph into supported sub-graphs (type: torch.fx.GraphModule) and unsupported
           sub-graphs.
        2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call.
        3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph.
    """

    def __init__(self, options: Optional[OrtBackendOptions] = None):
        from onnxruntime.capi import _pybind_state as ORTC

        import torch.onnx
        import torch.onnx._internal._exporter_legacy
        import torch.onnx._internal.fx.decomposition_table

        self._options: Final = OrtBackendOptions() if options is None else options

        # options.export_options contains information shared between exporter and DORT.
        # For example, they should use the same decomposition table when
        #  1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py)
        #  2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model
        #     (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below).
        #
        # Convert user-facing option to internal option used by ONNX exporter
        # to access required information.
        # Some useful fields:
        # - Decomposition table for decomposing FX operators in exporter is
        #   self._resolved_onnx_exporter_options.decomposition_table.
        # - self._resolved_onnx_exporter_options.onnx_registry records what
        #   aten/prim ops are supported by exporter and their exporters (type: callable).
        self._resolved_onnx_exporter_options = (
            torch.onnx._internal._exporter_legacy.ResolvedExportOptions(
                torch.onnx.ExportOptions()
                if self._options.export_options is None
                else self._options.export_options
            )
        )

        #  Given DORT's computation flow:
        #   1. OrtOperatorSupport uses support_dict and extra_support_dict to select operators
        #      and send them to DORT.
        #   2. Then, DORT exports the selected sub-graphs into ONNX.
        #   3. Finally DORT calls ORT to do the computation.
        #  OrtOperatorSupport and create_onnx_friendly_decomposition_table(...)
        #  must use the same support_dict. If the support_dict here contains something not
        #  supported by the exporter, the exporter will fail in step 2 since the selected graphs
        #  may contain unsupported operators such as aten::_who_you_are.
        #  This restriction is automatically satisfied since DORT and the exporter share the same
        #  self._resolved_onnx_exporter_options.
        support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table(
            self._resolved_onnx_exporter_options.onnx_registry
        )

        extra_support_dict: Dict[str, Any] = {
            "getattr": None,
            # To send operator.getitem to ORT, add the corresponding string
            # recognized by PyTorch's OperatorSupport class.
            "_operator.getitem": None,
            # To send operator.mul to ORT, add the corresponding string
            # recognized by PyTorch's OperatorSupport class.
            "_operator.mul": None,
            "_operator.add": None,
            "_operator.sub": None,
        }

        self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict)
        # TODO(wschin): this is a naive implementation of cache without proper guard
        # See https://github.com/pytorch/pytorch/issues/106868.
        self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {}
        # Conceptually, this field is a 2-layer dictionary
        #   GraphModule 0
        #     ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
        #     ONNX Model 1
        #     ...
        #   GraphModule 1
        #     ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession)
        #     ONNX Model 3
        #     ...
        #   ...
        # , which caches all previous compilation result so that we can reuse them.
        # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs
        # (e.g., tensors with different ranks). GraphModule 0 and GraphModule 1 are different
        # graphs captured by Dynamo and sent to OrtBackend.compile.
        self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules()

        self._assert_allclose_to_baseline = False

        self.execution_count = 0

        # Function which invokes ORT to do the real computation. The
        # OrtValueVector-based runner is used when this onnxruntime build
        # exposes OrtValueVector.push_back_batch; otherwise fall back.
        self.run = (
            _run_onnx_session_with_ortvaluevector
            if hasattr(ORTC.OrtValueVector, "push_back_batch")
            else _run_onnx_session_with_fetch
        )

    def _select_eps(
        self, graph_module: torch.fx.GraphModule, *args
    ) -> Sequence[Tuple[str, Mapping[str, Any]]]:
        """Return the ordered, de-duplicated execution providers for a new session.

        Candidates are considered in this order:
          1. ``preferred_execution_providers`` from the options;
          2. when ``infer_execution_providers`` is set, EPs inferred from the
             input ``args`` or, if none are found there, from ``graph_module``
             (ordered by ``_sort_eps``);
          3. ``default_execution_providers``, or ``_infer_default_eps()`` when unset.

        Each entry is normalized to a ``(name, provider_options)`` tuple, where a
        bare string or a ``None`` options slot becomes an empty dict.
        """
        inferred_eps: Tuple[str, ...] = ()
        if self._options.infer_execution_providers:
            if eps_from_args := _infer_ep_from_device(*args):
                # If user feeds CUDA tensor as input argument,
                # we want to use CUDA EP.
                # Thus, `eps_from_args` (deduced from input arguments)
                # has highest priority.
                inferred_eps = eps_from_args
            elif eps_from_graph_module := _infer_ep_from_graph_module(graph_module):
                # If there is no EP in input arguments, we deduce EP from
                # graph_module's outputs. Those outputs may come from
                # FakeTensorProp or Dynamo's built-in symbolic shape inference.
                inferred_eps = eps_from_graph_module

        selected_eps = []

        for ep in (
            *(self._options.preferred_execution_providers or []),
            *_sort_eps(inferred_eps),
            *(self._options.default_execution_providers or _infer_default_eps()),
        ):
            if isinstance(ep, str):
                ep = (ep, {})
            elif isinstance(ep, tuple) and ep[1] is None:
                ep = (ep[0], {})
            if ep is not None and ep not in selected_eps:
                selected_eps.append(ep)

        return selected_eps

    def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs):
        """This function replaces GraphModule._wrapped_call in compiled model.

        The _wrapped_call is the underlying implementation of forward method. Replacing
        it means we delegate the computation to _ort_acclerated_call and therefore
        onnxruntime.InferenceSession.

        On a cache miss, the graph is exported to ONNX, an InferenceSession is
        created, and its execution info is cached under ``graph_module``. Later
        calls reuse it whenever ``search_reusable_session_execution_info`` finds a
        match for ``graph_module`` and ``args``.

        Args:
            graph_module: The fused sub-graph being executed.
            *args: Runtime inputs to the sub-graph.
            **kwargs: Forwarded to ``FakeTensorProp.propagate`` on a cache miss
                when dynamic shapes are disabled.

        Returns:
            A single output if the reference output is a single tensor,
            otherwise the outputs produced by ``self.run``.
        """
        import onnxruntime

        from torch.onnx._internal.fx import fx_onnx_interpreter, passes

        # The export passes below rebind ``graph_module``. Keep the object this
        # call received so the session cache is searched and populated with the
        # same key.
        original_graph_module = graph_module

        cached_execution_info_per_session = (
            self._all_ort_execution_info.search_reusable_session_execution_info(
                original_graph_module, *args
            )
        )
        if cached_execution_info_per_session:
            onnx_session = cached_execution_info_per_session.session
            input_names = cached_execution_info_per_session.input_names
            output_names = cached_execution_info_per_session.output_names
            input_value_infos = cached_execution_info_per_session.input_value_infos
            output_value_infos = cached_execution_info_per_session.output_value_infos
            input_devices = cached_execution_info_per_session.input_devices
            output_devices = cached_execution_info_per_session.output_devices
            prim_outputs = cached_execution_info_per_session.example_outputs
        else:
            # It's first time seeing such as graph. Let's make a new session
            # (type: onnxruntime.InferenceSession) for it.

            graph_module = passes.MovePlaceholderToFront(
                self._resolved_onnx_exporter_options.diagnostic_context,
                graph_module,
            ).run()
            # Generate reference outputs. They are used to indicate output
            # tensors' types and devices when calling ORT.
            #
            # WARNING: The downstream code should not change prim_outputs and
            # this backend should always produces output with schema identical to prim_outputs'.

            if self._resolved_onnx_exporter_options.dynamic_shapes:
                # No pre-allocation when dynamic shape is enabled; this is
                # enforced by the `preallocate_output` flag computed below.
                extracted_outputs = _extract_graph_module_outputs(graph_module)

                def maybe_map_to_meta_val(value):
                    if hasattr(value, "meta") and "val" in value.meta:
                        # Select outputs with "val" information. Without "val",
                        # it's not possible access output_arg.meta["val"].device.
                        return value.meta["val"]
                    else:
                        return value

                prim_outputs = _pytree.tree_map(
                    maybe_map_to_meta_val, extracted_outputs
                )
            else:
                try:
                    prim_outputs = FakeTensorProp(graph_module).propagate(
                        *args, **kwargs
                    )
                except Exception:
                    logger.warning("FakeTensorProp failed for %s", graph_module)
                    # Rethrow the FakeTensorProp failure because it is not yet handled.
                    # Without inferred output shapes no output buffer could be
                    # preallocated anyway.
                    raise

            # Create the object to iterate through the nodes in graph one-by-one
            # and calls the corresponding ONNX exporter for each node.
            fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter(
                diagnostic_context=self._resolved_onnx_exporter_options.diagnostic_context
            )
            # Cast FX variables if they will result schema-mismatch when searching
            # for ONNX operator. E.g., add(double_tensor, int_tensor) is fine in PyTorch,
            # but ONNX expects add(double_tensor, double_tensor).
            graph_module = passes.InsertTypePromotion(
                self._resolved_onnx_exporter_options.diagnostic_context, graph_module
            ).run()
            # Start the per-node exporting process. It's conceptually a for loop
            # scanning through the nodes in the graph.
            exported = fx_interpreter.run(
                fx_graph_module=graph_module,
                onnxfunction_dispatcher=self._resolved_onnx_exporter_options.onnxfunction_dispatcher,
            )
            # Convert the exported result to ONNX ModelProto.
            onnx_model = exported.to_model_proto(
                opset_version=self._resolved_onnx_exporter_options.onnx_registry.opset_version,
            )

            try:
                from onnxscript import optimizer  # type: ignore[import]
                from onnxscript.rewriter import (  # type: ignore[import]
                    onnxruntime as ort_rewriter,
                )

                onnx_model = optimizer.optimize(onnx_model)
                onnx_model = ort_rewriter.rewrite(onnx_model)
            except ImportError:
                logger.warning(
                    "ONNXScript optimizer is not available. Skipping optimization. "
                    "Please `pip install onnxscript -U` to enable post-export optimization."
                )

            # Modify ONNX model using pre-registered graph transforms.
            # They are in-place modifications for avoiding unnecessary
            # copy of ONNX initializers.
            if self._options.pre_ort_model_transforms:
                for transform in self._options.pre_ort_model_transforms:
                    transform(onnx_model)

            onnx_model_bytes = onnx_model.SerializeToString()
            if os.environ.get("ONNXRT_DUMP_PATH", None):
                # If not empty, environment variable ONNXRT_DUMP_PATH defined the path
                # where generated onnx files should be stored.
                # This module keeps a global variables keeping track of the
                # stored models.
                # If ONNXRT_DUMP_PATH="dumped/dumped_model_"
                # The first file name will be 'dumped/dumped_model_0.onnx'.
                # For every dumped model, a text file 'dumped/dumped_model_0.txt'
                # is created as well to contain the string representing the graph_module.
                _dump_onnx_model(onnx_model_bytes, graph_module=graph_module)

            # Initialize a ORT session to execute this ONNX model.
            # Note that TorchDynamo assumes all inputs/outputs are on the
            # same device, but it's subject to change (very likely with
            # dynamic shape support), so we add execution providers
            # based on the logic in _select_eps: (explicitly preferred EPs,
            # EPs inferred from inputs or graph, and the fallback default EP).
            #
            # TODO(wschin): enable external allocators.
            # See https://github.com/pytorch/pytorch/issues/106867
            onnx_session = onnxruntime.InferenceSession(
                path_or_bytes=onnx_model_bytes,
                sess_options=self._options.ort_session_options,
                providers=self._select_eps(graph_module, *args),
            )

            # Cache ORT session. It's reused for the same "graph_module".
            # Generate ONNX model and extract its input and output names.
            input_names = tuple(input.name for input in onnx_model.graph.input)
            output_names = tuple(output.name for output in onnx_model.graph.output)
            input_devices = _get_onnx_devices(args)
            # Cache devices for inputs and outputs. They are used to invoke
            # ORT session. Output devices indicate where (e.g., GPU or CPU)
            # to store outputs
            if isinstance(prim_outputs, tuple):
                output_devices = _get_onnx_devices(prim_outputs)
            else:
                output_devices = _get_onnx_devices((prim_outputs,))

            input_value_infos = tuple(input for input in onnx_model.graph.input)
            output_value_infos = tuple(output for output in onnx_model.graph.output)

            execution_info_per_session = OrtExecutionInfoPerSession(
                session=onnx_session,
                input_names=input_names,
                input_value_infos=input_value_infos,
                output_names=output_names,
                output_value_infos=output_value_infos,
                input_devices=input_devices,
                output_devices=output_devices,
                example_outputs=prim_outputs,
            )

            # Cache under the same key that was searched above.
            self._all_ort_execution_info.cache_session_execution_info(
                original_graph_module, execution_info_per_session
            )

        self.execution_count += 1

        # Pre-allocating outputs needs concrete output shapes. With dynamic
        # shapes enabled, prim_outputs come from node.meta["val"] rather than
        # FakeTensorProp and may carry symbolic shapes, so pre-allocation is
        # disabled in that mode regardless of the user option.
        preallocate_output = bool(
            self._options.preallocate_output
            and not self._resolved_onnx_exporter_options.dynamic_shapes
        )

        # ORT always returns a tuple of outputs. If the original output is a tensor,
        # ORT output's first element must be extracted and returned. Otherwise, type
        # mismatch may happen in downstream computation.
        is_single_tensor_output = isinstance(prim_outputs, torch.Tensor)
        normalized_prim_outputs = (
            (prim_outputs,) if is_single_tensor_output else prim_outputs
        )
        assert isinstance(normalized_prim_outputs, tuple)
        assert all(
            isinstance(elem, (torch.Tensor, torch.SymInt, int))
            for elem in normalized_prim_outputs
        )

        _nvtx_range_push("run_onnx_session_with_ortvaluevector")
        onnx_outputs = self.run(
            onnx_session,
            input_names,
            args,
            input_devices,
            output_names,
            normalized_prim_outputs,
            output_devices,
            preallocate_output,
            input_value_infos,
            normalized_prim_outputs,
        )
        _nvtx_range_pop()

        if self._assert_allclose_to_baseline:
            # Compute baseline.
            baseline_outputs = torch._prims.executor.execute(
                graph_module, *args, executor="aten"
            )
            normalized_baseline_outputs = (
                (baseline_outputs,) if is_single_tensor_output else baseline_outputs
            )
            # Ensure every output tensor is close to the corresponding baseline.
            for onnx_output, baseline_output in zip(
                onnx_outputs, normalized_baseline_outputs
            ):
                torch.testing.assert_close(onnx_output, baseline_output)
        return onnx_outputs[0] if is_single_tensor_output else onnx_outputs

    # Correctly spelled alias for the method above; the class docstring and the
    # comments in `compile` refer to it by this name.
    _ort_accelerated_call = _ort_acclerated_call

    def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule:
        """Partition ``graph_module`` by ONNX support and route supported parts to ORT.

        The partitioned module is cached per input ``graph_module``; ``args`` is
        not used by this method.
        """
        # Deferred import since CapabilityBasedPartitioner is not decorated with
        # @compatibility; importing it at the module level will result in the test
        # failing: pytest test/test_fx.py -k test_public_api_surface
        # because this module is imported into torch.onnx.
        from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

        # FX graph based partitioning based on ONNX supported ops.
        # Given a graph module
        #  GraphModule0
        #   node_0
        #   node_1
        #   node_2
        #   node_3
        #   node_4
        # If only node_2 is not supported by ONNX, this graph module will be partitioned into
        #  GraphModule0
        #   GraphModule1
        #    node_0
        #    node_1
        #   node_2
        #   GraphModule2
        #    node_3
        #    node_4
        # by calling CapabilityBasedPartitioner.partition_and_fuse.
        # Then, GraphModule1's and GraphModule2's forward method (GraphModule._wrapped_call)
        # will be replaced by OrtBackend._ort_accelerated_call to delegate computation to ORT.
        if graph_module in self._partitioner_cache:
            partitioned_prim_graph_module = self._partitioner_cache[graph_module]
        else:
            prim_graph_module = graph_module
            partitioner = CapabilityBasedPartitioner(
                prim_graph_module,
                self._supported_ops,
                allows_single_node_partition=True,
            )
            partitioned_prim_graph_module = partitioner.partition_and_fuse()
            self._partitioner_cache[graph_module] = partitioned_prim_graph_module

            # Overriding fused_module's __call__() function with ort_acclerated_call()
            # This loop goes through all graph partitions (each of them is an ONNX-representable graph)
            # and override their _wrapped_call function with _ort_accelerated_call.
            # Inside _ort_accelerated_call, the partition's graph is exported into ONNX and executed by ORT.
            for node in partitioned_prim_graph_module.graph.nodes:
                # TODO(wschin): use a better way to identify fused submodule
                # See https://github.com/pytorch/pytorch/issues/106872.
                if node.op == "call_module" and "fused_" in node.name:
                    fused_module = getattr(partitioned_prim_graph_module, node.name)
                    # self.ort_acclerated_call is responsible for exporting graph to ONNX,
                    # creating ORT session, and running ORT session.
                    fused_module._wrapped_call = self._ort_acclerated_call

        return partitioned_prim_graph_module

    def __call__(
        self, graph_module: torch.fx.GraphModule, args
    ) -> torch.fx.GraphModule:
        """If ``OrtBackendOptions.use_aot_autograd`` is ``True``, the `aot_autograd` compiler
        will be invoked, wrapping this ``OrtBackend`` instance's ``compile`` method. Otherwise,
        the ``compile`` method is invoked directly."""
        if self._options.use_aot_autograd:
            from functorch.compile import min_cut_rematerialization_partition
            from torch._dynamo.backends.common import aot_autograd

            return aot_autograd(
                fw_compiler=self.compile,
                partition_fn=min_cut_rematerialization_partition,
                decompositions=self._resolved_onnx_exporter_options.decomposition_table,
            )(graph_module, args)

        return self.compile(graph_module, args)

    # Upper bound on backends created through get_cached_instance_for_options.
    __instance_cache_max_count: Final = 8
    # Backends created through get_cached_instance_for_options, in creation order.
    __instance_cache: Final[List["OrtBackend"]] = []

    @staticmethod
    def get_cached_instance_for_options(
        options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
    ) -> "OrtBackend":
        """Returns a possibly cached instance of an ``OrtBackend``. If an existing
        backend was created previously through this function with the same options,
        it will be returned. Otherwise a new backend will be created, cached, and
        returned.

        Note: if ``options`` sets ``ort_session_options``, a new ``OrtBackend``
        will always be returned, since ``onnxruntime.SessionOptions`` cannot
        participate in caching."""

        def reusable(a: OrtBackendOptions, b: OrtBackendOptions):
            if (
                a.preferred_execution_providers != b.preferred_execution_providers
                or a.infer_execution_providers != b.infer_execution_providers
                or a.default_execution_providers != b.default_execution_providers
                or a.preallocate_output != b.preallocate_output
                or a.use_aot_autograd != b.use_aot_autograd
                or a.pre_ort_model_transforms != b.pre_ort_model_transforms
            ):
                return False

            # onnxruntime.SessionOptions is a pybind11 object, cannot be pickled,
            # and holds too much potential state to reasonably check manually;
            # if ort_session_options is provided at all, the backend does not
            # participate in caching.
            if a.ort_session_options is not None or b.ort_session_options is not None:
                return False

            if a.export_options is b.export_options:
                return True

            # Similarly, some objects in ExportOptions are too stateful to use for
            # caching. We should revisit this.
            if a.export_options is not None and b.export_options is not None:
                return (
                    a.export_options.dynamic_shapes == b.export_options.dynamic_shapes
                    and a.export_options.diagnostic_options
                    == b.export_options.diagnostic_options
                    and a.export_options.onnx_registry is b.export_options.onnx_registry
                    and a.export_options.fake_context is b.export_options.fake_context
                )

            # We can't account for how the two option sets may differ, so it's not safe to reuse.
            return False

        if not isinstance(options, OrtBackendOptions):
            options = OrtBackendOptions(**(options or {}))

        backend = next(
            (b for b in OrtBackend.__instance_cache if reusable(b._options, options)),
            None,
        )

        if backend is None:
            assert (
                len(OrtBackend.__instance_cache) < OrtBackend.__instance_cache_max_count
            ), (
                f"No more than {OrtBackend.__instance_cache_max_count} instances of "
                f"{OrtBackend} allowed. Please instantiate `{OrtBackend}` explicitly "
                "to pass to `torch.compile`. "
                "See https://github.com/pytorch/pytorch/pull/107973#discussion_r1306144795 "
                "for discussion."
            )
            OrtBackend.__instance_cache.append(backend := OrtBackend(options))

        return backend

    @staticmethod
    def clear_cached_instances():
        """Drop every backend cached by ``get_cached_instance_for_options``."""
        OrtBackend.__instance_cache.clear()

    @staticmethod
    def get_cached_instances():
        """Return a snapshot tuple of the currently cached backends."""
        return tuple(OrtBackend.__instance_cache)


@compatibility(is_backward_compatible=False)
def torch_compile_backend(
    graph_module: torch.fx.GraphModule,
    args,
    *,
    options: Optional[Union[OrtBackendOptions, Mapping[str, Any]]] = None,
):
    """Compile ``graph_module`` with an ``OrtBackend`` configured by ``options``.

    The backend is obtained from ``OrtBackend.get_cached_instance_for_options``,
    so an equivalent previously created backend may be reused. That backend is
    then invoked with ``graph_module`` and ``args``, and its result is returned.
    """
    ort_backend = OrtBackend.get_cached_instance_for_options(options)
    compiled = ort_backend(graph_module, args)
    return compiled
