1515
1616from datetime import datetime
1717from enum import Enum
18- from typing import Optional , Union , List , Dict
19- import re
18+ from typing import Any , Optional , Union , List , Dict
19+ from json import dumps
20+ from re import sub , search
2021
2122from sagemaker .utils import get_module
2223from sagemaker .lineage ._utils import get_resource_name_from_arn
@@ -235,7 +236,7 @@ def _artifact_to_lineage_object(self):
235236class PyvisVisualizer (object ):
236237 """Create object used for visualizing graph using Pyvis library."""
237238
238- def __init__ (self , graph_styles ):
239+ def __init__ (self , graph_styles , pyvis_options : Optional [ Dict [ str , Any ]] = None ):
239240 """Init for PyvisVisualizer.
240241
241242 Args:
@@ -260,7 +261,8 @@ def __init__(self, graph_styles):
260261 "symbol": "★", # shape symbol for legend
261262 },
262263 }
263-
264+ pyvis_options(optional): A dict containing PyVis options to customize visualization.
265+ (see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
264266 """
265267 # import visualization packages
266268 (
@@ -272,36 +274,29 @@ def __init__(self, graph_styles):
272274
273275 self .graph_styles = graph_styles
274276
275- # pyvis graph options
276- self ._options = """
277- var options = {
278- "configure":{
279- "enabled": false
280- },
281- "layout": {
282- "hierarchical": {
283- "enabled": true,
284- "blockShifting": true,
285- "direction": "LR",
286- "sortMethod": "directed",
287- "shakeTowards": "leaves"
288- }
289- },
290- "interaction": {
291- "multiselect": true,
292- "navigationButtons": true
293- },
294- "physics": {
295- "enabled": false,
296- "hierarchicalRepulsion": {
297- "centralGravity": 0,
298- "avoidOverlap": null
277+ if pyvis_options is None :
278+ # default pyvis graph options
279+ pyvis_options = {
280+ "configure" : {"enabled" : False },
281+ "layout" : {
282+ "hierarchical" : {
283+ "enabled" : True ,
284+ "blockShifting" : True ,
285+ "direction" : "LR" ,
286+ "sortMethod" : "directed" ,
287+ "shakeTowards" : "leaves" ,
288+ }
289+ },
290+ "interaction" : {"multiselect" : True , "navigationButtons" : True },
291+ "physics" : {
292+ "enabled" : False ,
293+ "hierarchicalRepulsion" : {"centralGravity" : 0 , "avoidOverlap" : None },
294+ "minVelocity" : 0.75 ,
295+ "solver" : "hierarchicalRepulsion" ,
299296 },
300- "minVelocity": 0.75,
301- "solver": "hierarchicalRepulsion"
302297 }
303- }
304- "" "
298+ # A string representation of a Javascript-like object used to override pyvis options
299+ self . _pyvis_options = f"var options = { dumps ( pyvis_options ) } "
305300
306301 def _import_visual_modules (self ):
307302 """Import modules needed for visualization."""
@@ -382,14 +377,14 @@ def render(self, elements, path="lineage_graph_pyvis.html"):
382377
383378 """
384379 net = self .Network (height = "600px" , width = "82%" , notebook = True , directed = True )
385- net .set_options (self ._options )
380+ net .set_options (self ._pyvis_options )
386381
387382 # add nodes to graph
388383 for arn , source , entity , is_start_arn in elements ["nodes" ]:
389- entity_text = re . sub (r"(\w)([A-Z])" , r"\1 \2" , entity )
390- source = re . sub (r"(\w)([A-Z])" , r"\1 \2" , source )
391- account_id = re . search (r":\d{12}:" , arn )
392- name = re . search (r"\/.*" , arn )
384+ entity_text = sub (r"(\w)([A-Z])" , r"\1 \2" , entity )
385+ source = sub (r"(\w)([A-Z])" , r"\1 \2" , source )
386+ account_id = search (r":\d{12}:" , arn )
387+ name = search (r"\/.*" , arn )
393388 node_info = (
394389 "Entity: "
395390 + entity_text
@@ -516,7 +511,11 @@ def _get_visualization_elements(self):
516511 elements = {"nodes" : verts , "edges" : edges }
517512 return elements
518513
519- def visualize (self , path : Optional [str ] = "lineage_graph_pyvis.html" ):
514+ def visualize (
515+ self ,
516+ path : Optional [str ] = "lineage_graph_pyvis.html" ,
517+ pyvis_options : Optional [Dict [str , Any ]] = None ,
518+ ):
520519 """Visualize lineage query result.
521520
522521 Creates a PyvisVisualizer object to render network graph with Pyvis library.
@@ -527,6 +526,8 @@ def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
527526 Args:
528527 path(optional): The path/filename of the rendered graph html file.
529528 (default path: "lineage_graph_pyvis.html")
529+ pyvis_options(optional): A dict containing PyVis options to customize visualization.
530+ (see https://visjs.github.io/vis-network/docs/network/#options for supported fields)
530531
531532 Returns:
532533 display graph: The interactive visualization is presented as a static HTML file.
@@ -561,7 +562,7 @@ def visualize(self, path: Optional[str] = "lineage_graph_pyvis.html"):
561562 },
562563 }
563564
564- pyvis_vis = PyvisVisualizer (lineage_graph_styles )
565+ pyvis_vis = PyvisVisualizer (lineage_graph_styles , pyvis_options )
565566 elements = self ._get_visualization_elements ()
566567 return pyvis_vis .render (elements = elements , path = path )
567568
0 commit comments