<html><head><meta name="color-scheme" content="light dark"></head><body><pre style="word-wrap: break-word; white-space: pre-wrap;">import unittest
import string

import numpy as np

from numba import njit, jit, literal_unroll
from numba.core import event as ev
from numba.tests.support import TestCase, override_config


class TestEvent(TestCase):

    def setUp(self):
        # Trigger compilation to ensure all listeners are initialized
        njit(lambda: None)()
        self.__registered_listeners = len(ev._registered)

    def tearDown(self):
        # Check there is no lingering listeners
        self.assertEqual(len(ev._registered), self.__registered_listeners)

    def test_recording_listener(self):
        @njit
        def foo(x):
            return x + x

        with ev.install_recorder("numba:compile") as rec:
            foo(1)

        self.assertIsInstance(rec, ev.RecordingListener)
        # Check there must be at least two events.
        # Because there must be a START and END for the compilation of foo()
        self.assertGreaterEqual(len(rec.buffer), 2)

    def test_compiler_lock_event(self):
        @njit
        def foo(x):
            return x + x

        foo(1)
        md = foo.get_metadata(foo.signatures[0])
        lock_duration = md['timers']['compiler_lock']
        self.assertIsInstance(lock_duration, float)
        self.assertGreater(lock_duration, 0)

    def test_llvm_lock_event(self):
        @njit
        def foo(x):
            return x + x

        foo(1)
        md = foo.get_metadata(foo.signatures[0])
        lock_duration = md['timers']['llvm_lock']
        self.assertIsInstance(lock_duration, float)
        self.assertGreater(lock_duration, 0)

    def test_run_pass_event(self):
        @njit
        def foo(x):
            return x + x

        with ev.install_recorder("numba:run_pass") as recorder:
            foo(2)

        self.assertGreater(len(recorder.buffer), 0)
        for _, event in recorder.buffer:
            # Check that all fields are there
            data = event.data
            self.assertIsInstance(data['name'], str)
            self.assertIsInstance(data['qualname'], str)
            self.assertIsInstance(data['module'], str)
            self.assertIsInstance(data['flags'], str)
            self.assertIsInstance(data['args'], str)
            self.assertIsInstance(data['return_type'], str)

    def test_install_listener(self):
        ut = self

        class MyListener(ev.Listener):
            def on_start(self, event):
                ut.assertEqual(event.status, ev.EventStatus.START)
                ut.assertEqual(event.kind, "numba:compile")
                ut.assertIs(event.data["dispatcher"], foo)
                dispatcher = event.data["dispatcher"]
                ut.assertIs(dispatcher, foo)
                # Check that the compiling signature is NOT in the overloads
                ut.assertNotIn(event.data["args"], dispatcher.overloads)

            def on_end(self, event):
                ut.assertEqual(event.status, ev.EventStatus.END)
                ut.assertEqual(event.kind, "numba:compile")
                dispatcher = event.data["dispatcher"]
                ut.assertIs(dispatcher, foo)
                # Check that the compiling signature is in the overloads
                ut.assertIn(event.data["args"], dispatcher.overloads)

        @njit
        def foo(x):
            return x

        listener = MyListener()
        with ev.install_listener("numba:compile", listener) as yielded:
            foo(1)

        # Check that the yielded value is the same listener
        self.assertIs(listener, yielded)

    def test_global_register(self):
        ut = self

        class MyListener(ev.Listener):
            def on_start(self, event):
                ut.assertEqual(event.status, ev.EventStatus.START)
                ut.assertEqual(event.kind, "numba:compile")
                # Check it is the same dispatcher
                dispatcher = event.data["dispatcher"]
                ut.assertIs(dispatcher, foo)
                # Check that the compiling signature is NOT in the overloads
                ut.assertNotIn(event.data["args"], dispatcher.overloads)

            def on_end(self, event):
                ut.assertEqual(event.status, ev.EventStatus.END)
                ut.assertEqual(event.kind, "numba:compile")
                # Check it is the same dispatcher
                dispatcher = event.data["dispatcher"]
                ut.assertIs(dispatcher, foo)
                # Check that the compiling signature is in the overloads
                ut.assertIn(event.data["args"], dispatcher.overloads)

        @njit
        def foo(x):
            return x

        listener = MyListener()
        ev.register("numba:compile", listener)
        foo(1)
        ev.unregister("numba:compile", listener)

    def test_lifted_dispatcher(self):
        @jit(forceobj=True)
        def foo():
            object()   # to trigger loop-lifting
            c = 0
            for i in range(10):
                c += i
            return c

        with ev.install_recorder("numba:compile") as rec:
            foo()

        # Check that there are 4 events.
        # Two for `foo()` and two for the lifted loop.
        self.assertGreaterEqual(len(rec.buffer), 4)

        cres = foo.overloads[foo.signatures[0]]
        [ldisp] = cres.lifted

        lifted_cres = ldisp.overloads[ldisp.signatures[0]]
        self.assertIsInstance(
            lifted_cres.metadata["timers"]["compiler_lock"],
            float,
        )
        self.assertIsInstance(
            lifted_cres.metadata["timers"]["llvm_lock"],
            float,
        )

    def test_timing_properties(self):
        a = tuple(string.ascii_lowercase)

        @njit
        def bar(x):
            acc = 0
            for i in literal_unroll(a):
                if i in {'1': x}:
                    acc += 1
                else:
                    acc += np.sqrt(x[0, 0])
            return np.sin(x), acc

        @njit
        def foo(x):
            return bar(np.zeros((x, x)))

        with override_config('LLVM_PASS_TIMINGS', True):
            foo(1)

        def get_timers(fn, prop):
            md = fn.get_metadata(fn.signatures[0])
            return md[prop]

        foo_timers = get_timers(foo, 'timers')
        bar_timers = get_timers(bar, 'timers')
        foo_llvm_timer = get_timers(foo, 'llvm_pass_timings')
        bar_llvm_timer = get_timers(bar, 'llvm_pass_timings')

        # Check: time spent in bar() must be longer than in foo()
        self.assertLess(bar_timers['llvm_lock'],
                        foo_timers['llvm_lock'])
        self.assertLess(bar_timers['compiler_lock'],
                        foo_timers['compiler_lock'])

        # Check: time spent in LLVM itself must be less than in the LLVM lock
        self.assertLess(foo_llvm_timer.get_total_time(),
                        foo_timers['llvm_lock'])
        self.assertLess(bar_llvm_timer.get_total_time(),
                        bar_timers['llvm_lock'])

        # Check: time spent in LLVM lock must be less than in compiler
        self.assertLess(foo_timers['llvm_lock'],
                        foo_timers['compiler_lock'])
        self.assertLess(bar_timers['llvm_lock'],
                        bar_timers['compiler_lock'])


if __name__ == "__main__":
    unittest.main()
</pre></body></html>