ENH: Allow downstream to register for DLPack support by seberg · Pull Request #31256 · numpy/numpy · GitHub
Skip to content

ENH: Allow downstream to register for DLPack support#31256

Open
seberg wants to merge 4 commits intonumpy:mainfrom
seberg:dlpack-register-dtype
Open

ENH: Allow downstream to register for DLPack support#31256
seberg wants to merge 4 commits intonumpy:mainfrom
seberg:dlpack-register-dtype

Conversation

@seberg
Copy link
Copy Markdown
Member

@seberg seberg commented Apr 16, 2026

This basically allows ml_dtypes to add something like:

np.dtypes.register_dlpack_dtype((4, 16), np.dtype("bfloat16"))

after defining their bfloat16 dtype. After that import and export will work. (And of course the same for all other dtypes.)

This is a pragmatic version just using dicts internally. The code is slightly longer then I had hoped, but straight forward.

Now, generally, I don't like registering on the dtype instances, but:

  • It seems simple and focused enough that it isn't a big long-term burden.
  • The more proper "put it on the DType class" style is actually a bit more complex because it would need either need two full methods (and a loop to ask all DTypes or so) or assumptions around dtypes being non-parametric.

In the end, basically, I suspect this is good enough for all we need and not bad enough that need to worry a lot that it might stand in the way one day.

(I had an agent execute much of this, but all my design and all audited, just not all typed out by hand really...)

This is a pragmatic version just using dicts internally.  The code
is slightly longer then I had hoped, but straight forward.

Now, generally, I don't like regsistering on the dtype instances, but:
* It seems simple and focused enough that it isn't a big long-term burden.
* The more proper "put it on the dtype" style is actually a bit more complex
  because it would need assumptions like that we mean the normal byte-order
  (or multiple methods).
  I.e. my expectation is that it would be enough added complexity to not be
  all that nice.

In the end, basically, I suspect this is good enough for all we need and
not bad enough that need to worry a lot that it might stand in the way one day.

(I had an agent execute much of this, but all my design and all audited, just
not all typed out by hand really...)
@seberg seberg force-pushed the dlpack-register-dtype branch from 092dc8a to 7bb13a7 Compare April 16, 2026 15:28
Comment thread numpy/dtypes.pyi Outdated
@seberg seberg force-pushed the dlpack-register-dtype branch from 61f2bc5 to 0f2fc9d Compare April 16, 2026 15:55
Comment thread numpy/_core/src/multiarray/dlpack.c Outdated
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Return a registered dtype or raise an error if non is found.
* Return a registered dtype or raise an error if none is found.

(just reading for curiosity)

Comment thread numpy/_core/src/multiarray/dlpack.c Outdated
PyErr_SetString(PyExc_BufferError,
"DLPack only supports signed/unsigned integers, float "
"and complex dtypes (or dtypes registered via "
"by importing external packages).");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe slight reword on the "via by" ?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Made it: (or dtypes registered by third-party packages).


def test_roundtrip(self, dtype=np.dtype("S1")): # noqa: B008
# Register "S1" as kDLFloat8_e3m4 == 7
# (use of kwarg ensure singleton in free-threading)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, that seems pretty tricky to do right!

Copy link
Copy Markdown
Member Author

@seberg seberg Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in normal use, I think. np.dtype("bfloat16") does return a singleton. It also is only important during registration to avoid the error. "S1" is just a weird choice (but I chose it because it is weird) :).

def test_register_conflict(self):
np.dtypes.register_dlpack_dtype((4, 16), np.dtype(np.float16))
with pytest.raises(ValueError, match="already exported"):
np.dtypes.register_dlpack_dtype((5, 16), np.dtype(np.float16))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is based on the same object or the same properties of the dtype? I guess if someone tries hard enough to encode the same dtype "behavior" in different objects that's their problem (and not necessarily a problem anyway, apart from sanity)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, you can definitely register stuff that would break dlpack with dtype that are not supported by NumPy itself. (up to the point of crashing things)


# But... accept that this now does get exported (but won't roundtrip)
arr = np.from_dlpack(np.array(["12", "23"], dtype="S2"))
assert arr.dtype == np.float16
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit weird I guess--on main it is: BufferError: DLPack only supports signed/unsigned integers, float and complex dtypes..

No point in i.e., warning since nobody pays attention to warnings anyway?

Is there a good reason that someone would want to do this on purpose or is it most likely to happen by mistake in a control flow?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dunno... I thought I'd just allow this in theory just as a (maybe silly) hedge towards two competing bfloat16 implementations.
(Although even then, it might already work, due to the dict lookup in practice.)

Comment thread numpy/dtypes.py Outdated

__all__ = []

def register_dlpack_dtype(dlpack_key, dtype):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe dumb question--why the positional-only enforcement in the docs below but not the function sig proper? I guess it is in the pyi prototype as well, so might be safe to add

Copy link
Copy Markdown
Member Author

@seberg seberg Apr 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Honestly, no good reason, I'll change it...

EDIT: Added and removed the the duplication (as this is Python defined shouldn't be needed).

Co-authored-by: Tyler Reddy <tyler.je.reddy@gmail.com>
@ngoldbaum
Copy link
Copy Markdown
Member

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants