11import matplotlib
2-
32matplotlib .use ("Agg" ) # Add this line before other imports
3+
44import importlib
5+ import inspect
56import logging
67import os
8+
9+ # Import for AWS component discovery
10+ import pkgutil
711import platform
812import subprocess
913from typing import Any , Dict , List , Union
1014
1115import graphviz
1216import matplotlib .pyplot as plt
1317import networkx as nx
14- from diagrams import Diagram as CloudDiagram
1518from strands import tool
1619
17- # AWS service categories - comprehensive list
18- AWS_CATEGORIES = [
19- "analytics" ,
20- "compute" ,
21- "database" ,
22- "network" ,
23- "storage" ,
24- "security" ,
25- "integration" ,
26- "management" ,
27- "ml" ,
28- "general" ,
29- "mobile" ,
30- "migration" ,
31- "devtools" ,
32- "blockchain" ,
33- "business" ,
34- "cost" ,
35- "customer" ,
36- "game" ,
37- "iot" ,
38- "media" ,
39- "quantum" ,
40- "robotics" ,
41- "satellite" ,
42- ]
43-
44- # Common aliases for AWS components
45- COMMON_ALIASES = {
46- # Non-AWS components
47- "users" : "Users" ,
48- "user" : "Users" ,
49- "client" : "Users" ,
50- "clients" : "Users" ,
51- "internet" : "Internet" ,
52- "web" : "Internet" ,
53- # Common AWS aliases
54- "api_gateway" : "APIGateway" ,
55- "apigateway" : "APIGateway" ,
56- "api-gateway" : "APIGateway" ,
57- "dynamo" : "Dynamodb" ,
58- "ddb" : "Dynamodb" ,
59- "ec2" : "EC2" ,
60- "instance" : "EC2" ,
61- "server" : "EC2" ,
62- "ecr" : "EC2ContainerRegistry" ,
63- "registry" : "EC2ContainerRegistry" ,
64- "ecs" : "ElasticContainerService" ,
65- "container" : "ElasticContainerService" ,
66- "eks" : "ElasticKubernetesService" ,
67- "kubernetes" : "ElasticKubernetesService" ,
68- "k8s" : "ElasticKubernetesService" ,
69- "redis" : "Elasticache" ,
70- "memcached" : "Elasticache" ,
71- "es" : "Elasticsearch" ,
72- "elb" : "ElasticLoadBalancing" ,
73- "loadbalancer" : "ElasticLoadBalancing" ,
74- "alb" : "ApplicationLoadBalancer" ,
75- "nlb" : "NetworkLoadBalancer" ,
76- "events" : "Eventbridge" ,
77- "identity" : "IAM" ,
78- "streaming" : "Kinesis" ,
79- "encryption" : "KMS" ,
80- "function" : "Lambda" ,
81- "serverless" : "Lambda" ,
82- "mysql" : "RDS" ,
83- "postgres" : "RDS" ,
84- "warehouse" : "Redshift" ,
85- "dns" : "Route53" ,
86- "r53" : "Route53" ,
87- "bucket" : "SimpleStorageServiceS3" ,
88- "storage" : "SimpleStorageServiceS3" ,
89- "secrets" : "SecretsManager" ,
90- "notification" : "SimpleNotificationServiceSns" ,
91- "topic" : "SimpleNotificationServiceSns" ,
92- "queue" : "SimpleQueueServiceSqs" ,
93- "messaging" : "SimpleQueueServiceSqs" ,
94- "workflow" : "StepFunctions" ,
95- "firewall" : "WAF" ,
96- }
97-
98- # Cache for discovered components
99- _aws_component_cache = {}
100-
20+ from diagrams import Diagram as CloudDiagram
21+ from diagrams import aws
10122
102- def get_aws_node (node_type : str ) -> Any :
103- """Dynamically discover and return AWS component - supports all 532+ components"""
104- # Check cache first
105- if node_type in _aws_component_cache :
106- return _aws_component_cache [node_type ]
10723
108- # Normalize input
109- normalized = node_type .lower ().replace ("-" , "_" ).replace (" " , "_" )
24+ class AWSComponentRegistry :
25+ """
26+ Class responsible for discovering and managing AWS components from the diagrams package.
27+ Encapsulates the component discovery, caching and lookup functionality.
28+ """
11029
111- # Try common aliases first
112- canonical_name = COMMON_ALIASES .get (normalized , node_type )
30+ def __init__ (self ):
31+ """Initialize the registry with discovered components and aliases"""
32+ self ._component_cache = {}
33+ self .categories = self ._discover_categories ()
34+ self .components = self ._discover_components ()
35+ self .aliases = self ._build_aliases ()
11336
114- # Try non-AWS components (Users, Internet, etc.)
115- if canonical_name in [ "Users" , "Internet" , "Mobile" ]:
116- # Try main diagrams module first
37+ def _discover_categories ( self ) -> List [ str ]:
38+ """Dynamically discover all AWS categories from the diagrams package"""
39+ categories = []
11740 try :
118- module = importlib .import_module ("diagrams" )
119- if hasattr (module , canonical_name ):
120- component = getattr (module , canonical_name )
121- _aws_component_cache [node_type ] = component
122- return component
123- except ImportError :
124- pass
125-
126- # Try onprem.network for Internet
127- if canonical_name == "Internet" :
41+ # Use pkgutil to discover all modules in diagrams.aws
42+ for _ , name , is_pkg in pkgutil .iter_modules (aws .__path__ ):
43+ if not is_pkg and not name .startswith ("_" ):
44+ categories .append (name )
45+ except Exception as e :
46+ logging .warning (f"Failed to discover AWS categories: { e } " )
47+ return []
48+ return categories
49+
50+ def _discover_components (self ) -> Dict [str , List [str ]]:
51+ """Dynamically discover all available AWS components by category"""
52+ components = {}
53+ for category in self .categories :
54+ try :
55+ module = importlib .import_module (f"diagrams.aws.{ category } " )
56+ # Get all public classes (components) from the module
57+ components [category ] = [
58+ name
59+ for name , obj in inspect .getmembers (module )
60+ if inspect .isclass (obj ) and not name .startswith ("_" )
61+ ]
62+ except ImportError :
63+ continue
64+ return components
65+
66+ def _build_aliases (self ) -> Dict [str , str ]:
67+ """Build aliases dictionary by analyzing available components"""
68+ aliases = {}
69+
70+ # Add non-AWS components first
71+ aliases .update (
72+ {
73+ "users" : "Users" ,
74+ "user" : "Users" ,
75+ "client" : "Users" ,
76+ "clients" : "Users" ,
77+ "internet" : "Internet" ,
78+ "web" : "Internet" ,
79+ "mobile" : "Mobile" ,
80+ }
81+ )
82+
83+ # Analyze component names to create common aliases
84+ for _ , component_list in self .components .items ():
85+ for component in component_list :
86+ # Create lowercase alias
87+ aliases [component .lower ()] = component
88+
89+ # Create alias without service prefix/suffix
90+ clean_name = component .replace ("Service" , "" ).replace ("Amazon" , "" ).replace ("AWS" , "" )
91+ if clean_name != component :
92+ aliases [clean_name .lower ()] = component
93+
94+ # Add common abbreviations
95+ if component .isupper (): # Likely an acronym
96+ aliases [component .lower ()] = component
97+
98+ return aliases
99+
100+ def get_node (self , node_type : str ) -> Any :
101+ """Get AWS component class using dynamic discovery with caching"""
102+ # Check cache first
103+ if node_type in self ._component_cache :
104+ return self ._component_cache [node_type ]
105+
106+ # Normalize input
107+ normalized = node_type .lower ()
108+
109+ # Try common aliases first
110+ canonical_name = self .aliases .get (normalized , node_type )
111+
112+ # Search through all discovered components
113+ for category , component_list in self .components .items ():
128114 try :
129- module = importlib .import_module ("diagrams.onprem.network" )
130- if hasattr (module , canonical_name ):
115+ module = importlib .import_module (f"diagrams.aws.{ category } " )
116+ # Try exact match first
117+ if canonical_name in component_list :
131118 component = getattr (module , canonical_name )
132- _aws_component_cache [node_type ] = component
119+ self . _component_cache [node_type ] = component
133120 return component
121+ # Try case-insensitive match
122+ for component_name in component_list :
123+ if component_name .lower () == canonical_name .lower ():
124+ component = getattr (module , component_name )
125+ self ._component_cache [node_type ] = component
126+ return component
134127 except ImportError :
135- pass
128+ continue
136129
137- # Search all AWS categories for the component
138- for category in AWS_CATEGORIES :
139- try :
140- module = importlib .import_module (f"diagrams.aws.{ category } " )
141- # Try exact match first
142- if hasattr (module , canonical_name ):
143- component = getattr (module , canonical_name )
144- _aws_component_cache [node_type ] = component
145- return component
146- # Try case-insensitive match
147- for attr in dir (module ):
148- if attr .lower () == canonical_name .lower () and not attr .startswith ("_" ):
149- component = getattr (module , attr )
150- _aws_component_cache [node_type ] = component
151- return component
152- except ImportError :
153- continue
130+ raise ValueError (f"Component '{ node_type } ' not found in available AWS components" )
131+
132+ def list_available_components (self , category : str = None ) -> Dict [str , List [str ]]:
133+ """List all available AWS components and their aliases"""
134+ if category :
135+ return {category : self .components .get (category , [])}
136+ return self .components
137+
138+
139+ # Initialize the AWS component registry as a singleton
140+ aws_registry = AWSComponentRegistry ()
154141
155- # Try original input as-is
156- for category in AWS_CATEGORIES :
157- try :
158- module = importlib .import_module (f"diagrams.aws.{ category } " )
159- if hasattr (module , node_type ):
160- component = getattr (module , node_type )
161- _aws_component_cache [node_type ] = component
162- return component
163- except ImportError :
164- continue
165142
166- raise ValueError (f"Component '{ node_type } ' not found. Try: { list (COMMON_ALIASES .keys ())[:10 ]} ..." )
143+ # Expose necessary functions and variables at module level for backward compatibility
144+ def get_aws_node (node_type : str ) -> Any :
145+ """Get AWS component class using dynamic discovery"""
146+ return aws_registry .get_node (node_type )
147+
148+
149+ def list_available_components (category : str = None ) -> Dict [str , List [str ]]:
150+ """List all available AWS components and their aliases"""
151+ return aws_registry .list_available_components (category )
167152
168153
169- # These functions have been removed as the agent can generate mermaid/ascii directly
154+ # Export variables for backward compatibility
155+ AWS_CATEGORIES = aws_registry .categories
156+ AVAILABLE_AWS_COMPONENTS = aws_registry .components
157+ COMMON_ALIASES = aws_registry .aliases
170158
171159
172160class DiagramBuilder :
@@ -180,13 +168,7 @@ def __init__(self, nodes, edges=None, title="diagram", style=None):
180168
181169 def render (self , diagram_type : str , output_format : str ) -> str :
182170 """Main render method that delegates to specific renderers"""
183- if output_format in ["mermaid" , "ascii" ]:
184- raise NotImplementedError (
185- "Mermaid and ASCII rendering has been removed. "
186- "Use the agent's LLM capabilities to generate mermaid code directly."
187- )
188171
189- # Delegate to specific diagram type methods
190172 method_map = {
191173 "cloud" : self ._render_cloud ,
192174 "graph" : self ._render_graph ,
@@ -198,13 +180,6 @@ def render(self, diagram_type: str, output_format: str) -> str:
198180
199181 return method_map [diagram_type ](output_format )
200182
201- def _render_text (self , output_format : str ) -> str :
202- """Handle text formats (mermaid/ascii) - unified for all diagram types"""
203- raise NotImplementedError (
204- "Mermaid and ASCII rendering has been removed. "
205- "Use the agent's LLM capabilities to generate mermaid code directly."
206- )
207-
208183 def _render_cloud (self , output_format : str ) -> str :
209184 """Create AWS architecture diagram"""
210185 if not self .nodes :
@@ -227,16 +202,34 @@ def _render_cloud(self, output_format: str) -> str:
227202 for node_id , node_type in invalid_nodes :
228203 # Find close matches
229204 close_matches = [k for k in COMMON_ALIASES .keys () if node_type .lower () in k or k in node_type .lower ()]
205+ # Find canonical names for suggestions
206+ canonical_suggestions = [COMMON_ALIASES [k ] for k in close_matches [:3 ]] if close_matches else []
207+
230208 if close_matches :
231- suggestions .append (f" - '{ node_id } ' (type: '{ node_type } ') -> try: { close_matches [:3 ]} " )
209+ suggestions .append (
210+ f" - '{ node_id } ' (type: '{ node_type } ') -> try: \
211+ { close_matches [:3 ]} (maps to: { canonical_suggestions } )"
212+ )
232213 else :
233214 suggestions .append (f" - '{ node_id } ' (type: '{ node_type } ') -> no close matches found" )
234215
235- common_types = ["ec2" , "s3" , "lambda" , "rds" , "api_gateway" , "cloudfront" , "route53" , "elb" ]
216+ common_types = [
217+ "ec2" ,
218+ "s3" ,
219+ "lambda" ,
220+ "rds" ,
221+ "api_gateway" ,
222+ "cloudfront" ,
223+ "route53" ,
224+ "elb" ,
225+ "opensearch" ,
226+ "dynamodb" ,
227+ ]
236228 error_msg = (
237229 f"Invalid AWS component types found:\n { chr (10 ).join (suggestions )} \n \n "
238230 f"Common types: { common_types } \n Note: "
239- f"All 532+ AWS components are supported - try the exact AWS service name"
231+ f"All 532+ AWS components are supported - \
232+ try using one of the aliases in COMMON_ALIASES or the exact AWS service name"
240233 )
241234 raise ValueError (error_msg )
242235
@@ -381,12 +374,6 @@ def __init__(
381374
382375 def render (self , output_format : str = "png" ) -> str :
383376 """Render the UML diagram based on type"""
384- # Handle text formats (mermaid/ascii) for UML diagrams
385- if output_format in ["mermaid" , "ascii" ]:
386- raise NotImplementedError (
387- "Mermaid and ASCII rendering has been removed. "
388- "Use the agent's LLM capabilities to generate mermaid code directly."
389- )
390377
391378 method_map = {
392379 # Structural diagrams
@@ -1041,13 +1028,6 @@ def _add_class_relationship(self, dot: graphviz.Digraph, rel: Dict):
10411028 multiplicity = rel .get ("multiplicity" , "" )
10421029 dot .edge (rel ["from" ], rel ["to" ], label = multiplicity )
10431030
1044- def _render_text (self , output_format : str ) -> str :
1045- """Handle text formats (mermaid/ascii) for UML diagrams"""
1046- raise NotImplementedError (
1047- "Mermaid and ASCII rendering has been removed. "
1048- "Use the agent's LLM capabilities to generate mermaid code directly."
1049- )
1050-
10511031 def _save_diagram (self , dot : graphviz .Digraph , output_format : str ) -> str :
10521032 """Save diagram and return file path"""
10531033 output_path = save_diagram_to_directory (self .title , "" )
0 commit comments