1616from datetime import datetime
1717from enum import Enum
1818from typing import Optional , Union , List , Dict
19+ import re
1920
2021from sagemaker .lineage ._utils import get_resource_name_from_arn , get_module
2122
@@ -260,6 +261,8 @@ def __init__(self, graph_styles):
260261 (
261262 self .Network ,
262263 self .Options ,
264+ self .IFrame ,
265+ self .BeautifulSoup ,
263266 ) = self ._import_visual_modules ()
264267
265268 self .graph_styles = graph_styles
@@ -300,13 +303,60 @@ def _import_visual_modules(self):
300303 get_module ("pyvis" )
301304 from pyvis .network import Network
302305 from pyvis .options import Options
306+ from IPython .display import IFrame
303307
304- return Network , Options
308+ get_module ("bs4" )
309+ from bs4 import BeautifulSoup
310+
311+ return Network , Options , IFrame , BeautifulSoup
305312
306313 def _node_color (self , entity ):
307314 """Return node color by background-color specified in graph styles."""
308315 return self .graph_styles [entity ]["style" ]["background-color" ]
309316
317+ def _get_legend_line (self , component_name ):
318+ """Generate lengend div line for each graph component in graph_styles."""
319+ if self .graph_styles [component_name ]["isShape" ] == "False" :
320+ return '<div><div style="background-color: {color}; width: 1.6vw; height: 1.6vw;\
321+ display: inline-block; font-size: 1.5vw; vertical-align: -0.2em;"></div>\
322+ <div style="width: 0.3vw; height: 1.5vw; display: inline-block;"></div>\
323+ <div style="display: inline-block; font-size: 1.5vw;">{name}</div></div>' .format (
324+ color = self .graph_styles [component_name ]["style" ]["background-color" ],
325+ name = self .graph_styles [component_name ]["name" ],
326+ )
327+ else :
328+ return '<div style="background-color: #ffffff; width: 1.6vw; height: 1.6vw;\
329+ display: inline-block; font-size: 0.9vw; vertical-align: -0.2em;">{shape}</div>\
330+ <div style="width: 0.3vw; height: 1.5vw; display: inline-block;"></div>\
331+ <div style="display: inline-block; font-size: 1.5vw;">{name}</div></div>' .format (
332+ shape = self .graph_styles [component_name ]["style" ]["shape" ],
333+ name = self .graph_styles [component_name ]["name" ],
334+ )
335+
336+ def _add_legend (self , path ):
337+ """Embed legend to html file generated by pyvis."""
338+ f = open (path , "r" )
339+ content = self .BeautifulSoup (f , "html.parser" )
340+
341+ legend = """
342+ <div style="display: inline-block; font-size: 1vw; font-family: verdana;
343+ vertical-align: top; padding: 1vw;">
344+ """
345+ # iterate through graph styles to get legend
346+ for component in self .graph_styles .keys ():
347+ legend += self ._get_legend_line (component_name = component )
348+
349+ legend += "</div>"
350+
351+ legend_div = self .BeautifulSoup (legend , "html.parser" )
352+
353+ content .div .insert_after (legend_div )
354+
355+ html = content .prettify ()
356+
357+ with open (path , "w" , encoding = "utf8" ) as file :
358+ file .write (html )
359+
310360 def render (self , elements , path = "pyvisExample.html" ):
311361 """Render graph for lineage query result.
312362
@@ -325,23 +375,51 @@ def render(self, elements, path="pyvisExample.html"):
325375 display graph: The interactive visualization is presented as a static HTML file.
326376
327377 """
328- net = self .Network (height = "500px " , width = "100 %" , notebook = True , directed = True )
378+ net = self .Network (height = "600px " , width = "82 %" , notebook = True , directed = True )
329379 net .set_options (self ._options )
330380
331381 # add nodes to graph
332382 for arn , source , entity , is_start_arn in elements ["nodes" ]:
383+ entity_text = re .sub (r"(\w)([A-Z])" , r"\1 \2" , entity )
384+ source = re .sub (r"(\w)([A-Z])" , r"\1 \2" , source )
385+ account_id = re .search (r":\d{12}:" , arn )
386+ name = re .search (r"\/.*" , arn )
387+ node_info = (
388+ "Entity: "
389+ + entity_text
390+ + "\n Type: "
391+ + source
392+ + "\n Account ID: "
393+ + str (account_id .group ()[1 :- 1 ])
394+ + "\n Name: "
395+ + str (name .group ()[1 :])
396+ )
333397 if is_start_arn : # startarn
334398 net .add_node (
335- arn , label = source , title = entity , color = self ._node_color (entity ), shape = "star"
399+ arn ,
400+ label = source ,
401+ title = node_info ,
402+ color = self ._node_color (entity ),
403+ shape = "star" ,
404+ borderWidth = 3 ,
336405 )
337406 else :
338- net .add_node (arn , label = source , title = entity , color = self ._node_color (entity ))
407+ net .add_node (
408+ arn ,
409+ label = source ,
410+ title = node_info ,
411+ color = self ._node_color (entity ),
412+ borderWidth = 3 ,
413+ )
339414
340415 # add edges to graph
341416 for src , dest , asso_type in elements ["edges" ]:
342- net .add_edge (src , dest , title = asso_type )
417+ net .add_edge (src , dest , title = asso_type , width = 2 )
418+
419+ net .write_html (path )
420+ self ._add_legend (path )
343421
344- return net . show (path )
422+ return self . IFrame (path , width = "100%" , height = "600px" )
345423
346424
347425class LineageQueryResult (object ):
@@ -391,7 +469,7 @@ def __str__(self):
391469
392470 """
393471 return (
394- "{\n "
472+ "{"
395473 + "\n \n " .join ("'{}': {}," .format (key , val ) for key , val in self .__dict__ .items ())
396474 + "\n }"
397475 )
0 commit comments