Accept NumPy arrays in advanced indexing by ndgrigorian · Pull Request #2128 · IntelPython/dpctl · GitHub
Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 102 additions & 63 deletions dpctl/tensor/_copy_utils.py
13 changes: 7 additions & 6 deletions dpctl/tensor/_slicing.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numbers
from operator import index
from cpython.buffer cimport PyObject_CheckBuffer
from numpy import ndarray


cdef bint _is_buffer(object o):
Expand Down Expand Up @@ -46,7 +47,7 @@ cdef Py_ssize_t _slice_len(

cdef bint _is_integral(object x) except *:
"""Gives True if x is an integral slice spec"""
if isinstance(x, usm_ndarray):
if isinstance(x, (ndarray, usm_ndarray)):
if x.ndim > 0:
return False
if x.dtype.kind not in "ui":
Expand Down Expand Up @@ -74,7 +75,7 @@ cdef bint _is_integral(object x) except *:

cdef bint _is_boolean(object x) except *:
"""Gives True if x is an integral slice spec"""
if isinstance(x, usm_ndarray):
if isinstance(x, (ndarray, usm_ndarray)):
if x.ndim > 0:
return False
if x.dtype.kind not in "b":
Expand Down Expand Up @@ -185,7 +186,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
raise IndexError(
"Index {0} is out of range for axes 0 with "
"size {1}".format(ind, shape[0]))
elif isinstance(ind, usm_ndarray):
elif isinstance(ind, (ndarray, usm_ndarray)):
return (shape, strides, offset, (ind,), 0)
elif isinstance(ind, tuple):
axes_referenced = 0
Expand Down Expand Up @@ -216,7 +217,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
axes_referenced += 1
if not array_streak_started and array_streak_interrupted:
explicit_index += 1
elif isinstance(i, usm_ndarray):
elif isinstance(i, (ndarray, usm_ndarray)):
if not seen_arrays_yet:
seen_arrays_yet = True
array_streak_started = True
Expand Down Expand Up @@ -302,7 +303,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
array_streak = False
elif _is_integral(ind_i):
if array_streak:
if not isinstance(ind_i, usm_ndarray):
if not isinstance(ind_i, (ndarray, usm_ndarray)):
ind_i = index(ind_i)
# integer will be converted to an array,
# still raise if OOB
Expand Down Expand Up @@ -337,7 +338,7 @@ def _basic_slice_meta(ind, shape : tuple, strides : tuple, offset : int):
"Index {0} is out of range for axes "
"{1} with size {2}".format(ind_i, k, shape[k])
)
elif isinstance(ind_i, usm_ndarray):
elif isinstance(ind_i, (ndarray, usm_ndarray)):
if not array_streak:
array_streak = True
if not advanced_start_pos_set:
Expand Down
22 changes: 22 additions & 0 deletions dpctl/tests/test_usm_ndarray_indexing.py
Loading