diff --git a/examples/gridplot_simple.py b/examples/gridplot_simple.py new file mode 100644 index 000000000..382f2b486 --- /dev/null +++ b/examples/gridplot_simple.py @@ -0,0 +1,53 @@ +import numpy as np +from wgpu.gui.auto import WgpuCanvas +import pygfx as gfx +from fastplotlib.layouts import GridPlot +from fastplotlib.graphics import ImageGraphic, LineGraphic, HistogramGraphic +from fastplotlib import run +from math import sin, cos, radians + +# GridPlot of shape 2 x 3 +grid_plot = GridPlot(shape=(2, 3)) + +image_graphics = list() + +hist_data1 = np.random.normal(0, 256, 2048) +hist_data2 = np.random.poisson(0, 256) + +# Make a random image graphic for each subplot +for i, subplot in enumerate(grid_plot): + img = np.random.rand(512, 512) * 255 + ig = ImageGraphic(data=img, vmin=0, vmax=255, cmap='gnuplot2') + image_graphics.append(ig) + + # add the graphic to the subplot + subplot.add_graphic(ig) + + histogram = HistogramGraphic(data=hist_data1, bins=100) + histogram.world_object.rotation.w = cos(radians(45)) + histogram.world_object.rotation.z = sin(radians(45)) + + histogram.world_object.scale.y = 1 + histogram.world_object.scale.x = 8 + + for dv_position in ["right", "top", "bottom", "left"]: + h2 = HistogramGraphic(data=hist_data1, bins=100) + + subplot.docked_viewports[dv_position].size = 60 + subplot.docked_viewports[dv_position].add_graphic(h2) +# + +# Define a function to update the image graphics +# with new randomly generated data +def set_random_frame(): + for ig in image_graphics: + new_data = np.random.rand(512, 512) * 255 + ig.update_data(data=new_data) + + +# add the animation +# grid_plot.add_animations(set_random_frame) + +grid_plot.show() + +run() diff --git a/fastplotlib/__init__.py b/fastplotlib/__init__.py index d11a4e08b..575288b33 100644 --- a/fastplotlib/__init__.py +++ b/fastplotlib/__init__.py @@ -1,5 +1,6 @@ from .plot import Plot from pathlib import Path +from wgpu.gui.auto import run with open(Path(__file__).parent.joinpath("VERSION"), "r") as f: diff --git a/fastplotlib/layouts/_base.py b/fastplotlib/layouts/_base.py new file mode 100644 index 000000000..68221ad11 --- /dev/null +++ b/fastplotlib/layouts/_base.py @@ -0,0 +1,182 @@ +from pygfx import Scene, OrthographicCamera, PerspectiveCamera, PanZoomController, OrbitController, \ + Viewport, WgpuRenderer +from wgpu.gui.auto import WgpuCanvas +from warnings import warn +from ..graphics._base import Graphic +from typing import * + + +class PlotArea: + def __init__( + self, + parent, + position: Any, + camera: Union[OrthographicCamera, PerspectiveCamera], + controller: Union[PanZoomController, OrbitController], + scene: Scene, + canvas: WgpuCanvas, + renderer: WgpuRenderer, + name: str = None, + ): + self._parent: PlotArea = parent + self._position = position + + self._scene = scene + self._canvas = canvas + self._renderer = renderer + if parent is None: + self._viewport: Viewport = Viewport(renderer) + else: + self._viewport = Viewport(parent.renderer) + + self._camera = camera + self._controller = controller + + self.controller.add_default_event_handlers( + self.viewport, + self.camera + ) + + self.renderer.add_event_handler(self.set_viewport_rect, "resize") + + self._graphics: List[Graphic] = list() + + self.name = name + + # need to think about how to deal with children better + self.children = list() + + self.set_viewport_rect() + + # several read-only properties + @property + def parent(self): + return self._parent + + @property + def position(self) -> Union[Tuple[int, int], Any]: + """Used by subclass based on its referencing system""" + return self._position + + @property + def scene(self) -> Scene: + return self._scene + + @property + def canvas(self) -> WgpuCanvas: + return self._canvas + + @property + def renderer(self) -> WgpuRenderer: + return self._renderer + + @property + def viewport(self) -> Viewport: + return self._viewport + + @property + def camera(self) -> Union[OrthographicCamera, PerspectiveCamera]: + return self._camera + + # in the future we can think about how to allow changing the controller + @property + def controller(self) -> Union[PanZoomController, OrbitController]: + return self._controller + + def get_rect(self) -> Tuple[float, float, float, float]: + """allows setting the region occupied by the viewport w.r.t. the parent""" + raise NotImplementedError("Must be implemented in subclass") + + def set_viewport_rect(self, *args): + self.viewport.rect = self.get_rect() + + def render(self): + self.controller.update_camera(self.camera) + self.viewport.render(self.scene, self.camera) + + for child in self.children: + child.render() + + def add_graphic(self, graphic, center: bool = True): + if graphic.name is not None: # skip for those that have no name + graphic_names = list() + + for g in self._graphics: + graphic_names.append(g.name) + + if graphic.name in graphic_names: + raise ValueError(f"graphics must have unique names, current graphic names are:\n {graphic_names}") + + self._graphics.append(graphic) + self.scene.add(graphic.world_object) + + if center: + self.center_graphic(graphic) + + def _refresh_camera(self): + self.controller.update_camera(self.camera) + if sum(self.renderer.logical_size) > 0: + scene_lsize = self.viewport.rect[2], self.viewport.rect[3] + else: + scene_lsize = (1, 1) + + self.camera.set_view_size(*scene_lsize) + self.camera.update_projection_matrix() + + def center_graphic(self, graphic, zoom: float = 1.3): + if not isinstance(self.camera, OrthographicCamera): + warn("`center_graphic()` not yet implemented for `PerspectiveCamera`") + return + + self._refresh_camera() + + self.controller.show_object(self.camera, graphic.world_object) + + self.controller.zoom(zoom) + + def center_scene(self, zoom: float = 1.3): + if not len(self.scene.children) > 0: + return + + if not isinstance(self.camera, OrthographicCamera): + warn("`center_scene()` not yet implemented for `PerspectiveCamera`") + return + + self._refresh_camera() + + self.controller.show_object(self.camera, self.scene) + + self.controller.zoom(zoom) + + def get_graphics(self): + return self._graphics + + def remove_graphic(self, graphic): + self.scene.remove(graphic.world_object) + + def __getitem__(self, name: str): + for graphic in self._graphics: + if graphic.name == name: + return graphic + + graphic_names = list() + for g in self._graphics: + graphic_names.append(g.name) + raise IndexError(f"no graphic of given name, the current graphics are:\n {graphic_names}") + + def __str__(self): + if self.name is None: + name = "unnamed" + else: + name = self.name + + return f"{name}: {self.__class__.__name__} @ {hex(id(self))}" + + def __repr__(self): + newline = "\n\t" + + return f"{self}\n" \ + f" parent: {self.parent}\n" \ + f" Graphics:\n" \ + f"\t{newline.join(graphic.__repr__() for graphic in self.get_graphics())}" \ + f"\n" diff --git a/fastplotlib/layouts/_gridplot.py b/fastplotlib/layouts/_gridplot.py index e5e4884e4..f6899c485 100644 --- a/fastplotlib/layouts/_gridplot.py +++ b/fastplotlib/layouts/_gridplot.py @@ -154,9 +154,9 @@ def __getitem__(self, index: Union[Tuple[int, int], str]): else: return self._subplots[index[0], index[1]] - def animate(self): + def render(self): for subplot in self: - subplot.animate(self.canvas.get_logical_size()) + subplot.render() for f in self._animate_funcs: f() @@ -173,7 +173,7 @@ def add_animations(self, *funcs: callable): self._animate_funcs += funcs def show(self): - self.canvas.request_draw(self.animate) + self.canvas.request_draw(self.render) for subplot in self: subplot.center_scene() @@ -193,6 +193,3 @@ def __next__(self) -> Subplot: def __repr__(self): return f"fastplotlib.{self.__class__.__name__} @ {hex(id(self))}\n" - - - diff --git a/fastplotlib/layouts/_subplot.py b/fastplotlib/layouts/_subplot.py index e2ab465a3..bb847c4ec 100644 --- a/fastplotlib/layouts/_subplot.py +++ b/fastplotlib/layouts/_subplot.py @@ -1,65 +1,46 @@ -import pygfx -from pygfx import Scene, OrthographicCamera, PerspectiveCamera, PanZoomController, Viewport, AxesHelper, GridHelper +from pygfx import Scene, OrthographicCamera, PanZoomController, OrbitOrthoController, \ + AxesHelper, GridHelper, WgpuRenderer, Background, BackgroundMaterial from ..graphics import HeatmapGraphic from ._defaults import create_camera, create_controller from typing import * from wgpu.gui.auto import WgpuCanvas -from warnings import warn +import numpy as np from math import copysign +from ._base import PlotArea -class Subplot: +class Subplot(PlotArea): def __init__( self, position: Tuple[int, int] = None, parent_dims: Tuple[int, int] = None, camera: str = '2d', - controller: Union[pygfx.PanZoomController, pygfx.OrbitOrthoController] = None, + controller: Union[PanZoomController, OrbitOrthoController] = None, canvas: WgpuCanvas = None, - renderer: pygfx.Renderer = None, + renderer: WgpuRenderer = None, + name: str = None, **kwargs ): - self.scene: pygfx.Scene = pygfx.Scene() - - self._graphics = list() - if canvas is None: canvas = WgpuCanvas() if renderer is None: - renderer = pygfx.renderers.WgpuRenderer(canvas) - - self.canvas = canvas - self.renderer = renderer - - if "name" in kwargs.keys(): - self.name = kwargs["name"] - else: - self.name = None + renderer = WgpuRenderer(canvas) if position is None: position = (0, 0) - self.position: Tuple[int, int] = position if parent_dims is None: parent_dims = (1, 1) self.nrows, self.ncols = parent_dims - self.camera: Union[pygfx.OrthographicCamera, pygfx.PerspectiveCamera] = create_camera(camera) - if controller is None: controller = create_controller(camera) - self.controller: Union[pygfx.PanZoomController, pygfx.OrbitOrthoController] = controller - # might be better as an attribute of GridPlot - # but easier to iterate when in same object as camera and scene - self.viewport: pygfx.Viewport = pygfx.Viewport(renderer) + self.docked_viewports = dict() - self.controller.add_default_event_handlers( - self.viewport, - self.camera - ) + self.spacing = 2 self._axes: AxesHelper = AxesHelper(size=100) for arrow in self._axes.children: @@ -69,91 +50,64 @@ def __init__( self._animate_funcs = list() - self.renderer.add_event_handler(self._produce_rect, "resize") + super(Subplot, self).__init__( + parent=None, + position=position, + camera=create_camera(camera), + controller=controller, + scene=Scene(), + canvas=canvas, + renderer=renderer, + name=name + ) - def _produce_rect(self, *args):#, w, h): - i, j = self.position + for pos in ["left", "top", "right", "bottom"]: + dv = _DockedViewport(self, pos, size=0) + dv.name = pos + self.docked_viewports[pos] = dv + self.children.append(dv) - w, h = self.renderer.logical_size + def get_rect(self): + row_ix, col_ix = self.position + width_canvas, height_canvas = self.renderer.logical_size - spacing = 2 # spacing in pixels + x_pos = ((width_canvas / self.ncols) + ((col_ix - 1) * (width_canvas / self.ncols))) + self.spacing + y_pos = ((height_canvas / self.nrows) + ((row_ix - 1) * (height_canvas / self.nrows))) + self.spacing + width_subplot = (width_canvas / self.ncols) - self.spacing + height_suplot = (height_canvas / self.nrows) - self.spacing + + rect = np.array([ + x_pos, + y_pos, + width_subplot, + height_suplot + ]) + + for dv in self.docked_viewports.values(): + rect = rect + dv.get_parent_rect_adjust() - self.viewport.rect = [ - ((w / self.ncols) + ((j - 1) * (w / self.ncols))) + spacing, - ((h / self.nrows) + ((i - 1) * (h / self.nrows))) + spacing, - (w / self.ncols) - spacing, - (h / self.nrows) - spacing - ] + return rect - def animate(self, canvas_dims: Tuple[int, int] = None): - self.controller.update_camera(self.camera) - self.viewport.render(self.scene, self.camera) + def render(self): + super(Subplot, self).render() for f in self._animate_funcs: f() def add_animations(self, *funcs: callable): - for f in funcs: - if not callable(f): - raise TypeError( - f"all positional arguments to add_animations() must be callable types, you have passed a: {type(f)}" - ) - self._animate_funcs += funcs + if not all([callable(f) for f in funcs]): + raise TypeError( + f"all positional arguments to add_animations() must be callable types" + ) - def add_graphic(self, graphic, center: bool = True): - if graphic.name is not None: # skip for those that have no name - graphic_names = list() - - for g in self._graphics: - graphic_names.append(g.name) - - if graphic.name in graphic_names: - raise ValueError(f"graphics must have unique names, current graphic names are:\n {graphic_names}") + self._animate_funcs += funcs - self._graphics.append(graphic) - self.scene.add(graphic.world_object) + def add_graphic(self, graphic, center: bool = True): + super(Subplot, self).add_graphic(graphic, center) if isinstance(graphic, HeatmapGraphic): self.controller.scale.y = copysign(self.controller.scale.y, -1) - if center: - self.center_graphic(graphic) - - def _refresh_camera(self): - self.controller.update_camera(self.camera) - if sum(self.renderer.logical_size) > 0: - scene_lsize = self.viewport.rect[2], self.viewport.rect[3] - else: - scene_lsize = (1, 1) - - self.camera.set_view_size(*scene_lsize) - self.camera.update_projection_matrix() - - def center_graphic(self, graphic, zoom: float = 1.3): - if not isinstance(self.camera, pygfx.OrthographicCamera): - warn("`center_graphic()` not yet implemented for `PerspectiveCamera`") - return - - self._refresh_camera() - - self.controller.show_object(self.camera, graphic.world_object) - - self.controller.zoom(zoom) - - def center_scene(self, zoom: float = 1.3): - if not len(self.scene.children) > 0: - return - - if not isinstance(self.camera, pygfx.OrthographicCamera): - warn("`center_scene()` not yet implemented for `PerspectiveCamera`") - return - - self._refresh_camera() - - self.controller.show_object(self.camera, self.scene) - - self.controller.zoom(zoom) - def set_axes_visibility(self, visible: bool): if visible: self.scene.add(self._axes) @@ -166,30 +120,123 @@ def set_grid_visibility(self, visible: bool): else: self.scene.remove(self._grid) - def remove_graphic(self, graphic): - self.scene.remove(graphic.world_object) - def get_graphics(self): - return self._graphics +class _DockedViewport(PlotArea): + _valid_positions = [ + "right", + "left", + "top", + "bottom" + ] + + def __init__( + self, + parent: Subplot, + position: str, + size: int, + ): + if position not in self._valid_positions: + raise ValueError(f"the `position` of an AnchoredViewport must be one of: {self._valid_positions}") + + self._size = size + + super(_DockedViewport, self).__init__( + parent=parent, + position=position, + camera=OrthographicCamera(), + controller=PanZoomController(), + scene=Scene(), + canvas=parent.canvas, + renderer=parent.renderer + ) + + self.scene.add( + Background(None, BackgroundMaterial((0.2, 0.0, 0, 1), (0, 0.0, 0.2, 1))) + ) - def __getitem__(self, name: str): - for graphic in self._graphics: - if graphic.name == name: - return graphic + @property + def size(self) -> int: + return self._size - graphic_names = list() - for g in self._graphics: - graphic_names.append(g.name) - raise IndexError(f"no graphic of given name, the current graphics are:\n {graphic_names}") + @size.setter + def size(self, s: int): + self._size = s + self.parent.set_viewport_rect() + self.set_viewport_rect() - def __repr__(self): - newline = "\n " - if self.name is not None: - return f"'{self.name}' fastplotlib.{self.__class__.__name__} @ {hex(id(self))}\n" \ - f"Graphics: \n " \ - f"{newline.join(graphic.__repr__() for graphic in self.get_graphics())}" + def get_rect(self, *args): + if self.size == 0: + self.viewport.rect = None + return + + row_ix_parent, col_ix_parent = self.parent.position + width_canvas, height_canvas = self.parent.renderer.logical_size + + spacing = 2 # spacing in pixels + + if self.position == "right": + x_pos = (width_canvas / self.parent.ncols) + ((col_ix_parent - 1) * (width_canvas / self.parent.ncols)) + (width_canvas / self.parent.ncols) - self.size + y_pos = ((height_canvas / self.parent.nrows) + ((row_ix_parent - 1) * (height_canvas / self.parent.nrows))) + spacing + width_viewport = self.size + height_viewport = (height_canvas / self.parent.nrows) - spacing + + elif self.position == "left": + x_pos = (width_canvas / self.parent.ncols) + ((col_ix_parent - 1) * (width_canvas / self.parent.ncols)) + y_pos = ((height_canvas / self.parent.nrows) + ((row_ix_parent - 1) * (height_canvas / self.parent.nrows))) + spacing + width_viewport = self.size + height_viewport = (height_canvas / self.parent.nrows) - spacing + + elif self.position == "top": + x_pos = (width_canvas / self.parent.ncols) + ((col_ix_parent - 1) * (width_canvas / self.parent.ncols)) + spacing + y_pos = ((height_canvas / self.parent.nrows) + ((row_ix_parent - 1) * (height_canvas / self.parent.nrows))) + spacing + width_viewport = (width_canvas / self.parent.ncols) - spacing + height_viewport = self.size + + elif self.position == "bottom": + x_pos = (width_canvas / self.parent.ncols) + ((col_ix_parent - 1) * (width_canvas / self.parent.ncols)) + spacing + y_pos = ((height_canvas / self.parent.nrows) + ((row_ix_parent - 1) * (height_canvas / self.parent.nrows))) + (height_canvas / self.parent.nrows) - self.size + width_viewport = (width_canvas / self.parent.ncols) - spacing + height_viewport = self.size else: - return f"fastplotlib.{self.__class__.__name__} @ {hex(id(self))} \n" \ - f"Graphics: \n " \ - f"{newline.join(graphic.__repr__() for graphic in self.get_graphics())}" + raise ValueError("invalid position") + + return [x_pos, y_pos, width_viewport, height_viewport] + + def get_parent_rect_adjust(self): + if self.position == "right": + return np.array([ + 0, # parent subplot x-position is same + 0, + -self.size, # width of parent subplot is `self.size` smaller + 0 + ]) + + elif self.position == "left": + return np.array([ + self.size, # `self.size` added to parent subplot x-position + 0, + -self.size, # width of parent subplot is `self.size` smaller + 0 + ]) + + elif self.position == "top": + return np.array([ + 0, + self.size, # `self.size` added to parent subplot y-position + 0, + -self.size, # height of parent subplot is `self.size` smaller + ]) + + elif self.position == "bottom": + return np.array([ + 0, + 0, # parent subplot y-position is same, + 0, + -self.size, # height of parent subplot is `self.size` smaller + ]) + + def render(self): + if self.size == 0: + return + super(_DockedViewport, self).render() diff --git a/fastplotlib/plot.py b/fastplotlib/plot.py index 491b223ec..7fc26ecd3 100644 --- a/fastplotlib/plot.py +++ b/fastplotlib/plot.py @@ -39,14 +39,14 @@ def _create_graphic(self, graphic_class, *args, **kwargs): return graphic - def animate(self): - super(Plot, self).animate(canvas_dims=None) + def render(self): + super(Plot, self).render() self.renderer.flush() self.canvas.request_draw() def show(self): - self.canvas.request_draw(self.animate) + self.canvas.request_draw(self.render) self.center_scene() return self.canvas