"""
Serialization support for compiled functions.
"""
import sys
import abc
import io
import copyreg


import pickle
from numba import cloudpickle
from llvmlite import ir


#
# Pickle support
#

def _rebuild_reduction(cls, *args):
    """
    Global hook to rebuild a given class from its __reduce__ arguments.
    """
    return cls._rebuild(*args)


# Keep unpickled object via `numba_unpickle` alive.
_unpickled_memo = {}


def _numba_unpickle(address, bytedata, hashed):
    """Used by `numba_unpickle` from _helperlib.c

    Parameters
    ----------
    address : int
    bytedata : bytes
    hashed : bytes

    Returns
    -------
    obj : object
        unpickled object
    """
    key = (address, hashed)
    try:
        obj = _unpickled_memo[key]
    except KeyError:
        _unpickled_memo[key] = obj = cloudpickle.loads(bytedata)
    return obj


def dumps(obj):
    """Similar to `pickle.dumps()`. Returns the serialized object in bytes.
    """
    pickler = NumbaPickler
    with io.BytesIO() as buf:
        p = pickler(buf, protocol=4)
        p.dump(obj)
        pickled = buf.getvalue()

    return pickled


def runtime_build_excinfo_struct(static_exc, exc_args):
    exc, static_args, locinfo = cloudpickle.loads(static_exc)
    real_args = []
    exc_args_iter = iter(exc_args)
    for arg in static_args:
        if isinstance(arg, ir.Value):
            real_args.append(next(exc_args_iter))
        else:
            real_args.append(arg)
    return (exc, tuple(real_args), locinfo)


# Alias to pickle.loads to allow `serialize.loads()`
loads = cloudpickle.loads


class _CustomPickled:
    """A wrapper for objects that must be pickled with `NumbaPickler`.

    Standard `pickle` will pick up the implementation registered via `copyreg`.
    This will spawn a `NumbaPickler` instance to serialize the data.

    `NumbaPickler` overrides the handling of this type so as not to spawn a
    new pickler for the object when it is already being pickled by a
    `NumbaPickler`.
    """

    __slots__ = 'ctor', 'states'

    def __init__(self, ctor, states):
        self.ctor = ctor
        self.states = states

    def _reduce(self):
        return _CustomPickled._rebuild, (self.ctor, self.states)

    @classmethod
    def _rebuild(cls, ctor, states):
        return cls(ctor, states)


def _unpickle__CustomPickled(serialized):
    """standard unpickling for `_CustomPickled`.

    Uses `NumbaPickler` to load.
    """
    ctor, states = loads(serialized)
    return _CustomPickled(ctor, states)


def _pickle__CustomPickled(cp):
    """standard pickling for `_CustomPickled`.

    Uses `NumbaPickler` to dump.
    """
    serialized = dumps((cp.ctor, cp.states))
    return _unpickle__CustomPickled, (serialized,)


# Register custom pickling for the standard pickler.
copyreg.pickle(_CustomPickled, _pickle__CustomPickled)


def custom_reduce(cls, states):
    """For customizing object serialization in `__reduce__`.

    Object states provided here are used as keyword arguments to the
    `._rebuild()` class method.

    Parameters
    ----------
    states : dict
        Dictionary of object states to be serialized.

    Returns
    -------
    result : tuple
        This tuple conforms to the return type requirement for `__reduce__`.
    """
    return custom_rebuild, (_CustomPickled(cls, states),)


def custom_rebuild(custom_pickled):
    """Customized object deserialization.

    This function is referenced internally by `custom_reduce()`.
    """
    cls, states = custom_pickled.ctor, custom_pickled.states
    return cls._rebuild(**states)


def is_serialiable(obj):
    """Check if *obj* can be serialized.

    Parameters
    ----------
    obj : object

    Returns
    --------
    can_serialize : bool
    """
    with io.BytesIO() as fout:
        pickler = NumbaPickler(fout)
        try:
            pickler.dump(obj)
        except pickle.PicklingError:
            return False
        else:
            return True


def _no_pickle(obj):
    raise pickle.PicklingError(f"Pickling of {type(obj)} is unsupported")


def disable_pickling(typ):
    """This is called on a type to disable pickling
    """
    NumbaPickler.disabled_types.add(typ)
    # Return `typ` to allow use as a decorator
    return typ


class NumbaPickler(cloudpickle.CloudPickler):
    disabled_types = set()
    """A set of types that pickling cannot is disabled.
    """

    def reducer_override(self, obj):
        # Overridden to disable pickling of certain types
        if type(obj) in self.disabled_types:
            _no_pickle(obj)  # noreturn
        return super().reducer_override(obj)


def _custom_reduce__custompickled(cp):
    return cp._reduce()


NumbaPickler.dispatch_table[_CustomPickled] = _custom_reduce__custompickled


class ReduceMixin(abc.ABC):
    """A mixin class for objects that should be reduced by the NumbaPickler
    instead of the standard pickler.
    """
    # Subclass MUST override the below methods

    @abc.abstractmethod
    def _reduce_states(self):
        raise NotImplementedError

    @abc.abstractclassmethod
    def _rebuild(cls, **kwargs):
        raise NotImplementedError

    # Subclass can override the below methods

    def _reduce_class(self):
        return self.__class__

    # Private methods

    def __reduce__(self):
        return custom_reduce(self._reduce_class(), self._reduce_states())


class PickleCallableByPath:
    """Wrap a callable object to be pickled by path to workaround limitation
    in pickling due to non-pickleable objects in function non-locals.

    Note:
    - Do not use this as a decorator.
    - Wrapped object must be a global that exist in its parent module and it
      can be imported by `from the_module import the_object`.

    Usage:

    >>> def my_fn(x):
    >>>     ...
    >>> wrapped_fn = PickleCallableByPath(my_fn)
    >>> # refer to `wrapped_fn` instead of `my_fn`
    """
    def __init__(self, fn):
        self._fn = fn

    def __call__(self, *args, **kwargs):
        return self._fn(*args, **kwargs)

    def __reduce__(self):
        return type(self)._rebuild, (self._fn.__module__, self._fn.__name__,)

    @classmethod
    def _rebuild(cls, modname, fn_path):
        return cls(getattr(sys.modules[modname], fn_path))
