ENH: Allow downstream to register for DLPack support #31256
seberg wants to merge 4 commits into numpy:main
This is a pragmatic version just using dicts internally. The code is slightly longer than I had hoped, but straightforward.

Now, generally, I don't like registering on the dtype instances, but:

* It seems simple and focused enough that it isn't a big long-term burden.
* The more proper "put it on the dtype" style is actually a bit more complex because it would need assumptions like that we mean the normal byte-order (or multiple methods). I.e. my expectation is that it would be enough added complexity to not be all that nice.

In the end, basically, I suspect this is good enough for all we need and not bad enough that we need to worry a lot that it might stand in the way one day.

(I had an agent execute much of this, but all my design and all audited, just not all typed out by hand really...)
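For illustration only, here is a minimal pure-Python sketch of what a dict-based registry along these lines could look like. The class `DLPackRegistry` and its method names are hypothetical and are not the code in this PR; plain strings stand in for dtype instances, and the "first registration for a key wins on import" rule is an assumption made for the sketch.

```python
class DLPackRegistry:
    """Toy model of a dict-based DLPack dtype registry (hypothetical, not NumPy's code)."""

    def __init__(self):
        self._import = {}  # (type_code, bits) -> dtype returned on import
        self._export = {}  # dtype -> (type_code, bits) used on export

    def register(self, dlpack_key, dtype, /):
        code, bits = dlpack_key
        key = (int(code), int(bits))
        existing = self._export.get(dtype)
        if existing is not None and existing != key:
            # A dtype may only be exported under a single DLPack key.
            raise ValueError(f"{dtype!r} is already exported as {existing}")
        self._export[dtype] = key
        # Assumption: the first dtype registered for a key is used on import,
        # so a later dtype sharing the key exports fine but will not roundtrip.
        self._import.setdefault(key, dtype)

    def export_key(self, dtype):
        try:
            return self._export[dtype]
        except KeyError:
            raise BufferError(f"no DLPack dtype registered for {dtype!r}") from None

    def import_dtype(self, dlpack_key):
        try:
            return self._import[tuple(dlpack_key)]
        except KeyError:
            raise BufferError(f"no dtype registered for {dlpack_key!r}") from None


registry = DLPackRegistry()
registry.register((4, 16), "bfloat16")  # 4 is kDLBfloat in DLPack's type codes
```

With this model, registering the same stand-in dtype under a second key raises `ValueError`, while a second dtype registered under `(4, 16)` can be exported but is imported back as the first one.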
| * Return a registered dtype or raise an error if non is found. | |
| * Return a registered dtype or raise an error if none is found. |
(just reading for curiosity)
| PyErr_SetString(PyExc_BufferError, | ||
| "DLPack only supports signed/unsigned integers, float " | ||
| "and complex dtypes (or dtypes registered via " | ||
| "by importing external packages)."); |
Maybe a slight reword of the "via by"?
Made it: (or dtypes registered by third-party packages).
| def test_roundtrip(self, dtype=np.dtype("S1")): # noqa: B008 | ||
| # Register "S1" as kDLFloat8_e3m4 == 7 | ||
| # (use of kwarg ensure singleton in free-threading) |
hmm, that seems pretty tricky to do right!
Not in normal use, I think. np.dtype("bfloat16") does return a singleton. It also is only important during registration to avoid the error. "S1" is just a weird choice (but I chose it because it is weird) :).
| def test_register_conflict(self): | ||
| np.dtypes.register_dlpack_dtype((4, 16), np.dtype(np.float16)) | ||
| with pytest.raises(ValueError, match="already exported"): | ||
| np.dtypes.register_dlpack_dtype((5, 16), np.dtype(np.float16)) |
Is this based on the same object or the same properties of the dtype? I guess if someone tries hard enough to encode the same dtype "behavior" in different objects that's their problem (and not necessarily a problem anyway, apart from sanity).
Right, you can definitely register stuff that would break DLPack with dtypes that are not supported by NumPy itself (up to the point of crashing things).
| # But... accept that this now does get exported (but won't roundtrip) | ||
| arr = np.from_dlpack(np.array(["12", "23"], dtype="S2")) | ||
| assert arr.dtype == np.float16 |
A bit weird I guess--on main it is: BufferError: DLPack only supports signed/unsigned integers, float and complex dtypes..
No point in, e.g., warning, since nobody pays attention to warnings anyway?
Is there a good reason that someone would want to do this on purpose or is it most likely to happen by mistake in a control flow?
I dunno... I thought I'd just allow this in theory just as a (maybe silly) hedge towards two competing bfloat16 implementations.
(Although even then, it might already work, due to the dict lookup in practice.)
| __all__ = [] | ||
| def register_dlpack_dtype(dlpack_key, dtype): |
maybe dumb question--why the positional-only enforcement in the docs below but not the function sig proper? I guess it is in the pyi prototype as well, so might be safe to add
Honestly, no good reason, I'll change it...
EDIT: Added and removed the duplication (as this is Python defined shouldn't be needed).
Co-authored-by: Tyler Reddy <tyler.je.reddy@gmail.com>

This basically allows ml_dtypes to add something like:
after defining their bfloat16 dtype. After that import and export will work. (And of course the same for all other dtypes.)