Get nearest graphics by kushalkolar · Pull Request #519 · fastplotlib/fastplotlib · GitHub
Skip to content
2 changes: 1 addition & 1 deletion fastplotlib/graphics/line_collection.py
1 change: 1 addition & 0 deletions fastplotlib/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from .functions import *
from .gpu import enumerate_adapters, select_adapter, print_wgpu_report
from ._plot_helpers import *


@dataclass
Expand Down
53 changes: 53 additions & 0 deletions fastplotlib/utils/_plot_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Sequence

import numpy as np

from ..graphics._base import Graphic
from ..graphics._collection_base import GraphicCollection


def get_nearest_graphics(
pos: tuple[float, float] | tuple[float, float, float],
graphics: Sequence[Graphic] | GraphicCollection,
) -> np.ndarray[Graphic]:
"""
Returns the nearest ``graphics`` to the passed position ``pos`` in world space.
Uses the distance between ``pos`` and the center of the bounding sphere for each graphic.

Parameters
----------
pos: (x, y) | (x, y, z)
position in world space, z-axis is ignored when calculating L2 norms if ``pos`` is 2D

graphics: Sequence, i.e. array, list, tuple, etc. of Graphic | GraphicCollection
the graphics from which to return a sorted array of graphics in order of closest
to furthest graphic

Returns
-------
tuple[Graphic]
nearest graphics to ``pos`` in order

"""

if isinstance(graphics, GraphicCollection):
graphics = graphics.graphics

if not all(isinstance(g, Graphic) for g in graphics):
raise TypeError("all elements of `graphics` must be Graphic objects")
Comment on lines +36 to +37
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there also be a check for if all the graphics are in the same subplot? Or would that not make a difference?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We leave that to the user, so it's just a very simple function unaware of plot areas


pos = np.asarray(pos)

if pos.shape != (2,) or not pos.shape != (3,):
raise TypeError

# get centers
centers = np.empty(shape=(len(graphics), len(pos)))
for i in range(centers.shape[0]):
centers[i] = graphics[i].world_object.get_world_bounding_sphere()[: len(pos)]

# l2
distances = np.linalg.norm(centers[:, : len(pos)] - pos, ord=2, axis=1)

sort_indices = np.argsort(distances)
return np.asarray(graphics)[sort_indices]
33 changes: 33 additions & 0 deletions tests/test_plot_helpers.py