From 9a01cd5ee7fda8e4dd923670f0466d1233bc6de0 Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Thu, 25 Dec 2025 01:04:06 -0800 Subject: [PATCH 1/5] start ndprocessors --- fastplotlib/utils/__init__.py | 1 + fastplotlib/utils/_protocols.py | 12 ++ fastplotlib/widgets/nd_widget/_processor.py | 141 ++++++++++++++++++++ 3 files changed, 154 insertions(+) create mode 100644 fastplotlib/utils/_protocols.py create mode 100644 fastplotlib/widgets/nd_widget/_processor.py diff --git a/fastplotlib/utils/__init__.py b/fastplotlib/utils/__init__.py index dd527ca67..8001ae375 100644 --- a/fastplotlib/utils/__init__.py +++ b/fastplotlib/utils/__init__.py @@ -6,6 +6,7 @@ from .gpu import enumerate_adapters, select_adapter, print_wgpu_report from ._plot_helpers import * from .enums import * +from ._protocols import ArrayProtocol @dataclass diff --git a/fastplotlib/utils/_protocols.py b/fastplotlib/utils/_protocols.py new file mode 100644 index 000000000..c168ecfa4 --- /dev/null +++ b/fastplotlib/utils/_protocols.py @@ -0,0 +1,12 @@ +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class ArrayProtocol(Protocol): + @property + def ndim(self) -> int: ... + + @property + def shape(self) -> tuple[int, ...]: ... + + def __getitem__(self, key): ... diff --git a/fastplotlib/widgets/nd_widget/_processor.py b/fastplotlib/widgets/nd_widget/_processor.py new file mode 100644 index 000000000..9e5299118 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_processor.py @@ -0,0 +1,141 @@ +import inspect +from typing import Literal, Callable, Any +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol + + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +class NDProcessor: + def __init__( + self, + data: ArrayProtocol, + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + ): + self._data = self._validate_data(data) + self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) + + @property + def data(self) -> ArrayProtocol: + return self._data + + @data.setter + def data(self, data: ArrayProtocol): + self._data = self._validate_data(data) + + def _validate_data(self, data: ArrayProtocol): + if not isinstance(data, ArrayProtocol): + raise TypeError("`data` must implement the ArrayProtocol") + + return data + + @property + def window_funcs(self) -> tuple[WindowFuncCallable | None] | None: + pass + + @property + def window_sizes(self) -> tuple[int | None] | None: + pass + + @property + def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: + pass + + @property + def slider_dims(self) -> tuple[int, ...] | None: + pass + + @property + def slider_index_maps(self) -> tuple[Callable[[Any], int] | None, ...]: + return self._slider_index_maps + + @slider_index_maps.setter + def slider_index_maps(self, maps): + self._maps = self._validate_slider_index_maps(maps) + + def _validate_slider_index_maps(self, maps): + if maps is not None: + if not all([callable(m) or m is None for m in maps]): + raise TypeError + + return maps + + def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: + pass + + +class NDImageProcessor(NDProcessor): + @property + def n_display_dims(self) -> Literal[2, 3]: + pass + + def _validate_n_display_dims(self, n_display_dims): + if n_display_dims not in (2, 3): + raise ValueError("`n_display_dims` must be") + + +class NDTimeSeriesProcessor(NDProcessor): + def __init__( + self, + data: ArrayProtocol, + graphic: Literal["line", "heatmap"] = "line", + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + display_window: int | float | None = None, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + ): + super().__init__( + data=data, + n_display_dims=n_display_dims, + slider_index_maps=slider_index_maps, + ) + + self._display_window = display_window + + def _validate_data(self, data: ArrayProtocol): + data = super()._validate_data(data) + + # need to make shape be [n_lines, n_datapoints, 2] + # this will work for displaying a linestack and heatmap + # for heatmap just slice: [..., 1] + # TODO: Think about how to allow n-dimensional lines, + # maybe [d1, d2, ..., d(n - 1), n_lines, n_datapoint, 2] + # and dn is the x-axis values?? + if data.ndim == 1: + pass + + @property + def display_window(self) -> int | float | None: + """display window in the reference units along the x-axis""" + return self._display_window + + def __getitem__(self, indices: tuple[Any, ...]) -> ArrayProtocol: + if self.display_window is not None: + # map reference units -> array int indices if necessary + if self.slider_index_maps is not None: + indices_window = self.slider_index_maps(self.display_window) + else: + indices_window = self.display_window + + # half window size + hw = indices_window // 2 + + # for now assume just a single index provided that indicates x axis value + start = max(indices - hw, 0) + stop = indices + hw + + # slice dim would be ndim - 1 + + return self.data[start:stop] From c46455ff71e460772148bf629dd906beffaf3cca Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 27 Dec 2025 03:52:32 -0800 Subject: [PATCH 2/5] basic timeseries --- fastplotlib/widgets/nd_widget/_processor.py | 187 +++++++++++++++++--- 1 file changed, 159 insertions(+), 28 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_processor.py b/fastplotlib/widgets/nd_widget/_processor.py index 9e5299118..d0a8e66ab 100644 --- a/fastplotlib/widgets/nd_widget/_processor.py +++ b/fastplotlib/widgets/nd_widget/_processor.py @@ -5,6 +5,7 @@ import numpy as np from numpy.typing import ArrayLike +from ...graphics import ImageGraphic, LineStack, LineCollection, ScatterGraphic from ...utils import subsample_array, ArrayProtocol @@ -14,13 +15,13 @@ class NDProcessor: def __init__( - self, - data: ArrayProtocol, - n_display_dims: Literal[2, 3] = 2, - slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, - window_funcs: tuple[WindowFuncCallable | None] | None = None, - window_sizes: tuple[int | None] | None = None, - spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + self, + data, + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): self._data = self._validate_data(data) self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) @@ -84,17 +85,30 @@ def _validate_n_display_dims(self, n_display_dims): raise ValueError("`n_display_dims` must be") +VALID_TIMESERIES_Y_DATA_SHAPES = ( + "[n_datapoints] for 1D array of y-values, [n_datapoints, 2] " + "for a 1D array of y and z-values, [n_lines, n_datapoints] for a 2D stack of lines with y-values, " + "or [n_lines, n_datapoints, 2] for a stack of lines with y and z-values." +) + + +# Limitation, no heatmap if z-values present, I don't think you can visualize that class NDTimeSeriesProcessor(NDProcessor): def __init__( - self, - data: ArrayProtocol, - graphic: Literal["line", "heatmap"] = "line", - n_display_dims: Literal[2, 3] = 2, - slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, - display_window: int | float | None = None, - window_funcs: tuple[WindowFuncCallable | None] | None = None, - window_sizes: tuple[int | None] | None = None, - spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + self, + data: list[ + ArrayProtocol, ArrayProtocol + ], # list: [x_vals_array, y_vals_and_z_vals_array] + x_values: ArrayProtocol = None, + cmap: str = None, + cmap_transform: ArrayProtocol = None, + display_graphic: Literal["line", "heatmap"] = "line", + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + display_window: int | float | None = 100, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, ): super().__init__( data=data, @@ -104,23 +118,73 @@ def __init__( self._display_window = display_window - def _validate_data(self, data: ArrayProtocol): - data = super()._validate_data(data) + self._display_graphic = None + self.display_graphic = display_graphic - # need to make shape be [n_lines, n_datapoints, 2] - # this will work for displaying a linestack and heatmap - # for heatmap just slice: [..., 1] - # TODO: Think about how to allow n-dimensional lines, - # maybe [d1, d2, ..., d(n - 1), n_lines, n_datapoint, 2] - # and dn is the x-axis values?? - if data.ndim == 1: - pass + self._uniform_x_values: ArrayProtocol | None = None + self._interp_yz: ArrayProtocol | None = None + + @property + def data(self) -> list[ArrayProtocol, ArrayProtocol]: + return self._data + + @data.setter + def data(self, data: list[ArrayProtocol, ArrayProtocol]): + self._data = self._validate_data(data) + + def _validate_data(self, data: list[ArrayProtocol, ArrayProtocol]): + x_vals, yz_vals = data + + if x_vals.ndim != 1: + raise ("data x values must be 1D") + + if data[1].ndim > 3: + raise ValueError( + f"data yz values must be of shape: {VALID_TIMESERIES_Y_DATA_SHAPES}. You passed data of shape: {yz_vals.shape}" + ) + + return data + + @property + def display_graphic(self) -> Literal["line", "heatmap"]: + return self._display_graphic + + @display_graphic.setter + def display_graphic(self, dg: Literal["line", "heatmap"]): + dg = self._validate_display_graphic(dg) + + if dg == "heatmap": + # check if x-vals uniformly spaced + norm = np.linalg.norm(np.diff(np.diff(self.x_values))) / len(self.x_values) + if norm > 10 ** -12: + # need to create evenly spaced x-values + x0 = self.data[0][0] + xn = self.data[0][-1] + self._uniform_x_values = np.linspace(x0, xn, num=len(self.data[0])) + + # TODO: interpolate yz values on the fly only when within the display window + + def _validate_display_graphic(self, dg): + if dg not in ("line", "heatmap"): + raise ValueError + + return dg @property def display_window(self) -> int | float | None: """display window in the reference units along the x-axis""" return self._display_window + @display_window.setter + def display_window(self, dw: int | float | None): + if dw is None: + self._display_window = None + + elif not isinstance(dw, (int, float)): + raise TypeError + + self._display_window = dw + def __getitem__(self, indices: tuple[Any, ...]) -> ArrayProtocol: if self.display_window is not None: # map reference units -> array int indices if necessary @@ -134,8 +198,75 @@ def __getitem__(self, indices: tuple[Any, ...]) -> ArrayProtocol: # for now assume just a single index provided that indicates x axis value start = max(indices - hw, 0) - stop = indices + hw + stop = start + indices_window # slice dim would be ndim - 1 + return self.data[0][start:stop], self.data[1][:, start:stop] + + +class NDTimeSeries: + def __init__(self, processor: NDTimeSeriesProcessor, display_graphic): + self._processor = processor + + self._indices = 0 + + if display_graphic == "line": + self._create_line_stack() + + @property + def processor(self) -> NDTimeSeriesProcessor: + return self._processor + + @property + def graphic(self) -> LineStack | ImageGraphic: + """LineStack or ImageGraphic for heatmaps""" + return self._graphic + + @property + def display_window(self) -> int | float | None: + return self.processor.display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + # create new graphic if it changed + if dw != self.display_window: + create_new_graphic = True + else: + create_new_graphic = False + + self.processor.display_window = dw + + if create_new_graphic: + if isinstance(self.graphic, LineStack): + self.set_index(self._indices) + + def set_index(self, indices: tuple[Any, ...]): + # set the graphic at the given data indices + data_slice = self.processor[indices] + + if isinstance(self.graphic, LineStack): + line_stack_data = self._create_line_stack_data(data_slice) + + for g, line_data in zip(self.graphic.graphics, line_stack_data): + if line_data.shape[1] == 2: + # only x and y values + g.data[:, :-1] = line_data + else: + # has z values too + g.data[:] = line_data + + self._indices = indices + + def _create_line_stack_data(self, data_slice): + xs = data_slice[0] # 1D + yz = data_slice[1] # [n_lines, n_datapoints] for y-vals or [n_lines, n_datapoints, 2] for yz-vals + + # need to go from x_vals and yz_vals arrays to an array of shape: [n_lines, n_datapoints, 2 | 3] + return np.dstack([np.repeat(xs[None], repeats=yz.shape[0], axis=0), yz]) + + def _create_line_stack(self): + data_slice = self.processor[self._indices] + + ls_data = self._create_line_stack_data(data_slice) - return self.data[start:stop] + self._graphic = LineStack(ls_data) From d93fa5d5fdc685b8d7f2b7bc38a95abb100f31da Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 27 Dec 2025 03:52:52 -0800 Subject: [PATCH 3/5] add __init__ --- fastplotlib/widgets/nd_widget/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 fastplotlib/widgets/nd_widget/__init__.py diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py new file mode 100644 index 000000000..e69de29bb From fddefb826f44f443c2504557c6d5e76b2e50c05f Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Sat, 27 Dec 2025 17:46:54 -0800 Subject: [PATCH 4/5] heatmap for timeseries works! --- fastplotlib/widgets/nd_widget/_processor.py | 55 ++++++++++++++++++++- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/fastplotlib/widgets/nd_widget/_processor.py b/fastplotlib/widgets/nd_widget/_processor.py index d0a8e66ab..0add36594 100644 --- a/fastplotlib/widgets/nd_widget/_processor.py +++ b/fastplotlib/widgets/nd_widget/_processor.py @@ -155,6 +155,7 @@ def display_graphic(self, dg: Literal["line", "heatmap"]): if dg == "heatmap": # check if x-vals uniformly spaced + # this is very fast to do on the fly, especially for typical small display windows norm = np.linalg.norm(np.diff(np.diff(self.x_values))) / len(self.x_values) if norm > 10 ** -12: # need to create evenly spaced x-values @@ -205,13 +206,17 @@ def __getitem__(self, indices: tuple[Any, ...]) -> ArrayProtocol: class NDTimeSeries: - def __init__(self, processor: NDTimeSeriesProcessor, display_graphic): + def __init__(self, processor: NDTimeSeriesProcessor, graphic): self._processor = processor self._indices = 0 - if display_graphic == "line": + if graphic == "line": self._create_line_stack() + elif graphic == "heatmap": + self._create_heatmap() + else: + raise ValueError @property def processor(self) -> NDTimeSeriesProcessor: @@ -222,6 +227,19 @@ def graphic(self) -> LineStack | ImageGraphic: """LineStack or ImageGraphic for heatmaps""" return self._graphic + @graphic.setter + def graphic(self, g: Literal["line", "heatmap"]): + if g == "line": + # TODO: remove existing graphic + self._create_line_stack() + + elif g == "heatmap": + # make sure "yz" data is only ys and no z values + # can't represent y and z vals in a heatmap + if self.processor.data[1].ndim > 2: + raise ValueError("Only y-values are supported for heatmaps, not yz-values") + self._create_heatmap() + @property def display_window(self) -> int | float | None: return self.processor.display_window @@ -255,6 +273,10 @@ def set_index(self, indices: tuple[Any, ...]): # has z values too g.data[:] = line_data + elif isinstance(self.graphic, ImageGraphic): + hm_data, scale = self._create_heatmap_data(data_slice) + self.graphic.data = hm_data + self._indices = indices def _create_line_stack_data(self, data_slice): @@ -270,3 +292,32 @@ def _create_line_stack(self): ls_data = self._create_line_stack_data(data_slice) self._graphic = LineStack(ls_data) + + def _create_heatmap_data(self, data_slice) -> tuple[ArrayProtocol, float]: + """Returns [n_lines, y_values] array and scale factor for x dimension""" + # check if x-vals uniformly spaced + # this is very fast to do on the fly, especially for typical small display windows + x, y = data_slice + norm = np.linalg.norm(np.diff(np.diff(x))) / x.size + if norm > 10 ** -12: + # need to create evenly spaced x-values + x_uniform = np.linspace(x[0], x[-1], num=x.size) + # yz is [n_lines, n_datapoints] + y_interp = np.zeros(shape=y.shape, dtype=np.float32) + for i in range(y.shape[0]): + y_interp[i] = np.interp(x_uniform, x, y[i]) + + else: + y_interp = y + + x_scale = x[-1] / x.size + + return y_interp, x_scale + + def _create_heatmap(self): + data_slice = self.processor[self._indices] + + hm_data, x_scale = self._create_heatmap_data(data_slice) + + self._graphic = ImageGraphic(hm_data) + self._graphic.world_object.world.scale_x = x_scale \ No newline at end of file From d5e4c7d45901b1f5f2de89e68ef4d416d0ea7dde Mon Sep 17 00:00:00 2001 From: kushalkolar Date: Mon, 29 Dec 2025 01:44:11 -0800 Subject: [PATCH 5/5] NDPositions, basics work, reorganize, increase default scatter size --- fastplotlib/graphics/scatter.py | 2 +- fastplotlib/widgets/nd_widget/_nd_image.py | 13 ++ .../widgets/nd_widget/_nd_positions.py | 137 ++++++++++++++++++ .../{_processor.py => _nd_timeseries.py} | 104 +------------ .../widgets/nd_widget/_processor_base.py | 74 ++++++++++ 5 files changed, 227 insertions(+), 103 deletions(-) create mode 100644 fastplotlib/widgets/nd_widget/_nd_image.py create mode 100644 fastplotlib/widgets/nd_widget/_nd_positions.py rename fastplotlib/widgets/nd_widget/{_processor.py => _nd_timeseries.py} (70%) create mode 100644 fastplotlib/widgets/nd_widget/_processor_base.py diff --git a/fastplotlib/graphics/scatter.py b/fastplotlib/graphics/scatter.py index a2e696a82..5268dcc51 100644 --- a/fastplotlib/graphics/scatter.py +++ b/fastplotlib/graphics/scatter.py @@ -53,7 +53,7 @@ def __init__( image: np.ndarray = None, point_rotations: float | np.ndarray = 0, point_rotation_mode: Literal["uniform", "vertex", "curve"] = "uniform", - sizes: float | np.ndarray | Sequence[float] = 1, + sizes: float | np.ndarray | Sequence[float] = 5, uniform_size: bool = False, size_space: str = "screen", isolated_buffer: bool = True, diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py new file mode 100644 index 000000000..f115e146e --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -0,0 +1,13 @@ +from typing import Literal + +from ._processor_base import NDProcessor + + +class NDImageProcessor(NDProcessor): + @property + def n_display_dims(self) -> Literal[2, 3]: + pass + + def _validate_n_display_dims(self, n_display_dims): + if n_display_dims not in (2, 3): + raise ValueError("`n_display_dims` must be") diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py new file mode 100644 index 000000000..db8c80e72 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -0,0 +1,137 @@ +import inspect +from typing import Literal, Callable, Any, Type +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol + +from ...graphics import ImageGraphic, LineGraphic, LineStack, LineCollection, ScatterGraphic +from ._processor_base import NDProcessor + +# TODO: Maybe get rid of n_display_dims in NDProcessor, +# we will know the display dims automatically here from the last dim +# so maybe we only need it for images? +class NDPositionsProcessor(NDProcessor): + def __init__( + self, + data: ArrayProtocol, + multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points + display_window: int | float | None = 100, # window for n_datapoints dim only + ): + super().__init__(data=data) + + self._display_window = display_window + + self.multi = multi + + def _validate_data(self, data: ArrayProtocol): + # TODO: determine right validation shape etc. + return data + + @property + def display_window(self) -> int | float | None: + """display window in the reference units for the n_datapoints dim""" + return self._display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + if dw is None: + self._display_window = None + + elif not isinstance(dw, (int, float)): + raise TypeError + + self._display_window = dw + + @property + def multi(self) -> bool: + return self._multi + + @multi.setter + def multi(self, m: bool): + if m and self.data.ndim < 3: + # p is p-datapoints, n is how many lines/scatter to show simultaneously + raise ValueError("ndim must be >= 3 for multi, shape must be [s1..., sn, n, p, 2 | 3]") + + self._multi = m + + def __getitem__(self, indices: tuple[Any, ...]): + """sliders through all slider dims and outputs an array that can be used to set graphic data""" + if self.display_window is not None: + indices_window = self.display_window + + # half window size + hw = indices_window // 2 + + # for now assume just a single index provided that indicates x axis value + start = max(indices - hw, 0) + stop = start + indices_window + + slices = [slice(start, stop)] + + # TODO: implement slicing for multiple slider dims, i.e. [s1, s2, ... n_datapoints, 2 | 3] + # this currently assumes the shape is: [n_datapoints, 2 | 3] + if self.multi: + # n - 2 dim is n_lines or n_scatters + slices.insert(0, slice(None)) + + return self.data[tuple(slices)] + + +class NDPositions: + def __init__(self, data, graphic: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic], multi: bool = False): + self._indices = 0 + + if issubclass(graphic, LineCollection): + multi = True + + self._processor = NDPositionsProcessor(data, multi=multi) + + self._create_graphic(graphic) + + @property + def processor(self) -> NDPositionsProcessor: + return self._processor + + @property + def graphic(self) -> LineGraphic | LineCollection | LineStack | ScatterGraphic | list[ScatterGraphic]: + """LineStack or ImageGraphic for heatmaps""" + return self._graphic + + @property + def indices(self) -> tuple: + return self._indices + + @indices.setter + def indices(self, indices): + data_slice = self.processor[indices] + + if isinstance(self.graphic, list): + # list of scatter + for i in range(len(self.graphic)): + # data_slice shape is [n_scatters, n_datapoints, 2 | 3] + # by using data_slice.shape[-1] it will auto-select if the data is only xy or has xyz + self.graphic[i].data[:, :data_slice.shape[-1]] = data_slice[i] + + elif isinstance(self.graphic, (LineGraphic, ScatterGraphic)): + self.graphic.data[:, :data_slice.shape[-1]] = data_slice + + elif isinstance(self.graphic, LineCollection): + for i in range(len(self.graphic)): + # data_slice shape is [n_lines, n_datapoints, 2 | 3] + self.graphic[i].data[:, :data_slice.shape[-1]] = data_slice[i] + + def _create_graphic(self, graphic_cls: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic]): + if self.processor.multi and issubclass(graphic_cls, ScatterGraphic): + # make list of scatters + self._graphic = list() + data_slice = self.processor[self.indices] + for d in data_slice: + scatter = graphic_cls(d) + self._graphic.append(scatter) + + else: + data_slice = self.processor[self.indices] + self._graphic = graphic_cls(data_slice) diff --git a/fastplotlib/widgets/nd_widget/_processor.py b/fastplotlib/widgets/nd_widget/_nd_timeseries.py similarity index 70% rename from fastplotlib/widgets/nd_widget/_processor.py rename to fastplotlib/widgets/nd_widget/_nd_timeseries.py index 0add36594..8630044cf 100644 --- a/fastplotlib/widgets/nd_widget/_processor.py +++ b/fastplotlib/widgets/nd_widget/_nd_timeseries.py @@ -5,84 +5,10 @@ import numpy as np from numpy.typing import ArrayLike -from ...graphics import ImageGraphic, LineStack, LineCollection, ScatterGraphic from ...utils import subsample_array, ArrayProtocol - -# must take arguments: array-like, `axis`: int, `keepdims`: bool -WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] - - -class NDProcessor: - def __init__( - self, - data, - n_display_dims: Literal[2, 3] = 2, - slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, - window_funcs: tuple[WindowFuncCallable | None] | None = None, - window_sizes: tuple[int | None] | None = None, - spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, - ): - self._data = self._validate_data(data) - self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) - - @property - def data(self) -> ArrayProtocol: - return self._data - - @data.setter - def data(self, data: ArrayProtocol): - self._data = self._validate_data(data) - - def _validate_data(self, data: ArrayProtocol): - if not isinstance(data, ArrayProtocol): - raise TypeError("`data` must implement the ArrayProtocol") - - return data - - @property - def window_funcs(self) -> tuple[WindowFuncCallable | None] | None: - pass - - @property - def window_sizes(self) -> tuple[int | None] | None: - pass - - @property - def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: - pass - - @property - def slider_dims(self) -> tuple[int, ...] | None: - pass - - @property - def slider_index_maps(self) -> tuple[Callable[[Any], int] | None, ...]: - return self._slider_index_maps - - @slider_index_maps.setter - def slider_index_maps(self, maps): - self._maps = self._validate_slider_index_maps(maps) - - def _validate_slider_index_maps(self, maps): - if maps is not None: - if not all([callable(m) or m is None for m in maps]): - raise TypeError - - return maps - - def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: - pass - - -class NDImageProcessor(NDProcessor): - @property - def n_display_dims(self) -> Literal[2, 3]: - pass - - def _validate_n_display_dims(self, n_display_dims): - if n_display_dims not in (2, 3): - raise ValueError("`n_display_dims` must be") +from ...graphics import ImageGraphic, LineStack, LineCollection, ScatterGraphic +from ._processor_base import NDProcessor, WindowFuncCallable VALID_TIMESERIES_Y_DATA_SHAPES = ( @@ -145,32 +71,6 @@ def _validate_data(self, data: list[ArrayProtocol, ArrayProtocol]): return data - @property - def display_graphic(self) -> Literal["line", "heatmap"]: - return self._display_graphic - - @display_graphic.setter - def display_graphic(self, dg: Literal["line", "heatmap"]): - dg = self._validate_display_graphic(dg) - - if dg == "heatmap": - # check if x-vals uniformly spaced - # this is very fast to do on the fly, especially for typical small display windows - norm = np.linalg.norm(np.diff(np.diff(self.x_values))) / len(self.x_values) - if norm > 10 ** -12: - # need to create evenly spaced x-values - x0 = self.data[0][0] - xn = self.data[0][-1] - self._uniform_x_values = np.linspace(x0, xn, num=len(self.data[0])) - - # TODO: interpolate yz values on the fly only when within the display window - - def _validate_display_graphic(self, dg): - if dg not in ("line", "heatmap"): - raise ValueError - - return dg - @property def display_window(self) -> int | float | None: """display window in the reference units along the x-axis""" diff --git a/fastplotlib/widgets/nd_widget/_processor_base.py b/fastplotlib/widgets/nd_widget/_processor_base.py new file mode 100644 index 000000000..fa56e4b52 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_processor_base.py @@ -0,0 +1,74 @@ +import inspect +from typing import Literal, Callable, Any +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol + + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +class NDProcessor: + def __init__( + self, + data, + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + ): + self._data = self._validate_data(data) + self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) + + @property + def data(self) -> ArrayProtocol: + return self._data + + @data.setter + def data(self, data: ArrayProtocol): + self._data = self._validate_data(data) + + def _validate_data(self, data: ArrayProtocol): + if not isinstance(data, ArrayProtocol): + raise TypeError("`data` must implement the ArrayProtocol") + + return data + + @property + def window_funcs(self) -> tuple[WindowFuncCallable | None] | None: + pass + + @property + def window_sizes(self) -> tuple[int | None] | None: + pass + + @property + def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: + pass + + @property + def slider_dims(self) -> tuple[int, ...] | None: + pass + + @property + def slider_index_maps(self) -> tuple[Callable[[Any], int] | None, ...]: + return self._slider_index_maps + + @slider_index_maps.setter + def slider_index_maps(self, maps): + self._maps = self._validate_slider_index_maps(maps) + + def _validate_slider_index_maps(self, maps): + if maps is not None: + if not all([callable(m) or m is None for m in maps]): + raise TypeError + + return maps + + def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: + pass