diff --git a/fastplotlib/graphics/scatter.py b/fastplotlib/graphics/scatter.py index a2e696a82..5268dcc51 100644 --- a/fastplotlib/graphics/scatter.py +++ b/fastplotlib/graphics/scatter.py @@ -53,7 +53,7 @@ def __init__( image: np.ndarray = None, point_rotations: float | np.ndarray = 0, point_rotation_mode: Literal["uniform", "vertex", "curve"] = "uniform", - sizes: float | np.ndarray | Sequence[float] = 1, + sizes: float | np.ndarray | Sequence[float] = 5, uniform_size: bool = False, size_space: str = "screen", isolated_buffer: bool = True, diff --git a/fastplotlib/utils/__init__.py b/fastplotlib/utils/__init__.py index dd527ca67..8001ae375 100644 --- a/fastplotlib/utils/__init__.py +++ b/fastplotlib/utils/__init__.py @@ -6,6 +6,7 @@ from .gpu import enumerate_adapters, select_adapter, print_wgpu_report from ._plot_helpers import * from .enums import * +from ._protocols import ArrayProtocol @dataclass diff --git a/fastplotlib/utils/_protocols.py b/fastplotlib/utils/_protocols.py new file mode 100644 index 000000000..c168ecfa4 --- /dev/null +++ b/fastplotlib/utils/_protocols.py @@ -0,0 +1,12 @@ +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class ArrayProtocol(Protocol): + @property + def ndim(self) -> int: ... + + @property + def shape(self) -> tuple[int, ...]: ... + + def __getitem__(self, key): ... diff --git a/fastplotlib/widgets/nd_widget/__init__.py b/fastplotlib/widgets/nd_widget/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/fastplotlib/widgets/nd_widget/_nd_image.py b/fastplotlib/widgets/nd_widget/_nd_image.py new file mode 100644 index 000000000..f115e146e --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_image.py @@ -0,0 +1,13 @@ +from typing import Literal + +from ._processor_base import NDProcessor + + +class NDImageProcessor(NDProcessor): + @property + def n_display_dims(self) -> Literal[2, 3]: + pass + + def _validate_n_display_dims(self, n_display_dims): + if n_display_dims not in (2, 3): + raise ValueError("`n_display_dims` must be") diff --git a/fastplotlib/widgets/nd_widget/_nd_positions.py b/fastplotlib/widgets/nd_widget/_nd_positions.py new file mode 100644 index 000000000..db8c80e72 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_positions.py @@ -0,0 +1,137 @@ +import inspect +from typing import Literal, Callable, Any, Type +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol + +from ...graphics import ImageGraphic, LineGraphic, LineStack, LineCollection, ScatterGraphic +from ._processor_base import NDProcessor + +# TODO: Maybe get rid of n_display_dims in NDProcessor, +# we will know the display dims automatically here from the last dim +# so maybe we only need it for images? +class NDPositionsProcessor(NDProcessor): + def __init__( + self, + data: ArrayProtocol, + multi: bool = False, # TODO: interpret [n - 2] dimension as n_lines or n_points + display_window: int | float | None = 100, # window for n_datapoints dim only + ): + super().__init__(data=data) + + self._display_window = display_window + + self.multi = multi + + def _validate_data(self, data: ArrayProtocol): + # TODO: determine right validation shape etc. + return data + + @property + def display_window(self) -> int | float | None: + """display window in the reference units for the n_datapoints dim""" + return self._display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + if dw is None: + self._display_window = None + + elif not isinstance(dw, (int, float)): + raise TypeError + + self._display_window = dw + + @property + def multi(self) -> bool: + return self._multi + + @multi.setter + def multi(self, m: bool): + if m and self.data.ndim < 3: + # p is p-datapoints, n is how many lines/scatter to show simultaneously + raise ValueError("ndim must be >= 3 for multi, shape must be [s1..., sn, n, p, 2 | 3]") + + self._multi = m + + def __getitem__(self, indices: tuple[Any, ...]): + """sliders through all slider dims and outputs an array that can be used to set graphic data""" + if self.display_window is not None: + indices_window = self.display_window + + # half window size + hw = indices_window // 2 + + # for now assume just a single index provided that indicates x axis value + start = max(indices - hw, 0) + stop = start + indices_window + + slices = [slice(start, stop)] + + # TODO: implement slicing for multiple slider dims, i.e. [s1, s2, ... n_datapoints, 2 | 3] + # this currently assumes the shape is: [n_datapoints, 2 | 3] + if self.multi: + # n - 2 dim is n_lines or n_scatters + slices.insert(0, slice(None)) + + return self.data[tuple(slices)] + + +class NDPositions: + def __init__(self, data, graphic: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic], multi: bool = False): + self._indices = 0 + + if issubclass(graphic, LineCollection): + multi = True + + self._processor = NDPositionsProcessor(data, multi=multi) + + self._create_graphic(graphic) + + @property + def processor(self) -> NDPositionsProcessor: + return self._processor + + @property + def graphic(self) -> LineGraphic | LineCollection | LineStack | ScatterGraphic | list[ScatterGraphic]: + """LineStack or ImageGraphic for heatmaps""" + return self._graphic + + @property + def indices(self) -> tuple: + return self._indices + + @indices.setter + def indices(self, indices): + data_slice = self.processor[indices] + + if isinstance(self.graphic, list): + # list of scatter + for i in range(len(self.graphic)): + # data_slice shape is [n_scatters, n_datapoints, 2 | 3] + # by using data_slice.shape[-1] it will auto-select if the data is only xy or has xyz + self.graphic[i].data[:, :data_slice.shape[-1]] = data_slice[i] + + elif isinstance(self.graphic, (LineGraphic, ScatterGraphic)): + self.graphic.data[:, :data_slice.shape[-1]] = data_slice + + elif isinstance(self.graphic, LineCollection): + for i in range(len(self.graphic)): + # data_slice shape is [n_lines, n_datapoints, 2 | 3] + self.graphic[i].data[:, :data_slice.shape[-1]] = data_slice[i] + + def _create_graphic(self, graphic_cls: Type[LineGraphic | LineCollection | LineStack | ScatterGraphic]): + if self.processor.multi and issubclass(graphic_cls, ScatterGraphic): + # make list of scatters + self._graphic = list() + data_slice = self.processor[self.indices] + for d in data_slice: + scatter = graphic_cls(d) + self._graphic.append(scatter) + + else: + data_slice = self.processor[self.indices] + self._graphic = graphic_cls(data_slice) diff --git a/fastplotlib/widgets/nd_widget/_nd_timeseries.py b/fastplotlib/widgets/nd_widget/_nd_timeseries.py new file mode 100644 index 000000000..8630044cf --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_nd_timeseries.py @@ -0,0 +1,223 @@ +import inspect +from typing import Literal, Callable, Any +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol + +from ...graphics import ImageGraphic, LineStack, LineCollection, ScatterGraphic +from ._processor_base import NDProcessor, WindowFuncCallable + + +VALID_TIMESERIES_Y_DATA_SHAPES = ( + "[n_datapoints] for 1D array of y-values, [n_datapoints, 2] " + "for a 1D array of y and z-values, [n_lines, n_datapoints] for a 2D stack of lines with y-values, " + "or [n_lines, n_datapoints, 2] for a stack of lines with y and z-values." +) + + +# Limitation, no heatmap if z-values present, I don't think you can visualize that +class NDTimeSeriesProcessor(NDProcessor): + def __init__( + self, + data: list[ + ArrayProtocol, ArrayProtocol + ], # list: [x_vals_array, y_vals_and_z_vals_array] + x_values: ArrayProtocol = None, + cmap: str = None, + cmap_transform: ArrayProtocol = None, + display_graphic: Literal["line", "heatmap"] = "line", + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + display_window: int | float | None = 100, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + ): + super().__init__( + data=data, + n_display_dims=n_display_dims, + slider_index_maps=slider_index_maps, + ) + + self._display_window = display_window + + self._display_graphic = None + self.display_graphic = display_graphic + + self._uniform_x_values: ArrayProtocol | None = None + self._interp_yz: ArrayProtocol | None = None + + @property + def data(self) -> list[ArrayProtocol, ArrayProtocol]: + return self._data + + @data.setter + def data(self, data: list[ArrayProtocol, ArrayProtocol]): + self._data = self._validate_data(data) + + def _validate_data(self, data: list[ArrayProtocol, ArrayProtocol]): + x_vals, yz_vals = data + + if x_vals.ndim != 1: + raise ("data x values must be 1D") + + if data[1].ndim > 3: + raise ValueError( + f"data yz values must be of shape: {VALID_TIMESERIES_Y_DATA_SHAPES}. You passed data of shape: {yz_vals.shape}" + ) + + return data + + @property + def display_window(self) -> int | float | None: + """display window in the reference units along the x-axis""" + return self._display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + if dw is None: + self._display_window = None + + elif not isinstance(dw, (int, float)): + raise TypeError + + self._display_window = dw + + def __getitem__(self, indices: tuple[Any, ...]) -> ArrayProtocol: + if self.display_window is not None: + # map reference units -> array int indices if necessary + if self.slider_index_maps is not None: + indices_window = self.slider_index_maps(self.display_window) + else: + indices_window = self.display_window + + # half window size + hw = indices_window // 2 + + # for now assume just a single index provided that indicates x axis value + start = max(indices - hw, 0) + stop = start + indices_window + + # slice dim would be ndim - 1 + return self.data[0][start:stop], self.data[1][:, start:stop] + + +class NDTimeSeries: + def __init__(self, processor: NDTimeSeriesProcessor, graphic): + self._processor = processor + + self._indices = 0 + + if graphic == "line": + self._create_line_stack() + elif graphic == "heatmap": + self._create_heatmap() + else: + raise ValueError + + @property + def processor(self) -> NDTimeSeriesProcessor: + return self._processor + + @property + def graphic(self) -> LineStack | ImageGraphic: + """LineStack or ImageGraphic for heatmaps""" + return self._graphic + + @graphic.setter + def graphic(self, g: Literal["line", "heatmap"]): + if g == "line": + # TODO: remove existing graphic + self._create_line_stack() + + elif g == "heatmap": + # make sure "yz" data is only ys and no z values + # can't represent y and z vals in a heatmap + if self.processor.data[1].ndim > 2: + raise ValueError("Only y-values are supported for heatmaps, not yz-values") + self._create_heatmap() + + @property + def display_window(self) -> int | float | None: + return self.processor.display_window + + @display_window.setter + def display_window(self, dw: int | float | None): + # create new graphic if it changed + if dw != self.display_window: + create_new_graphic = True + else: + create_new_graphic = False + + self.processor.display_window = dw + + if create_new_graphic: + if isinstance(self.graphic, LineStack): + self.set_index(self._indices) + + def set_index(self, indices: tuple[Any, ...]): + # set the graphic at the given data indices + data_slice = self.processor[indices] + + if isinstance(self.graphic, LineStack): + line_stack_data = self._create_line_stack_data(data_slice) + + for g, line_data in zip(self.graphic.graphics, line_stack_data): + if line_data.shape[1] == 2: + # only x and y values + g.data[:, :-1] = line_data + else: + # has z values too + g.data[:] = line_data + + elif isinstance(self.graphic, ImageGraphic): + hm_data, scale = self._create_heatmap_data(data_slice) + self.graphic.data = hm_data + + self._indices = indices + + def _create_line_stack_data(self, data_slice): + xs = data_slice[0] # 1D + yz = data_slice[1] # [n_lines, n_datapoints] for y-vals or [n_lines, n_datapoints, 2] for yz-vals + + # need to go from x_vals and yz_vals arrays to an array of shape: [n_lines, n_datapoints, 2 | 3] + return np.dstack([np.repeat(xs[None], repeats=yz.shape[0], axis=0), yz]) + + def _create_line_stack(self): + data_slice = self.processor[self._indices] + + ls_data = self._create_line_stack_data(data_slice) + + self._graphic = LineStack(ls_data) + + def _create_heatmap_data(self, data_slice) -> tuple[ArrayProtocol, float]: + """Returns [n_lines, y_values] array and scale factor for x dimension""" + # check if x-vals uniformly spaced + # this is very fast to do on the fly, especially for typical small display windows + x, y = data_slice + norm = np.linalg.norm(np.diff(np.diff(x))) / x.size + if norm > 10 ** -12: + # need to create evenly spaced x-values + x_uniform = np.linspace(x[0], x[-1], num=x.size) + # yz is [n_lines, n_datapoints] + y_interp = np.zeros(shape=y.shape, dtype=np.float32) + for i in range(y.shape[0]): + y_interp[i] = np.interp(x_uniform, x, y[i]) + + else: + y_interp = y + + x_scale = x[-1] / x.size + + return y_interp, x_scale + + def _create_heatmap(self): + data_slice = self.processor[self._indices] + + hm_data, x_scale = self._create_heatmap_data(data_slice) + + self._graphic = ImageGraphic(hm_data) + self._graphic.world_object.world.scale_x = x_scale \ No newline at end of file diff --git a/fastplotlib/widgets/nd_widget/_processor_base.py b/fastplotlib/widgets/nd_widget/_processor_base.py new file mode 100644 index 000000000..fa56e4b52 --- /dev/null +++ b/fastplotlib/widgets/nd_widget/_processor_base.py @@ -0,0 +1,74 @@ +import inspect +from typing import Literal, Callable, Any +from warnings import warn + +import numpy as np +from numpy.typing import ArrayLike + +from ...utils import subsample_array, ArrayProtocol + + +# must take arguments: array-like, `axis`: int, `keepdims`: bool +WindowFuncCallable = Callable[[ArrayLike, int, bool], ArrayLike] + + +class NDProcessor: + def __init__( + self, + data, + n_display_dims: Literal[2, 3] = 2, + slider_index_maps: tuple[Callable[[Any], int] | None, ...] | None = None, + window_funcs: tuple[WindowFuncCallable | None] | None = None, + window_sizes: tuple[int | None] | None = None, + spatial_func: Callable[[ArrayProtocol], ArrayProtocol] | None = None, + ): + self._data = self._validate_data(data) + self._slider_index_maps = self._validate_slider_index_maps(slider_index_maps) + + @property + def data(self) -> ArrayProtocol: + return self._data + + @data.setter + def data(self, data: ArrayProtocol): + self._data = self._validate_data(data) + + def _validate_data(self, data: ArrayProtocol): + if not isinstance(data, ArrayProtocol): + raise TypeError("`data` must implement the ArrayProtocol") + + return data + + @property + def window_funcs(self) -> tuple[WindowFuncCallable | None] | None: + pass + + @property + def window_sizes(self) -> tuple[int | None] | None: + pass + + @property + def spatial_func(self) -> Callable[[ArrayProtocol], ArrayProtocol] | None: + pass + + @property + def slider_dims(self) -> tuple[int, ...] | None: + pass + + @property + def slider_index_maps(self) -> tuple[Callable[[Any], int] | None, ...]: + return self._slider_index_maps + + @slider_index_maps.setter + def slider_index_maps(self, maps): + self._maps = self._validate_slider_index_maps(maps) + + def _validate_slider_index_maps(self, maps): + if maps is not None: + if not all([callable(m) or m is None for m in maps]): + raise TypeError + + return maps + + def __getitem__(self, item: tuple[Any, ...]) -> ArrayProtocol: + pass