Semantics "with dppl_context" by 1e-to · Pull Request #40 · IntelPython/numba · GitHub
Skip to content
This repository was archived by the owner on Jan 25, 2023. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 2 additions & 16 deletions numba/core/cpu_dispatcher.py
1 change: 1 addition & 0 deletions numba/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from numba.core.errors import DeprecationError, NumbaDeprecationWarning
from numba.stencils.stencil import stencil
from numba.core import config, sigutils, registry, cpu_dispatcher
from numba.dppl import gpu_dispatcher


_logger = logging.getLogger(__name__)
Expand Down
5 changes: 4 additions & 1 deletion numba/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,9 @@ def with_lifting(func_ir, typingctx, targetctx, flags, locals):
"""
from numba.core import postproc

def dispatcher_factory(func_ir, objectmode=False, **kwargs):
def dispatcher_factory(func_ir, objectmode=False, dppl_mode=False, **kwargs):
from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith
from numba.dppl.withcontexts import DPPLLiftedWith

myflags = flags.copy()
if objectmode:
Expand All @@ -335,6 +336,8 @@ def dispatcher_factory(func_ir, objectmode=False, **kwargs):
myflags.force_pyobject = True
myflags.no_cpython_wrapper = False
cls = ObjModeLiftedWith
elif dppl_mode:
cls = DPPLLiftedWith

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is it actually doing?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary to create a dispatcher specifically for lifted code with new semantics. This is for now an intermediate solution, in the future it will be necessary to rewrite it so as not to change the numba files.

else:
cls = LiftedWith
return cls(func_ir, typingctx, targetctx, myflags, locals, **kwargs)
Expand Down
22 changes: 22 additions & 0 deletions numba/dppl/gpu_dispatcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from numba.core import dispatcher, compiler
from numba.core.registry import cpu_target, dispatcher_registry
import numba.dppl_config as dppl_config
from numba.dppl.compiler import DPPLCompiler


class GPUDispatcher(dispatcher.Dispatcher):
targetdescr = cpu_target

def __init__(self, py_func, locals={}, targetoptions={}, impl_kind='direct', pipeline_class=compiler.Compiler):
if dppl_config.dppl_present:
dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=DPPLCompiler)
else:
print("---------------------------------------------------------------------")
print("WARNING : DPPL pipeline ignored. Ensure OpenCL drivers are installed.")
print("---------------------------------------------------------------------")
dispatcher.Dispatcher.__init__(self, py_func, locals=locals,
targetoptions=targetoptions, impl_kind=impl_kind, pipeline_class=pipeline_class)


dispatcher_registry['gpu'] = GPUDispatcher
120 changes: 120 additions & 0 deletions numba/dppl/tests/dppl/test_with_semantics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
from numba.dppl.testing import unittest
from numba.dppl.testing import DPPLTestCase
from numba.dppl.withcontexts import dppl_context
from numba.core import typing, cpu
from numba.core.compiler import compile_ir, DEFAULT_FLAGS
from numba.core.transforms import with_lifting
from numba.core.registry import cpu_target
from numba.core.bytecode import FunctionIdentity, ByteCode
from numba.core.interpreter import Interpreter
from numba.tests.support import captured_stdout
from numba import njit, prange
import numpy as np


def get_func_ir(func):
func_id = FunctionIdentity.from_function(func)
bc = ByteCode(func_id=func_id)
interp = Interpreter(func_id)
func_ir = interp.interpret(bc)
return func_ir


def liftcall1():
x = 1
print("A", x)
with dppl_context:
x += 1
print("B", x)
return x


def liftcall2():
x = 1
print("A", x)
with dppl_context:
x += 1
print("B", x)
with dppl_context:
x += 10
print("C", x)
return x


def liftcall3():
x = 1
print("A", x)
with dppl_context:
if x > 0:
x += 1
print("B", x)
with dppl_context:
for i in range(10):
x += i
print("C", x)
return x


class BaseTestWithLifting(DPPLTestCase):
def setUp(self):
super(BaseTestWithLifting, self).setUp()
self.typingctx = typing.Context()
self.targetctx = cpu.CPUContext(self.typingctx)
self.flags = DEFAULT_FLAGS

def check_extracted_with(self, func, expect_count, expected_stdout):
the_ir = get_func_ir(func)
new_ir, extracted = with_lifting(
the_ir, self.typingctx, self.targetctx, self.flags,
locals={},
)
self.assertEqual(len(extracted), expect_count)
cres = self.compile_ir(new_ir)

with captured_stdout() as out:
cres.entry_point()

self.assertEqual(out.getvalue(), expected_stdout)

def compile_ir(self, the_ir, args=(), return_type=None):
typingctx = self.typingctx
targetctx = self.targetctx
flags = self.flags
# Register the contexts in case for nested @jit or @overload calls
with cpu_target.nested_context(typingctx, targetctx):
return compile_ir(typingctx, targetctx, the_ir, args,
return_type, flags, locals={})


class TestLiftCall(BaseTestWithLifting):

def check_same_semantic(self, func):
"""Ensure same semantic with non-jitted code
"""
jitted = njit(target="gpu")(func)
with captured_stdout() as got:
jitted()

with captured_stdout() as expect:
func()

self.assertEqual(got.getvalue(), expect.getvalue())

def test_liftcall1(self):
self.check_extracted_with(liftcall1, expect_count=1,
expected_stdout="A 1\nB 2\n")
self.check_same_semantic(liftcall1)

def test_liftcall2(self):
self.check_extracted_with(liftcall2, expect_count=2,
expected_stdout="A 1\nB 2\nC 12\n")
self.check_same_semantic(liftcall2)

def test_liftcall3(self):
self.check_extracted_with(liftcall3, expect_count=2,
expected_stdout="A 1\nB 2\nC 47\n")
self.check_same_semantic(liftcall3)


if __name__ == '__main__':
unittest.main()
144 changes: 144 additions & 0 deletions numba/dppl/withcontexts.py