Fix `StridedMemoryView` by deferring the check for whether a capsule is versioned by leofang · Pull Request #292 · NVIDIA/cuda-python · GitHub
Skip to content

Fix StridedMemoryView by deferring the check for whether a capsule is versioned#292

Merged
leofang merged 2 commits into
NVIDIA:mainfrom
leofang:fix_view
Dec 13, 2024
Merged

Fix StridedMemoryView by deferring the check for whether a capsule is versioned#292
leofang merged 2 commits into
NVIDIA:mainfrom
leofang:fix_view

Conversation

@leofang

@leofang leofang commented Dec 12, 2024

Copy link
Copy Markdown
Member

Close #285.

xref: #285 (comment)

@leofang leofang added bug Something isn't working P0 High priority - Must do! cuda.core Everything related to the cuda.core module labels Dec 12, 2024
@leofang leofang added this to the cuda.core beta 2 milestone Dec 12, 2024
@leofang leofang self-assigned this Dec 12, 2024
@copy-pr-bot

copy-pr-bot Bot commented Dec 12, 2024

Copy link
Copy Markdown
Contributor

@leofang

leofang commented Dec 12, 2024

Copy link
Copy Markdown
Member Author

@yangcal could you test this instead?

@leofang

leofang commented Dec 12, 2024

Copy link
Copy Markdown
Member Author

/ok to test

@leofang

leofang commented Dec 13, 2024

Copy link
Copy Markdown
Member Author

/ok to test

@keenan-simpson keenan-simpson left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved and locally verified

import jax.numpy as jnp
import jax
from cuda.core.experimental.utils import args_viewable_as_strided_memory
from cuda.core.experimental import Device

@args_viewable_as_strided_memory((0,))
def parse_tensor(arr):
    dev = Device(0)
    dev.set_current()
    stream = dev.create_stream()
    view = arr.view(stream.handle)

arr = jnp.array([1, 2, 3], device = jax.devices("cuda")[0])
parse_tensor(arr)

@leofang

leofang commented Dec 13, 2024

Copy link
Copy Markdown
Member Author

@leofang leofang merged commit ddc1f94 into NVIDIA:main Dec 13, 2024
@leofang leofang deleted the fix_view branch December 13, 2024 21:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working cuda.core Everything related to the cuda.core module P0 High priority - Must do!

Projects

None yet

Development

Successfully merging this pull request may close these issues.

StridedMemoryView fails with Jax arrays

2 participants