import copy
import itertools
import math
import random
import sys
import unittest

import numpy as np

from numba import jit, njit
from numba.core import utils, errors
from numba.tests.support import TestCase, MemoryLeakMixin

from numba.misc.quicksort import make_py_quicksort, make_jit_quicksort
from numba.misc.mergesort import make_jit_mergesort
from numba.misc.timsort import make_py_timsort, make_jit_timsort, MergeRun


def make_temp_list(keys, n):
    """
    Build a temporary list of length *n*, filled with the first item
    of *keys*.
    """
    filler = keys[0]
    return [filler] * n

def make_temp_array(keys, n):
    """
    Allocate an uninitialized array of *n* items with the dtype of *keys*.
    """
    return np.empty(n, dtype=keys.dtype)


# Module-level sort implementations shared by the test classes below.

# Timsort built by make_py_timsort(), using list temporaries
py_list_timsort = make_py_timsort(make_temp_list)

# Timsort built by make_py_timsort(), using Numpy array temporaries
py_array_timsort = make_py_timsort(make_temp_array)

# Timsort built by make_jit_timsort(), using list temporaries
jit_list_timsort = make_jit_timsort(make_temp_list)

# Timsort built by make_jit_timsort(), using Numpy array temporaries
jit_array_timsort = make_jit_timsort(make_temp_array)

# Quicksort built by make_py_quicksort() with its default arguments
py_quicksort = make_py_quicksort()

# Quicksort built by make_jit_quicksort() with its default arguments
jit_quicksort = make_jit_quicksort()


def sort_usecase(val):
    """Sort *val* in place by calling its ``sort()`` method."""
    val.sort()

def argsort_usecase(val):
    """Return the result of calling the ``argsort()`` method of *val*."""
    indices = val.argsort()
    return indices

def argsort_kind_usecase(val, is_stable=False):
    """
    Call ``val.argsort()`` with ``kind='mergesort'`` when *is_stable* is
    true, and with ``kind='quicksort'`` otherwise.
    """
    # Each call spells out its own literal `kind` string.
    if not is_stable:
        return val.argsort(kind='quicksort')
    return val.argsort(kind='mergesort')

def sorted_usecase(val):
    """Return ``sorted(val)``."""
    ordered = sorted(val)
    return ordered

def sorted_reverse_usecase(val, b):
    """Return ``sorted(val)``, with *b* passed as the ``reverse`` flag."""
    ordered = sorted(val, reverse=b)
    return ordered

def np_sort_usecase(val):
    """Return ``np.sort(val)``."""
    ordered = np.sort(val)
    return ordered

def np_argsort_usecase(val):
    """Return ``np.argsort(val)``."""
    indices = np.argsort(val)
    return indices

def np_argsort_kind_usecase(val, is_stable=False):
    """
    Call ``np.argsort(val)`` with ``kind='mergesort'`` when *is_stable* is
    true, and with ``kind='quicksort'`` otherwise.
    """
    # Each call spells out its own literal `kind` string.
    if not is_stable:
        return np.argsort(val, kind='quicksort')
    return np.argsort(val, kind='mergesort')

def list_sort_usecase(n):
    """
    Build a list of *n* random floats (seeded with 42) and a sorted copy.

    Returns the ``(unsorted, sorted)`` pair of lists.
    """
    np.random.seed(42)
    values = [np.random.random() for _ in range(n)]
    ordered = values[:]
    ordered.sort()
    return values, ordered

def list_sort_reverse_usecase(n, b):
    """
    Build a list of *n* random floats (seeded with 42) and a copy sorted
    with *b* passed as the ``reverse`` flag.

    Returns the ``(unsorted, sorted)`` pair of lists.
    """
    np.random.seed(42)
    values = [np.random.random() for _ in range(n)]
    ordered = values[:]
    ordered.sort(reverse=b)
    return values, ordered


class BaseSortingTest(object):
    """
    Mixin providing deterministic sample lists and sort-checking assertions.

    It is meant to be combined with a ``unittest.TestCase`` subclass, which
    supplies the ``assert*`` methods used by the helpers below.
    """

    def random_list(self, n, offset=10):
        """
        Return the integers ``offset .. offset + n - 1`` in a shuffled
        (but reproducible) order.
        """
        random.seed(42)
        l = list(range(offset, offset + n))
        random.shuffle(l)
        return l

    def sorted_list(self, n, offset=10):
        """Return the integers ``offset .. offset + n - 1`` in ascending order."""
        return list(range(offset, offset + n))

    def revsorted_list(self, n, offset=10):
        """Return the integers ``offset .. offset + n - 1`` in descending order."""
        return list(range(offset, offset + n))[::-1]

    def initially_sorted_list(self, n, m=None, offset=10):
        """
        Return a list of *n* items whose first *m* items (default ``n // 2``)
        are sorted, followed by ``n - m`` shuffled items.
        """
        if m is None:
            m = n // 2
        l = self.sorted_list(m, offset)
        # The shuffled tail starts above the last sorted item.  With an empty
        # sorted prefix there is no last item, so fall back to `offset` alone.
        tail_offset = (l[-1] if l else 0) + offset
        l += self.random_list(n - m, offset=tail_offset)
        return l

    def _dup_list(self, n, factor, offset):
        """
        Return *n* items cycling over ``n // factor + 1`` consecutive
        integers starting at *offset*.  *factor* defaults to ``sqrt(n)``.
        """
        if factor is None:
            # max() avoids a zero divisor for n == 0
            factor = max(1, int(math.sqrt(n)))
        l = (list(range(offset, offset + (n // factor) + 1)) * (factor + 1))[:n]
        assert len(l) == n, (len(l), n)
        return l

    def duprandom_list(self, n, factor=None, offset=10):
        """
        Return a shuffled (but reproducible) list of *n* items containing
        duplicates.
        """
        random.seed(42)
        l = self._dup_list(n, factor, offset)
        random.shuffle(l)
        return l

    def dupsorted_list(self, n, factor=None, offset=10):
        """Return a sorted list of *n* items containing duplicates."""
        l = self._dup_list(n, factor, offset)
        l.sort()
        return l

    def assertSorted(self, orig, result):
        """Check that *result* holds the items of *orig* in sorted order."""
        self.assertEqual(len(result), len(orig))
        # sorted() returns a list, so make sure we compare to another list
        self.assertEqual(list(result), sorted(orig))

    def assertSortedValues(self, orig, orig_values, result, result_values):
        """
        Check that *result* is *orig* sorted, that *result_values* was
        permuted alongside the keys, and that the sort was stable.
        """
        self.assertEqual(len(result), len(orig))
        self.assertEqual(list(result), sorted(orig))
        zip_sorted = sorted(zip(orig, orig_values), key=lambda x: x[0])
        zip_result = list(zip(result, result_values))
        self.assertEqual(zip_sorted, zip_result)
        # Check stability
        for i in range(len(zip_result) - 1):
            (k1, v1), (k2, v2) = zip_result[i], zip_result[i + 1]
            if k1 == k2:
                # Assuming values are unique, which is enforced by the tests
                self.assertLess(orig_values.index(v1), orig_values.index(v2))

    def fibo(self):
        """Yield the Fibonacci sequence 1, 1, 2, 3, 5, ... indefinitely."""
        a = 1
        b = 1
        while True:
            yield a
            a, b = b, a + b

    def make_sample_sorted_lists(self, n):
        """Return sorted sample lists of *n* items, built at two offsets."""
        lists = []
        for offset in (20, 120):
            lists.append(self.sorted_list(n, offset))
            # `offset` must be passed by keyword: the second positional
            # parameter of dupsorted_list() is `factor`.
            lists.append(self.dupsorted_list(n, offset=offset))
        return lists

    def make_sample_lists(self, n):
        """
        Return sample lists of *n* items (sorted, sorted with duplicates,
        reversed, shuffled with duplicates), built at two offsets.
        """
        lists = []
        for offset in (20, 120):
            lists.append(self.sorted_list(n, offset))
            # `offset` must be passed by keyword: the second positional
            # parameter of the dup*_list() helpers is `factor`.
            lists.append(self.dupsorted_list(n, offset=offset))
            lists.append(self.revsorted_list(n, offset))
            lists.append(self.duprandom_list(n, offset=offset))
        return lists


class BaseTimsortTest(BaseSortingTest):
    """
    Tests exercising the individual building blocks of a timsort.

    Concrete subclasses must provide ``timsort`` (the implementation under
    test) and ``array_factory`` (turning a list into the container to sort).
    """

    def merge_init(self, keys):
        """Return a fresh merge state for *keys*."""
        return self.timsort.merge_init(keys)

    def _binarysort_cases(self, n):
        """
        Yield ``(list, start)`` inputs for the binarysort tests.

        A non-zero *start* is only paired with lists whose first half is
        already sorted.
        """
        half = n // 2
        l = self.sorted_list(n)
        yield l, 0
        yield l, half
        yield self.revsorted_list(n), 0
        l = self.initially_sorted_list(n, half)
        yield l, 0
        yield l, half
        yield self.random_list(n), 0
        yield self.duprandom_list(n), 0

    def test_binarysort(self):
        n = 20
        f = self.timsort.binarysort

        for l, start in self._binarysort_cases(n):
            res = self.array_factory(l)
            # The same container is passed as both keys and values
            f(res, res, 0, n, start)
            self.assertSorted(l, res)

    def test_binarysort_with_values(self):
        n = 20
        v = list(range(100, 100 + n))
        f = self.timsort.binarysort

        for l, start in self._binarysort_cases(n):
            res = self.array_factory(l)
            res_v = self.array_factory(v)
            f(res, res_v, 0, n, start)
            self.assertSortedValues(l, v, res, res_v)

    def test_count_run(self):
        n = 16
        f = self.timsort.count_run

        def check(l, lo, hi):
            run_len, desc = f(self.array_factory(l), lo, hi)
            run_end = lo + run_len
            # Fully check invariants
            if desc:
                # Strictly descending run...
                for k in range(lo, run_end - 1):
                    self.assertGreater(l[k], l[k + 1])
                # ... which stops at the first non-descending step
                if run_end < hi:
                    self.assertLessEqual(l[run_end - 1], l[run_end])
            else:
                # Non-descending run...
                for k in range(lo, run_end - 1):
                    self.assertLessEqual(l[k], l[k + 1])
                # ... which stops at the first descending step
                if run_end < hi:
                    self.assertGreater(l[run_end - 1], l[run_end], l)

        for make_list in (self.sorted_list, self.revsorted_list):
            l = make_list(n, offset=100)
            check(l, 0, n)
            check(l, 1, n - 1)
            check(l, 1, 2)
        for make_list in (self.random_list, self.duprandom_list):
            l = make_list(n, offset=100)
            for i in range(len(l) - 1):
                check(l, i, n)

    def _check_gallop(self, func, assert_prev, assert_next):
        """
        Check the gallop function *func* for many keys, ranges and hints.

        With ``k = func(key, l, start, stop, hint)``, the test asserts
        ``assert_prev(l[k - 1], key)`` and ``assert_next(l[k], key)``
        whenever those indices fall inside ``[start, stop)``.
        """
        n = 20

        def check(l, key, start, stop, hint):
            k = func(key, l, start, stop, hint)
            # Fully check invariants
            self.assertGreaterEqual(k, start)
            self.assertLessEqual(k, stop)
            if k > start:
                assert_prev(l[k - 1], key)
            if k < stop:
                assert_next(l[k], key)

        for l in (self.sorted_list(n, offset=100),
                  self.dupsorted_list(n, offset=100)):
            l = self.array_factory(l)
            for key in (l[5], l[15], l[0], -1000, l[-1], 1000):
                for start, stop in ((0, n), (1, n - 1), (8, n - 8)):
                    for hint in range(start, stop):
                        check(l, key, start, stop, hint)

    def test_gallop_left(self):
        # l[k - 1] < key <= l[k]
        self._check_gallop(self.timsort.gallop_left,
                           self.assertLess, self.assertGreaterEqual)

    def test_gallop_right(self):
        # l[k - 1] <= key < l[k]
        self._check_gallop(self.timsort.gallop_right,
                           self.assertLessEqual, self.assertGreater)

    def test_merge_compute_minrun(self):
        f = self.timsort.merge_compute_minrun

        for i in range(0, 64):
            self.assertEqual(f(i), i)
        for i in range(6, 63):
            if 2**i > sys.maxsize:
                break
            self.assertEqual(f(2**i), 32)
        for i in self.fibo():
            if i < 64:
                continue
            if i >= sys.maxsize:
                break
            k = f(i)
            self.assertGreaterEqual(k, 32)
            self.assertLessEqual(k, 64)
            if i > 500:
                # i/k is close to, but strictly less than, an exact power of 2
                quot = i // k
                p = 2 ** utils.bit_length(quot)
                self.assertLess(quot, p)
                self.assertGreaterEqual(quot, 0.9 * p)

    def check_merge_lo_hi(self, func, a, b):
        """
        Merge the adjacent sorted runs *a* and *b* with *func* (``merge_lo``
        or ``merge_hi``), then check the merged keys and the merge state.
        """
        na = len(a)
        nb = len(b)

        # Add sentinels at start and end, to check they weren't moved
        orig_keys = [42] + a + b + [-42]
        keys = self.array_factory(orig_keys)
        ms = self.merge_init(keys)
        ssa = 1
        ssb = ssa + na

        new_ms = func(ms, keys, keys, ssa, na, ssb, nb)
        self.assertEqual(keys[0], orig_keys[0])
        self.assertEqual(keys[-1], orig_keys[-1])
        self.assertSorted(orig_keys[1:-1], keys[1:-1])
        # Check the MergeState result
        self.assertGreaterEqual(len(new_ms.keys), len(ms.keys))
        self.assertGreaterEqual(len(new_ms.values), len(ms.values))
        self.assertIs(new_ms.pending, ms.pending)
        self.assertGreaterEqual(new_ms.min_gallop, 1)

    def test_merge_lo_hi(self):
        f_lo = self.timsort.merge_lo
        f_hi = self.timsort.merge_hi

        # The larger sizes exercise galloping
        for (na, nb) in [(12, 16), (40, 40), (100, 110), (1000, 1100)]:
            for a, b in itertools.product(self.make_sample_sorted_lists(na),
                                          self.make_sample_sorted_lists(nb)):
                self.check_merge_lo_hi(f_lo, a, b)
                self.check_merge_lo_hi(f_hi, b, a)

    def check_merge_at(self, a, b):
        """
        Push the adjacent sorted runs *a* and *b* on the merge stack, merge
        them with ``merge_at`` and check the keys and the stack afterwards.
        """
        f = self.timsort.merge_at
        merge_append = self.timsort.merge_append
        # Prepare the array to be sorted
        na = len(a)
        nb = len(b)
        # Add sentinels at start and end, to check they weren't moved
        orig_keys = [42] + a + b + [-42]
        ssa = 1
        ssb = ssa + na

        stack_sentinel = MergeRun(-42, -42)

        def make_state(*extra_runs):
            # Returns (keys, merge state, stack index of the run for `a`)
            keys = self.array_factory(orig_keys)
            ms = self.merge_init(keys)
            # Push sentinel on stack, to check it wasn't touched
            ms = merge_append(ms, stack_sentinel)
            i = ms.n
            for run in (MergeRun(ssa, na), MergeRun(ssb, nb)) + extra_runs:
                ms = merge_append(ms, run)
            return keys, ms, i

        def run_merge_at(ms, keys, i):
            new_ms = f(ms, keys, keys, i)
            self.assertEqual(keys[0], orig_keys[0])
            self.assertEqual(keys[-1], orig_keys[-1])
            self.assertSorted(orig_keys[1:-1], keys[1:-1])
            # Check stack state
            self.assertIs(new_ms.pending, ms.pending)
            self.assertEqual(new_ms.pending[i], (ssa, na + nb))
            self.assertEqual(new_ms.pending[0], stack_sentinel)
            return new_ms

        # First check with i == len(stack) - 2
        keys, ms, i = make_state()
        ms = run_merge_at(ms, keys, i)
        self.assertEqual(ms.n, i + 1)

        # Now check with i == len(stack) - 3
        # A last run (trivial here)
        last_run = MergeRun(ssb + nb, 1)
        keys, ms, i = make_state(last_run)
        ms = run_merge_at(ms, keys, i)
        self.assertEqual(ms.n, i + 2)
        self.assertEqual(ms.pending[ms.n - 1], last_run)

    def test_merge_at(self):
        # The larger sizes exercise galloping
        for (na, nb) in [(12, 16), (40, 40), (100, 110), (500, 510)]:
            for a, b in itertools.product(self.make_sample_sorted_lists(na),
                                          self.make_sample_sorted_lists(nb)):
                self.check_merge_at(a, b)
                self.check_merge_at(b, a)

    def test_merge_force_collapse(self):
        f = self.timsort.merge_force_collapse

        # Test with runs of ascending sizes, then descending sizes
        sizes_list = [(8, 10, 15, 20)]
        sizes_list.append(sizes_list[0][::-1])

        for sizes in sizes_list:
            for chunks in itertools.product(*(self.make_sample_sorted_lists(n)
                                              for n in sizes)):
                # Create runs of the given sizes
                orig_keys = list(itertools.chain.from_iterable(chunks))
                keys = self.array_factory(orig_keys)
                ms = self.merge_init(keys)
                pos = 0
                for c in chunks:
                    ms = self.timsort.merge_append(ms, MergeRun(pos, len(c)))
                    pos += len(c)
                # Sanity check
                self.assertEqual(sum(ms.pending[ms.n - 1]), len(keys))
                # Now merge the runs
                ms = f(ms, keys, keys)
                # Remaining run is the whole list
                self.assertEqual(ms.n, 1)
                self.assertEqual(ms.pending[0], MergeRun(0, len(keys)))
                # The list is now sorted
                self.assertSorted(orig_keys, keys)

    def test_run_timsort(self):
        f = self.timsort.run_timsort

        for size_factor in (1, 10):
            # Make lists to be sorted from three chunks of different kinds.
            sizes = (15, 30, 20)

            all_lists = [self.make_sample_lists(n * size_factor) for n in sizes]
            for chunks in itertools.product(*all_lists):
                orig_keys = list(itertools.chain.from_iterable(chunks))
                keys = self.array_factory(orig_keys)
                f(keys)
                # The list is now sorted
                self.assertSorted(orig_keys, keys)

    def test_run_timsort_with_values(self):
        # Run timsort, but also with a values array
        f = self.timsort.run_timsort_with_values

        for size_factor in (1, 5):
            chunk_size = 80 * size_factor
            a = self.dupsorted_list(chunk_size)
            b = self.duprandom_list(chunk_size)
            c = self.revsorted_list(chunk_size)
            orig_keys = a + b + c
            orig_values = list(range(1000, 1000 + len(orig_keys)))

            keys = self.array_factory(orig_keys)
            values = self.array_factory(orig_values)
            f(keys, values)
            # This checks sort stability
            self.assertSortedValues(orig_keys, orig_values, keys, values)


class TestTimsortPurePython(BaseTimsortTest, TestCase):
    """Run the timsort tests against ``py_list_timsort``, sorting lists."""

    timsort = py_list_timsort

    # Much faster than a Numpy array in pure Python
    array_factory = list


class TestTimsortArraysPurePython(BaseTimsortTest, TestCase):
    """Run the timsort tests against ``py_array_timsort``, sorting arrays."""

    timsort = py_array_timsort

    def array_factory(self, lst):
        """Return a new int32 Numpy array holding the items of *lst*."""
        arr = np.array(lst, dtype=np.int32)
        return arr


class JITTimsortMixin(object):
    """
    Mixin selecting ``jit_array_timsort`` as the implementation under test.

    ``test_merge_at`` and ``test_merge_force_collapse`` are set to None,
    which disables the tests of those names inherited from the base class.
    """

    timsort = jit_array_timsort

    test_merge_at = None
    test_merge_force_collapse = None

    # Compiled wrappers keyed by (timsort, func).  Kept on the class so that
    # each wrapper is only built once, whichever test instance asks for it.
    _mergestate_wrappers = {}

    def wrap_with_mergestate(self, timsort, func, _cache=None):
        """
        Wrap *func* into another compiled function inserting a runtime-created
        mergestate as the first function argument.

        Wrappers are memoized per ``(timsort, func)`` pair in *_cache*, which
        defaults to a cache shared by all instances.
        """
        if _cache is None:
            _cache = self._mergestate_wrappers

        key = timsort, func
        try:
            return _cache[key]
        except KeyError:
            pass

        merge_init = timsort.merge_init

        @timsort.compile
        def wrapper(keys, values, *args):
            ms = merge_init(keys)
            res = func(ms, keys, values, *args)
            return res

        _cache[key] = wrapper
        return wrapper


class TestTimsortArrays(JITTimsortMixin, BaseTimsortTest, TestCase):
    """Run the timsort tests against the JIT timsort, sorting int32 arrays."""

    def array_factory(self, lst):
        """Return a new int32 Numpy array holding the items of *lst*."""
        return np.array(lst, dtype=np.int32)

    def check_merge_lo_hi(self, func, a, b):
        """
        Merge the adjacent sorted runs *a* and *b* with *func* and check the
        merged keys.

        *func* is called through ``wrap_with_mergestate()``, which creates the
        merge state inside the compiled wrapper; only the keys are checked.
        """
        na = len(a)
        nb = len(b)

        wrapped = self.wrap_with_mergestate(self.timsort, func)

        # Add sentinels at start and end, to check they weren't moved
        orig_keys = [42] + a + b + [-42]
        keys = self.array_factory(orig_keys)
        ssa = 1
        ssb = ssa + na

        # Called for its in-place effect on `keys`; the result is not checked
        wrapped(keys, keys, ssa, na, ssb, nb)
        self.assertEqual(keys[0], orig_keys[0])
        self.assertEqual(keys[-1], orig_keys[-1])
        self.assertSorted(orig_keys[1:-1], keys[1:-1])



class BaseQuicksortTest(BaseSortingTest):
    """
    Tests exercising the individual building blocks of a quicksort.

    Concrete subclasses must provide ``quicksort`` (the implementation under
    test), ``make_quicksort`` (its factory) and ``array_factory`` (turning a
    list into the container to sort).
    """

    # NOTE these tests assume a non-argsort quicksort.

    def _small_inputs(self, n):
        """Return lists of *n* items in various initial orders."""
        return [
            self.sorted_list(n),
            self.revsorted_list(n),
            self.initially_sorted_list(n, n // 2),
            self.random_list(n),
            self.duprandom_list(n),
        ]

    def _with_sentinels(self, l):
        """
        Return *l* as a container with a sentinel added at each end, so that
        tests can detect writes outside of the range being processed.
        """
        return self.array_factory([9999] + l + [-9999])

    def _check_sentinels(self, res):
        """Check the sentinels added by _with_sentinels() are still there."""
        self.assertEqual(res[0], 9999)
        self.assertEqual(res[-1], -9999)

    def _chunked_key_lists(self):
        """
        Yield lists to be sorted, each made from two chunks of different
        kinds, at two different scales.
        """
        for size_factor in (1, 5):
            sizes = (15, 20)
            all_lists = [self.make_sample_lists(n * size_factor) for n in sizes]
            for chunks in itertools.product(*all_lists):
                yield list(itertools.chain.from_iterable(chunks))

    def test_insertion_sort(self):
        n = 20
        f = self.quicksort.insertion_sort

        for l in self._small_inputs(n):
            res = self._with_sentinels(l)
            # The items of `l` sit at indices 1..n (both included)
            f(res, res, 1, n)
            self._check_sentinels(res)
            self.assertSorted(l, res[1:-1])

    def test_partition(self):
        n = 20
        f = self.quicksort.partition

        for l in self._small_inputs(n):
            res = self._with_sentinels(l)
            index = f(res, res, 1, n)
            self._check_sentinels(res)
            pivot = res[index]
            for i in range(1, index):
                self.assertLessEqual(res[i], pivot)
            # The items of `l` sit at indices 1..n, so the check must go up
            # to index n included.
            for i in range(index + 1, n + 1):
                self.assertGreaterEqual(res[i], pivot)

    def test_partition3(self):
        # Test the unused partition3() function
        n = 20
        f = self.quicksort.partition3

        for l in self._small_inputs(n):
            res = self._with_sentinels(l)
            lt, gt = f(res, 1, n)
            self._check_sentinels(res)
            pivot = res[lt]
            for i in range(1, lt):
                self.assertLessEqual(res[i], pivot)
            for i in range(lt, gt + 1):
                self.assertEqual(res[i], pivot)
            # The items of `l` sit at indices 1..n, so the check must go up
            # to index n included.
            for i in range(gt + 1, n + 1):
                self.assertGreater(res[i], pivot)

    def test_run_quicksort(self):
        f = self.quicksort.run_quicksort

        for orig_keys in self._chunked_key_lists():
            keys = self.array_factory(orig_keys)
            f(keys)
            # The list is now sorted
            self.assertSorted(orig_keys, keys)

    def test_run_quicksort_lt(self):
        def lt(a, b):
            return a > b

        f = self.make_quicksort(lt=lt).run_quicksort

        for orig_keys in self._chunked_key_lists():
            keys = self.array_factory(orig_keys)
            f(keys)
            # The list is now rev-sorted
            self.assertSorted(orig_keys, keys[::-1])

        # An imperfect comparison function, as LT(a, b) does not imply not LT(b, a).
        # The sort should handle it gracefully.
        def lt_floats(a, b):
            return math.isnan(b) or a < b

        f = self.make_quicksort(lt=lt_floats).run_quicksort

        np.random.seed(42)
        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.1] = float('nan')
            orig_keys = list(orig)
            keys = self.array_factory(orig_keys)
            f(keys)
            non_nans = orig[~np.isnan(orig)]
            # Non-NaNs are sorted at the front
            self.assertSorted(non_nans, keys[:len(non_nans)])


class TestQuicksortPurePython(BaseQuicksortTest, TestCase):
    """Run the quicksort tests against ``py_quicksort``, sorting lists."""

    quicksort = py_quicksort
    # staticmethod() so that the factory isn't bound to the test instance
    make_quicksort = staticmethod(make_py_quicksort)

    # Much faster than a Numpy array in pure Python
    array_factory = list


class TestQuicksortArrays(BaseQuicksortTest, TestCase):
    """Run the quicksort tests against ``jit_quicksort``, sorting arrays."""

    quicksort = jit_quicksort
    # staticmethod() so that the factory isn't bound to the test instance
    make_quicksort = staticmethod(make_jit_quicksort)

    def array_factory(self, lst):
        """Return a new float64 Numpy array holding the items of *lst*."""
        arr = np.array(lst, dtype=np.float64)
        return arr

class TestQuicksortMultidimensionalArrays(BaseSortingTest, TestCase):
    """
    Run a quicksort built with ``is_np_array=True`` on arrays of various
    shapes, comparing against Numpy's own sort.
    """

    quicksort = make_jit_quicksort(is_np_array=True)
    make_quicksort = staticmethod(make_jit_quicksort)

    def assertSorted(self, orig, result):
        """Check that *orig* and *result* have equal shapes and contents."""
        self.assertEqual(orig.shape, result.shape)
        self.assertPreciseEqual(orig, result)

    def array_factory(self, lst, shape=None):
        """
        Return a float64 array of the items of *lst* reshaped to *shape*,
        or to a single row when *shape* is None.
        """
        array = np.array(lst, dtype=np.float64)
        if shape is not None:
            return array.reshape(shape)
        return array.reshape(-1, array.shape[0])

    def get_shapes(self, n):
        """
        Return a list of multi-dimensional shapes holding exactly *n* items
        (the list may contain duplicates).
        """
        shapes = []
        if n == 1:
            return shapes

        for divisor in range(2, int(math.sqrt(n)) + 1):
            if n % divisor != 0:
                continue
            quotient = n // divisor
            shapes.append((quotient, divisor))
            shapes.append((divisor, quotient))
            # Split the quotient further to get shapes with more dimensions
            for sub_shape in self.get_shapes(quotient):
                shapes.append((divisor,) + sub_shape)
                shapes.append(sub_shape + (divisor,))

        return shapes

    def _shapes_to_test(self, n):
        """Return the shapes from get_shapes(), plus None (a single row)."""
        return self.get_shapes(n) + [None]

    def _chunked_key_lists(self):
        """
        Yield lists to be sorted, each made from two chunks of different
        kinds, at two different scales.
        """
        for size_factor in (1, 5):
            sizes = (15, 20)
            all_lists = [self.make_sample_lists(n * size_factor) for n in sizes]
            for chunks in itertools.product(*all_lists):
                yield list(itertools.chain.from_iterable(chunks))

    def test_run_quicksort(self):
        f = self.quicksort.run_quicksort

        for orig_keys in self._chunked_key_lists():
            for shape in self._shapes_to_test(len(orig_keys)):
                keys = self.array_factory(orig_keys, shape=shape)
                expected = np.sort(self.array_factory(orig_keys, shape=shape))
                f(keys)
                # The list is now sorted
                self.assertSorted(expected, keys)

    def test_run_quicksort_lt(self):
        def lt(a, b):
            return a > b

        f = self.make_quicksort(lt=lt, is_np_array=True).run_quicksort

        for orig_keys in self._chunked_key_lists():
            for shape in self._shapes_to_test(len(orig_keys)):
                keys = self.array_factory(orig_keys, shape=shape)
                # Sorting the negated keys and negating back gives the
                # reverse-sorted keys.
                expected = -np.sort(-self.array_factory(orig_keys, shape=shape))
                f(keys)
                # The list is now rev-sorted
                self.assertSorted(expected, keys)

        # An imperfect comparison function, as LT(a, b) does not imply not LT(b, a).
        # The sort should handle it gracefully.
        def lt_floats(a, b):
            return math.isnan(b) or a < b

        f = self.make_quicksort(lt=lt_floats, is_np_array=True).run_quicksort

        np.random.seed(42)
        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.1] = float('nan')
            orig_keys = list(orig)
            for shape in self._shapes_to_test(len(orig_keys)):
                keys = self.array_factory(orig_keys, shape=shape)
                expected = np.sort(self.array_factory(orig_keys, shape=shape))
                f(keys)
                # Non-NaNs are sorted at the front
                self.assertSorted(expected, keys)

class TestNumpySort(TestCase):
    """
    Compare the jitted versions of the Numpy sort and argsort use cases
    against their pure Python counterparts.
    """

    def setUp(self):
        np.random.seed(42)

    def int_arrays(self):
        """Yield random integer arrays of various sizes."""
        for size in (5, 20, 50, 500):
            yield np.random.randint(99, size=size)

    def float_arrays(self):
        """Yield random float arrays of various sizes, some with NaNs."""
        for size in (5, 20, 50, 500):
            yield np.random.random(size=size) * 100
        # Now with NaNs.  Numpy sorts them at the end.
        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.1] = float('nan')
            yield orig
        # 90% of values are NaNs.
        for size in (50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.9] = float('nan')
            yield orig

    def complex_arrays(self):
        """
        Yield complex arrays whose real parts come from float_arrays() and
        whose imaginary parts are a shuffled copy of the real parts.
        """
        for real in self.float_arrays():
            # An explicit copy is required: slicing a Numpy array gives a
            # view, and shuffling a view would shuffle `real` as well.
            imag = real.copy()
            np.random.shuffle(imag)
            yield np.array([complex(*x) for x in zip(real, imag)])

    def has_duplicates(self, arr):
        """
        Whether the array has duplicates.  Takes NaNs into account.
        """
        if np.count_nonzero(np.isnan(arr)) > 1:
            return True
        if np.unique(arr).size < arr.size:
            return True
        return False

    def check_sort_inplace(self, pyfunc, cfunc, val):
        """Check *pyfunc* and *cfunc* sort copies of *val* identically."""
        expected = copy.copy(val)
        got = copy.copy(val)
        pyfunc(expected)
        cfunc(got)
        self.assertPreciseEqual(got, expected)

    def check_sort_copy(self, pyfunc, cfunc, val):
        """
        Check *pyfunc* and *cfunc* return the same sorted result for *val*,
        without mutating it.
        """
        orig = copy.copy(val)
        expected = pyfunc(val)
        got = cfunc(val)
        self.assertPreciseEqual(got, expected)
        # The original wasn't mutated
        self.assertPreciseEqual(val, orig)

    def check_argsort(self, pyfunc, cfunc, val, kwargs=None):
        """
        Check the indices returned by *cfunc* do sort *val*, and match those
        of *pyfunc* when *val* has no duplicates.  *kwargs* is an optional
        dict of keyword arguments passed to both functions.
        """
        if kwargs is None:
            kwargs = {}
        orig = copy.copy(val)
        expected = pyfunc(val, **kwargs)
        got = cfunc(val, **kwargs)
        self.assertPreciseEqual(orig[got], np.sort(orig),
                                msg="the array wasn't argsorted")
        # Numba and Numpy results may differ if there are duplicates
        # in the array
        if not self.has_duplicates(orig):
            self.assertPreciseEqual(got, expected)
        # The original wasn't mutated
        self.assertPreciseEqual(val, orig)

    def test_array_sort_int(self):
        pyfunc = sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for orig in self.int_arrays():
            self.check_sort_inplace(pyfunc, cfunc, orig)

    def test_array_sort_float(self):
        pyfunc = sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for orig in self.float_arrays():
            self.check_sort_inplace(pyfunc, cfunc, orig)

    def test_array_sort_complex(self):
        pyfunc = sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for orig in self.complex_arrays():
            self.check_sort_inplace(pyfunc, cfunc, orig)

    def test_np_sort_int(self):
        pyfunc = np_sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for orig in self.int_arrays():
            self.check_sort_copy(pyfunc, cfunc, orig)

    def test_np_sort_float(self):
        pyfunc = np_sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (5, 20, 50, 500):
            orig = np.random.random(size=size) * 100
            orig[np.random.random(size=size) < 0.1] = float('nan')
            self.check_sort_copy(pyfunc, cfunc, orig)

    def test_np_sort_complex(self):
        pyfunc = np_sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (5, 20, 50, 500):
            real = np.random.random(size=size) * 100
            imag = np.random.random(size=size) * 100
            real[np.random.random(size=size) < 0.1] = float('nan')
            imag[np.random.random(size=size) < 0.1] = float('nan')
            orig = np.array([complex(*x) for x in zip(real, imag)])
            self.check_sort_copy(pyfunc, cfunc, orig)

    def test_argsort_int(self):
        def check(pyfunc):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.int_arrays():
                self.check_argsort(pyfunc, cfunc, orig)

        check(argsort_usecase)
        check(np_argsort_usecase)

    def test_argsort_kind_int(self):
        def check(pyfunc, is_stable):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.int_arrays():
                self.check_argsort(pyfunc, cfunc, orig,
                                   dict(is_stable=is_stable))

        check(argsort_kind_usecase, is_stable=True)
        check(np_argsort_kind_usecase, is_stable=True)
        check(argsort_kind_usecase, is_stable=False)
        check(np_argsort_kind_usecase, is_stable=False)

    def test_argsort_float(self):
        def check(pyfunc):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.float_arrays():
                self.check_argsort(pyfunc, cfunc, orig)

        check(argsort_usecase)
        check(np_argsort_usecase)

    def test_argsort_float_supplemental(self):
        def check(pyfunc, is_stable):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.float_arrays():
                self.check_argsort(pyfunc, cfunc, orig,
                                   dict(is_stable=is_stable))

        check(argsort_kind_usecase, is_stable=True)
        check(np_argsort_kind_usecase, is_stable=True)
        check(argsort_kind_usecase, is_stable=False)
        check(np_argsort_kind_usecase, is_stable=False)

    def test_argsort_complex(self):
        def check(pyfunc):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.complex_arrays():
                self.check_argsort(pyfunc, cfunc, orig)

        check(argsort_usecase)
        check(np_argsort_usecase)

    def test_argsort_complex_supplemental(self):
        def check(pyfunc, is_stable):
            cfunc = jit(nopython=True)(pyfunc)
            for orig in self.complex_arrays():
                self.check_argsort(pyfunc, cfunc, orig,
                                   dict(is_stable=is_stable))

        check(argsort_kind_usecase, is_stable=True)
        check(np_argsort_kind_usecase, is_stable=True)
        check(argsort_kind_usecase, is_stable=False)
        check(np_argsort_kind_usecase, is_stable=False)

    def test_bad_array(self):
        cfunc = jit(nopython=True)(np_sort_usecase)
        msg = '.*Argument "a" must be array-like.*'
        with self.assertRaisesRegex(errors.TypingError, msg):
            cfunc(None)


class TestPythonSort(TestCase):
    """Tests for ``list.sort()`` and ``sorted()`` compiled in nopython mode."""

    def setUp(self):
        super().setUp()
        # test_sorted and test_sorted_reverse draw their input from NumPy's
        # global RNG; seed it so that any failure can be reproduced.
        np.random.seed(42)

    def test_list_sort(self):
        """``list.sort()`` inside a compiled function sorts ascending."""
        pyfunc = list_sort_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (20, 50, 500):
            # The usecase returns the unsorted list and its sorted copy.
            orig, ret = cfunc(size)
            self.assertEqual(sorted(orig), ret)
            self.assertNotEqual(orig, ret)   # sanity check

    def test_list_sort_reverse(self):
        """``list.sort(reverse=b)`` honours both values of the flag."""
        pyfunc = list_sort_reverse_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (20, 50, 500):
            for b in (False, True):
                orig, ret = cfunc(size, b)
                self.assertEqual(sorted(orig, reverse=b), ret)
                self.assertNotEqual(orig, ret)   # sanity check

    def test_sorted(self):
        """Compiled ``sorted()`` matches the interpreter on float arrays."""
        pyfunc = sorted_usecase
        cfunc = jit(nopython=True)(pyfunc)

        for size in (20, 50, 500):
            orig = np.random.random(size=size) * 100
            expected = sorted(orig)
            got = cfunc(orig)
            self.assertPreciseEqual(got, expected)
            self.assertNotEqual(list(orig), got)   # sanity check

    def test_sorted_reverse(self):
        """Compiled ``sorted(..., reverse=b)`` matches the interpreter."""
        pyfunc = sorted_reverse_usecase
        cfunc = jit(nopython=True)(pyfunc)
        size = 20

        orig = np.random.random(size=size) * 100
        for b in (False, True):
            expected = sorted(orig, reverse=b)
            got = cfunc(orig, b)
            self.assertPreciseEqual(got, expected)
            self.assertNotEqual(list(orig), got)   # sanity check


class TestMergeSort(TestCase):
    """Checks the JIT merge sort argsort against NumPy's mergesort."""

    def setUp(self):
        # Fixed seed so the random keys drawn below are reproducible.
        np.random.seed(321)

    def check_argsort_stable(self, sorter, low, high, count):
        """Compare ``sorter`` with ``np.argsort(kind='mergesort')``.

        The keys are ``count`` random integers drawn with
        ``np.random.randint(low, high, count)``.
        """
        # make data with high possibility of duplicated key
        keys = np.random.randint(low, high, count)
        reference = np.argsort(keys, kind='mergesort')
        result = sorter(keys)
        np.testing.assert_equal(reference, result)

    def test_argsort_stable(self):
        """Run the stability check over several (low, high, count) cases."""
        cases = (
            (-2, 2, 5),
            (-5, 5, 10),
            (0, 10, 101),
            (0, 100, 1003),
        )
        mergesort = make_jit_mergesort(is_argsort=True)
        run_mergesort = mergesort.run_mergesort

        @njit
        def sorter(arr):
            return run_mergesort(arr)

        for low, high, count in cases:
            self.check_argsort_stable(sorter, low, high, count)


def nop_compiler(x):
    """Identity "compiler": return *x* unchanged.

    Usable wherever a JIT decorator is expected, so that the decorated
    function is left as plain Python.
    """
    return x


class TestSortSlashSortedWithKey(MemoryLeakMixin, TestCase):
    """Tests for the ``key`` and ``reverse`` arguments of sort()/sorted()."""

    def test_01(self):
        """Integer list; key omitted or a separately jitted function."""
        data = [3, 1, 4, 1, 5, 9]

        @njit
        def external_key(z):
            return 1. / z

        @njit
        def sort_both(x, key=None):
            in_place = x[:]
            in_place.sort(key=key)
            return sorted(x[:], key=key), in_place

        interpreted = sort_both.py_func

        self.assertPreciseEqual(sort_both(data[:]), interpreted(data[:]))
        self.assertPreciseEqual(sort_both(data[:], external_key),
                                interpreted(data[:], external_key))

    def test_02(self):
        """Key function defined as a closure inside the jitted function."""
        data = [3, 1, 4, 1, 5, 9]

        @njit
        def sort_both(x):
            def closure_key(z):
                return 1. / z
            in_place = x[:]
            in_place.sort(key=closure_key)
            return sorted(x[:], key=closure_key), in_place

        got = sort_both(data[:])
        expected = sort_both.py_func(data[:])
        self.assertPreciseEqual(got, expected)

    def test_03(self):
        """Closure key that escapes into another function as an argument."""
        data = [3, 1, 4, 1, 5, 9]

        def gen(compiler):
            # Build the same pair of functions under the given "compiler" so
            # the jitted and the plain Python versions can be compared.

            @compiler
            def sorter(x, func):
                in_place = x[:]
                in_place.sort(key=func)
                return sorted(x[:], key=func), in_place

            @compiler
            def entry(x):
                def closure_escapee_key(z):
                    return 1. / z
                return sorter(x, closure_escapee_key)

            return entry

        jitted = gen(njit)
        plain = gen(nop_compiler)
        self.assertPreciseEqual(jitted(data[:]), plain(data[:]))

    def test_04(self):
        """String list; key omitted or a jitted ``upper()`` key."""
        data = ['a', 'b', 'B', 'b', 'C', 'A']

        @njit
        def external_key(z):
            return z.upper()

        @njit
        def sort_both(x, key=None):
            in_place = x[:]
            in_place.sort(key=key)
            return sorted(x[:], key=key), in_place

        interpreted = sort_both.py_func

        self.assertPreciseEqual(sort_both(data[:]), interpreted(data[:]))
        self.assertPreciseEqual(sort_both(data[:], external_key),
                                interpreted(data[:], external_key))

    def test_05(self):
        """Every combination of key and (bool or int) ``reverse`` value."""
        data = ['a', 'b', 'B', 'b', 'C', 'A']

        @njit
        def external_key(z):
            return z.upper()

        @njit
        def sort_both(x, key=None, reverse=False):
            in_place = x[:]
            in_place.sort(key=key, reverse=reverse)
            return (sorted(x[:], key=key, reverse=reverse), in_place)

        interpreted = sort_both.py_func

        for key in (None, external_key):
            for rev in (True, False, 1, -12, 0):
                got = sort_both(data[:], key, rev)
                expected = interpreted(data[:], key, rev)
                self.assertPreciseEqual(got, expected)

    def test_optional_on_key(self):
        """A key that is either a function or None is a typing error."""
        data = [3, 1, 4, 1, 5, 9]

        @njit
        def sort_both(x, predicate):
            if predicate:
                def closure_key(z):
                    return 1. / z
            else:
                closure_key = None

            in_place = x[:]
            in_place.sort(key=closure_key)

            return (sorted(x[:], key=closure_key), in_place)

        with self.assertRaises(errors.TypingError) as raises:
            sort_both(data[:], True)

        expected_msg = ("Key must concretely be None or a Numba JIT compiled "
                        "function")
        self.assertIn(expected_msg, str(raises.exception))

    def test_exceptions_sorted(self):
        """Illegal ``key`` / ``reverse`` are rejected by sort() and sorted()."""

        @njit
        def foo_sorted(x, key=None, reverse=False):
            return sorted(x[:], key=key, reverse=reverse)

        @njit
        def foo_sort(x, key=None, reverse=False):
            in_place = x[:]
            in_place.sort(key=key, reverse=reverse)
            return in_place

        @njit
        def external_key(z):
            return 1. / z

        data = [3, 1, 4, 1, 5, 9]
        bad_key_msg = "Key must be None or a Numba JIT compiled function"
        bad_reverse_msg = "an integer is required for 'reverse'"

        for impl in (foo_sort, foo_sorted):

            # check illegal key
            with self.assertRaises(errors.TypingError) as raises:
                impl(data, key="illegal")
            self.assertIn(bad_key_msg, str(raises.exception))

            # check illegal reverse
            with self.assertRaises(errors.TypingError) as raises:
                impl(data, key=external_key, reverse="go backwards")
            self.assertIn(bad_reverse_msg, str(raises.exception))


class TestArrayArgsort(MemoryLeakMixin, TestCase):
    """Tests specific to array.argsort"""

    def test_exceptions(self):
        """Typing errors for a non-literal ``kind`` and an unknown keyword."""

        @njit
        def nonliteral_kind(kind):
            np.arange(5).argsort(kind=kind)

        @njit
        def unsupported_kwarg():
            np.arange(5).argsort(foo='')

        # check non-literal kind
        with self.assertRaises(errors.TypingError) as raises:
            # valid spelling but not literal
            nonliteral_kind('quicksort')
        self.assertIn('"kind" must be a string literal',
                      str(raises.exception))

        # check unsupported keyword argument
        with self.assertRaises(errors.TypingError) as raises:
            unsupported_kwarg()
        self.assertIn("Unsupported keywords: ['foo']",
                      str(raises.exception))


# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
