import logging
import sys
import types
import threading
import inspect
from functools import wraps
from itertools import chain
from numba.core import config

class TLS(threading.local):
    """Use a subclass to properly initialize the TLS variables in all threads."""
    def __init__(self):
        self.tracing = False
        self.indent = 0

tls = TLS()

def find_function_info(func, spec, args):
    """Return function meta-data in a tuple.

    (name, type)"""

    module = getattr(func, '__module__', None)
    name = getattr(func, '__name__', None)
    self = getattr(func, '__self__', None)
    cname = None
    if self:
        cname = self.__name__
        #cname = self.__class__.__name__
    # Try to deduce the class' name even for unbound methods from their
    # first argument, which we assume to be a class instance if named 'self'...
    elif len(spec.args) and spec.args[0] == 'self':
        cname = args[0].__class__.__name__
    # ...or a class object if named 'cls'
    elif len(spec.args) and spec.args[0] == 'cls':
        cname = args[0].__name__
    if name:
        qname = []
        if module and module != '__main__':
            qname.append(module)
            qname.append('.')
        if cname:
            qname.append(cname)
            qname.append('.')
        qname.append(name)
        name = ''.join(qname)
    return name, None

def chop(value):
    MAX_SIZE = 320
    s = repr(value)
    if len(s) > MAX_SIZE:
        return s[:MAX_SIZE] + '...' + s[-1]
    else:
        return s

def create_events(fname, spec, args, kwds):

    values = dict()
    if spec.defaults:
        values = dict(zip(spec.args[-len(spec.defaults):],spec.defaults))
    values.update(kwds)
    values.update(list(zip(spec.args[:len(args)], args)))
    positional = ['%s=%r'%(a, values.pop(a)) for a in spec.args]
    anonymous = [str(a) for a in args[len(positional):]]
    keywords = ['%s=%r'%(k, values[k]) for k in sorted(values.keys())]
    params = ', '.join([f for f in chain(positional, anonymous, keywords) if f])

    enter = ['>> ', tls.indent * ' ', fname, '(', params, ')']
    leave = ['<< ', tls.indent * ' ', fname]
    return enter, leave


def dotrace(*args, **kwds):
    """Function decorator to trace a function's entry and exit.

    *args: categories in which to trace this function. Example usage:

    @trace
    def function(...):...

    @trace('mycategory')
    def function(...):...


    """

    recursive = kwds.get('recursive', False)
    def decorator(func):

        spec = None
        logger = logging.getLogger('trace')
        def wrapper(*args, **kwds):
            if not logger.isEnabledFor(logging.INFO) or tls.tracing:
                return func(*args, **kwds)

            fname, ftype = find_function_info(func, spec, args)

            try:
                tls.tracing = True
                enter, leave = create_events(fname, spec, args, kwds)

                try:
                    logger.info(''.join(enter))
                    tls.indent += 1
                    try:
                        try:
                            tls.tracing = False
                            result = func(*args, **kwds)
                        finally:
                            tls.tracing = True
                    except:
                        type, value, traceback = sys.exc_info()
                        leave.append(' => exception thrown\n\traise ')
                        mname = type.__module__
                        if mname != '__main__':
                            leave.append(mname)
                            leave.append('.')
                        leave.append(type.__name__)
                        if value.args:
                            leave.append('(')
                            leave.append(', '.join(chop(v) for v in value.args))
                            leave.append(')')
                        else:
                            leave.append('()')
                        raise
                    else:
                        if result is not None:
                            leave.append(' -> ')
                            leave.append(chop(result))
                finally:
                    tls.indent -= 1
                    logger.info(''.join(leave))
            finally:
                tls.tracing = False
            return result
        # wrapper end

        result = None
        rewrap = lambda x: x
        # Unwrap already wrapped functions
        # (to be rewrapped again later)
        if type(func) == classmethod:
            rewrap = type(func)
            # Note: 'func.__func__' only works in Python 3
            func = func.__get__(True).__func__
        elif type(func) == staticmethod:
            rewrap = type(func)
            # Note: 'func.__func__' only works in Python 3
            func = func.__get__(True)
        elif type(func) == property:
            raise NotImplementedError

        spec = inspect.getfullargspec(func)
        return rewrap(wraps(func)(wrapper))

    arg0 = len(args) and args[0] or None
    # not supported yet...
    if recursive:
        raise NotImplementedError
        if inspect.ismodule(arg0):
            for n, f in inspect.getmembers(arg0, inspect.isfunction):
                setattr(arg0, n, decorator(f))
            for n, c in inspect.getmembers(arg0, inspect.isclass):
                dotrace(c, *args, recursive=recursive)
        elif inspect.isclass(arg0):
            for n, f in inspect.getmembers(arg0, lambda x: (inspect.isfunction(x) or
                                                            inspect.ismethod(x))):
                setattr(arg0, n, decorator(f))


    if callable(arg0) or type(arg0) in (classmethod, staticmethod):
        return decorator(arg0)
    elif type(arg0) == property:
        # properties combine up to three functions: 'get', 'set', 'del',
        # so let's wrap them all.
        pget, pset, pdel = None, None, None
        if arg0.fget:
            pget = decorator(arg0.fget)
        if arg0.fset:
            pset = decorator(arg0.fset)
        if arg0.fdel:
            pdel = decorator(arg0.fdel)
        return property(pget, pset, pdel)

    else:
        return decorator

def notrace(*args, **kwds):
    """Just a no-op in case tracing is disabled."""
    def decorator(func):
        return func
    arg0 = len(args) and args[0] or None

    if callable(arg0) or type(arg0) in (classmethod, staticmethod):
        return decorator(arg0)
    else:
        return decorator

def doevent(msg):
    msg = ['== ', tls.indent * ' ', msg]
    logger = logging.getLogger('trace')
    logger.info(''.join(msg))

def noevent(msg):
    pass

if config.TRACE:
    logger = logging.getLogger('trace')
    logger.setLevel(logging.INFO)
    logger.handlers = [logging.StreamHandler()]
    trace = dotrace
    event = doevent
else:
    trace = notrace
    event = noevent
