diff --git a/changelog.md b/changelog.md index 500e1c18..9fbd6059 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,9 @@ ## New features +* Add `GraphWidget` methods to change render options in place without re-rendering: `set_layout`, `set_zoom`, `set_pan`, `set_renderer`, and `set_show_layout_button` +* Add `GraphWidget` methods to change styling in place without re-rendering such as `color_relationships` + ## Bug fixes ## Improvements diff --git a/docs/source/api-reference/widget.rst b/docs/source/api-reference/widget.rst new file mode 100644 index 00000000..8ec6510f --- /dev/null +++ b/docs/source/api-reference/widget.rst @@ -0,0 +1,2 @@ +.. autoclass:: neo4j_viz.GraphWidget + :members: diff --git a/examples/getting-started.ipynb b/examples/getting-started.ipynb index 2ae449b0..7a012a5a 100644 --- a/examples/getting-started.ipynb +++ b/examples/getting-started.ipynb @@ -1633,12 +1633,12 @@ "

Expected window.__NEO4J_VIZ_DATA__ to be set.

\n", "

This page should be generated by neo4j_viz's render() method.

\n", " \n", - " `,new Error(\"window.__NEO4J_VIZ_DATA__ is not defined\");const ypr={get(t){return x3[t]},on(){},off(){},set(){},save_changes(){}},_3=document.getElementById(\"neo4j-viz-f81941349bdf\");if(!_3)throw new Error(\"Container element #neo4j-viz-f81941349bdf not found\");_3.style.width=x3.width??\"100%\";_3.style.height=x3.height??\"100vh\";mpr.render({model:ypr,el:_3});\n", + " `,new Error(\"window.__NEO4J_VIZ_DATA__ is not defined\");const ypr={get(t){return x3[t]},on(){},off(){},set(){},save_changes(){}},_3=document.getElementById(\"neo4j-viz-75a25c3af547\");if(!_3)throw new Error(\"Container element #neo4j-viz-75a25c3af547 not found\");_3.style.width=x3.width??\"100%\";_3.style.height=x3.height??\"100vh\";mpr.render({model:ypr,el:_3});\n", " \n", - " \n", + " \n", "\n", " \n", - "
\n", + "
\n", " \n", "\n" ], @@ -1681,1660 +1681,28 @@ "\n", "VG = VisualizationGraph(nodes=nodes, relationships=relationships)\n", "\n", - "VG.render(initial_zoom=2)" - ] - }, - { - "cell_type": "markdown", - "id": "365a1c31", - "metadata": {}, - "source": [ - "As we can see in the graph above, the radius of one of the nodes is larger than the others.\n", - "This is because we set the \"size\" field of the node to 20, while the others are set to 10.\n", - "\n", - "At this time all nodes have the same color.\n", - "If we want to distinguish between the different types of nodes, we can color them differently with the `color_nodes` method.\n", - "We can pass the field we want to use to color the nodes as an argument.\n", - "In this case, we will use the \"caption\" field.\n", - "Nodes with the same \"caption\" will have the same color.\n", - "We will use the default colorscheme, which is the Neo4j colorscheme.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "d935b3d4", - "metadata": { - "tags": [ - "preserve-output" - ] - }, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - "\n", - " \n", - " \n", - " \n", - " neo4j-viz\n", - " \n", - " \n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - " \n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ "VG.color_nodes(field=\"size\")\n", "VG.set_node_captions(field=\"size\")\n", "\n", "VG.render(initial_zoom=2)" ] }, + { + "cell_type": "markdown", + "id": "365a1c31", + "metadata": {}, + "source": [ + "As we can see in the graph above, the radius of one of the nodes is larger than the others.\n", + "This is because we set the \"size\" field of the node to 20, while the others are set to 10.\n", + "\n", + "At this time all nodes have the same color.\n", + "If we want to distinguish between the different types of nodes, we can color them differently with the `color_nodes` method.\n", + "We can pass the field we want to use to color the nodes as an argument.\n", + "In this case, we will use the \"caption\" field.\n", + "Nodes with the same \"caption\" will have the same color.\n", + "We will use the default colorscheme, which is the Neo4j colorscheme.\n" + ] + }, { "cell_type": "markdown", "id": "a28bd5aa", @@ -3358,10 +1726,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "6j6duo4v7p9", - "metadata": {}, - "outputs": [], + "metadata": { + "tags": [ + "preserve-output" + ] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8f9849af878743d4b73329f4fd7cc977", + "version_major": 2, + "version_minor": 1 + }, + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "widget = VG.render_widget()\n", "widget" @@ -3395,6 +1783,37 @@ "\n", "widget.add_data(nodes=new_node, relationships=new_rel)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "686e0beb", + "metadata": {}, + "outputs": [], + "source": [ + "widget.color_relationships(field=\"caption\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c68712f0", + "metadata": {}, + "outputs": [], + "source": [ + "widget.nodes[0].size = 50\n", + "widget.sync_nodes() # manually trigger sync to update widget" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f174b6ed00027bf5", + "metadata": {}, + "outputs": [], + "source": [ + "widget.set_zoom(1.5) # change the rendering options dynamically" + ] } ], "metadata": { diff --git a/python-wrapper/src/neo4j_viz/_graph_entity_operations.py b/python-wrapper/src/neo4j_viz/_graph_entity_operations.py new file mode 100644 index 00000000..5ef0ee8a --- /dev/null +++ b/python-wrapper/src/neo4j_viz/_graph_entity_operations.py @@ -0,0 +1,408 @@ +from __future__ import annotations + +import warnings +from collections.abc import Hashable, Iterable +from typing import Any, Callable, Protocol + +from pydantic.alias_generators import to_snake +from pydantic_extra_types.color import Color, ColorType + +from .colors import NEO4J_COLORS_CONTINUOUS, NEO4J_COLORS_DISCRETE, ColorSpace, ColorsType +from .node import Node, NodeIdType +from .node_size import RealNumber, verify_radii +from .relationship import Relationship + + +class EntityHost(Protocol): + """The interface a host must expose to be driven by `GraphEntityOperations`.""" + + nodes: list[Node] + relationships: list[Relationship] + + def _sync_entities(self, *, nodes: bool = ..., relationships: bool = ...) -> None: ... + + +class GraphEntityOperations: + """Recolor, resize, caption and pin operations over a host's graph entities. + + This is a composable component: it does not own the data, but reads the `nodes` and + `relationships` from its `host` and mutates the entities in place. After each mutation + it calls the host's `_sync_entities` hook so the host can react (e.g. the widget pushes + the changes to its frontend). + """ + + def __init__(self, host: EntityHost) -> None: + self._host = host + + @property + def nodes(self) -> list[Node]: + return self._host.nodes + + @property + def relationships(self) -> list[Relationship]: + return self._host.relationships + + def toggle_nodes_pinned(self, pinned: dict[NodeIdType, bool]) -> None: + """Pin or unpin nodes. See `VisualizationGraph.toggle_nodes_pinned` for details.""" + for node in self.nodes: + node_pinned = pinned.get(node.id) + + if node_pinned is None: + continue + + node.pinned = node_pinned + + self._host._sync_entities(nodes=True) + + def set_node_captions( + self, + *, + field: str | None = None, + property: str | None = None, + override: bool = True, + ) -> None: + """Set node captions from a field or property. See `VisualizationGraph.set_node_captions` for details.""" + if not ((field is None) ^ (property is None)): + raise ValueError( + f"Exactly one of the arguments `field` (received '{field}') and `property` (received '{property}') must be provided" + ) + + if property: + # Use property + for node in self.nodes: + if not override and node.caption is not None: + continue + + value = node.properties.get(property, "") + node.caption = str(value) + else: + # Use field + assert field is not None + attribute = to_snake(field) + + for node in self.nodes: + if not override and node.caption is not None: + continue + + value = getattr(node, attribute, "") + node.caption = str(value) + + self._host._sync_entities(nodes=True) + + def resize_nodes( + self, + sizes: dict[NodeIdType, RealNumber] | None = None, + node_radius_min_max: tuple[RealNumber, RealNumber] | None = (3, 60), + property: str | None = None, + ) -> None: + """Resize nodes from explicit sizes or a property. See `VisualizationGraph.resize_nodes` for details.""" + if sizes is not None and property is not None: + raise ValueError("At most one of the arguments `sizes` and `property` can be provided") + + if sizes is None and property is None and node_radius_min_max is None: + raise ValueError("At least one of `sizes`, `property` or `node_radius_min_max` must be given") + + # Gather node sizes + all_sizes = {} + if sizes is not None: + for node in self.nodes: + size = sizes.get(node.id, node.size) + if size is not None: + all_sizes[node.id] = size + elif property is not None: + for node in self.nodes: + size = node.properties.get(property, node.size) + if size is not None: + all_sizes[node.id] = size + else: + for node in self.nodes: + if node.size is not None: + all_sizes[node.id] = node.size + + # Validate node sizes + for id, size in all_sizes.items(): + if size is None: + continue + + if not isinstance(size, (int, float)): + raise ValueError(f"Size for node '{id}' must be a real number, but was {size}") + + if size < 0: + raise ValueError(f"Size for node '{id}' must be non-negative, but was {size}") + + if node_radius_min_max is not None: + verify_radii(node_radius_min_max) + + final_sizes = self._normalize_values(all_sizes, node_radius_min_max) + else: + final_sizes = all_sizes + + # Apply the final sizes to the nodes + for node in self.nodes: + size = final_sizes.get(node.id) + + if size is None: + continue + + node.size = size + + self._host._sync_entities(nodes=True) + + def resize_relationships( + self, + widths: dict[str | int, RealNumber] | None = None, + property: str | None = None, + ) -> None: + """Resize relationship widths from explicit widths or a property. See `VisualizationGraph.resize_relationships` for details.""" + if widths is not None and property is not None: + raise ValueError("At most one of the arguments `widths` and `property` can be provided") + + if widths is None and property is None: + raise ValueError("At least one of `widths` or `property` must be given") + + # Gather relationship widths + all_widths = {} + if widths is not None: + for rel in self.relationships: + width = widths.get(rel.id, rel.width) + if width is not None: + all_widths[rel.id] = width + elif property is not None: + for rel in self.relationships: + width = rel.properties.get(property, rel.width) + if width is not None: + all_widths[rel.id] = width + + # Validate and apply relationship widths + for rel in self.relationships: + width = all_widths.get(rel.id) + + if width is None: + continue + + if not isinstance(width, (int, float)): + raise ValueError(f"Width for relationship '{rel.id}' must be a real number, but was {width}") + + if width <= 0: + raise ValueError(f"Width for relationship '{rel.id}' must be positive, but was {width}") + + rel.width = width + + self._host._sync_entities(relationships=True) + + @staticmethod + def _normalize_values( + node_map: dict[NodeIdType, RealNumber], min_max: tuple[float, float] = (0, 1) + ) -> dict[NodeIdType, RealNumber]: + unscaled_min_size = min(node_map.values()) + unscaled_max_size = max(node_map.values()) + unscaled_size_range = float(unscaled_max_size - unscaled_min_size) + + new_min_size, new_max_size = min_max + new_size_range = new_max_size - new_min_size + + if abs(unscaled_size_range) < 1e-6: + default_node_size = new_min_size + new_size_range / 2.0 + new_map = {id: default_node_size for id in node_map} + else: + new_map = { + id: new_min_size + new_size_range * ((nz - unscaled_min_size) / unscaled_size_range) + for id, nz in node_map.items() + } + + return new_map + + def color_nodes( + self, + *, + field: str | None = None, + property: str | None = None, + colors: ColorsType | None = None, + color_space: ColorSpace = ColorSpace.DISCRETE, + override: bool = True, + ) -> None: + """Color nodes by a field or property (discrete or continuous). See `VisualizationGraph.color_nodes` for details.""" + if not ((field is None) ^ (property is None)): + raise ValueError( + f"Exactly one of the arguments `field` (received '{field}') and `property` (received '{property}') must be provided" + ) + + if field is None: + assert property is not None + attribute = property + + def node_to_attr(node: Node) -> Any: + return node.properties.get(attribute) + + else: + assert field is not None + attribute = to_snake(field) + + def node_to_attr(node: Node) -> Any: + return getattr(node, attribute) + + if color_space == ColorSpace.DISCRETE: + if colors is None: + colors = NEO4J_COLORS_DISCRETE + else: + node_map = {node.id: node_to_attr(node) for node in self.nodes if node_to_attr(node) is not None} + normalized_map = self._normalize_values(node_map) + + if colors is None: + colors = NEO4J_COLORS_CONTINUOUS + + if not isinstance(colors, list): + raise ValueError("For continuous properties, `colors` must be a list of colors representing a range") + + num_colors = len(colors) + colors = { + node_to_attr(node): colors[round(normalized_map[node.id] * (num_colors - 1))] + for node in self.nodes + if node_to_attr(node) is not None + } + + if isinstance(colors, dict): + self._color_items_dict(self.nodes, colors, override, node_to_attr) + else: + self._color_items_iter(self.nodes, attribute, colors, override, node_to_attr) + + self._host._sync_entities(nodes=True) + + def color_relationships( + self, + *, + field: str | None = None, + property: str | None = None, + colors: ColorsType | None = None, + color_space: ColorSpace = ColorSpace.DISCRETE, + override: bool = True, + ) -> None: + """Color relationships by a field or property (discrete or continuous). See `VisualizationGraph.color_relationships` for details.""" + if not ((field is None) ^ (property is None)): + raise ValueError( + f"Exactly one of the arguments `field` (received '{field}') and `property` (received '{property}') must be provided" + ) + + if field is None: + assert property is not None + attribute = property + + def rel_to_attr(rel: Relationship) -> Any: + return rel.properties.get(attribute) + + else: + assert field is not None + attribute = to_snake(field) + + def rel_to_attr(rel: Relationship) -> Any: + return getattr(rel, attribute) + + if color_space == ColorSpace.DISCRETE: + if colors is None: + colors = NEO4J_COLORS_DISCRETE + else: + rel_map = {rel.id: rel_to_attr(rel) for rel in self.relationships if rel_to_attr(rel) is not None} + normalized_map = self._normalize_values(rel_map) + + if colors is None: + colors = NEO4J_COLORS_CONTINUOUS + + if not isinstance(colors, list): + raise ValueError("For continuous properties, `colors` must be a list of colors representing a range") + + num_colors = len(colors) + colors = { + rel_to_attr(rel): colors[round(normalized_map[rel.id] * (num_colors - 1))] + for rel in self.relationships + if rel_to_attr(rel) is not None + } + + if isinstance(colors, dict): + self._color_items_dict(self.relationships, colors, override, rel_to_attr) + else: + self._color_items_iter(self.relationships, attribute, colors, override, rel_to_attr) + + self._host._sync_entities(relationships=True) + + def _color_items_dict( + self, + items: list[Node] | list[Relationship], + colors: dict[Hashable, ColorType], + override: bool, + item_to_attr: Callable[[Any], Any], + ) -> None: + for item in items: + color = colors.get(item_to_attr(item)) + + if color is None: + continue + + if item.color is not None and not override: + continue + + if not isinstance(color, Color): + item.color = Color(color) + else: + item.color = color + + def _color_items_iter( + self, + items: list[Node] | list[Relationship], + attribute: str, + colors: Iterable[ColorType], + override: bool, + item_to_attr: Callable[[Any], Any], + ) -> None: + exhausted_colors = False + prop_to_color = {} + colors_iter = iter(colors) + for item in items: + raw_prop = item_to_attr(item) + try: + prop = self._make_hashable(raw_prop) + except ValueError: + item_type = "nodes" if isinstance(item, Node) else "relationships" + raise ValueError(f"Unable to color {item_type} by unhashable property type '{type(raw_prop)}'") + + if prop not in prop_to_color: + next_color = next(colors_iter, None) + if next_color is None: + exhausted_colors = True + colors_iter = iter(colors) + next_color = next(colors_iter) + prop_to_color[prop] = next_color + + color = prop_to_color[prop] + + if item.color is not None and not override: + continue + + if not isinstance(color, Color): + item.color = Color(color) + else: + item.color = color + + if exhausted_colors: + warnings.warn( + f"Ran out of colors for property '{attribute}'. {len(prop_to_color)} colors were needed, but only " + f"{len(set(prop_to_color.values()))} were given, so reused colors" + ) + + @staticmethod + def _make_hashable(raw_prop: Any) -> Hashable: + prop = raw_prop + if isinstance(raw_prop, list): + prop = tuple(raw_prop) + elif isinstance(raw_prop, set): + prop = frozenset(raw_prop) + elif isinstance(raw_prop, dict): + prop = tuple(sorted(raw_prop.items())) + + try: + hash(prop) + except TypeError: + raise ValueError(f"Unable to convert '{raw_prop}' of type {type(raw_prop)} to a hashable type") + + assert isinstance(prop, Hashable) + + return prop diff --git a/python-wrapper/src/neo4j_viz/options.py b/python-wrapper/src/neo4j_viz/options.py index 2c4eb4ca..ad36bbf2 100644 --- a/python-wrapper/src/neo4j_viz/options.py +++ b/python-wrapper/src/neo4j_viz/options.py @@ -2,7 +2,7 @@ import warnings from enum import Enum -from typing import Any, Optional, Union +from typing import Any, Optional, TypedDict, Union import enum_tools.documentation from pydantic import BaseModel, Field, ValidationError, model_validator @@ -144,6 +144,41 @@ def check(self, renderer: Renderer, num_nodes: int) -> None: } +class PanPosition(TypedDict): + """The ``{x, y}`` pan position consumed by the frontend.""" + + x: float + y: float + + +class NvlOptionsDict(TypedDict, total=False): + """The subset of NVL instance options set from Python, nested under ``nvlOptions``. + + The frontend's ``nvlOptions`` is a ``Partial`` with many more fields; this only + types the keys the Python wrapper writes. Other keys round-trip through unchanged at runtime. + """ + + disableWebGL: bool + minZoom: float + maxZoom: float + allowDynamicMinZoom: bool + + +class RenderOptionsDict(TypedDict, total=False): + """The JS-shaped render options consumed by the ``GraphWidget`` frontend. + + This mirrors the ``GraphOptions`` type in ``js-applet/src/graph-widget.tsx`` and is the + structure stored in :attr:`GraphWidget.options`. + """ + + layout: str + layoutOptions: dict[str, Any] + nvlOptions: NvlOptionsDict + zoom: float + pan: PanPosition + showLayoutButton: bool + + class RenderOptions(BaseModel, extra="allow"): """ Options as documented at https://neo4j.com/docs/nvl/current/base-library/#_options @@ -178,7 +213,7 @@ def check_layout_options_match(self) -> RenderOptions: raise ValueError("layout_options must be of type ForceDirectedLayoutOptions for force-directed layout") return self - def to_js_options(self) -> dict[str, Any]: + def to_js_options(self) -> RenderOptionsDict: """Convert render options to the JS-compatible format for the GraphVisualization component. Returns a dict with keys that map to React component props and NVL options: @@ -188,7 +223,7 @@ def to_js_options(self) -> dict[str, Any]: - ``pan``: ``{x, y}`` pan position - ``layoutOptions``: layout-specific options """ - result: dict[str, Any] = {} + result: RenderOptionsDict = {} if self.layout is not None: match self.layout: @@ -206,7 +241,7 @@ def to_js_options(self) -> dict[str, Any]: if self.layout_options is not None: result["layoutOptions"] = self.layout_options.model_dump(exclude_none=True) - nvl_options: dict[str, Any] = {} + nvl_options: NvlOptionsDict = {} if self.renderer is not None: nvl_options["disableWebGL"] = self.renderer != Renderer.WEB_GL if self.min_zoom is not None: diff --git a/python-wrapper/src/neo4j_viz/visualization_graph.py b/python-wrapper/src/neo4j_viz/visualization_graph.py index c25b8604..28ba65c6 100644 --- a/python-wrapper/src/neo4j_viz/visualization_graph.py +++ b/python-wrapper/src/neo4j_viz/visualization_graph.py @@ -1,16 +1,14 @@ from __future__ import annotations -import warnings -from collections.abc import Hashable, Iterable -from typing import Any, Callable, Literal +from functools import cached_property +from typing import Any, Literal from IPython.display import HTML -from pydantic.alias_generators import to_snake -from pydantic_extra_types.color import Color, ColorType -from .colors import NEO4J_COLORS_CONTINUOUS, NEO4J_COLORS_DISCRETE, ColorSpace, ColorsType +from ._graph_entity_operations import GraphEntityOperations +from .colors import ColorSpace, ColorsType from .node import Node, NodeIdType -from .node_size import RealNumber, verify_radii +from .node_size import RealNumber from .nvl import NVL from .options import ( Layout, @@ -87,202 +85,12 @@ def __init__(self, nodes: list[Node], relationships: list[Relationship]) -> None def __str__(self) -> str: return f"VisualizationGraph(nodes={len(self.nodes)}, relationships={len(self.relationships)})" - def _build_render_options( - self, - layout: Layout | str | None, - layout_options: dict[str, Any] | LayoutOptions | None, - renderer: Renderer | str, - pan_position: tuple[float, float] | None, - initial_zoom: float | None, - min_zoom: float, - max_zoom: float, - allow_dynamic_min_zoom: bool, - max_allowed_nodes: int, - show_layout_button: bool, - ) -> RenderOptions: - """Shared validation + option building for render / render_widget.""" - num_nodes = len(self.nodes) - if num_nodes > max_allowed_nodes: - raise ValueError( - f"Too many nodes ({num_nodes}) to render. Maximum allowed nodes is set " - f"to {max_allowed_nodes} for performance reasons. It can be increased by " - "overriding `max_allowed_nodes`, but rendering could then take a long time" - ) - - if isinstance(renderer, str): - renderer = Renderer(renderer) - - Renderer.check(renderer, num_nodes) - - if not layout: - layout = Layout.FORCE_DIRECTED - if isinstance(layout, str): - layout = Layout(layout.lower()) - if not layout_options: - layout_options = {} - - if isinstance(layout_options, dict): - layout_options_typed = construct_layout_options(layout, layout_options) - else: - layout_options_typed = layout_options - - return RenderOptions( - layout=layout, - layout_options=layout_options_typed, - renderer=renderer, - pan_X=pan_position[0] if pan_position is not None else None, - pan_Y=pan_position[1] if pan_position is not None else None, - initial_zoom=initial_zoom, - min_zoom=min_zoom, - max_zoom=max_zoom, - allow_dynamic_min_zoom=allow_dynamic_min_zoom, - show_layout_button=show_layout_button, - ) - - def render( - self, - layout: Layout | str | None = None, - layout_options: dict[str, Any] | LayoutOptions | None = None, - renderer: Renderer | str = Renderer.CANVAS, - width: str = "100%", - height: str = "600px", - pan_position: tuple[float, float] | None = None, - initial_zoom: float | None = None, - min_zoom: float = 0.075, - max_zoom: float = 10, - allow_dynamic_min_zoom: bool = True, - max_allowed_nodes: int = 10_000, - theme: Literal["auto"] | Literal["light"] | Literal["dark"] = "auto", - ) -> HTML: - """ - Render the graph as an HTML object. - - Returns an :class:`IPython.display.HTML` object that will be displayed in environments - that support HTML rendering, such as Jupyter notebooks or Streamlit applications. - - Parameters - ---------- - layout: - The `Layout` to use. - layout_options: - The `LayoutOptions` to use. - renderer: - The `Renderer` to use. - width: - The width of the rendered graph. - height: - The height of the rendered graph. - pan_position: - The initial pan position. - initial_zoom: - The initial zoom level. - min_zoom: - The minimum zoom level. - max_zoom: - The maximum zoom level. - allow_dynamic_min_zoom: - Whether to allow dynamic minimum zoom level. - max_allowed_nodes: - The maximum allowed number of nodes to render. - theme: - The theme of the rendered graph. Can be 'auto', 'light', or 'dark' - - Example - ------- - Basic rendering of a VisualizationGraph: - >>> from neo4j_viz import Node, Relationship, VisualizationGraph - """ - render_options = self._build_render_options( - layout, - layout_options, - renderer, - pan_position, - initial_zoom, - min_zoom, - max_zoom, - allow_dynamic_min_zoom, - max_allowed_nodes, - show_layout_button=False, # The button only works with the widget - ) - - return NVL().render( - self.nodes, - self.relationships, - render_options, - width, - height, - theme, - ) - - def render_widget( - self, - layout: Layout | str | None = None, - layout_options: dict[str, Any] | LayoutOptions | None = None, - renderer: Renderer | str = Renderer.CANVAS, - width: str = "100%", - height: str = "600px", - pan_position: tuple[float, float] | None = None, - initial_zoom: float | None = None, - min_zoom: float = 0.075, - max_zoom: float = 10, - allow_dynamic_min_zoom: bool = True, - max_allowed_nodes: int = 10_000, - theme: Literal["auto"] | Literal["light"] | Literal["dark"] = "auto", - ) -> GraphWidget: - """ - Render the graph as an interactive Jupyter widget (anywidget). - - Returns a :class:`GraphWidget` that provides two-way data sync between Python - and JavaScript. Works in JupyterLab, Notebook 7, VS Code, and Colab. - - Parameters - ---------- - layout: - The `Layout` to use. - layout_options: - The `LayoutOptions` to use. - renderer: - The `Renderer` to use. - width: - The width of the rendered graph. - height: - The height of the rendered graph. - pan_position: - The initial pan position. - initial_zoom: - The initial zoom level. - min_zoom: - The minimum zoom level. - max_zoom: - The maximum zoom level. - allow_dynamic_min_zoom: - Whether to allow dynamic minimum zoom level. - max_allowed_nodes: - The maximum allowed number of nodes to render. - theme: - The theme to use for the rendered graph. - """ - render_options = self._build_render_options( - layout, - layout_options, - renderer, - pan_position, - initial_zoom, - min_zoom, - max_zoom, - allow_dynamic_min_zoom, - max_allowed_nodes, - show_layout_button=True, - ) + @cached_property + def _entity_ops(self) -> GraphEntityOperations: + return GraphEntityOperations(self) - return GraphWidget.from_graph_data( - self.nodes, - self.relationships, - width=width, - height=height, - options=render_options, - theme=theme, - ) + def _sync_entities(self, *, nodes: bool = False, relationships: bool = False) -> None: + """Hook invoked after entities are mutated in place. A no-op for a plain graph.""" def toggle_nodes_pinned(self, pinned: dict[NodeIdType, bool]) -> None: """ @@ -293,13 +101,7 @@ def toggle_nodes_pinned(self, pinned: dict[NodeIdType, bool]) -> None: pinned: A dictionary mapping from node ID to whether the node should be pinned or not. """ - for node in self.nodes: - node_pinned = pinned.get(node.id) - - if node_pinned is None: - continue - - node.pinned = node_pinned + self._entity_ops.toggle_nodes_pinned(pinned) def set_node_captions( self, @@ -328,7 +130,7 @@ def set_node_captions( ... Node(id="0", properties={"name": "Alice", "age": 30}), ... Node(id="1", properties={"name": "Bob", "age": 25}), ... ] - >>> VG = VisualizationGraph(nodes=nodes) + >>> VG = VisualizationGraph(nodes=nodes, relationships=[]) Set node captions from a property: @@ -337,38 +139,8 @@ def set_node_captions( Set node captions from a field, only if not already set: >>> VG.set_node_captions(field="id", override=False) - - Set captions from multiple properties with fallback: - - >>> for node in VG.nodes: - ... caption = node.properties.get("name") or node.properties.get("title") or node.id - ... if override or node.caption is None: - ... node.caption = str(caption) """ - if not ((field is None) ^ (property is None)): - raise ValueError( - f"Exactly one of the arguments `field` (received '{field}') and `property` (received '{property}') must be provided" - ) - - if property: - # Use property - for node in self.nodes: - if not override and node.caption is not None: - continue - - value = node.properties.get(property, "") - node.caption = str(value) - else: - # Use field - assert field is not None - attribute = to_snake(field) - - for node in self.nodes: - if not override and node.caption is not None: - continue - - value = getattr(node, attribute, "") - node.caption = str(value) + self._entity_ops.set_node_captions(field=field, property=property, override=override) def resize_nodes( self, @@ -391,55 +163,7 @@ def resize_nodes( property: The property of the nodes to use for sizing. Must be None if `sizes` is provided. """ - if sizes is not None and property is not None: - raise ValueError("At most one of the arguments `sizes` and `property` can be provided") - - if sizes is None and property is None and node_radius_min_max is None: - raise ValueError("At least one of `sizes`, `property` or `node_radius_min_max` must be given") - - # Gather node sizes - all_sizes = {} - if sizes is not None: - for node in self.nodes: - size = sizes.get(node.id, node.size) - if size is not None: - all_sizes[node.id] = size - elif property is not None: - for node in self.nodes: - size = node.properties.get(property, node.size) - if size is not None: - all_sizes[node.id] = size - else: - for node in self.nodes: - if node.size is not None: - all_sizes[node.id] = node.size - - # Validate node sizes - for id, size in all_sizes.items(): - if size is None: - continue - - if not isinstance(size, (int, float)): - raise ValueError(f"Size for node '{id}' must be a real number, but was {size}") - - if size < 0: - raise ValueError(f"Size for node '{id}' must be non-negative, but was {size}") - - if node_radius_min_max is not None: - verify_radii(node_radius_min_max) - - final_sizes = self._normalize_values(all_sizes, node_radius_min_max) - else: - final_sizes = all_sizes - - # Apply the final sizes to the nodes - for node in self.nodes: - size = final_sizes.get(node.id) - - if size is None: - continue - - node.size = size + self._entity_ops.resize_nodes(sizes=sizes, node_radius_min_max=node_radius_min_max, property=property) def resize_relationships( self, @@ -458,61 +182,7 @@ def resize_relationships( property: The property of the relationships to use for sizing. Must be None if `widths` is provided. """ - if widths is not None and property is not None: - raise ValueError("At most one of the arguments `widths` and `property` can be provided") - - if widths is None and property is None: - raise ValueError("At least one of `widths` or `property` must be given") - - # Gather relationship widths - all_widths = {} - if widths is not None: - for rel in self.relationships: - width = widths.get(rel.id, rel.width) - if width is not None: - all_widths[rel.id] = width - elif property is not None: - for rel in self.relationships: - width = rel.properties.get(property, rel.width) - if width is not None: - all_widths[rel.id] = width - - # Validate and apply relationship widths - for rel in self.relationships: - width = all_widths.get(rel.id) - - if width is None: - continue - - if not isinstance(width, (int, float)): - raise ValueError(f"Width for relationship '{rel.id}' must be a real number, but was {width}") - - if width <= 0: - raise ValueError(f"Width for relationship '{rel.id}' must be positive, but was {width}") - - rel.width = width - - @staticmethod - def _normalize_values( - node_map: dict[NodeIdType, RealNumber], min_max: tuple[float, float] = (0, 1) - ) -> dict[NodeIdType, RealNumber]: - unscaled_min_size = min(node_map.values()) - unscaled_max_size = max(node_map.values()) - unscaled_size_range = float(unscaled_max_size - unscaled_min_size) - - new_min_size, new_max_size = min_max - new_size_range = new_max_size - new_min_size - - if abs(unscaled_size_range) < 1e-6: - default_node_size = new_min_size + new_size_range / 2.0 - new_map = {id: default_node_size for id in node_map} - else: - new_map = { - id: new_min_size + new_size_range * ((nz - unscaled_min_size) / unscaled_size_range) - for id, nz in node_map.items() - } - - return new_map + self._entity_ops.resize_relationships(widths=widths, property=property) def color_nodes( self, @@ -561,7 +231,7 @@ def color_nodes( ... Node(id="0", properties={"label": "Person", "score": 10}), ... Node(id="1", properties={"label": "Person", "score": 20}), ... ] - >>> VG = VisualizationGraph(nodes=nodes) + >>> VG = VisualizationGraph(nodes=nodes, relationships=[]) Color nodes based on a discrete field such as "label": @@ -576,49 +246,9 @@ def color_nodes( >>> from palettable.wesanderson import Moonrise1_5 # type: ignore[import-untyped] >>> VG.color_nodes(field="label", colors=Moonrise1_5.colors) """ - if not ((field is None) ^ (property is None)): - raise ValueError( - f"Exactly one of the arguments `field` (received '{field}') and `property` (received '{property}') must be provided" - ) - - if field is None: - assert property is not None - attribute = property - - def node_to_attr(node: Node) -> Any: - return node.properties.get(attribute) - - else: - assert field is not None - attribute = to_snake(field) - - def node_to_attr(node: Node) -> Any: - return getattr(node, attribute) - - if color_space == ColorSpace.DISCRETE: - if colors is None: - colors = NEO4J_COLORS_DISCRETE - else: - node_map = {node.id: node_to_attr(node) for node in self.nodes if node_to_attr(node) is not None} - normalized_map = self._normalize_values(node_map) - - if colors is None: - colors = NEO4J_COLORS_CONTINUOUS - - if not isinstance(colors, list): - raise ValueError("For continuous properties, `colors` must be a list of colors representing a range") - - num_colors = len(colors) - colors = { - node_to_attr(node): colors[round(normalized_map[node.id] * (num_colors - 1))] - for node in self.nodes - if node_to_attr(node) is not None - } - - if isinstance(colors, dict): - self._color_items_dict(self.nodes, colors, override, node_to_attr) - else: - self._color_items_iter(self.nodes, attribute, colors, override, node_to_attr) + self._entity_ops.color_nodes( + field=field, property=property, colors=colors, color_space=color_space, override=override + ) def color_relationships( self, @@ -678,129 +308,203 @@ def color_relationships( >>> VG.color_relationships(property="score", color_space=ColorSpace.CONTINUOUS) """ - if not ((field is None) ^ (property is None)): + self._entity_ops.color_relationships( + field=field, property=property, colors=colors, color_space=color_space, override=override + ) + + def _build_render_options( + self, + layout: Layout | str | None, + layout_options: dict[str, Any] | LayoutOptions | None, + renderer: Renderer | str, + pan_position: tuple[float, float] | None, + initial_zoom: float | None, + min_zoom: float, + max_zoom: float, + allow_dynamic_min_zoom: bool, + max_allowed_nodes: int, + show_layout_button: bool, + ) -> RenderOptions: + """Shared validation + option building for render / render_widget.""" + num_nodes = len(self.nodes) + if num_nodes > max_allowed_nodes: raise ValueError( - f"Exactly one of the arguments `field` (received '{field}') and `property` (received '{property}') must be provided" + f"Too many nodes ({num_nodes}) to render. Maximum allowed nodes is set " + f"to {max_allowed_nodes} for performance reasons. It can be increased by " + "overriding `max_allowed_nodes`, but rendering could then take a long time" ) - if field is None: - assert property is not None - attribute = property - - def rel_to_attr(rel: Relationship) -> Any: - return rel.properties.get(attribute) + if isinstance(renderer, str): + renderer = Renderer(renderer) - else: - assert field is not None - attribute = to_snake(field) + Renderer.check(renderer, num_nodes) - def rel_to_attr(rel: Relationship) -> Any: - return getattr(rel, attribute) + if not layout: + layout = Layout.FORCE_DIRECTED + if isinstance(layout, str): + layout = Layout(layout.lower()) + if not layout_options: + layout_options = {} - if color_space == ColorSpace.DISCRETE: - if colors is None: - colors = NEO4J_COLORS_DISCRETE + if isinstance(layout_options, dict): + layout_options_typed = construct_layout_options(layout, layout_options) else: - rel_map = {rel.id: rel_to_attr(rel) for rel in self.relationships if rel_to_attr(rel) is not None} - normalized_map = self._normalize_values(rel_map) - - if colors is None: - colors = NEO4J_COLORS_CONTINUOUS - - if not isinstance(colors, list): - raise ValueError("For continuous properties, `colors` must be a list of colors representing a range") - - num_colors = len(colors) - colors = { - rel_to_attr(rel): colors[round(normalized_map[rel.id] * (num_colors - 1))] - for rel in self.relationships - if rel_to_attr(rel) is not None - } + layout_options_typed = layout_options - if isinstance(colors, dict): - self._color_items_dict(self.relationships, colors, override, rel_to_attr) - else: - self._color_items_iter(self.relationships, attribute, colors, override, rel_to_attr) + return RenderOptions( + layout=layout, + layout_options=layout_options_typed, + renderer=renderer, + pan_X=pan_position[0] if pan_position is not None else None, + pan_Y=pan_position[1] if pan_position is not None else None, + initial_zoom=initial_zoom, + min_zoom=min_zoom, + max_zoom=max_zoom, + allow_dynamic_min_zoom=allow_dynamic_min_zoom, + show_layout_button=show_layout_button, + ) - def _color_items_dict( + def render( self, - items: list[Node] | list[Relationship], - colors: dict[Hashable, ColorType], - override: bool, - item_to_attr: Callable[[Any], Any], - ) -> None: - for item in items: - color = colors.get(item_to_attr(item)) + layout: Layout | str | None = None, + layout_options: dict[str, Any] | LayoutOptions | None = None, + renderer: Renderer | str = Renderer.CANVAS, + width: str = "100%", + height: str = "600px", + pan_position: tuple[float, float] | None = None, + initial_zoom: float | None = None, + min_zoom: float = 0.075, + max_zoom: float = 10, + allow_dynamic_min_zoom: bool = True, + max_allowed_nodes: int = 10_000, + theme: Literal["auto"] | Literal["light"] | Literal["dark"] = "auto", + ) -> HTML: + """ + Render the graph as an HTML object. + + Returns an :class:`IPython.display.HTML` object that will be displayed in environments + that support HTML rendering, such as Jupyter notebooks or Streamlit applications. - if color is None: - continue + Parameters + ---------- + layout: + The `Layout` to use. + layout_options: + The `LayoutOptions` to use. + renderer: + The `Renderer` to use. + width: + The width of the rendered graph. + height: + The height of the rendered graph. + pan_position: + The initial pan position. + initial_zoom: + The initial zoom level. + min_zoom: + The minimum zoom level. + max_zoom: + The maximum zoom level. + allow_dynamic_min_zoom: + Whether to allow dynamic minimum zoom level. + max_allowed_nodes: + The maximum allowed number of nodes to render. + theme: + The theme of the rendered graph. Can be 'auto', 'light', or 'dark' - if item.color is not None and not override: - continue + Example + ------- + Basic rendering of a VisualizationGraph: + >>> from neo4j_viz import Node, Relationship, VisualizationGraph + """ + render_options = self._build_render_options( + layout, + layout_options, + renderer, + pan_position, + initial_zoom, + min_zoom, + max_zoom, + allow_dynamic_min_zoom, + max_allowed_nodes, + show_layout_button=False, # The button only works with the widget + ) - if not isinstance(color, Color): - item.color = Color(color) - else: - item.color = color + return NVL().render( + self.nodes, + self.relationships, + render_options, + width, + height, + theme, + ) - def _color_items_iter( + def render_widget( self, - items: list[Node] | list[Relationship], - attribute: str, - colors: Iterable[ColorType], - override: bool, - item_to_attr: Callable[[Any], Any], - ) -> None: - exhausted_colors = False - prop_to_color = {} - colors_iter = iter(colors) - for item in items: - raw_prop = item_to_attr(item) - try: - prop = self._make_hashable(raw_prop) - except ValueError: - item_type = "nodes" if isinstance(item, Node) else "relationships" - raise ValueError(f"Unable to color {item_type} by unhashable property type '{type(raw_prop)}'") - - if prop not in prop_to_color: - next_color = next(colors_iter, None) - if next_color is None: - exhausted_colors = True - colors_iter = iter(colors) - next_color = next(colors_iter) - prop_to_color[prop] = next_color - - color = prop_to_color[prop] - - if item.color is not None and not override: - continue - - if not isinstance(color, Color): - item.color = Color(color) - else: - item.color = color - - if exhausted_colors: - warnings.warn( - f"Ran out of colors for property '{attribute}'. {len(prop_to_color)} colors were needed, but only " - f"{len(set(prop_to_color.values()))} were given, so reused colors" - ) - - @staticmethod - def _make_hashable(raw_prop: Any) -> Hashable: - prop = raw_prop - if isinstance(raw_prop, list): - prop = tuple(raw_prop) - elif isinstance(raw_prop, set): - prop = frozenset(raw_prop) - elif isinstance(raw_prop, dict): - prop = tuple(sorted(raw_prop.items())) + layout: Layout | str | None = None, + layout_options: dict[str, Any] | LayoutOptions | None = None, + renderer: Renderer | str = Renderer.CANVAS, + width: str = "100%", + height: str = "600px", + pan_position: tuple[float, float] | None = None, + initial_zoom: float | None = None, + min_zoom: float = 0.075, + max_zoom: float = 10, + allow_dynamic_min_zoom: bool = True, + max_allowed_nodes: int = 10_000, + theme: Literal["auto"] | Literal["light"] | Literal["dark"] = "auto", + ) -> GraphWidget: + """ + Render the graph as an interactive Jupyter widget (anywidget). - try: - hash(prop) - except TypeError: - raise ValueError(f"Unable to convert '{raw_prop}' of type {type(raw_prop)} to a hashable type") + Returns a :class:`GraphWidget` that provides two-way data sync between Python + and JavaScript. Works in JupyterLab, Notebook 7, VS Code, and Colab. - assert isinstance(prop, Hashable) + Parameters + ---------- + layout: + The `Layout` to use. + layout_options: + The `LayoutOptions` to use. + renderer: + The `Renderer` to use. + width: + The width of the rendered graph. + height: + The height of the rendered graph. + pan_position: + The initial pan position. + initial_zoom: + The initial zoom level. + min_zoom: + The minimum zoom level. + max_zoom: + The maximum zoom level. + allow_dynamic_min_zoom: + Whether to allow dynamic minimum zoom level. + max_allowed_nodes: + The maximum allowed number of nodes to render. + theme: + The theme to use for the rendered graph. + """ + render_options = self._build_render_options( + layout, + layout_options, + renderer, + pan_position, + initial_zoom, + min_zoom, + max_zoom, + allow_dynamic_min_zoom, + max_allowed_nodes, + show_layout_button=True, + ) - return prop + return GraphWidget.from_graph_data( + self.nodes, + self.relationships, + width=width, + height=height, + options=render_options, + theme=theme, + ) diff --git a/python-wrapper/src/neo4j_viz/widget.py b/python-wrapper/src/neo4j_viz/widget.py index 433d411d..0abfa7c9 100644 --- a/python-wrapper/src/neo4j_viz/widget.py +++ b/python-wrapper/src/neo4j_viz/widget.py @@ -2,13 +2,25 @@ import json import pathlib -from typing import Any, Union +from functools import cached_property +from typing import Any, Union, cast import anywidget import traitlets +from ._graph_entity_operations import GraphEntityOperations +from .colors import ColorSpace, ColorsType from .node import Node, NodeIdType -from .options import RenderOptions +from .node_size import RealNumber +from .options import ( + Layout, + LayoutOptions, + NvlOptionsDict, + Renderer, + RenderOptions, + RenderOptionsDict, + construct_layout_options, +) from .relationship import Relationship, RelationshipIdType @@ -42,14 +54,18 @@ def entity_to_json(entity_list: list[Node | Relationship], widget: anywidget.Any return [_serialize_entity(entity) for entity in entity_list] +# Dev mode: set ANYWIDGET_HMR=1 and run ``yarn dev`` in js-applet/ +# for hot module replacement during development. + + class GraphWidget(anywidget.AnyWidget): """Jupyter widget for interactive graph visualization. Uses anywidget to render a React-based graph component with two-way data sync between Python and JavaScript. - Dev mode: set ANYWIDGET_HMR=1 and run ``yarn dev`` in js-applet/ - for hot module replacement during development. + The widget exposes utility methods that mutate the graph in place and + automatically sync the changes to the frontend. """ _esm = _STATIC / "widget.js" @@ -87,6 +103,347 @@ def from_graph_data( def __str__(self) -> str: return f"GraphWidget(nodes={len(self.nodes)}, relationships={len(self.relationships)}, options={self.options}, theme={self.theme}, width={self.width}, height={self.height})" + @cached_property + def _entity_ops(self) -> GraphEntityOperations: + return GraphEntityOperations(self) + + def sync_nodes(self) -> None: + """Manually trigger a sync of the `nodes` list to the frontend.""" + self._sync_entities(nodes=True) + + def sync_relationships(self) -> None: + """Manually trigger a sync of the `relationships` list to the frontend.""" + self._sync_entities(relationships=True) + + def _sync_entities(self, *, nodes: bool = False, relationships: bool = False) -> None: + """Propagate in-place entity mutations to the frontend. + + The utility methods delegated to :class:`GraphEntityOperations` mutate the `Node` + and `Relationship` objects in place. This does not change the identity (or equality) + of the `nodes`/`relationships` lists, so traitlets does not detect a change and would + not sync. We therefore explicitly push the affected trait(s) to JavaScript, which + re-serializes them via `entity_to_json`. When the widget is not connected to a + frontend (e.g. outside a notebook), `send_state` is a no-op. + """ + keys = [] + if nodes: + keys.append("nodes") + if relationships: + keys.append("relationships") + if keys: + self.send_state(keys if len(keys) > 1 else keys[0]) + + def toggle_nodes_pinned(self, pinned: dict[NodeIdType, bool]) -> None: + """ + Toggle whether nodes should be pinned or not. + + Parameters + ---------- + pinned: + A dictionary mapping from node ID to whether the node should be pinned or not. + """ + self._entity_ops.toggle_nodes_pinned(pinned) + + def set_node_captions( + self, + *, + field: str | None = None, + property: str | None = None, + override: bool = True, + ) -> None: + """ + Set the caption for nodes in the graph based on either a node field or a node property. + + Parameters + ---------- + field: + The field of the nodes to use as the caption. Must be None if `property` is provided. + property: + The property of the nodes to use as the caption. Must be None if `field` is provided. + override: + Whether to override existing captions of the nodes, if they have any. + + Examples + -------- + Given a GraphWidget `widget`: + + >>> nodes = [ + ... Node(id="0", properties={"name": "Alice", "age": 30}), + ... Node(id="1", properties={"name": "Bob", "age": 25}), + ... ] + >>> widget = GraphWidget(nodes=nodes) + + Set node captions from a property: + + >>> widget.set_node_captions(property="name") + + Set node captions from a field, only if not already set: + + >>> widget.set_node_captions(field="id", override=False) + """ + self._entity_ops.set_node_captions(field=field, property=property, override=override) + + def resize_nodes( + self, + sizes: dict[NodeIdType, RealNumber] | None = None, + node_radius_min_max: tuple[RealNumber, RealNumber] | None = (3, 60), + property: str | None = None, + ) -> None: + """ + Resize the nodes in the graph. + + Parameters + ---------- + sizes: + A dictionary mapping from node ID to the new size of the node. + If a node ID is not in the dictionary, the size of the node is not changed. + Must be None if `property` is provided. + node_radius_min_max: + Minimum and maximum node size radius as a tuple. To avoid tiny or huge nodes in the visualization, the + node sizes are scaled to fit in the given range. If None, the sizes are used as is. + property: + The property of the nodes to use for sizing. Must be None if `sizes` is provided. + """ + self._entity_ops.resize_nodes(sizes=sizes, node_radius_min_max=node_radius_min_max, property=property) + + def resize_relationships( + self, + widths: dict[str | int, RealNumber] | None = None, + property: str | None = None, + ) -> None: + """ + Resize the width of relationships in the graph. + + Parameters + ---------- + widths: + A dictionary mapping from relationship ID to the new width of the relationship. + If a relationship ID is not in the dictionary, the width of the relationship is not changed. + Must be None if `property` is provided. + property: + The property of the relationships to use for sizing. Must be None if `widths` is provided. + """ + self._entity_ops.resize_relationships(widths=widths, property=property) + + def color_nodes( + self, + *, + field: str | None = None, + property: str | None = None, + colors: ColorsType | None = None, + color_space: ColorSpace = ColorSpace.DISCRETE, + override: bool = True, + ) -> None: + """ + Color the nodes in the graph based on either a node field, or a node property. + + It's possible to color the nodes based on a discrete or continuous color space. In the discrete case, a new + color from the `colors` provided is assigned to each unique value of the node field/property. + In the continuous case, the `colors` should be a list of colors representing a range that are used to + create a gradient of colors based on the values of the node field/property. + + Parameters + ---------- + field: + The field of the nodes to base the coloring on. The type of this field must be hashable, or be a + list, set or dict containing only hashable types. Must be None if `property` is provided. + property: + The property of the nodes to base the coloring on. The type of this property must be hashable, or be a + list, set or dict containing only hashable types. Must be None if `field` is provided. + colors: + The colors to use for the nodes. + If `color_space` is `ColorSpace.DISCRETE`, the colors can be a dictionary mapping from field/property value + to color, or an iterable of colors in which case the colors are used in order. + If `color_space` is `ColorSpace.CONTINUOUS`, the colors must be a list of colors representing a range. + Allowed color values are for example “#FF0000”, “red” or (255, 0, 0) (full list: https://docs.pydantic.dev/2.0/usage/types/extra_types/color_types/). + The default colors are the Neo4j graph colors. + color_space: + The type of space of the provided `colors`. Either `ColorSpace.DISCRETE` or `ColorSpace.CONTINUOUS`. It determines whether + colors are assigned based on unique field/property values or a gradient of the values of the field/property. + override: + Whether to override existing colors of the nodes, if they have any. + + Examples + -------- + + Given a GraphWidget `widget`: + + >>> nodes = [ + ... Node(id="0", properties={"label": "Person", "score": 10}), + ... Node(id="1", properties={"label": "Person", "score": 20}), + ... ] + >>> widget = GraphWidget(nodes=nodes) + + Color nodes based on a discrete field such as "label": + + >>> widget.color_nodes(field="label", color_space=ColorSpace.DISCRETE) + + Color nodes based on a continuous field such as "score": + + >>> widget.color_nodes(field="score", color_space=ColorSpace.CONTINUOUS) + + Color nodes based on a custom colors such as from palettable: + + >>> from palettable.wesanderson import Moonrise1_5 # type: ignore[import-untyped] + >>> widget.color_nodes(field="label", colors=Moonrise1_5.colors) + """ + self._entity_ops.color_nodes( + field=field, property=property, colors=colors, color_space=color_space, override=override + ) + + def color_relationships( + self, + *, + field: str | None = None, + property: str | None = None, + colors: ColorsType | None = None, + color_space: ColorSpace = ColorSpace.DISCRETE, + override: bool = True, + ) -> None: + """ + Color the relationships in the graph based on either a relationship field, or a relationship property. + + It's possible to color the relationships based on a discrete or continuous color space. In the discrete case, + a new color from the `colors` provided is assigned to each unique value of the relationship field/property. + In the continuous case, the `colors` should be a list of colors representing a range that are used to + create a gradient of colors based on the values of the relationship field/property. + + Parameters + ---------- + field: + The field of the relationships to base the coloring on. The type of this field must be hashable, or be a + list, set or dict containing only hashable types. Must be None if `property` is provided. + property: + The property of the relationships to base the coloring on. The type of this property must be hashable, or be a + list, set or dict containing only hashable types. Must be None if `field` is provided. + colors: + The colors to use for the relationships. + If `color_space` is `ColorSpace.DISCRETE`, the colors can be a dictionary mapping from field/property value + to color, or an iterable of colors in which case the colors are used in order. + If `color_space` is `ColorSpace.CONTINUOUS`, the colors must be a list of colors representing a range. + Allowed color values are for example “#FF0000”, “red” or (255, 0, 0) (full list: https://docs.pydantic.dev/2.0/usage/types/extra_types/color_types/). + The default colors are the Neo4j graph colors. + color_space: + The type of space of the provided `colors`. Either `ColorSpace.DISCRETE` or `ColorSpace.CONTINUOUS`. It determines whether + colors are assigned based on unique field/property values or a gradient of the values of the field/property. + override: + Whether to override existing colors of the relationships, if they have any. + + Examples + -------- + + Given a GraphWidget `widget`: + + >>> nodes = [Node(id="0"), Node(id="1")] + >>> relationships = [ + ... Relationship(source="0", target="1", caption="ACTED_IN", properties={"score": 10}), + ... Relationship(source="1", target="0", caption="DIRECTED", properties={"score": 20}), + ... ] + >>> widget = GraphWidget(nodes=nodes, relationships=relationships) + + Color relationships based on a discrete field such as "caption": + + >>> widget.color_relationships(field="caption", color_space=ColorSpace.DISCRETE) + + Color relationships based on a continuous field such as "score": + + >>> widget.color_relationships(property="score", color_space=ColorSpace.CONTINUOUS) + """ + self._entity_ops.color_relationships( + field=field, property=property, colors=colors, color_space=color_space, override=override + ) + + def _render_options(self) -> RenderOptionsDict: + """Return a typed, mutable copy of the current JS-shaped render options.""" + return cast(RenderOptionsDict, dict(self.options)) + + def set_layout(self, layout: Layout | str, layout_options: dict[str, Any] | LayoutOptions | None = None) -> None: + """ + Change the layout algorithm used to position the graph, in place. + + Parameters + ----------- + layout: + The layout algorithm to use (e.g. `Layout.FORCE_DIRECTED`, `Layout.HIERARCHICAL`). + layout_options: + Optional layout-specific options. Either a `HierarchicalLayoutOptions`/`ForceDirectedLayoutOptions` + instance or a plain dict, which is validated against the chosen layout. Layout options are only + supported for the force-directed and hierarchical layouts. + """ + if isinstance(layout, str): + layout = Layout(layout) + + if isinstance(layout_options, dict): + layout_options = construct_layout_options(layout, layout_options) + + js = RenderOptions(layout=layout, layout_options=layout_options).to_js_options() + + new = self._render_options() + new["layout"] = js["layout"] + if "layoutOptions" in js: + new["layoutOptions"] = js["layoutOptions"] + else: + new.pop("layoutOptions", None) + self.options = dict(new) + + def set_zoom(self, zoom: float) -> None: + """ + Change the zoom level of the graph, in place. + + Parameters + ----------- + zoom: + The zoom level to apply. + """ + new = self._render_options() + new["zoom"] = zoom + self.options = dict(new) + + def set_pan(self, x: float, y: float) -> None: + """ + Change the pan position of the graph, in place. + + Parameters + ----------- + x: + The pan position along the x-axis. + y: + The pan position along the y-axis. + """ + new = self._render_options() + new["pan"] = {"x": x, "y": y} + self.options = dict(new) + + def set_renderer(self, renderer: Renderer) -> None: + """ + Change the renderer used to draw the graph, in place. + + Parameters + ----------- + renderer: + The renderer to use, either `Renderer.WEB_GL` or `Renderer.CANVAS`. + """ + Renderer.check(renderer, len(self.nodes)) + + new = self._render_options() + nvl_options = cast(NvlOptionsDict, dict(new.get("nvlOptions", {}))) + nvl_options["disableWebGL"] = renderer != Renderer.WEB_GL + new["nvlOptions"] = nvl_options + self.options = dict(new) + + def set_show_layout_button(self, show: bool = True) -> None: + """ + Toggle the layout selector button in the widget UI, in place. + + Parameters + ----------- + show: + Whether the layout button should be shown. + """ + new = self._render_options() + new["showLayoutButton"] = show + self.options = dict(new) + def add_data( self, nodes: Node | list[Node] | None = None, relationships: Relationship | list[Relationship] | None = None ) -> None: diff --git a/python-wrapper/tests/test_widget.py b/python-wrapper/tests/test_widget.py index 8033411c..83bc26c8 100644 --- a/python-wrapper/tests/test_widget.py +++ b/python-wrapper/tests/test_widget.py @@ -4,7 +4,7 @@ import pytest from neo4j_viz import GraphWidget, Node, Relationship, VisualizationGraph -from neo4j_viz.options import Layout, RenderOptions +from neo4j_viz.options import Layout, Renderer, RenderOptions from neo4j_viz.widget import _serialize_entity @@ -213,6 +213,87 @@ def test_remove_data(self) -> None: assert {r.id for r in widget.relationships} == {43} +class TestWidgetUtilityMethods: + def _spy_send_state(self, widget: GraphWidget) -> list[Any]: + synced: list[Any] = [] + widget.send_state = lambda key=None: synced.append(key) + return synced + + def test_color_nodes(self) -> None: + widget = GraphWidget(nodes=[Node(id="n1", properties={"label": "A"}), Node(id="n2", properties={"label": "B"})]) + synced = self._spy_send_state(widget) + + widget.color_nodes(property="label") + + assert widget.nodes[0].color is not None + assert widget.nodes[1].color is not None + assert widget.nodes[0].color != widget.nodes[1].color + # Mutating in place must still push the updated nodes to the frontend. + assert synced == ["nodes"] + + def test_color_relationships(self) -> None: + widget = GraphWidget( + nodes=[Node(id="n1"), Node(id="n2")], + relationships=[ + Relationship(source="n1", target="n2", caption="KNOWS"), + Relationship(source="n2", target="n1", caption="LIKES"), + ], + ) + synced = self._spy_send_state(widget) + + widget.color_relationships(field="caption") + + assert widget.relationships[0].color is not None + assert widget.relationships[0].color != widget.relationships[1].color + assert synced == ["relationships"] + + def test_resize_nodes(self) -> None: + widget = GraphWidget( + nodes=[ + Node(id="n1", properties={"score": 10}), + Node(id="n2", properties={"score": 20}), + ] + ) + synced = self._spy_send_state(widget) + + widget.resize_nodes(property="score", node_radius_min_max=(10, 50)) + + assert widget.nodes[0].size == 10 + assert widget.nodes[1].size == 50 + assert synced == ["nodes"] + + def test_resize_relationships(self) -> None: + widget = GraphWidget( + nodes=[Node(id="n1"), Node(id="n2")], + relationships=[Relationship(id="r1", source="n1", target="n2")], + ) + synced = self._spy_send_state(widget) + + widget.resize_relationships(widths={"r1": 5}) + + assert widget.relationships[0].width == 5 + assert synced == ["relationships"] + + def test_set_node_captions(self) -> None: + widget = GraphWidget(nodes=[Node(id="n1", properties={"name": "Alice"})]) + synced = self._spy_send_state(widget) + + widget.set_node_captions(property="name") + + assert widget.nodes[0].caption == "Alice" + assert synced == ["nodes"] + + def test_toggle_nodes_pinned(self) -> None: + widget = GraphWidget(nodes=[Node(id="n1", pinned=False), Node(id="n2")]) + synced = self._spy_send_state(widget) + + widget.toggle_nodes_pinned({"n1": True}) + + assert widget.nodes[0].pinned is True + assert widget.nodes[1].pinned is None + assert synced == ["nodes"] + + render_widget_cases = { "default": {}, "force layout": {"layout": Layout.FORCE_DIRECTED}, @@ -284,3 +365,98 @@ def test_render_widget_options_passed_through(self) -> None: assert widget.options["zoom"] == 2.0 assert widget.options["nvlOptions"]["minZoom"] == 0.1 assert widget.options["nvlOptions"]["maxZoom"] == 5.0 + + +class TestRenderOptionSetters: + def test_set_layout(self) -> None: + widget = GraphWidget() + + widget.set_layout(Layout.HIERARCHICAL) + + assert widget.options["layout"] == "hierarchical" + + def test_set_layout_with_options(self) -> None: + widget = GraphWidget() + + widget.set_layout(Layout.FORCE_DIRECTED, {"gravity": 0.1}) + + assert widget.options["layout"] == "d3Force" + assert widget.options["layoutOptions"] == {"gravity": 0.1} + + def test_set_layout_clears_stale_layout_options(self) -> None: + widget = GraphWidget(options={"layoutOptions": {"gravity": 0.1}}) + + widget.set_layout(Layout.GRID) + + assert widget.options["layout"] == "grid" + assert "layoutOptions" not in widget.options + + def test_set_layout_with_mismatched_options_raises(self) -> None: + widget = GraphWidget() + + with pytest.raises(ValueError): + widget.set_layout(Layout.HIERARCHICAL, {"gravity": 0.1}) + + def test_set_zoom(self) -> None: + widget = GraphWidget() + + widget.set_zoom(2.0) + + assert widget.options["zoom"] == 2.0 + + def test_set_pan(self) -> None: + widget = GraphWidget() + + widget.set_pan(100, 50) + + assert widget.options["pan"] == {"x": 100, "y": 50} + + def test_set_renderer_canvas(self) -> None: + widget = GraphWidget() + + widget.set_renderer(Renderer.CANVAS) + + assert widget.options["nvlOptions"]["disableWebGL"] is True + + def test_set_renderer_webgl(self) -> None: + widget = GraphWidget() + + with pytest.warns(UserWarning): + widget.set_renderer(Renderer.WEB_GL) + + assert widget.options["nvlOptions"]["disableWebGL"] is False + + def test_set_renderer_preserves_other_nvl_options(self) -> None: + widget = GraphWidget(options={"nvlOptions": {"minZoom": 0.1}}) + + widget.set_renderer(Renderer.CANVAS) + + assert widget.options["nvlOptions"]["minZoom"] == 0.1 + assert widget.options["nvlOptions"]["disableWebGL"] is True + + def test_set_show_layout_button(self) -> None: + widget = GraphWidget() + + widget.set_show_layout_button() + assert widget.options["showLayoutButton"] is True + + widget.set_show_layout_button(False) + assert widget.options["showLayoutButton"] is False + + def test_setter_preserves_unrelated_options(self) -> None: + widget = GraphWidget(options={"layout": "hierarchical"}) + + widget.set_zoom(3.0) + + assert widget.options["zoom"] == 3.0 + assert widget.options["layout"] == "hierarchical" + + def test_setter_triggers_sync(self) -> None: + widget = GraphWidget() + changes: list[dict[str, Any]] = [] + widget.observe(lambda change: changes.append(change), names=["options"]) + + widget.set_zoom(2.0) + + assert len(changes) == 1 + assert changes[0]["name"] == "options"