# Tests numba.analysis functions
import collections
import types as pytypes

import numpy as np
from numba.core.compiler import run_frontend, Flags, StateDict
from numba import jit, njit, literal_unroll
from numba.core import types, errors, ir, rewrites, ir_utils, utils, cpu
from numba.core import postproc
from numba.core.inline_closurecall import InlineClosureCallPass
from numba.tests.support import (TestCase, MemoryLeakMixin, SerialMixin,
                                 IRPreservingTestPipeline)
from numba.core.analysis import dead_branch_prune, rewrite_semantic_constants
from numba.core.untyped_passes import (ReconstructSSA, TranslateByteCode,
                                       IRProcessing, DeadBranchPrune,
                                       PreserveIR)
from numba.core.compiler import DefaultPassBuilder, CompilerBase, PassManager


# Module-level global read by the `impl` functions defined in
# `TestBranchPrune.test_global_bake_in`; that test temporarily rebinds it and
# restores it afterwards.
_GLOBAL = 123

# A `Flags` instance with `enable_pyobject` switched on.
enable_pyobj_flags = Flags()
enable_pyobj_flags.enable_pyobject = True


def compile_to_ir(func):
    """
    Run the frontend on the Python function ``func`` and return the resulting
    function IR after SSA reconstruction and the 'before-inference' rewrites
    have been applied to it.
    """
    func_ir = run_frontend(func)

    # Minimal compiler state: just the IR, no typing information available.
    state = StateDict()
    state.func_ir = func_ir
    state.typemap = state.calltypes = None

    # Transform to SSA form.
    ssa_pass = ReconstructSSA()
    ssa_pass.run_pass(state)

    # Apply the untyped rewrites, this is needed to get `print` etc rewritten.
    registry = rewrites.rewrite_registry
    registry.apply('before-inference', state)
    return func_ir


class TestBranchPruneBase(MemoryLeakMixin, TestCase):
    """
    Tests branch pruning
    """
    # Set to True to dump the IR at each stage of `assert_prune`.
    _DEBUG = False

    def find_branches(self, the_ir):
        # Find *all* `ir.Branch` instructions in `the_ir`, returned as a list
        # in the order in which the blocks are iterated.
        branches = []
        for blk in the_ir.blocks.values():
            branches.extend(blk.find_insts(cls=ir.Branch))
        return branches

    def assert_prune(self, func, args_tys, prune, *args, **kwargs):
        # This checks that the expected pruned branches have indeed been pruned.
        # func is a python function to assess
        # args_tys is the numba types arguments tuple
        # prune arg is a list, one entry per branch. The value in the entry is
        # encoded as follows:
        # True: using constant inference only, the True branch will be pruned
        # False: using constant inference only, the False branch will be pruned
        # None: under no circumstances should this branch be pruned
        # 'both': both the True and the False branch will be pruned
        # *args: the argument instances to pass to the function to check
        #        execution is still valid post transform, may be empty for a
        #        function that takes no arguments
        # **kwargs:
        #        - flags: args to pass to `jit` default is `nopython=True`,
        #          e.g. permits use of e.g. object mode.

        func_ir = compile_to_ir(func)
        # keep an untouched copy to compare the pruned IR against
        before = func_ir.copy()
        if self._DEBUG:
            print("=" * 80)
            print("before inline")
            func_ir.dump()

        # run closure inlining to ensure that nonlocals in closures are visible
        inline_pass = InlineClosureCallPass(func_ir,
                                            cpu.ParallelOptions(False),)
        inline_pass.run()

        # Remove all Dels, and re-run postproc
        post_proc = postproc.PostProcessor(func_ir)
        post_proc.run()

        rewrite_semantic_constants(func_ir, args_tys)
        if self._DEBUG:
            print("=" * 80)
            print("before prune")
            func_ir.dump()

        dead_branch_prune(func_ir, args_tys)

        # the prune mutates `func_ir` in place
        after = func_ir
        if self._DEBUG:
            print("after prune")
            func_ir.dump()

        before_branches = self.find_branches(before)
        self.assertEqual(len(before_branches), len(prune))

        # what is expected to be pruned; the lengths were checked above so the
        # branches and the prune spec can be walked in lockstep
        expect_removed = []
        for branch, action in zip(before_branches, prune):
            if action is True:
                expect_removed.append(branch.truebr)
            elif action is False:
                expect_removed.append(branch.falsebr)
            elif action is None:
                pass  # nothing should be removed!
            elif action == 'both':
                expect_removed.append(branch.falsebr)
                expect_removed.append(branch.truebr)
            else:
                # an explicit failure, a bare `assert` would vanish under -O
                self.fail(f"invalid entry in prune specification: {action!r}")

        # compare labels
        original_labels = set(before.blocks)
        new_labels = set(after.blocks)
        # assert that the new labels are precisely the original less the
        # expected pruned labels
        try:
            self.assertEqual(new_labels, original_labels - set(expect_removed))
        except AssertionError:
            print("new_labels", sorted(new_labels))
            print("original_labels", sorted(original_labels))
            print("expect_removed", sorted(expect_removed))
            raise

        supplied_flags = kwargs.pop('flags', {'nopython': True})
        # NOTE: original testing used `compile_isolated` hence use of `cres`.
        cres = jit(args_tys, **supplied_flags)(func).overloads[args_tys]
        # `args` is always a tuple (possibly empty), so unpacking it covers
        # functions with and without arguments alike.
        res = cres.entry_point(*args)
        expected = func(*args)
        self.assertEqual(res, expected)


class TestBranchPrune(TestBranchPruneBase, SerialMixin):

    def test_single_if(self):
        # A single `if` with no `else`. The condition is in turn: a comparison
        # of literals, an identity check against None, an equality check
        # against an integer literal and an equality check against a local
        # that holds None.
        none_sig = (types.NoneType('none'),)
        lit10_sig = (types.IntegerLiteral(10),)

        def impl(x):
            if 1 == 0:
                return 3.14159

        self.assert_prune(impl, none_sig, [True], None)

        def impl(x):
            if 1 == 1:
                return 3.14159

        self.assert_prune(impl, none_sig, [False], None)

        def impl(x):
            if x is None:
                return 3.14159

        self.assert_prune(impl, none_sig, [False], None)
        self.assert_prune(impl, lit10_sig, [True], 10)

        def impl(x):
            if x == 10:
                return 3.14159

        self.assert_prune(impl, none_sig, [True], None)
        self.assert_prune(impl, lit10_sig, [None], 10)

        def impl(x):
            if x == 10:
                z = 3.14159  # noqa: F841 # no effect

        self.assert_prune(impl, none_sig, [True], None)
        self.assert_prune(impl, lit10_sig, [None], 10)

        def impl(x):
            z = None
            y = z
            if x == y:
                return 100

        self.assert_prune(impl, none_sig, [False], None)
        self.assert_prune(impl, lit10_sig, [True], 10)

    def test_single_if_else(self):
        # `if x is None` where both arms return.

        def impl(x):
            if x is None:
                return 3.14159
            else:
                return 1.61803

        # (argument type, expected prune spec, argument value)
        cases = (
            (types.NoneType('none'), [False], None),
            (types.IntegerLiteral(10), [True], 10),
        )
        for arg_ty, expected, value in cases:
            self.assert_prune(impl, (arg_ty,), expected, value)

    def test_single_if_const_val(self):
        # Equality against the constant 100, written with the argument on
        # either side of the `==`; both spellings share the same expectations.

        # (signature, expected prune spec, argument value)
        cases = (
            ((types.NoneType('none'),), [True], None),
            ((types.IntegerLiteral(100),), [None], 100),
        )

        def impl(x):
            if x == 100:
                return 3.14159

        for sig, expected, value in cases:
            self.assert_prune(impl, sig, expected, value)

        def impl(x):
            # switch the condition order
            if 100 == x:
                return 3.14159

        for sig, expected, value in cases:
            self.assert_prune(impl, sig, expected, value)

    def test_single_if_else_two_const_val(self):
        # Equality between two arguments, for each mix of integer literal and
        # None typed arguments.

        def impl(x, y):
            if x == y:
                return 3.14159
            else:
                return 1.61803

        lit100 = types.IntegerLiteral(100)
        lit1000 = types.IntegerLiteral(1000)
        none_ty = types.NoneType('none')

        # (signature, expected prune spec, call arguments)
        cases = (
            ((lit100, lit100), [None], (100, 100)),
            ((none_ty, none_ty), [False], (None, None)),
            ((lit100, none_ty), [True], (100, None)),
            ((lit100, lit1000), [None], (100, 1000)),
        )
        for sig, expected, call_args in cases:
            self.assert_prune(impl, sig, expected, *call_args)

    def test_single_if_else_w_following_undetermined(self):
        # An `if x is None` / `else` followed by a second branch on a local
        # set inside the first one; the second branch is never expected to be
        # pruned.
        none_sig = (types.NoneType('none'),)
        lit10_sig = (types.IntegerLiteral(10),)

        def impl(x):
            x_is_none_work = False
            if x is None:
                x_is_none_work = True
            else:
                dead = 7  # noqa: F841 # no effect

            if x_is_none_work:
                y = 10
            else:
                y = -3
            return y

        self.assert_prune(impl, none_sig, [False, None], None)
        self.assert_prune(impl, lit10_sig, [True, None], 10)

        def impl(x):
            x_is_none_work = False
            if x is None:
                x_is_none_work = True
            else:
                pass

            if x_is_none_work:
                y = 10
            else:
                y = -3
            return y

        # Python 3.10 creates a block with a NOP in it for the `pass` which
        # means it gets pruned, on earlier versions nothing is pruned.
        first_branch = False if utils.PYVERSION >= (3, 10) else None
        self.assert_prune(impl, none_sig, [first_branch, None], None)

        self.assert_prune(impl, lit10_sig, [True, None], 10)

    def test_double_if_else_rt_const(self):
        # The second branch compares two locals, one of which is reassigned
        # inside the first branch; it is never expected to be pruned.

        def impl(x):
            one_hundred = 100
            x_is_none_work = 4
            if x is None:
                x_is_none_work = 100
            else:
                dead = 7  # noqa: F841 # no effect

            if x_is_none_work == one_hundred:
                y = 10
            else:
                y = -3

            return y, x_is_none_work

        # (argument type, expected prune spec, argument value)
        cases = (
            (types.NoneType('none'), [False, None], None),
            (types.IntegerLiteral(10), [True, None], 10),
        )
        for arg_ty, expected, value in cases:
            self.assert_prune(impl, (arg_ty,), expected, value)

    def test_double_if_else_non_literal_const(self):
        # Comparison of an integer literal typed argument against a local.

        def impl(x):
            one_hundred = 100
            if x == one_hundred:
                y = 3.14159
            else:
                y = 1.61803
            return y

        # no prune as compilation specialization on literal value not permitted
        for value in (10, 100):
            sig = (types.IntegerLiteral(value),)
            self.assert_prune(impl, sig, [None], value)

    def test_single_two_branches_same_cond(self):
        # Two consecutive branches on the same argument, the second one being
        # the negation of the first.

        def impl(x):
            if x is None:
                y = 10
            else:
                y = 40

            if x is not None:
                z = 100
            else:
                z = 400

            return z, y

        # (argument type, expected prune spec, argument value)
        cases = (
            (types.NoneType('none'), [False, True], None),
            (types.IntegerLiteral(10), [True, False], 10),
        )
        for arg_ty, expected, value in cases:
            self.assert_prune(impl, (arg_ty,), expected, value)

    def test_cond_is_kwarg_none(self):
        # As test_single_two_branches_same_cond but the argument has a default
        # of None and is additionally typed as `Omitted(None)`.

        def impl(x=None):
            if x is None:
                y = 10
            else:
                y = 40

            if x is not None:
                z = 100
            else:
                z = 400

            return z, y

        # (argument type, expected prune spec, argument value)
        cases = (
            (types.Omitted(None), [False, True], None),
            (types.NoneType('none'), [False, True], None),
            (types.IntegerLiteral(10), [True, False], 10),
        )
        for arg_ty, expected, value in cases:
            self.assert_prune(impl, (arg_ty,), expected, value)

    def test_cond_is_kwarg_value(self):
        # The argument has an integer default and is compared against that
        # same integer with `==` and then `!=`.

        def impl(x=1000):
            if x == 1000:
                y = 10
            else:
                y = 40

            if x != 1000:
                z = 100
            else:
                z = 400

            return z, y

        nothing_pruned = [None, None]
        # (argument type, expected prune spec, argument value)
        cases = (
            (types.Omitted(1000), nothing_pruned, 1000),
            (types.IntegerLiteral(1000), nothing_pruned, 1000),
            (types.IntegerLiteral(0), nothing_pruned, 0),
            (types.NoneType('none'), [True, False], None),
        )
        for arg_ty, expected, value in cases:
            self.assert_prune(impl, (arg_ty,), expected, value)

    def test_cond_rewrite_is_correct(self):
        # this checks that when a condition is replaced, it is replaced by a
        # true/false bit that correctly represents the evaluated condition
        def fn(x):
            if x is None:
                return 10
            return 12

        def check(func, arg_tys, bit_val):
            # func: the python function to compile to IR
            # arg_tys: the numba types arguments tuple handed to the prune
            # bit_val: the value the rewritten condition const must have
            func_ir = compile_to_ir(func)

            # check there is 1 branch
            before_branches = self.find_branches(func_ir)
            self.assertEqual(len(before_branches), 1)

            # check the predicate of the branch is defined by a call and that
            # the first argument of that call, the condition, is a binop
            pred_var = before_branches[0].cond
            pred_defn = ir_utils.get_definition(func_ir, pred_var)
            self.assertEqual(pred_defn.op, 'call')
            condition_var = pred_defn.args[0]
            condition_op = ir_utils.get_definition(func_ir, condition_var)
            self.assertEqual(condition_op.op, 'binop')

            # do the prune, this should kill the dead branch and rewrite the
            # condition to a true/false const bit
            if self._DEBUG:
                print("=" * 80)
                print("before prune")
                func_ir.dump()
            dead_branch_prune(func_ir, arg_tys)
            if self._DEBUG:
                print("=" * 80)
                print("after prune")
                func_ir.dump()

            # after mutation, the condition should be a const value `bit_val`
            new_condition_defn = ir_utils.get_definition(func_ir, condition_var)
            # assertIsInstance reports the offending object on failure
            self.assertIsInstance(new_condition_defn, ir.Const)
            self.assertEqual(new_condition_defn.value, bit_val)

        check(fn, (types.NoneType('none'),), 1)
        check(fn, (types.IntegerLiteral(10),), 0)

    def test_global_bake_in(self):
        # Checks pruning of a branch whose condition compares the module level
        # global `_GLOBAL` against 123, first with the global at its module
        # level value and then with it temporarily rebound to 5.

        def impl(x):
            if _GLOBAL == 123:
                return x
            else:
                return x + 10

        # the False branch is expected to be pruned
        self.assert_prune(impl, (types.IntegerLiteral(1),), [False], 1)

        global _GLOBAL
        # remember the current value so that it can be restored
        tmp = _GLOBAL

        try:
            _GLOBAL = 5

            # a new function is defined after the global has been rebound
            def impl(x):
                if _GLOBAL == 123:
                    return x
                else:
                    return x + 10

            # this time the True branch is expected to be pruned
            self.assert_prune(impl, (types.IntegerLiteral(1),), [True], 1)
        finally:
            # put the global back even if the check above failed
            _GLOBAL = tmp

    def test_freevar_bake_in(self):
        # Checks pruning of a branch whose condition compares a variable
        # closed over from this method, `_FREEVAR`, against 123; first with
        # the variable set to 123 and then with it rebound to 12.

        _FREEVAR = 123

        def impl(x):
            if _FREEVAR == 123:
                return x
            else:
                return x + 10

        # the False branch is expected to be pruned
        self.assert_prune(impl, (types.IntegerLiteral(1),), [False], 1)

        _FREEVAR = 12

        # a new function is defined after the variable has been rebound
        def impl(x):
            if _FREEVAR == 123:
                return x
            else:
                return x + 10

        # this time the True branch is expected to be pruned
        self.assert_prune(impl, (types.IntegerLiteral(1),), [True], 1)

    def test_redefined_variables_are_not_considered_in_prune(self):
        # see issue #4163, checks that if a variable that is an argument is
        # redefined in the user code it is not considered const

        def impl(array, a=None):
            if a is None:
                a = 0
            if a < 0:
                return 10
            return 30

        array_ty = types.Array(types.float64, 2, 'C')
        sig = (array_ty, types.NoneType('none'))
        # neither branch may be pruned
        self.assert_prune(impl, sig, [None, None], np.zeros((2, 3)), None)

    def test_comparison_operators(self):
        # see issue #4163, checks that a variable that is an argument and has
        # value None survives TypeError from invalid comparison which should be
        # dead

        def impl(array, a=None):
            x = 0
            if a is None:
                return 10 # dynamic exec would return here
            # static analysis requires that this is executed with a=None,
            # hence TypeError
            if a < 0:
                return 20
            return x

        array_ty = types.Array(types.float64, 2, 'C')

        # `a` typed as None
        none_sig = (array_ty, types.NoneType('none'))
        self.assert_prune(impl, none_sig, [False, 'both'],
                          np.zeros((2, 3)), None)

        # `a` typed as a float, nothing may be pruned
        float_sig = (array_ty, types.float64)
        self.assert_prune(impl, float_sig, [None, None],
                          np.zeros((2, 3)), 12.)

    def test_redefinition_analysis_same_block(self):
        # checks that a redefinition in a block with prunable potential doesn't
        # break

        def impl(array, x, a=None):
            b = 2
            if x < 4:
                b = 12
            if a is None: # known true
                a = 7 # live
            else:
                b = 15 # dead
            if a < 0: # valid as a result of the redefinition of 'a'
                return 10
            return 30 + b + a

        array_ty = types.Array(types.float64, 2, 'C')
        sig = (array_ty, types.float64, types.NoneType('none'))
        expected = [None, False, None]
        self.assert_prune(impl, sig, expected, np.zeros((2, 3)), 1., None)

    def test_redefinition_analysis_different_block_can_exec(self):
        # checks that a redefinition in a block that may be executed prevents
        # pruning

        def impl(array, x, a=None):
            b = 0
            if x > 5:
                a = 11 # a redefined, cannot tell statically if this will exec
            if x < 4:
                b = 12
            if a is None: # cannot prune, cannot determine if re-defn occurred
                b += 5
            else:
                b += 7
                if a < 0:
                    return 10
            return 30 + b

        array_ty = types.Array(types.float64, 2, 'C')
        sig = (array_ty, types.float64, types.NoneType('none'))
        # four branches, none of which may be pruned
        expected = [None] * 4
        self.assert_prune(impl, sig, expected, np.zeros((2, 3)), 1., None)

    def test_redefinition_analysis_different_block_cannot_exec(self):
        # checks that a redefinition in a block guarded by something that
        # has prune potential

        def impl(array, x=None, a=None):
            b = 0
            if x is not None:
                a = 11
            if a is None:
                b += 5
            else:
                b += 7
            return 30 + b

        array_ty = types.Array(types.float64, 2, 'C')
        none_ty = types.NoneType('none')

        # (type of x, type of a, expected prune spec, value of x, value of a)
        cases = (
            (none_ty, none_ty, [True, None], None, None),
            (none_ty, types.float64, [True, None], None, 1.2),
            (types.float64, none_ty, [None, None], 1.2, None),
        )
        for x_ty, a_ty, expected, x_val, a_val in cases:
            self.assert_prune(impl, (array_ty, x_ty, a_ty), expected,
                              np.zeros((2, 3)), x_val, a_val)

    def test_closure_and_nonlocal_can_prune(self):
        # Closures must be inlined ahead of branch pruning in case nonlocal
        # is used. See issue #6585.
        def impl():
            x = 1000

            def closure():
                # rebinds the enclosing `x` to a constant
                nonlocal x
                x = 0

            closure()

            if x == 0:
                return True
            else:
                return False

        # `impl` takes no arguments, hence the empty signature and no call
        # arguments; the False branch is expected to be pruned.
        self.assert_prune(impl, (), [False,],)

    def test_closure_and_nonlocal_cannot_prune(self):
        # Closures must be inlined ahead of branch pruning in case nonlocal
        # is used. See issue #6585.
        def impl(n):
            x = 1000

            def closure(t):
                # rebinds the enclosing `x` to the value passed in
                nonlocal x
                x = t

            closure(n)

            if x == 0:
                return True
            else:
                return False

        # the branch must not be pruned
        self.assert_prune(impl, (types.int64,), [None,], 1)


class TestBranchPrunePredicates(TestBranchPruneBase, SerialMixin):
    # Really important thing to remember... branches on predicates end up as
    # POP_JUMP_IF_<bool> and the targets are backwards compared to normal, i.e.
    # the true condition is the far jump and the false condition is the near
    # one, e.g. `if x` would end up in Numba IR as `branch x 10, 6`.

    _TRUTHY = (1, "String", True, 7.4, 3j)
    _FALSEY = (0, "", False, 0.0, 0j, None)

    def _literal_const_sample_generator(self, pyfunc, consts):
        """
        This takes a python function, pyfunc, and manipulates its co_const
        __code__ member to create a new function with different co_consts as
        supplied in argument consts.

        consts is a dict {index: value} of co_const tuple index to constant
        value used to update a pyfunc clone's co_const.
        """
        pyfunc_code = pyfunc.__code__

        # translate consts spec to update the constants
        co_consts = {k: v for k, v in enumerate(pyfunc_code.co_consts)}
        for k, v in consts.items():
            co_consts[k] = v
        new_consts = tuple([v for _, v in sorted(co_consts.items())])

        # create code object with mutation
        new_code = pyfunc_code.replace(co_consts=new_consts)

        # get function
        return pytypes.FunctionType(new_code, globals())

    def test_literal_const_code_gen(self):
        # Sanity check of `_literal_const_sample_generator`: the clone must
        # carry the replaced constants while the original stays as written.
        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if _CONST1:
                return 3.14159
            else:
                _CONST2 = "PLACEHOLDER2"
            return _CONST2 + 4

        # replace the constants at co_consts index 1 and index 3
        new = self._literal_const_sample_generator(impl, {1:0, 3:20})
        iconst = impl.__code__.co_consts
        nconst = new.__code__.co_consts
        # the original constants, this pins the indices used above
        self.assertEqual(iconst, (None, "PLACEHOLDER1", 3.14159,
                                  "PLACEHOLDER2", 4))
        # the clone differs only at the two replaced indices
        self.assertEqual(nconst, (None, 0, 3.14159,  20, 4))
        # original: truthy string so the `if` arm returns
        self.assertEqual(impl(None), 3.14159)
        # clone: 0 is falsey so the `else` arm runs and 20 + 4 is returned
        self.assertEqual(new(None), 24)

    def test_single_if_const(self):
        # `if <const>` with no `else`, the constant at co_consts index 1 is
        # replaced in turn by each truthy and then each falsey sample.

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if _CONST1:
                return 3.14159

        # (constant, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for const, prune in expectations:
            func = self._literal_const_sample_generator(impl, {1: const})
            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    def test_single_if_negate_const(self):
        # `if not <const>` with no `else`, the constant at co_consts index 1
        # is replaced in turn by each truthy and then each falsey sample.

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if not _CONST1:
                return 3.14159

        # (constant, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for const, prune in expectations:
            func = self._literal_const_sample_generator(impl, {1: const})
            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    def test_single_if_else_const(self):
        # `if <const>` / `else`, the constant at co_consts index 1 is replaced
        # in turn by each truthy and then each falsey sample.

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if _CONST1:
                return 3.14159
            else:
                return 1.61803

        # (constant, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for const, prune in expectations:
            func = self._literal_const_sample_generator(impl, {1: const})
            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    def test_single_if_else_negate_const(self):
        # `if not <const>` / `else`, the constant at co_consts index 1 is
        # replaced in turn by each truthy and then each falsey sample.

        def impl(x):
            _CONST1 = "PLACEHOLDER1"
            if not _CONST1:
                return 3.14159
            else:
                return 1.61803

        # (constant, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for const, prune in expectations:
            func = self._literal_const_sample_generator(impl, {1: const})
            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    def test_single_if_freevar(self):
        # `if <freevar>` with no `else`; `const` is closed over by `func`.
        # (value of the freevar, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for const, prune in expectations:

            def func(x):
                if const:
                    return 3.14159, const

            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    def test_single_if_negate_freevar(self):
        # `if not <freevar>` with no `else`; `const` is closed over by `func`.
        # (value of the freevar, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for const, prune in expectations:

            def func(x):
                if not const:
                    return 3.14159, const

            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    def test_single_if_else_freevar(self):
        # `if <freevar>` / `else`; `const` is closed over by `func`.
        # (value of the freevar, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for const, prune in expectations:

            def func(x):
                if const:
                    return 3.14159, const
                else:
                    return 1.61803, const

            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    def test_single_if_else_negate_freevar(self):
        # `if not <freevar>` / `else`; `const` is closed over by `func`.
        # (value of the freevar, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for const, prune in expectations:

            def func(x):
                if not const:
                    return 3.14159, const
                else:
                    return 1.61803, const

            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    # globals in this section have absurd names after their test usecase names
    # so as to prevent collisions and permit tests to run in parallel
    def test_single_if_global(self):
        # `if <global>` with no `else`; the global is rebound to each sample
        # before `func` is defined and checked.
        global c_test_single_if_global

        # (value of the global, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for c, prune in expectations:
            c_test_single_if_global = c

            def func(x):
                if c_test_single_if_global:
                    return 3.14159, c_test_single_if_global

            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    def test_single_if_negate_global(self):
        # `if not <global>` with no `else`; the global is rebound to each
        # truthy and then each falsey sample before `func` is defined and
        # checked. The expected prune entries are the same as those used by
        # `test_single_if_negate_freevar`, which has the same shape.
        global c_test_single_if_negate_global

        for c_inp, prune in (self._TRUTHY, False), (self._FALSEY, True):
            for c in c_inp:
                c_test_single_if_negate_global = c

                def func(x):
                    # the predicate must be negated, without the `not` this
                    # test would just repeat `test_single_if_global`
                    if not c_test_single_if_negate_global:
                        return 3.14159, c_test_single_if_negate_global

                self.assert_prune(func, (types.NoneType('none'),), [prune],
                                  None)

    def test_single_if_else_global(self):
        # `if <global>` / `else`; the global is rebound to each sample before
        # `func` is defined and checked.
        global c_test_single_if_else_global

        # (value of the global, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for c, prune in expectations:
            c_test_single_if_else_global = c

            def func(x):
                if c_test_single_if_else_global:
                    return 3.14159, c_test_single_if_else_global
                else:
                    return 1.61803, c_test_single_if_else_global

            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    def test_single_if_else_negate_global(self):
        # `if not <global>` / `else`; the global is rebound to each sample
        # before `func` is defined and checked.
        global c_test_single_if_else_negate_global

        # (value of the global, expected prune entry)
        expectations = [(value, False) for value in self._TRUTHY]
        expectations += [(value, True) for value in self._FALSEY]

        for c, prune in expectations:
            c_test_single_if_else_negate_global = c

            def func(x):
                if not c_test_single_if_else_negate_global:
                    return 3.14159, c_test_single_if_else_negate_global
                else:
                    return 1.61803, c_test_single_if_else_negate_global

            self.assert_prune(func, (types.NoneType('none'),), [prune], None)

    def test_issue_5618(self):
        # Checks that an integer local used as a branch predicate can still be
        # stored into an array inside that branch; both the pure Python and
        # the compiled function must give 666.0.

        @njit
        def foo():
            values = np.zeros(1)
            tmp = 666
            # `tmp` is both the predicate and the value that is stored
            if tmp:
                values[0] = tmp
            return values

        self.assertPreciseEqual(foo.py_func()[0], 666.)
        self.assertPreciseEqual(foo()[0], 666.)


class TestBranchPruneSSA(MemoryLeakMixin, TestCase):
    # Tests SSA rewiring of phi nodes after branch pruning.

    # Compiler used via `pipeline_class` by the tests in this class.
    class SSAPrunerCompiler(CompilerBase):
        def define_pipelines(self):
            # This is a simple pipeline that does branch pruning on IR in SSA
            # form, then types and lowers as per the standard nopython pipeline.
            pm = PassManager("testing pm")
            pm.add_pass(TranslateByteCode, "analyzing bytecode")
            pm.add_pass(IRProcessing, "processing IR")
            # SSA early
            pm.add_pass(ReconstructSSA, "ssa")
            pm.add_pass(DeadBranchPrune, "dead branch pruning")
            # type and then lower as usual
            # the IR is stashed first so tests can inspect it post-prune via
            # the 'preserved_ir' metadata entry
            pm.add_pass(PreserveIR, "preserves the IR as metadata")
            dpb = DefaultPassBuilder
            # append the passes of the default typed pipeline...
            typed_passes = dpb.define_typed_pipeline(self.state)
            pm.passes.extend(typed_passes.passes)
            # ...followed by those of the default nopython lowering pipeline
            lowering_passes = dpb.define_nopython_lowering_pipeline(self.state)
            pm.passes.extend(lowering_passes.passes)
            pm.finalize()
            # a single pipeline is defined
            return [pm]

    def test_ssa_update_phi(self):
        # This checks that dead branch pruning is rewiring phi nodes correctly
        # after a block containing an incoming for a phi is removed.

        @njit(pipeline_class=self.SSAPrunerCompiler)
        def impl(p=None, q=None):
            z = 1
            r = False
            if p is None:
                r = True # live

            if r and q is not None:
                z = 20 # dead

            # one of the incoming blocks for z is dead, the phi needs an update
            # were this not done, it would refer to variables that do not exist
            # and result in a lowering error.
            return z, r

        # called with the defaults, i.e. both `p` and `q` are None
        self.assertPreciseEqual(impl(), impl.py_func())

    def test_ssa_replace_phi(self):
        # This checks that when a phi only has one incoming, because the other
        # has been pruned, that a direct assignment is used instead.

        @njit(pipeline_class=self.SSAPrunerCompiler)
        def impl(p=None):
            z = 0
            if p is None:
                z = 10
            else:
                z = 20

            return z

        self.assertPreciseEqual(impl(), impl.py_func())
        # fetch the IR stashed by the PreserveIR pass for the one signature
        # that was compiled by the call above
        func_ir = impl.overloads[impl.signatures[0]].metadata['preserved_ir']

        # check the func_ir, make sure there's no phi nodes
        for blk in func_ir.blocks.values():
            # an empty list is falsey, i.e. the block has no phi expressions
            self.assertFalse([*blk.find_exprs('phi')])


class TestBranchPrunePostSemanticConstRewrites(TestBranchPruneBase):
    # Tests that semantic constants rewriting works by virtue of branch pruning

    def test_array_ndim_attr(self):
        # Branch on `array.ndim` with a nested branch on the shape.

        def impl(array):
            if array.ndim == 2:
                if array.shape[1] == 2:
                    return 1
            else:
                return 10

        # (ndim of the array type, expected prune spec, shape of the argument)
        cases = (
            (2, [False, None], (2, 3)),
            (1, [True, 'both'], (2,)),
        )
        for ndim, expected, shape in cases:
            array_ty = types.Array(types.float64, ndim, 'C')
            self.assert_prune(impl, (array_ty,), expected, np.zeros(shape))

    def test_tuple_len(self):
        # Branch on `len(tup)` with a nested branch on an item of the tuple.

        def impl(tup):
            if len(tup) == 3:
                if tup[2] == 2:
                    return 1
            else:
                return 0

        # (tuple argument, expected prune spec); the tuple type is sized to
        # match the argument
        cases = (
            ((1, 2, 3), [False, None]),
            ((1, 2), [True, 'both']),
        )
        for values, expected in cases:
            tuple_ty = types.UniTuple(types.int64, len(values))
            self.assert_prune(impl, (tuple_ty,), expected, values)

    def test_attr_not_len(self):
        # The purpose of this test is to make sure that the conditions guarding
        # the rewrite part do not themselves raise exceptions.
        # This produces an `ir.Expr` call node for `float.as_integer_ratio`,
        # which is a getattr() on `float`.

        @njit
        def test():
            float.as_integer_ratio(1.23)

        # this should raise a TypingError
        with self.assertRaises(errors.TypingError) as e:
            test()

        # and the error must be about the attribute, not about the rewrite
        self.assertIn("Unknown attribute 'as_integer_ratio'", str(e.exception))

    def test_ndim_not_on_array(self):
        # `ndim` must only be treated as a constant when it is read from an
        # array, not from some other type that happens to have that attribute.

        FakeArray = collections.namedtuple('FakeArray', ['ndim'])

        def impl(fa):
            if fa.ndim == 2:
                return fa.ndim
            else:
                object()

        # check prune works for array ndim
        real_array_ty = types.Array(types.float64, 2, 'C')
        self.assert_prune(impl, (real_array_ty,), [False], np.zeros((2, 3)))

        # check prune fails for something with `ndim` attr that is not array
        fake_instance = FakeArray(ndim=2)
        fake_array_ty = types.NamedUniTuple(types.int64, 1, FakeArray)
        objmode_flags = dict(nopython=False, forceobj=True)
        self.assert_prune(impl, (fake_array_ty,), [None], fake_instance,
                          flags=objmode_flags)

    def test_semantic_const_propagates_before_static_rewrites(self):
        # see issue #5015, the ndim needs writing in as a const before
        # the rewrite passes run to make e.g. getitems static where possible
        @njit
        def impl(a, b):
            # slice bound comes from the `ndim` of another array
            return a.shape[:b.ndim]

        # a 4d array to take the shape of and a 2d array to supply `ndim`
        args = (np.zeros((5, 4, 3, 2)), np.zeros((1, 1)))

        # the compiled result must match that of the pure Python function
        self.assertPreciseEqual(impl(*args), impl.py_func(*args))

    def test_tuple_const_propagation(self):
        # Checks that `len` of each tuple visited by `literal_unroll` shows up
        # as a constant in the preserved IR.
        @njit(pipeline_class=IRPreservingTestPipeline)
        def impl(*args):
            s = 0
            for arg in literal_unroll(args):
                s += len(arg)
            return s

        inp = ((), (1, 2, 3), ())
        self.assertPreciseEqual(impl(*inp), impl.py_func(*inp))

        # IR stashed as metadata on the one compiled overload
        overload = impl.overloads[impl.signatures[0]]
        preserved = overload.metadata['preserved_ir']

        # make sure one of the inplace binop args is a Const
        seen_consts = set()
        for block in preserved.blocks.values():
            for binop in block.find_exprs('inplace_binop'):
                assignment = block.find_variable_assignment(binop.rhs.name)
                rhs_value = assignment.value
                self.assertIsInstance(rhs_value, ir.Const)
                seen_consts.add(rhs_value.value)

        # the constants must be exactly the lengths of the input tuples
        self.assertEqual(seen_consts, set(map(len, inp)))
