diff --git a/fastplotlib/graphics/selectors/_rectangle.py b/fastplotlib/graphics/selectors/_rectangle.py index 38bb0d2f9..02e5ec6b3 100644 --- a/fastplotlib/graphics/selectors/_rectangle.py +++ b/fastplotlib/graphics/selectors/_rectangle.py @@ -18,7 +18,7 @@ def parent(self) -> Graphic | None: return self._parent @property - def selection(self) -> Sequence[float] | List[Sequence[float]]: + def selection(self) -> np.ndarray[float]: """ (xmin, xmax, ymin, ymax) of the rectangle selection """ @@ -319,10 +319,11 @@ def get_selected_data( # do not need to check for mode for images, because the selector is bounded by the image shape # will always be `full` if "Image" in source.__class__.__name__: - col_ixs = slice(ixs[0][0], ixs[0][-1] + 1) - row_ixs = slice(ixs[1][0], ixs[1][-1] + 1) + row_ixs, col_ixs = ixs + row_slice = slice(row_ixs[0], row_ixs[-1] + 1) + col_slice = slice(col_ixs[0], col_ixs[-1] + 1) - return source.data[row_ixs, col_ixs] + return source.data[row_slice, col_slice] if mode not in ["full", "partial", "ignore"]: raise ValueError( @@ -414,7 +415,7 @@ def get_selected_data( def get_selected_indices( self, graphic: Graphic = None - ) -> Union[np.ndarray, List[np.ndarray]]: + ) -> np.ndarray | tuple[np.ndarray]: """ Returns the indices of the ``Graphic`` data bounded by the current selection. @@ -429,7 +430,7 @@ def get_selected_indices( ------- Union[np.ndarray, List[np.ndarray]] data indicies of the selection - | list of [x_indices_array, y_indices_array] if the graphic is an image + | tuple of [row_indices, y_indices_array] if the graphic is an image | list of indices along the x-dimension for each line if graphic is a line collection | array of indices along the x-dimension if graphic is a line """ @@ -437,14 +438,14 @@ def get_selected_indices( source = self._get_source(graphic) # selector (xmin, xmax, ymin, ymax) values - bounds = self.selection + xmin, xmax, ymin, ymax = self.selection # image data does not need to check for mode because the selector is always bounded # to the image if "Image" in source.__class__.__name__: - xs = np.arange(bounds[0], bounds[1], dtype=int) - ys = np.arange(bounds[2], bounds[3], dtype=int) - return [xs, ys] + col_ixs = np.arange(xmin, xmax, dtype=int) + row_ixs = np.arange(ymin, ymax, dtype=int) + return row_ixs, col_ixs if "Line" in source.__class__.__name__: if isinstance(source, GraphicCollection): @@ -452,20 +453,20 @@ def get_selected_indices( for g in source.graphics: data = g.data.value g_ixs = np.where( - (data[:, 0] >= bounds[0] - g.offset[0]) - & (data[:, 0] <= bounds[1] - g.offset[0]) - & (data[:, 1] >= bounds[2] - g.offset[1]) - & (data[:, 1] <= bounds[3] - g.offset[1]) + (data[:, 0] >= xmin - g.offset[0]) + & (data[:, 0] <= xmax - g.offset[0]) + & (data[:, 1] >= ymin - g.offset[1]) + & (data[:, 1] <= ymax - g.offset[1]) )[0] ixs.append(g_ixs) else: # map only this graphic data = source.data.value ixs = np.where( - (data[:, 0] >= bounds[0]) - & (data[:, 0] <= bounds[1]) - & (data[:, 1] >= bounds[2]) - & (data[:, 1] <= bounds[3]) + (data[:, 0] >= xmin) + & (data[:, 0] <= xmax) + & (data[:, 1] >= ymin) + & (data[:, 1] <= ymax) )[0] return ixs