ENH, API: Add `sort_compare` slot to DType and use sort wrappers if provided by MaanasArora · Pull Request #29987 · numpy/numpy · GitHub
Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/release/upcoming_changes/29987.new_feature.rst
36 changes: 33 additions & 3 deletions doc/source/reference/c-api/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2085,7 +2085,8 @@ struct contains a ``flags`` field which is a bitwise OR of ``NPY_SORTKIND``
values indicating the kind of sort to perform (that is, whether it is a
stable and/or descending sort). If the strided loop depends on the flags,
a good way to deal with this is to define :c:macro:`NPY_METH_get_loop`,
and not set any of the other loop slots.
and not set any of the other loop slots. Note that for both ascending and
descending sorts, NaN-like values should be sorted to the end.

.. c:struct:: PyArrayMethod_SortParameters

Expand All @@ -2097,6 +2098,10 @@ and not set any of the other loop slots.
These specs can be registered using :c:func:`PyUFunc_AddLoopsFromSpecs`
along with other ufunc loops.

Alternatively, custom sorting and argsorting for a DType can be
registered by defining the DType slot :c:macro:`NPY_DT_sort_compare`
with a comparison function.

API for calling array methods
-----------------------------

Expand Down Expand Up @@ -3712,6 +3717,25 @@ member of ``PyArrayDTypeMeta_Spec`` struct.
The number of decimal digits of precision. Corresponds to ``DIG`` from C
standard macros (e.g., ``FLT_DIG``, ``DBL_DIG``).

.. c:macro:: NPY_DT_sort_compare

.. c:type:: int (PyArrayDTypeMeta_CompareFuncWithContext)( \
char *a, char *b, PyArrayMethod_Context *context)

If defined, implements a comparison function for sorting arrays of this DType,
which can be used instead of the full sort loops (see :ref:`array-methods-sorting`).
If defined, NumPy will use this function to implement all sorting algorithms for
the DType.

The `parameters` member of the *context* can be used to access the sorting parameters
of type ``PyArrayMethod_SortParameters``, which are the same as passed to the sort
ArrayMethods. This can be used to determine if the sort is ascending or descending.

The function must return a negative value if *a* < *b*, zero if *a* == *b*,
and a positive value if *a* > *b*. If the sort is descending, the comparison
result should be inverted, however NaN handling should remain the same (i.e., NaNs
are always considered larger than any other value, regardless of sort order).

PyArray_ArrFuncs slots
^^^^^^^^^^^^^^^^^^^^^^

Expand All @@ -3738,6 +3762,8 @@ DType API slots but for now we have exposed the legacy
.. c:macro:: NPY_DT_PyArray_ArrFuncs_compare

Computes a comparison for `numpy.sort`, implements ``PyArray_CompareFunc``.
This slot may be deprecated in the future in favor of the
``NPY_DT_sort_compare`` DType API slot.

.. c:macro:: NPY_DT_PyArray_ArrFuncs_argmax

Expand Down Expand Up @@ -3781,13 +3807,17 @@ DType API slots but for now we have exposed the legacy

An array of PyArray_SortFunc of length ``NPY_NSORTS``. If set, allows
defining custom sorting implementations for each of the sorting
algorithms numpy implements.
algorithms numpy implements. This slot may be deprecated in the future
in favor of the ArrayMethod API for sorting
(see :ref:`array-methods-sorting`).

.. c:macro:: NPY_DT_PyArray_ArrFuncs_argsort

An array of PyArray_ArgSortFunc of length ``NPY_NSORTS``. If set,
allows defining custom argsorting implementations for each of the
sorting algorithms numpy implements.
sorting algorithms numpy implements. This slot may be deprecated in
the future in favor of the ArrayMethod API for argsorting
(see :ref:`array-methods-sorting`).

Macros and Static Inline Functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
4 changes: 4 additions & 0 deletions numpy/_core/include/numpy/dtype_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ typedef int (PyArrayMethod_PromoterFunction)(PyObject *ufunc,
#define NPY_DT_get_fill_zero_loop 10
#define NPY_DT_finalize_descr 11
#define NPY_DT_get_constant 12
#define NPY_DT_sort_compare 13

// These PyArray_ArrFunc slots will be deprecated and replaced eventually
// getitem and setitem can be defined as a performance optimization;
Expand Down Expand Up @@ -544,4 +545,7 @@ typedef struct {
NPY_SORTKIND flags;
} PyArrayMethod_SortParameters;

typedef int (PyArrayDTypeMeta_CompareFuncWithContext)(
char *, char *, PyArrayMethod_Context *context);

#endif /* NUMPY_CORE_INCLUDE_NUMPY___DTYPE_API_H_ */
35 changes: 35 additions & 0 deletions numpy/_core/src/common/npy_sort.h.src
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <Python.h>
#include <numpy/npy_common.h>
#include <numpy/ndarraytypes.h>
#include <numpy/ndarrayobject.h>

#define NPY_ENOMEM 1
#define NPY_ECOMP 2
Expand Down Expand Up @@ -107,6 +108,40 @@ NPY_NO_EXPORT int npy_aheapsort(void *vec, npy_intp *ind, npy_intp cnt, void *ar
NPY_NO_EXPORT int npy_amergesort(void *vec, npy_intp *ind, npy_intp cnt, void *arr);
NPY_NO_EXPORT int npy_atimsort(void *vec, npy_intp *ind, npy_intp cnt, void *arr);

/*
*****************************************************************************
** NEW-STYLE GENERIC SORT LOOPS **
*****************************************************************************
*/

NPY_NO_EXPORT int npy_quicksort_loop(PyArrayMethod_Context *context,
char *const *data, const npy_intp *dimensions, const npy_intp *strides,
NpyAuxData *transferdata);
NPY_NO_EXPORT int npy_mergesort_loop(PyArrayMethod_Context *context,
char *const *data, const npy_intp *dimensions, const npy_intp *strides,
NpyAuxData *transferdata);
NPY_NO_EXPORT int npy_aquicksort_loop(PyArrayMethod_Context *context,
char *const *data, const npy_intp *dimensions, const npy_intp *strides,
NpyAuxData *transferdata);
NPY_NO_EXPORT int npy_amergesort_loop(PyArrayMethod_Context *context,
char *const *data, const npy_intp *dimensions, const npy_intp *strides,
NpyAuxData *transferdata);

/*
*****************************************************************************
** GENERIC SORT IMPLEMENTATIONS **
*****************************************************************************
*/

NPY_NO_EXPORT int npy_quicksort_impl(void *start, npy_intp num, void *varr,
npy_intp elsize, PyArray_CompareFunc *cmp);
NPY_NO_EXPORT int npy_mergesort_impl(void *start, npy_intp num, void *varr,
npy_intp elsize, PyArray_CompareFunc *cmp);
NPY_NO_EXPORT int npy_aquicksort_impl(void *vv, npy_intp *tosort, npy_intp num, void *varr,
npy_intp elsize, PyArray_CompareFunc *cmp);
NPY_NO_EXPORT int npy_amergesort_impl(void *v, npy_intp *tosort, npy_intp num, void *varr,
npy_intp elsize, PyArray_CompareFunc *cmp);

#ifdef __cplusplus
}
#endif
Expand Down
10 changes: 5 additions & 5 deletions numpy/_core/src/multiarray/array_method.c
Original file line number Diff line number Diff line change
Expand Up @@ -343,11 +343,11 @@ fill_arraymethod_from_slots(
}
}
if (i >= meth->nin && NPY_DT_is_parametric(res->dtypes[i])) {
PyErr_Format(PyExc_TypeError,
"must provide a `resolve_descriptors` function if any "
"output DType is parametric. (method: %s)",
spec->name);
return -1;
// PyErr_Format(PyExc_TypeError,
// "must provide a `resolve_descriptors` function if any "
// "output DType is parametric. (method: %s)",
// spec->name);
// return -1;

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be merged!!! I noticed we cannot actually not pass resolve_descriptors in most cases, at least if the dtype is parametric (since it's passed as output)! Should we special case somehow?

}
}
}
Expand Down
136 changes: 136 additions & 0 deletions numpy/_core/src/multiarray/dtypemeta.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "dtype_traversal.h"
#include "npy_static_data.h"
#include "multiarraymodule.h"
#include "npy_sort.h"

#include <assert.h>

Expand Down Expand Up @@ -179,6 +180,133 @@ PyArray_ArrFuncs default_funcs = {
.setitem = &legacy_setitem_using_DType,
};


NPY_NO_EXPORT int
default_sort_get_loop(
PyArrayMethod_Context *context,
int aligned, int move_references,
const npy_intp *strides,
PyArrayMethod_StridedLoop **out_loop,
NpyAuxData **out_transferdata,
NPY_ARRAYMETHOD_FLAGS *flags)
{
PyArrayMethod_SortParameters *parameters = (PyArrayMethod_SortParameters *)context->parameters;

if (PyDataType_FLAGCHK(context->descriptors[0], NPY_NEEDS_PYAPI)) {
*flags |= NPY_METH_REQUIRES_PYAPI;
}

if (parameters->flags == NPY_SORT_STABLE) {
*out_loop = (PyArrayMethod_StridedLoop *)npy_mergesort_loop;
}
else if (parameters->flags == NPY_SORT_DEFAULT) {
*out_loop = (PyArrayMethod_StridedLoop *)npy_quicksort_loop;
}
else {
PyErr_SetString(PyExc_RuntimeError, "unsupported sort kind");
return -1;
}
return 0;
}


NPY_NO_EXPORT int
default_argsort_get_loop(
PyArrayMethod_Context *context,
int aligned, int move_references,
const npy_intp *strides,
PyArrayMethod_StridedLoop **out_loop,
NpyAuxData **out_transferdata,
NPY_ARRAYMETHOD_FLAGS *flags)
{
PyArrayMethod_SortParameters *parameters = (PyArrayMethod_SortParameters *)context->parameters;

if (PyDataType_FLAGCHK(context->descriptors[0], NPY_NEEDS_PYAPI)) {
*flags |= NPY_METH_REQUIRES_PYAPI;
}

if (parameters->flags == NPY_SORT_STABLE) {
*out_loop = (PyArrayMethod_StridedLoop *)npy_amergesort_loop;
}
else if (parameters->flags == NPY_SORT_DEFAULT) {
*out_loop = (PyArrayMethod_StridedLoop *)npy_aquicksort_loop;
}
else {
PyErr_SetString(PyExc_RuntimeError, "unsupported sort kind");
return -1;
}
return 0;
}


NPY_NO_EXPORT int
wrap_default_sort_methods(PyArray_DTypeMeta *cls)
{
const char *tp_name = ((PyTypeObject *)cls)->tp_name;

PyArray_DTypeMeta *sort_dtypes[2] = {cls, cls};
PyType_Slot sort_slots[2] = {
{NPY_METH_get_loop, &default_sort_get_loop},
{0, NULL},
};

char *sort_name = PyDataMem_NEW(strlen(tp_name) + 7);
if (sort_name == NULL) {
return -1;
}
sprintf(sort_name, "%s_sort", tp_name);

PyArrayMethod_Spec sort_spec = {
.name = (const char *)sort_name,
.nin = 1,
.nout = 1,
.dtypes = sort_dtypes,
.slots = sort_slots,
};

PyBoundArrayMethodObject *sort_method = PyArrayMethod_FromSpec_int(
&sort_spec, 1);
if (sort_method == NULL) {
PyDataMem_FREE(sort_name);
return -1;
}
NPY_DT_SLOTS(cls)->sort_meth = sort_method->method;
Py_INCREF(sort_method->method);
Py_DECREF(sort_method);

PyArray_DTypeMeta *argsort_dtypes[2] = {cls, &PyArray_IntpDType};
PyType_Slot argsort_slots[2] = {
{NPY_METH_get_loop, &default_argsort_get_loop},
{0, NULL},
};

char *argsort_name = PyDataMem_NEW(strlen(tp_name) + 9);
if (argsort_name == NULL) {
return -1;
}
sprintf(argsort_name, "%s_argsort", tp_name);

PyArrayMethod_Spec argsort_spec = {
.name = (const char *)argsort_name,
.nin = 1,
.nout = 1,
.dtypes = argsort_dtypes,
.slots = argsort_slots,
};
PyBoundArrayMethodObject *argsort_method = PyArrayMethod_FromSpec_int(
&argsort_spec, 1);
if (argsort_method == NULL) {
PyDataMem_FREE(argsort_name);
return -1;
}
NPY_DT_SLOTS(cls)->argsort_meth = argsort_method->method;
Py_INCREF(argsort_method->method);
Py_DECREF(argsort_method);

return 0;
}


/*
* Internal version of PyArrayInitDTypeMeta_FromSpec.
*
Expand Down Expand Up @@ -349,6 +477,14 @@ dtypemeta_initialize_struct_from_spec(
return -1;
}

/* If sort_compare is set, we need to fill in the sorting array method slots */
if (NPY_DT_SLOTS(DType)->sort_compare != NULL) {
if (wrap_default_sort_methods(DType) < 0) {
Py_DECREF(DType);
return -1;
}
}

/*
* And now, register all the casts that are currently defined!
*/
Expand Down
7 changes: 6 additions & 1 deletion numpy/_core/src/multiarray/dtypemeta.h
Loading
Loading