# mypy: allow-untyped-defs
import dataclasses
import hashlib
import re
import typing
from enum import IntEnum
from typing import Any, Dict, Optional, Union

from torch._export.serde import schema
from torch._export.serde.union import _Union


class SchemaUpdateError(Exception):
    """Raised by ``_check`` when the staged schema fails a consistency check."""


def _check(x, msg):
    """Raise ``SchemaUpdateError(msg)`` unless ``x`` is truthy.

    Unlike ``assert``, this check is not stripped when Python runs with ``-O``.
    """
    if x:
        return
    raise SchemaUpdateError(msg)


def _staged_schema():
    """Describe every public definition of ``schema`` as plain data.

    Returns a dict that maps each public name defined in
    ``torch._export.serde.schema`` to ``{"kind": ..., "fields": ...}``:

    * ``"enum"`` (IntEnum subclasses): fields map member name -> member value.
    * ``"struct"`` / ``"union"`` (dataclasses; unions subclass ``_Union``):
      fields map field name -> ``{"type": <type string>}`` plus
      ``"default": str(default)`` when the field has a default.

    ``SCHEMA_VERSION`` (converted to a list) and ``TREESPEC_VERSION`` are
    appended after all types.

    Insertion order is significant. ``_hash_schema`` hashes ``repr`` of this
    dict, so types are emitted in ``dir(schema)`` order, followed by the two
    version entries.

    Raises:
        AssertionError: for an unsupported annotation or module member, an
            Optional field whose default is not None, or a non-positive
            version number.
    """
    ret: Dict[str, Any] = {}
    # Public names defined in the schema module itself: name -> object.
    defs = {}

    def _handle_aggregate(ty):
        """Serialize the fields of dataclass ``ty`` as ``{field_name: info}``."""

        def dump_type(t):
            """Render annotation ``t`` as the type string stored in the schema."""
            if isinstance(t, type):
                # Plain class (builtin or schema-defined): use its bare name.
                return t.__name__
            elif isinstance(t, str):
                # Forward reference by name; it must name a collected definition.
                assert t in defs
                return t
            elif o := typing.get_origin(t):
                # Generic alias: map the origin to its typing name and recurse
                # into the type arguments.
                if o == list:
                    head = "List"
                elif o == dict:
                    head = "Dict"
                elif o == tuple:
                    # An empty tuple annotation is spelled "Tuple[()]".
                    if typing.get_args(t) == ():
                        return "Tuple[()]"
                    head = "Tuple"
                elif o == Union:
                    # Only Optional[X], i.e. Union[X, None], is supported.
                    args = typing.get_args(t)
                    assert len(args) == 2 and args[1] == type(None)
                    return f"Optional[{dump_type(args[0])}]"
                else:
                    raise AssertionError(f"Type {t} is not supported in export schema.")
                return (
                    f"{head}[{', '.join([dump_type(x) for x in typing.get_args(t)])}]"
                )
            elif t == ():
                # NOTE(review): presumably reached as the argument of Tuple[()]
                # on Python versions where get_args returns ((),) — confirm.
                return "()"
            else:
                raise AssertionError(f"Type {t} is not supported in export schema.")

        def dump_field(f):
            """Return ``{"type": ...}`` plus ``"default"`` (as str) if ``f`` has one."""
            t = dump_type(f.type)
            # Local result for this field; shadows the outer ``ret``.
            ret = {"type": t}

            # dataclasses.MISSING doubles as the "no default" sentinel here.
            value = dataclasses.MISSING
            if f.default is not dataclasses.MISSING:
                value = f.default
            elif f.default_factory is not dataclasses.MISSING:
                value = f.default_factory()

            # Optional fields must default to None; a missing default fails too.
            if t.startswith("Optional[") and value is not None:
                raise AssertionError(
                    f"Optional field {ty.__name__}.{f.name} must have default value to be None."
                )

            if value is not dataclasses.MISSING:
                default = str(value)
                ret["default"] = default
            return ret

        return {f.name: dump_field(f) for f in dataclasses.fields(ty)}

    def _handle_int_enum(name, ty):
        # Enum fields map member name -> member value.
        ret[name] = {"kind": "enum", "fields": {x.name: x.value for x in ty}}

    def _handle_struct(name, ty):
        ret[name] = {"kind": "struct", "fields": _handle_aggregate(ty)}

    def _handle_union(name, ty):
        ret[name] = {"kind": "union", "fields": _handle_aggregate(ty)}

    # Collect public names that belong to the schema module. Objects whose
    # __module__ points elsewhere (names imported into schema) are skipped.
    # Values without a __module__ attribute (e.g. plain constants) are kept.
    for name in dir(schema):
        if name.startswith("_"):
            continue

        value = getattr(schema, name)

        if hasattr(value, "__module__") and value.__module__ != schema.__name__:
            continue

        defs[name] = value

    # Classify each collected definition and serialize it into ``ret``.
    for name, value in defs.items():
        if isinstance(value, type):
            if issubclass(value, IntEnum):
                _handle_int_enum(name, value)
            elif dataclasses.is_dataclass(value):
                # _Union subclasses are dataclasses too, so check them first.
                if issubclass(value, _Union):
                    _handle_union(name, value)
                else:
                    _handle_struct(name, value)
            else:
                raise AssertionError(f"Unknown schema type {name}: {value}")
        elif isinstance(value, (int, tuple)):
            # The only non-type values allowed are the version constants.
            assert name in ("SCHEMA_VERSION", "TREESPEC_VERSION")
        else:
            raise AssertionError(f"Unknown variable {name}: {value}")

    # Version entries go last; every component must be positive.
    ret["SCHEMA_VERSION"] = list(defs["SCHEMA_VERSION"])
    assert all(x > 0 for x in ret["SCHEMA_VERSION"])
    ret["TREESPEC_VERSION"] = defs["TREESPEC_VERSION"]
    assert ret["TREESPEC_VERSION"] > 0
    return ret


def _diff_schema(dst, src):
    """Compare a committed schema ``dst`` with a freshly staged schema ``src``.

    Both arguments use the layout produced by ``_staged_schema``: type name ->
    ``{"kind": ..., "fields": {...}}``, plus ``SCHEMA_VERSION`` and
    ``TREESPEC_VERSION``. The version entries are excluded from the comparison.

    Returns:
        ``(additions, subtractions)``.

        * Types present only in ``src`` (resp. only in ``dst``) are included
          verbatim.
        * For types present in both, the entry is ``{"fields": {...}}``. It
          holds fields added (resp. removed), plus struct fields that gained
          (resp. lost) a default, recorded as ``{"default": ...}``.

    Raises:
        SchemaUpdateError: if a shared type changed kind, a struct/union field
            changed type, or an enum member changed value.
        AssertionError: for an unrecognized kind.
    """
    additions = {name: src[name] for name in src.keys() - dst.keys()}
    subtractions = {name: dst[name] for name in dst.keys() - src.keys()}

    shared_types = src.keys() & dst.keys()
    shared_types -= {"SCHEMA_VERSION", "TREESPEC_VERSION"}

    for name in shared_types:
        src_kind = src[name]["kind"]
        src_fields = src[name]["fields"]
        dst_kind = dst[name]["kind"]
        dst_fields = dst[name]["fields"]
        _check(
            src_kind == dst_kind,
            f"Type {name} changed kind from {dst_kind} to {src_kind}",
        )
        assert isinstance(src_fields, dict) and isinstance(dst_fields, dict)

        gained = {
            fname: src_fields[fname]
            for fname in src_fields.keys() - dst_fields.keys()
        }
        lost = {
            fname: dst_fields[fname]
            for fname in dst_fields.keys() - src_fields.keys()
        }

        for fname in src_fields.keys() & dst_fields.keys():
            new_field = src_fields[fname]
            old_field = dst_fields[fname]

            if src_kind == "enum":
                # Enum fields are raw member values.
                _check(
                    new_field == old_field,
                    f"Value of the enum field {name}.{fname} changed from {old_field} to {new_field}",
                )
                continue

            if src_kind not in ("struct", "union"):
                raise AssertionError(f"Unknown kind {src_kind}: {name}")

            _check(
                new_field["type"] == old_field["type"],
                f"Type of the field {name}.{fname} changed from {old_field['type']} to {new_field['type']}",
            )

            # Only structs track defaults being introduced or dropped.
            if src_kind == "struct":
                new_has_default = "default" in new_field
                old_has_default = "default" in old_field
                if new_has_default and not old_has_default:
                    gained[fname] = {"default": new_field["default"]}
                elif old_has_default and not new_has_default:
                    lost[fname] = {"default": old_field["default"]}

        if gained:
            assert name not in additions
            additions[name] = {"fields": gained}
        if lost:
            assert name not in subtractions
            subtractions[name] = {"fields": lost}

    return additions, subtractions


def _hash_schema(s):
    """Return the hex SHA-256 digest of ``repr(s)`` encoded as UTF-8.

    Since dict ``repr`` follows insertion order, reordering keys changes the hash.
    """
    hasher = hashlib.sha256()
    hasher.update(repr(s).encode("utf-8"))
    return hasher.hexdigest()


@dataclasses.dataclass(init=True, repr=True, eq=True)
class _Commit:
    """Outcome of staging schema.py against the committed schema.yaml."""

    # Schema staged from schema.py, and the checksum of that staged schema.
    result: Dict[str, Any]
    checksum_result: str
    # Slash-separated path of schema.yaml derived from the package name.
    path: str
    # Differences between ``base`` and ``result``.
    additions: Dict[str, Any]
    subtractions: Dict[str, Any]
    # Previously committed schema and its embedded checksum; the checksum is
    # None when no schema.yaml resource was found.
    base: Dict[str, Any]
    checksum_base: Optional[str]


def update_schema():
    """Stage schema.py and compare it with this package's schema.yaml.

    If schema.yaml is a resource of this package, it is parsed as the base
    schema and its embedded ``checksum<<...>>`` marker is extracted. Otherwise
    a base holding only ``None`` version entries is used.

    Returns:
        _Commit: the staged schema and its checksum, the schema.yaml path, the
        additions/subtractions reported by ``_diff_schema``, and the base
        schema with its checksum (``None`` when there was no schema.yaml).

    Raises:
        SchemaUpdateError: if schema.yaml exists but has no checksum marker.
    """
    import importlib.resources

    # Defaults used when no schema.yaml has been committed yet.
    checksum_base: Optional[str] = None
    dst: Dict[str, Any] = {"SCHEMA_VERSION": None, "TREESPEC_VERSION": None}

    if importlib.resources.is_resource(__package__, "schema.yaml"):
        content = importlib.resources.read_text(__package__, "schema.yaml")
        checksum_match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
        _check(checksum_match is not None, "checksum not found in schema.yaml")
        assert checksum_match is not None  # type narrowing only
        checksum_base = checksum_match.group(1)

        # NOTE(review): yaml.Loader is PyYAML's unrestricted loader. That is
        # acceptable for this package's own schema.yaml, but never reuse it on
        # untrusted input.
        from yaml import load, Loader

        dst = load(content, Loader=Loader)
        assert isinstance(dst, dict)

    src = _staged_schema()
    additions, subtractions = _diff_schema(dst, src)
    checksum_result = _hash_schema(src)
    yaml_path = __package__.replace(".", "/") + "/schema.yaml"
    return _Commit(
        result=src,
        checksum_result=checksum_result,
        path=yaml_path,
        additions=additions,
        subtractions=subtractions,
        base=dst,
        checksum_base=checksum_base,
    )


def check(commit: "_Commit", force_unsafe: bool = False):
    """Decide which ``SCHEMA_VERSION`` bump a staged schema change requires.

    Args:
        commit: the commit produced by ``update_schema``. Only its
            ``additions``, ``subtractions``, ``base`` and ``result``
            attributes are read.
        force_unsafe: when True, skip minor-change detection and return the
            version currently declared in ``commit.result``.

    Returns:
        ``(next_version, reason)``.

        * ``next_version`` is ``None`` when nothing changed. Otherwise it is a
          ``[major, minor]`` list, or the result's own version under
          ``force_unsafe``.
        * ``reason`` accumulates one line per detected change.
    """
    # NOTE(review): update_schema sets base["SCHEMA_VERSION"] to None when no
    # schema.yaml exists, so computing any bump below would raise TypeError in
    # that case. Confirm callers only reach this with a committed base.
    next_version = None
    reason = ""

    # Step 1: Detect major schema updates.
    for k, v in commit.additions.items():
        # Types absent from the base are brand-new. They are not major
        # changes here and are reported by step 2 instead.
        if k not in commit.base:
            continue
        kind = commit.result[k]["kind"]
        for f, d in v["fields"].items():
            # Only struct fields added without a default are incompatible.
            if "default" not in d and kind == "struct":
                reason += (
                    f"Field {k}.{f} is added to schema.py without a default value as an incomparible change "
                    + "which requires major version bump.\n"
                )
                next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]

    for k, v in commit.subtractions.items():
        # Types missing from the result were removed wholesale; step 2
        # handles them.
        if k not in commit.result:
            continue
        for f in v["fields"]:
            # Append rather than assign, so that earlier removed fields and
            # step-1 reasons are not discarded.
            reason += f"Field {k}.{f} is removed from schema.py as an incompatible change which requires major version bump.\n"
        next_version = [commit.base["SCHEMA_VERSION"][0] + 1, 1]

    if force_unsafe:
        reason += "--force-unsafe is used."
        next_version = commit.result["SCHEMA_VERSION"]
    else:
        # Step 2: Detect minor schema updates.
        if next_version is None and commit.additions:
            for k, v in commit.additions.items():
                for f in v["fields"]:
                    reason += (
                        f"Field {k}.{f} is added to schema.py as an compatible change "
                        + "which still requires minor version bump.\n"
                    )
            next_version = [
                commit.base["SCHEMA_VERSION"][0],
                commit.base["SCHEMA_VERSION"][1] + 1,
            ]
        if next_version is None and commit.subtractions:
            for k, v in commit.subtractions.items():
                for f in v["fields"]:
                    reason += (
                        f"Field {k}.{f} is removed from schema.py as an compatible change "
                        + "which still requires minor version bump.\n"
                    )
            next_version = [
                commit.base["SCHEMA_VERSION"][0],
                commit.base["SCHEMA_VERSION"][1] + 1,
            ]

    return next_version, reason
