# mypy: allow-untyped-defs
from __future__ import annotations

import dataclasses
import itertools
import logging
import math
import operator
from typing import (
    Callable,
    Dict,
    Generic,
    Optional,
    overload,
    SupportsFloat,
    TYPE_CHECKING,
    TypeVar,
    Union,
)
from typing_extensions import TypeGuard

import sympy
from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom

import torch
from torch._logging import LazyString
from torch._prims_common import dtype_to_type

from .functions import (
    _keep_float,
    FloatTrueDiv,
    FloorDiv,
    IntTrueDiv,
    OpaqueUnaryFn_exp,
    OpaqueUnaryFn_log,
    OpaqueUnaryFn_sqrt,
    PowByNatural,
    RoundDecimal,
    RoundToInt,
    safe_pow,
    ToFloat,
    TruncToFloat,
    TruncToInt,
)
from .interp import sympy_interp
from .numbers import int_oo, IntInfinity, NegativeIntInfinity


log = logging.getLogger(__name__)

__all__ = ["ValueRanges", "ValueRangeAnalysis", "bound_sympy"]

_T = TypeVar("_T", sympy.Expr, SympyBoolean)


class ValueRangeError(RuntimeError):
    pass


# Like sympify, but supports less stuff, and also ensures that direct
# sympy expressions don't have free variables
def simple_sympify(e):
    if isinstance(e, bool):
        return sympy.true if e else sympy.false
    elif isinstance(e, int):
        return sympy.Integer(e)
    elif isinstance(e, float):
        # infinity is special; we use it to bracket integers as well
        if math.isinf(e):
            return sympy.oo if e > 0 else -sympy.oo
        return sympy.Float(e)
    elif isinstance(e, sympy.Expr):
        assert e.is_number, e
        # NaNs can occur when doing things like 0 * sympy.oo, but it is better
        # if the operator notices this and takes care of it, because sometimes
        # the NaN is inappropriate (for example, for ints, the [-oo, oo] range
        # should go to zero when multiplied with [0, 0])
        assert e != sympy.nan
        return e
    elif isinstance(e, BooleanAtom):
        return e
    else:
        raise AssertionError(f"not simple sympy type {type(e)}: {e}")


# Sympy atomics only. Unlike <=, it also works on Sympy bools.
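# For example, sympy_generic_le(sympy.false, sympy.true) holds, while
# sympy_generic_le(sympy.true, sympy.false) does not (booleans are ordered False <= True).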
def sympy_generic_le(lower, upper):
    if isinstance(lower, sympy.Expr):
        assert isinstance(upper, sympy.Expr)
        return lower <= upper
    else:
        # the only failing case is lower == True, upper == False (i.e. True > False)
        assert isinstance(lower, SympyBoolean) and isinstance(upper, SympyBoolean), (
            lower,
            upper,
        )
        return not (lower and not upper)


def vr_is_bool(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[SympyBoolean]]:
    return vr.is_bool


def vr_is_expr(vr: ValueRanges[_T]) -> TypeGuard[ValueRanges[sympy.Expr]]:
    return not vr.is_bool


ExprIn = Union[int, float, sympy.Expr]
BoolIn = Union[bool, SympyBoolean]
AllIn = Union[ExprIn, BoolIn]
ExprFn = Callable[[sympy.Expr], sympy.Expr]
ExprFn2 = Callable[[sympy.Expr, sympy.Expr], sympy.Expr]
BoolFn = Callable[[SympyBoolean], SympyBoolean]
BoolFn2 = Callable[[SympyBoolean, SympyBoolean], SympyBoolean]
AllFn = Union[ExprFn, BoolFn]
AllFn2 = Union[ExprFn2, BoolFn2]


@dataclasses.dataclass(frozen=True)
class ValueRanges(Generic[_T]):
    if TYPE_CHECKING:
        # ruff doesn't understand circular references but mypy does
        ExprVR = ValueRanges[sympy.Expr]  # noqa: F821
        BoolVR = ValueRanges[SympyBoolean]  # noqa: F821
        AllVR = Union[ExprVR, BoolVR]

    # Although the type signature here suggests you can pass any
    # sympy expression, in practice the analysis here only works
    # with constant sympy expressions
    lower: _T
    upper: _T
    is_bool: bool
    is_int: bool
    is_float: bool

    def __repr__(self) -> str:
        return f"VR[{self.lower}, {self.upper}]"

    @overload
    def __init__(
        self: ValueRanges[sympy.Expr],
        lower: ExprIn,
        upper: ExprIn,
    ) -> None:
        ...

    @overload
    def __init__(  # type: ignore[misc]
        self: ValueRanges[SympyBoolean],
        lower: BoolIn,
        upper: BoolIn,
    ) -> None:
        ...

    def __init__(self, lower: AllIn, upper: AllIn) -> None:
        lower = simple_sympify(lower)
        upper = simple_sympify(upper)
        # TODO: when the bounds have free variables, this may be
        # nontrivial to actually verify
        try:
            if not sympy_generic_le(lower, upper):
                raise ValueRangeError(f"Invalid ranges [{lower}:{upper}]")
        except TypeError as e:
            raise TypeError(f"Could not compare {lower} <= {upper}") from e

        is_bool_lower = isinstance(lower, SympyBoolean)
        is_bool_upper = isinstance(upper, SympyBoolean)
        assert is_bool_lower == is_bool_upper, (lower, upper)

        # Warning: is_int/is_float is best effort.  We do pretty well in
        # Dynamo, but in Inductor these attributes are often wrong because we
        # are not very rigorous in dtype analysis.  This is also why we need
        # the flexible analysis for is_int: sometimes a sympy.oo pops in for
        # an integer bound. I would /like/ for us not to do this, but it's
        # too hard to push the invariant through right now.
        if isinstance(lower, sympy.Integer) and upper == sympy.oo:
            upper = int_oo
        if isinstance(upper, sympy.Integer) and lower == -sympy.oo:
            lower = -int_oo
        # NB: [-int_oo, -int_oo] and [int_oo, int_oo] are allowed
        integer_types = (sympy.Integer, NegativeIntInfinity, IntInfinity)
        is_int_lower = isinstance(lower, integer_types)
        is_int_upper = isinstance(upper, integer_types)

        # Because this is a frozen class
        object.__setattr__(self, "lower", lower)
        object.__setattr__(self, "upper", upper)
        # Unlike bool/int in Python, we don't report bools as ints
        #
        # NB: is_bool_lower == is_bool_upper, so we only need to check one
        object.__setattr__(self, "is_bool", is_bool_lower)
        object.__setattr__(
            self,
            "is_int",
            not self.is_bool and is_int_lower and is_int_upper,
        )
        """
        # This assert is just impossible right now, too many sympy bugs
        if self.is_int:
            # NB: sympy will sometimes randomly lose the float-ness of zero,
            # so we also need to account for that in the assertion here.
            # See also https://github.com/sympy/sympy/issues/26620
            assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], (
                lower,
                upper,
            )
            assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper)
        """
        # NB: [-oo, oo] always advertises as float!
        object.__setattr__(self, "is_float", not self.is_bool and not self.is_int)
        assert self.is_bool or self.is_int or self.is_float, (lower, upper)
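
    # Illustrative sketch (comments only, not executed): how the kind flags
    # typically come out, assuming the sympy.oo -> int_oo promotion above:
    #   ValueRanges(0, 10)                    # is_int
    #   ValueRanges(0, sympy.oo)              # upper promoted to int_oo, is_int
    #   ValueRanges(0.0, 1.0)                 # is_float
    #   ValueRanges(sympy.false, sympy.true)  # is_bool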

    def boolify(self) -> ValueRanges[SympyBoolean]:
        if vr_is_bool(self):
            return self
        elif self == ValueRanges.unknown():
            return ValueRanges.unknown_bool()
        else:
            raise AssertionError(f"not bool like {self}")

    def __contains__(self, x: AllIn) -> bool:
        return ValueRanges.wrap(x).issubset(self)

    def issubset(self, other):
        return sympy_generic_le(other.lower, self.lower) and sympy_generic_le(
            self.upper, other.upper
        )

    def tighten(self, other) -> ValueRanges:
        """Given two ValueRanges, returns their intersection"""
        return self & other

    # Intersection
    @overload
    def __and__(
        self: ValueRanges[sympy.Expr],
        other: ValueRanges[sympy.Expr],
    ) -> ValueRanges[sympy.Expr]:
        ...

    @overload
    def __and__(  # type: ignore[misc]
        self: ValueRanges[SympyBoolean],
        other: ValueRanges[SympyBoolean],
    ) -> ValueRanges[SympyBoolean]:
        ...

    def __and__(self: AllVR, other: AllVR) -> AllVR:
        if other == ValueRanges.unknown():
            return self
        if self == ValueRanges.unknown():
            return other
        assert self.is_bool == other.is_bool, (self, other)
        assert self.is_int == other.is_int, (self, other)
        assert self.is_float == other.is_float, (self, other)
        if self.is_bool:
            return ValueRanges(
                sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper)
            )
        else:
            return ValueRanges(
                sympy.Max(self.lower, other.lower), sympy.Min(self.upper, other.upper)
            )
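
    # For instance (illustrative sketch, not executed):
    #   ValueRanges(0, 10) & ValueRanges(5, 20)  ->  VR[5, 10]
    # For bools the order is False <= True, so intersection takes Or of the
    # lowers and And of the uppers:
    #   ValueRanges(sympy.false, sympy.true) & ValueRanges(sympy.true, sympy.true)  ->  VR[True, True]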

    # Union
    @overload
    def __or__(
        self: ValueRanges[sympy.Expr],
        other: ValueRanges[sympy.Expr],
    ) -> ValueRanges[sympy.Expr]:
        ...

    @overload
    def __or__(  # type: ignore[misc]
        self: ValueRanges[SympyBoolean],
        other: ValueRanges[SympyBoolean],
    ) -> ValueRanges[SympyBoolean]:
        ...

    def __or__(self: AllVR, other: AllVR) -> AllVR:
        if ValueRanges.unknown() in (self, other):
            return ValueRanges.unknown()
        assert self.is_bool == other.is_bool, (self, other)
        assert self.is_int == other.is_int, (self, other)
        assert self.is_float == other.is_float, (self, other)
        if self.is_bool:
            return ValueRanges(
                sympy.And(self.lower, other.lower), sympy.Or(self.upper, other.upper)
            )
        else:
            return ValueRanges(
                sympy.Min(self.lower, other.lower), sympy.Max(self.upper, other.upper)
            )
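
    # For instance (illustrative sketch, not executed):
    #   ValueRanges(0, 10) | ValueRanges(15, 20)  ->  VR[0, 20]  (interval hull)
    #   ValueRanges(sympy.true, sympy.true) | ValueRanges(sympy.false, sympy.false)  ->  VR[False, True]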

    def is_singleton(self) -> bool:
        return self.lower == self.upper

    @staticmethod
    def unknown() -> ValueRanges[sympy.Expr]:
        return ValueRanges(-sympy.oo, sympy.oo)

    @staticmethod
    def unknown_int() -> ValueRanges[sympy.Expr]:
        return ValueRanges(-int_oo, int_oo)

    @staticmethod
    def unknown_bool() -> ValueRanges[SympyBoolean]:
        return ValueRanges(sympy.false, sympy.true)

    @overload
    @staticmethod
    # work around the fact that bool and int overlap
    def wrap(arg: Union[ExprIn, ExprVR]) -> ExprVR:  # type: ignore[overload-overlap]
        ...

    @overload
    @staticmethod
    def wrap(arg: Union[BoolIn, BoolVR]) -> BoolVR:  # type: ignore[misc]
        ...

    @staticmethod
    def wrap(arg: Union[AllIn, AllVR]) -> AllVR:
        if isinstance(arg, ValueRanges):
            return arg
        if isinstance(arg, float) and math.isnan(arg):
            return ValueRanges.unknown()
        # arg is either ExprIn or BoolIn, but we don't know it here
        return ValueRanges(arg, arg)  # type: ignore[arg-type]

    @staticmethod
    def increasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        """Increasing: x <= y => f(x) <= f(y)."""
        x = ValueRanges.wrap(x)
        return ValueRanges(fn(x.lower), fn(x.upper))

    @overload
    @staticmethod
    def decreasing_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        ...

    @overload
    @staticmethod
    def decreasing_map(x: Union[BoolIn, BoolVR], fn: BoolFn) -> BoolVR:  # type: ignore[misc]
        ...

    @staticmethod
    def decreasing_map(x: Union[AllIn, AllVR], fn: AllFn) -> AllVR:
        """Decreasing: x <= y => f(x) >= f(y)."""
        x = ValueRanges.wrap(x)
        # consistently either Expr or Bool, but we don't know it here
        return ValueRanges(fn(x.upper), fn(x.lower))  # type: ignore[arg-type]

    @staticmethod
    def monotone_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        """It's increasing or decreasing."""
        x = ValueRanges.wrap(x)
        l = fn(x.lower)
        u = fn(x.upper)
        return ValueRanges(min(l, u), max(l, u))

    @staticmethod
    def convex_min_zero_map(x: Union[ExprIn, ExprVR], fn: ExprFn) -> ExprVR:
        """Fn is convex and has a minimum at 0."""
        x = ValueRanges.wrap(x)
        if 0 in x:
            upper = max(fn(x.lower), fn(x.upper))
            upper = simple_sympify(upper)
            if isinstance(upper, sympy.Float) or upper == sympy.oo:
                return ValueRanges(0.0, upper)
            return ValueRanges(0, upper)
        return ValueRanges.monotone_map(x, fn)
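
    # For instance, with fn = abs (convex with a minimum at 0), a range that
    # straddles zero maps to [0, max(|lower|, |upper|)] (illustrative sketch):
    #   ValueRanges.convex_min_zero_map(ValueRanges(-3, 2), abs)  ->  VR[0, 3]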

    @overload
    @staticmethod
    def coordinatewise_increasing_map(
        x: Union[ExprIn, ExprVR],
        y: Union[ExprIn, ExprVR],
        fn: ExprFn2,
    ) -> ExprVR:
        ...

    @overload
    @staticmethod
    def coordinatewise_increasing_map(  # type: ignore[misc]
        x: Union[BoolIn, BoolVR],
        y: Union[BoolIn, BoolVR],
        fn: BoolFn2,
    ) -> BoolVR:
        ...

    @staticmethod
    def coordinatewise_increasing_map(
        x: Union[AllIn, AllVR],
        y: Union[AllIn, AllVR],
        fn: AllFn2,
    ) -> AllVR:
        """
        It's increasing on each coordinate.

        Mathematically:
        For every 1 <= i <= n and x_i <= y_i, we have
        f(x_1, ..., x_i, ..., x_n) <= f(x_1, ..., y_i, ..., x_n)
        """
        x, y = ValueRanges.wrap(x), ValueRanges.wrap(y)
        return ValueRanges(
            fn(x.lower, y.lower),  # type: ignore[arg-type]
            fn(x.upper, y.upper),  # type: ignore[arg-type]
        )
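
    # For instance (illustrative sketch, not executed):
    #   ValueRanges.coordinatewise_increasing_map(
    #       ValueRanges(0, 2), ValueRanges(1, 5), operator.add
    #   )  ->  VR[1, 7]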

    @classmethod
    def coordinatewise_monotone_map(cls, x, y, fn):
        """It's increasing or decreasing on each coordinate."""
        x, y = cls.wrap(x), cls.wrap(y)
        products = [
            fn(a, b)
            for a, b in itertools.product([x.lower, x.upper], [y.lower, y.upper])
        ]
        return ValueRanges(min(products), max(products))
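
    # For instance, multiplication is monotone in each coordinate but not jointly
    # increasing, so all four corner products are considered (illustrative sketch):
    #   ValueRanges.coordinatewise_monotone_map(
    #       ValueRanges(-2, 3), ValueRanges(-1, 4), operator.mul
    #   )  ->  VR[-8, 12]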


class SymPyValueRangeAnalysis:
    """
    Gives bounds on a SymPy operator given bounds on its arguments.

    See `bound_sympy` for the function that applies this logic to a full SymPy expression.
    """

    @staticmethod
    def constant(value, dtype):
        if isinstance(value, ValueRanges):
            assert value.is_singleton()
            value = value.lower
        # NB: value is NOT a sympy expression, it's a constant!
        is_python = isinstance(value, (int, float, bool))
        assert is_python or isinstance(
            value, (BooleanAtom, sympy.Integer, sympy.Number)
        )

        # Using nan makes subsequent computation throw, and for the purposes of
        # optimization, returning the range [-math.inf, math.inf] is equivalent to giving up.
        if isinstance(value, SupportsFloat) and math.isnan(value):
            if dtype == torch.bool:
                return ValueRanges.unknown_bool()
            elif dtype.is_floating_point:
                return ValueRanges.unknown()
            else:
                return ValueRanges(-int_oo, int_oo)

        if is_python:
            type_ = dtype_to_type(dtype)
            value = type_(value)
        else:
            # We do a type check on a best-effort basis
            # We don't want to force a cast to sympy.Float if the value is Rational to avoid losing precision
            if dtype == torch.bool:
                assert isinstance(value, BooleanAtom)
            elif dtype.is_floating_point:
                assert not value.is_finite or value.is_real
            else:
                # dtype is intXX
                assert value.is_integer

        r = ValueRanges.wrap(value)
        return r

    @staticmethod
    def to_dtype(a, dtype, src_dtype=None):
        if dtype == torch.float64:
            return ValueRanges.increasing_map(a, ToFloat)
        elif dtype == torch.bool:
            return ValueRanges.unknown_bool()
        elif not dtype.is_floating_point:
            return ValueRanges.unknown_int()
        return ValueRanges.unknown()

    @staticmethod
    def trunc_to_int(a, dtype):
        return ValueRanges.increasing_map(a, TruncToInt)

    @staticmethod
    def not_(a):
        a = ValueRanges.wrap(a)
        a = a.boolify()
        assert a.is_bool
        return ValueRanges.decreasing_map(a, sympy.Not)

    @staticmethod
    def or_(a, b):
        return ValueRanges.coordinatewise_increasing_map(a, b, sympy.Or)

    @staticmethod
    def and_(a, b):
        return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And)

    @staticmethod
    def eq(a, b):
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if a.is_singleton() and b.is_singleton() and a.lower == b.lower:
            return ValueRanges.wrap(sympy.true)
        elif a.lower > b.upper or b.lower > a.upper:  # ranges disjoint
            return ValueRanges.wrap(sympy.false)
        return ValueRanges(sympy.false, sympy.true)
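
    # For instance (illustrative sketch, not executed):
    #   eq(ValueRanges(3, 3), ValueRanges(3, 3))  ->  VR[True, True]
    #   eq(ValueRanges(0, 1), ValueRanges(2, 5))  ->  VR[False, False]  (disjoint)
    #   eq(ValueRanges(0, 4), ValueRanges(2, 5))  ->  VR[False, True]   (may overlap)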

    @classmethod
    def ne(cls, a, b):
        return cls.not_(cls.eq(a, b))

    @classmethod
    def identity(cls, a):
        return ValueRanges.wrap(a)

    @classmethod
    def lt(cls, a, b):
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        assert a.is_bool == b.is_bool
        if a.is_bool:
            return cls.and_(cls.not_(a), b)
        else:
            if a.upper < b.lower:
                return ValueRanges.wrap(sympy.true)
            elif a.lower >= b.upper:
                return ValueRanges.wrap(sympy.false)
            return ValueRanges(sympy.false, sympy.true)

    @classmethod
    def gt(cls, a, b):
        return cls.lt(b, a)

    @classmethod
    def le(cls, a, b):
        return cls.not_(cls.gt(a, b))

    @classmethod
    def ge(cls, a, b):
        return cls.not_(cls.lt(a, b))

    @staticmethod
    def add(a, b):
        return ValueRanges.coordinatewise_increasing_map(
            a, b, _keep_float(operator.add)
        )

    @classmethod
    def mul(cls, a, b):
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)

        assert a.is_bool == b.is_bool
        if a.is_bool:
            return cls.and_(a, b)

        def safe_mul(a, b):
            # Make unknown() * wrap(0.0) == wrap(0.0)
            if a == 0.0 or a == 0:
                return a
            elif b == 0.0 or b == 0:
                return b
            else:
                return a * b

        return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul))

    @staticmethod
    def int_truediv(a, b):
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if 0 in b or ((-int_oo in a or int_oo in a) and (-int_oo in b or int_oo in b)):
            return ValueRanges.unknown()
        else:
            return ValueRanges.coordinatewise_monotone_map(
                a, b, _keep_float(IntTrueDiv)
            )

    @staticmethod
    def truediv(a, b):
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if 0 in b or (
            (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b)
        ):
            return ValueRanges.unknown()
        else:
            return ValueRanges.coordinatewise_monotone_map(
                a, b, _keep_float(FloatTrueDiv)
            )

    @staticmethod
    def floordiv(a, b):
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if 0 in b:
            return ValueRanges.unknown_int()
        products = []
        for x, y in itertools.product([a.lower, a.upper], [b.lower, b.upper]):
            r = FloorDiv(x, y)
            if r is sympy.nan:
                products.append((sympy.sign(x) * sympy.sign(y)) * int_oo)
            else:
                products.append(r)

        return ValueRanges(min(products), max(products))
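
    # For instance (illustrative sketch, not executed):
    #   floordiv(ValueRanges(2, 10), ValueRanges(3, 3))  ->  VR[0, 3]
    # Corners that evaluate to nan (e.g. infinity divided by infinity) are
    # replaced by a signed int_oo as above.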

    @classmethod
    def mod(cls, x, y):
        x = ValueRanges.wrap(x)
        y = ValueRanges.wrap(y)
        # nb. We implement C semantics

        def c_mod(a, b):
            ret = abs(a) % abs(b)
            if a < 0:
                ret *= -1
            return ret

        def c_div(a, b):
            x = a / b
            return sympy.Integer(x) if x.is_finite and x not in (int_oo, -int_oo) else x

        if 0 in y:
            return ValueRanges.unknown_int()
        elif y.is_singleton():
            y_val = abs(y.lower)
            # If the range wraps around a multiple of y, we need to take the
            # whole residue interval.

            # The function is locally linear if x.lower and x.upper fall in the
            # same division class, i.e. have the same truncated quotient.
            if c_div(x.lower, y_val) == c_div(x.upper, y_val):
                return ValueRanges.increasing_map(x, lambda u: c_mod(u, y_val))
            if x.upper < 0:
                # Negative case
                return ValueRanges(-y_val + 1, 0)
            elif x.lower > 0:
                # Positive case
                return ValueRanges(0, y_val - 1)
            else:
                # Mixed case
                lower = max(-y_val + 1, x.lower)
                upper = min(y_val - 1, x.upper)
                return ValueRanges(lower, upper)
        else:
            # Too difficult, we bail out
            upper = cls.abs(y).upper - 1
            return ValueRanges(-upper, upper)
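
    # For instance, with C (truncated-division) semantics (illustrative sketch):
    #   mod(ValueRanges(4, 5), ValueRanges(3, 3))    ->  VR[1, 2]   (locally linear)
    #   mod(ValueRanges(-7, -2), ValueRanges(3, 3))  ->  VR[-2, 0]  (negative case)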

    @classmethod
    def modular_indexing(cls, a, b, c):
        return cls.mod(cls.floordiv(a, b), c)

    @classmethod
    def is_non_overlapping_and_dense_indicator(cls, *args):
        return ValueRanges.unknown_int()

    @classmethod
    def pow_by_natural(cls, a, b):
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        if a.is_singleton() and b.is_singleton():
            return ValueRanges.wrap(safe_pow(a.lower, b.lower))
        # NB: Exclude zero, because zero is special
        elif a.lower >= 1:
            # We should know that b >= 0 but we may have forgotten this fact due
            # to replacements, so don't assert it, but DO clamp it to prevent
            # degenerate problems
            return ValueRanges.coordinatewise_increasing_map(
                a, b & ValueRanges(0, int_oo), PowByNatural
            )
        elif b.is_singleton():
            if b.lower % 2 == 0:
                # x^n where n is even
                return ValueRanges.convex_min_zero_map(
                    a, lambda x: safe_pow(x, b.lower)
                )
            else:
                # x^n where n is odd
                return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower))
        else:
            # a is potentially negative, and we don't know if the exponent is
            # even or odd.  So just conservatively set the upper and lower
            # bound based on what the maximum absolute value could be, in both
            # directions
            max_base = max(a.upper, -a.lower)
            return ValueRanges(
                -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper)
            )
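
    # For instance, assuming safe_pow agrees with ** on finite inputs (it mainly
    # adds int_oo saturation), an even singleton exponent goes through
    # convex_min_zero_map (illustrative sketch, not executed):
    #   pow_by_natural(ValueRanges(-3, 2), ValueRanges(2, 2))  ->  VR[0, 9]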

    @classmethod
    def pow(cls, a, b):
        return ValueRanges.unknown()

        # We could implement all this, but for floating point pow, is there
        # really a point?
        """
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)

        # Not implemented yet. It's a bit tricky
        # If you want to implement it, compute the partial derivatives of a ** b
        # and check the ranges where the function is increasing / decreasing
        # Another, non-tight way of doing this is to note that for a > 0, a ** b == exp(b * log(a)).
        # If this second option is implemented, be careful about the types and possible infinities here and there.
        if not b.is_singleton():
            return ValueRanges.unknown()

        b = b.lower
        if a.is_singleton():
            a = a.lower
            r = a**b
            if not r.is_finite:
                return ValueRanges.unknown()
            return ValueRanges.wrap(r)

        if b == 0:
            if not a.lower.is_finite:
                return ValueRanges.unknown()
            return ValueRanges.wrap(1.0)

        if b < 0:
            a = cls.reciprocal(a)
            b = -b

        if a == ValueRanges.unknown():
            return ValueRanges.unknown()

        # If the base is positive, then we're good, otherwise nothing's defined
        if a.lower >= 0:
            return ValueRanges.increasing_map(a, lambda x: x**b)
        else:
            return ValueRanges.unknown()
        """

    @staticmethod
    def reciprocal(x):
        """Needed as it's used in pow, but it won't appear on a SymPy expression"""
        x = ValueRanges.wrap(x)
        if 0 in x:
            return ValueRanges.unknown()
        else:
            return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y))  # type: ignore[operator]

    @staticmethod
    def abs(x):
        return ValueRanges.convex_min_zero_map(x, abs)

    @staticmethod
    def exp(x):
        return ValueRanges.increasing_map(x, OpaqueUnaryFn_exp)

    @staticmethod
    def log(x):
        x = ValueRanges.wrap(x)
        if x.lower <= 0:
            return ValueRanges.unknown()
        return ValueRanges.increasing_map(x, OpaqueUnaryFn_log)

    @classmethod
    def minimum(cls, a, b):
        return cls.min_or_max(a, b, sympy.Min)

    @classmethod
    def maximum(cls, a, b):
        return cls.min_or_max(a, b, sympy.Max)

    @staticmethod
    def min_or_max(a, b, fn):
        a = ValueRanges.wrap(a)
        b = ValueRanges.wrap(b)
        return ValueRanges.coordinatewise_increasing_map(a, b, fn)

    @classmethod
    def floor_to_int(cls, x, dtype):
        return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor)

    @classmethod
    def ceil_to_int(cls, x, dtype):
        return ValueRanges.increasing_map(
            x, sympy.functions.elementary.integers.ceiling
        )

    # I think these implementations are sound.  The hazard here is that sympy
    # will carry out the floor/ceil at too high precision and then something
    # bad will happen when we convert it to float.
    #
    # For truncation, the implementation is clearly sound, because the desired
    # target float is always exactly representable, since you're just chopping
    # off bits of the mantissa.  But what about ceil/floor?
    #
    # The important constraint here is that we're not defining floor on
    # arbitrary real numbers, only representable float numbers.  So we can
    # take advantage of the fact that before we reach the first
    # unrepresentable integer in floating point space, we have the range of
    # numbers corresponding to exponent zero: all integers, with no fractional
    # amounts.  floor/ceil is an identity operation in this case.  In the
    # range below here, representable floating point numbers are spaced
    # exactly 1/2 apart, and notably, both the floor/ceil are defined floating
    # point numbers.  There is no "gap" as you step up to the next exponent.

    @classmethod
    def floor(cls, x):
        return ValueRanges.increasing_map(
            x, _keep_float(sympy.functions.elementary.integers.floor)
        )

    @classmethod
    def ceil(cls, x):
        return ValueRanges.increasing_map(
            x, _keep_float(sympy.functions.elementary.integers.ceiling)
        )

    @classmethod
    def round_decimal(cls, number, ndigits):
        if not ndigits.is_singleton():
            return ValueRanges.unknown()

        ndigits = ndigits.lower
        # We can't use functools.partial here: we need to bind the second positional
        # parameter, and sympy functions don't support keyword arguments.
        fn = lambda number: RoundDecimal(number, ndigits)  # type: ignore[misc, assignment]  # noqa: E731

        return ValueRanges.increasing_map(number, fn)

    @classmethod
    def round_to_int(cls, number, dtype):
        return ValueRanges.increasing_map(number, RoundToInt)

    # It's used in some models on symints
    @staticmethod
    def sqrt(x):
        x = ValueRanges.wrap(x)
        if x.lower < 0:
            return ValueRanges.unknown()
        return ValueRanges.increasing_map(x, OpaqueUnaryFn_sqrt)

    @staticmethod
    def where(a, b, c):
        b = ValueRanges.wrap(b)
        c = ValueRanges.wrap(c)
        a = a.boolify()
        # We sometimes write unknown without specifying the type correctly
        # In particular, we do that when initialising the bounds for loads in bounds.py
        assert b.is_bool == c.is_bool or ValueRanges.unknown() in (b, c)
        if b.is_bool:
            return ValueRanges(sympy.And(b.lower, c.lower), sympy.Or(b.upper, c.upper))
        else:
            return ValueRanges(sympy.Min(b.lower, c.lower), sympy.Max(b.upper, c.upper))

    # expr_cond_pair is used to represent a single (expr, condition) pair in piecewise.
    # We just return the value range of the expression and its corresponding condition as a tuple
    # and defer the analysis to piecewise
    @staticmethod
    def expr_cond_pair(a, b):
        b = b.boolify()
        return (a, b)

    # A piecewise function can be used to convert a SymBool to a SymInt:
    # int_expr = Piecewise((1, bool_expr), (0, True)); it evaluates to 1 when bool_expr is True and 0 otherwise.
    #
    # ranges is a sequence of (expr_range, condition_range) pairs. The range pair is constructed in expr_cond_pair.
    # The ValueRange of Piecewise is just the union of all expr ranges whose condition expr can be True.
    @staticmethod
    def piecewise(*ranges):
        init_range = None
        for expr_range, cond_range in ranges:
            if sympy.true in cond_range:
                if init_range is None:
                    init_range = expr_range
                else:
                    init_range = init_range | expr_range
        return init_range
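
    # For instance (illustrative sketch): for Piecewise((1, b), (0, True)) with an
    # unknown bool b, either condition can be true, so the branch ranges are unioned:
    #   piecewise(
    #       (ValueRanges(1, 1), ValueRanges(sympy.false, sympy.true)),
    #       (ValueRanges(0, 0), ValueRanges(sympy.true, sympy.true)),
    #   )  ->  VR[0, 1]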

    @staticmethod
    def cos(x):
        # TODO: We should tighten value ranges
        # If the input range spans pi + 2*pi*k, then the output range is [-1, 1];
        # otherwise, bound by the min/max of the function at the interval endpoints.
        return ValueRanges(-1.0, 1.0)

    @staticmethod
    def cosh(x):
        return ValueRanges(0.0, sympy.oo)
        """
        x = ValueRanges.wrap(x)
        if x.lower > 0:
            return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh)
        elif x.upper < 0:
            return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh)
        return ValueRanges(0.0, sympy.oo)
        """

    @staticmethod
    def sin(x):
        # TODO: We should tighten value ranges
        # See details on cos
        return ValueRanges(-1.0, 1.0)

    @staticmethod
    def sinh(x):
        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh)
        return ValueRanges(-sympy.oo, sympy.oo)

    @staticmethod
    def tan(x):
        return ValueRanges(-sympy.oo, sympy.oo)

    @staticmethod
    def tanh(x):
        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh)
        return ValueRanges(-sympy.oo, sympy.oo)

    @staticmethod
    def asin(x):
        return ValueRanges(-sympy.oo, sympy.oo)
        """
        x = ValueRanges.wrap(x)
        if -1 <= x.lower and x.upper <= 1:
            return ValueRanges.increasing_map(x, OpaqueUnaryFn_asin)
        return ValueRanges.unknown()
        """

    @staticmethod
    def acos(x):
        return ValueRanges(-sympy.oo, sympy.oo)
        """
        x = ValueRanges.wrap(x)
        if -1 <= x.lower and x.upper <= 1:
            return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos)
        return ValueRanges.unknown()
        """

    @staticmethod
    def atan(x):
        return ValueRanges(-sympy.oo, sympy.oo)
        # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan)

    @staticmethod
    def trunc(x):
        return ValueRanges.increasing_map(x, TruncToFloat)


class ValueRangeAnalysis(SymPyValueRangeAnalysis):
    def __init__(self) -> None:
        self.name = "ValueRangeAnalysis"
        boolean_operators = (
            "xor",
            "logical_and",
            "logical_or",
            "logical_not",
        )
        for op in boolean_operators:
            setattr(self, op, self.bool_handler)

    @staticmethod
    def bool_handler(*args, **kwargs):
        # just assuming bools can have both values
        return ValueRanges(sympy.false, sympy.true)  # type: ignore[arg-type]

    @staticmethod
    def default_handler(*args, **kwargs):
        # Many ops are unlikely to show up in optimizable indexing compute,
        # so we don't have full coverage.
        return ValueRanges.unknown()

    def load(self, name: str, index: sympy.Expr):
        return ValueRanges.unknown()

    def store(self, name, index, value, mode=None):
        return

    def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
        return ValueRanges.unknown()

    @classmethod
    def index_expr(cls, index, dtype):
        assert isinstance(index, ValueRanges)
        return cls.to_dtype(index, dtype)

    @staticmethod
    def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None):
        x = ValueRanges.wrap(x)

        if dtype == torch.bool:
            if x.is_singleton():
                return ValueRanges.wrap(x.lower != 0)
            elif x.is_bool:
                return x
            elif 0 not in x:
                return ValueRanges.wrap(sympy.true)
            else:
                return ValueRanges(sympy.false, sympy.true)

        def cast(x, dtype):
            # dtype is int or float
            if dtype.is_floating_point:
                return sympy.Float(x)
            else:
                if x in (int_oo, -int_oo):
                    return x
                try:
                    return sympy.Integer(x)
                except TypeError:
                    # inf cannot be cast to Integer
                    return x

        if x.is_bool:
            if x.is_singleton():
                val = 1 if x.lower else 0
                return ValueRanges.wrap(cast(val, dtype))
            else:
                return ValueRanges(cast(0, dtype), cast(1, dtype))
        else:
            # int to float or float to int
            return ValueRanges(cast(x.lower, dtype), cast(x.upper, dtype))
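
    # For instance (illustrative sketch, not executed):
    #   to_dtype(ValueRanges(sympy.false, sympy.true), torch.int64)  ->  VR[0, 1]
    #   to_dtype(ValueRanges(-3, 5), torch.bool)                     ->  VR[False, True]
    #   to_dtype(ValueRanges(1, 5), torch.bool)                      ->  VR[True, True]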

    @staticmethod
    def square(x):
        return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2))

    @staticmethod
    def neg(x):
        return ValueRanges.decreasing_map(x, operator.neg)

    # TODO: this is slightly inaccurate because truncdiv operates at integer
    # precision, but we're going through float truediv which means we can
    # potentially lose precision on the bounds
    @classmethod
    def truncdiv(cls, a, b):
        x = cls.truediv(a, b)
        if x == ValueRanges.unknown():
            return x

        return cls.trunc(x)

    @classmethod
    def sub(cls, a, b):
        return cls.add(a, cls.neg(b))

    def __getattr__(self, name):
        log.debug("unhandled ValueRange op %s", name)
        return self.default_handler


def bound_sympy(
    expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None
) -> ValueRanges:
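    """Compute a ValueRanges over-approximation of ``expr`` given per-symbol ranges.

    Rough usage sketch (illustrative, assuming an integer symbol ``x``)::

        x = sympy.Symbol("x", integer=True)
        bound_sympy(2 * x + 1, {x: ValueRanges(0, 10)})  # roughly VR[1, 21]
    """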
    log.debug(
        "bound_sympy(%s)%s",
        expr,
        LazyString(
            lambda: "\n"
            + "\n".join(
                f"  {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols
            )
            if ranges
            else ""
        ),
    )
    if isinstance(expr, sympy.Number):
        return ValueRanges.wrap(expr)

    ranges = ranges or {}

    # If there's a tracing context, augment available constrained ranges.
    context = torch._guards.TracingContext.try_get()
    if context and context.fake_mode.shape_env:
        ranges = {**context.fake_mode.shape_env.var_to_range, **ranges}

    unbounded_vars = expr.free_symbols - ranges.keys()
    if unbounded_vars:
        # Give some bounds to the free variables via their SymPy assumptions
        # TODO: A better way of doing this would be to assign them a range upon creation, since
        #      size variables can come with a lower bound of 2 (we specialize on 0 and 1)
        unbounded_ranges: Dict[sympy.Symbol, ValueRanges] = {}
        for s in unbounded_vars:
            if s.is_integer:  # type: ignore[attr-defined]
                if s.is_positive:  # type: ignore[attr-defined]
                    vr = ValueRanges(1, int_oo)
                elif s.is_nonnegative:  # type: ignore[attr-defined]
                    vr = ValueRanges(0, int_oo)
                else:
                    vr = ValueRanges.unknown_int()
            else:
                # Don't bother trying very hard here
                vr = ValueRanges.unknown()
            unbounded_ranges[s] = vr  # type: ignore[index]
        ranges = {**ranges, **unbounded_ranges}

    return sympy_interp(SymPyValueRangeAnalysis, ranges, expr)
