Skip to content

Commit 68e5a06

Browse files
committed
fix: adding AWSComponentRegistry class to load in diagrams component types, rather than hardcoding
1 parent 8d43d51 commit 68e5a06

File tree

2 files changed

+150
-203
lines changed

2 files changed

+150
-203
lines changed

src/strands_tools/diagram.py

Lines changed: 150 additions & 170 deletions
Original file line numberDiff line numberDiff line change
@@ -1,172 +1,160 @@
11
import matplotlib
2-
32
matplotlib.use("Agg") # Add this line before other imports
3+
44
import importlib
5+
import inspect
56
import logging
67
import os
8+
9+
# Import for AWS component discovery
10+
import pkgutil
711
import platform
812
import subprocess
913
from typing import Any, Dict, List, Union
1014

1115
import graphviz
1216
import matplotlib.pyplot as plt
1317
import networkx as nx
14-
from diagrams import Diagram as CloudDiagram
1518
from strands import tool
1619

17-
# AWS service categories - comprehensive list
18-
AWS_CATEGORIES = [
19-
"analytics",
20-
"compute",
21-
"database",
22-
"network",
23-
"storage",
24-
"security",
25-
"integration",
26-
"management",
27-
"ml",
28-
"general",
29-
"mobile",
30-
"migration",
31-
"devtools",
32-
"blockchain",
33-
"business",
34-
"cost",
35-
"customer",
36-
"game",
37-
"iot",
38-
"media",
39-
"quantum",
40-
"robotics",
41-
"satellite",
42-
]
43-
44-
# Common aliases for AWS components
45-
COMMON_ALIASES = {
46-
# Non-AWS components
47-
"users": "Users",
48-
"user": "Users",
49-
"client": "Users",
50-
"clients": "Users",
51-
"internet": "Internet",
52-
"web": "Internet",
53-
# Common AWS aliases
54-
"api_gateway": "APIGateway",
55-
"apigateway": "APIGateway",
56-
"api-gateway": "APIGateway",
57-
"dynamo": "Dynamodb",
58-
"ddb": "Dynamodb",
59-
"ec2": "EC2",
60-
"instance": "EC2",
61-
"server": "EC2",
62-
"ecr": "EC2ContainerRegistry",
63-
"registry": "EC2ContainerRegistry",
64-
"ecs": "ElasticContainerService",
65-
"container": "ElasticContainerService",
66-
"eks": "ElasticKubernetesService",
67-
"kubernetes": "ElasticKubernetesService",
68-
"k8s": "ElasticKubernetesService",
69-
"redis": "Elasticache",
70-
"memcached": "Elasticache",
71-
"es": "Elasticsearch",
72-
"elb": "ElasticLoadBalancing",
73-
"loadbalancer": "ElasticLoadBalancing",
74-
"alb": "ApplicationLoadBalancer",
75-
"nlb": "NetworkLoadBalancer",
76-
"events": "Eventbridge",
77-
"identity": "IAM",
78-
"streaming": "Kinesis",
79-
"encryption": "KMS",
80-
"function": "Lambda",
81-
"serverless": "Lambda",
82-
"mysql": "RDS",
83-
"postgres": "RDS",
84-
"warehouse": "Redshift",
85-
"dns": "Route53",
86-
"r53": "Route53",
87-
"bucket": "SimpleStorageServiceS3",
88-
"storage": "SimpleStorageServiceS3",
89-
"secrets": "SecretsManager",
90-
"notification": "SimpleNotificationServiceSns",
91-
"topic": "SimpleNotificationServiceSns",
92-
"queue": "SimpleQueueServiceSqs",
93-
"messaging": "SimpleQueueServiceSqs",
94-
"workflow": "StepFunctions",
95-
"firewall": "WAF",
96-
}
97-
98-
# Cache for discovered components
99-
_aws_component_cache = {}
100-
20+
from diagrams import Diagram as CloudDiagram
21+
from diagrams import aws
10122

102-
def get_aws_node(node_type: str) -> Any:
103-
"""Dynamically discover and return AWS component - supports all 532+ components"""
104-
# Check cache first
105-
if node_type in _aws_component_cache:
106-
return _aws_component_cache[node_type]
10723

108-
# Normalize input
109-
normalized = node_type.lower().replace("-", "_").replace(" ", "_")
24+
class AWSComponentRegistry:
25+
"""
26+
Class responsible for discovering and managing AWS components from the diagrams package.
27+
Encapsulates the component discovery, caching and lookup functionality.
28+
"""
11029

111-
# Try common aliases first
112-
canonical_name = COMMON_ALIASES.get(normalized, node_type)
30+
def __init__(self):
31+
"""Initialize the registry with discovered components and aliases"""
32+
self._component_cache = {}
33+
self.categories = self._discover_categories()
34+
self.components = self._discover_components()
35+
self.aliases = self._build_aliases()
11336

114-
# Try non-AWS components (Users, Internet, etc.)
115-
if canonical_name in ["Users", "Internet", "Mobile"]:
116-
# Try main diagrams module first
37+
def _discover_categories(self) -> List[str]:
38+
"""Dynamically discover all AWS categories from the diagrams package"""
39+
categories = []
11740
try:
118-
module = importlib.import_module("diagrams")
119-
if hasattr(module, canonical_name):
120-
component = getattr(module, canonical_name)
121-
_aws_component_cache[node_type] = component
122-
return component
123-
except ImportError:
124-
pass
125-
126-
# Try onprem.network for Internet
127-
if canonical_name == "Internet":
41+
# Use pkgutil to discover all modules in diagrams.aws
42+
for _, name, is_pkg in pkgutil.iter_modules(aws.__path__):
43+
if not is_pkg and not name.startswith("_"):
44+
categories.append(name)
45+
except Exception as e:
46+
logging.warning(f"Failed to discover AWS categories: {e}")
47+
return []
48+
return categories
49+
50+
def _discover_components(self) -> Dict[str, List[str]]:
51+
"""Dynamically discover all available AWS components by category"""
52+
components = {}
53+
for category in self.categories:
54+
try:
55+
module = importlib.import_module(f"diagrams.aws.{category}")
56+
# Get all public classes (components) from the module
57+
components[category] = [
58+
name
59+
for name, obj in inspect.getmembers(module)
60+
if inspect.isclass(obj) and not name.startswith("_")
61+
]
62+
except ImportError:
63+
continue
64+
return components
65+
66+
def _build_aliases(self) -> Dict[str, str]:
67+
"""Build aliases dictionary by analyzing available components"""
68+
aliases = {}
69+
70+
# Add non-AWS components first
71+
aliases.update(
72+
{
73+
"users": "Users",
74+
"user": "Users",
75+
"client": "Users",
76+
"clients": "Users",
77+
"internet": "Internet",
78+
"web": "Internet",
79+
"mobile": "Mobile",
80+
}
81+
)
82+
83+
# Analyze component names to create common aliases
84+
for _, component_list in self.components.items():
85+
for component in component_list:
86+
# Create lowercase alias
87+
aliases[component.lower()] = component
88+
89+
# Create alias without service prefix/suffix
90+
clean_name = component.replace("Service", "").replace("Amazon", "").replace("AWS", "")
91+
if clean_name != component:
92+
aliases[clean_name.lower()] = component
93+
94+
# Add common abbreviations
95+
if component.isupper(): # Likely an acronym
96+
aliases[component.lower()] = component
97+
98+
return aliases
99+
100+
def get_node(self, node_type: str) -> Any:
101+
"""Get AWS component class using dynamic discovery with caching"""
102+
# Check cache first
103+
if node_type in self._component_cache:
104+
return self._component_cache[node_type]
105+
106+
# Normalize input
107+
normalized = node_type.lower()
108+
109+
# Try common aliases first
110+
canonical_name = self.aliases.get(normalized, node_type)
111+
112+
# Search through all discovered components
113+
for category, component_list in self.components.items():
128114
try:
129-
module = importlib.import_module("diagrams.onprem.network")
130-
if hasattr(module, canonical_name):
115+
module = importlib.import_module(f"diagrams.aws.{category}")
116+
# Try exact match first
117+
if canonical_name in component_list:
131118
component = getattr(module, canonical_name)
132-
_aws_component_cache[node_type] = component
119+
self._component_cache[node_type] = component
133120
return component
121+
# Try case-insensitive match
122+
for component_name in component_list:
123+
if component_name.lower() == canonical_name.lower():
124+
component = getattr(module, component_name)
125+
self._component_cache[node_type] = component
126+
return component
134127
except ImportError:
135-
pass
128+
continue
136129

137-
# Search all AWS categories for the component
138-
for category in AWS_CATEGORIES:
139-
try:
140-
module = importlib.import_module(f"diagrams.aws.{category}")
141-
# Try exact match first
142-
if hasattr(module, canonical_name):
143-
component = getattr(module, canonical_name)
144-
_aws_component_cache[node_type] = component
145-
return component
146-
# Try case-insensitive match
147-
for attr in dir(module):
148-
if attr.lower() == canonical_name.lower() and not attr.startswith("_"):
149-
component = getattr(module, attr)
150-
_aws_component_cache[node_type] = component
151-
return component
152-
except ImportError:
153-
continue
130+
raise ValueError(f"Component '{node_type}' not found in available AWS components")
131+
132+
def list_available_components(self, category: str = None) -> Dict[str, List[str]]:
133+
"""List all available AWS components and their aliases"""
134+
if category:
135+
return {category: self.components.get(category, [])}
136+
return self.components
137+
138+
139+
# Initialize the AWS component registry as a singleton
140+
aws_registry = AWSComponentRegistry()
154141

155-
# Try original input as-is
156-
for category in AWS_CATEGORIES:
157-
try:
158-
module = importlib.import_module(f"diagrams.aws.{category}")
159-
if hasattr(module, node_type):
160-
component = getattr(module, node_type)
161-
_aws_component_cache[node_type] = component
162-
return component
163-
except ImportError:
164-
continue
165142

166-
raise ValueError(f"Component '{node_type}' not found. Try: {list(COMMON_ALIASES.keys())[:10]}...")
143+
# Expose necessary functions and variables at module level for backward compatibility
144+
def get_aws_node(node_type: str) -> Any:
145+
"""Get AWS component class using dynamic discovery"""
146+
return aws_registry.get_node(node_type)
147+
148+
149+
def list_available_components(category: str = None) -> Dict[str, List[str]]:
150+
"""List all available AWS components and their aliases"""
151+
return aws_registry.list_available_components(category)
167152

168153

169-
# These functions have been removed as the agent can generate mermaid/ascii directly
154+
# Export variables for backward compatibility
155+
AWS_CATEGORIES = aws_registry.categories
156+
AVAILABLE_AWS_COMPONENTS = aws_registry.components
157+
COMMON_ALIASES = aws_registry.aliases
170158

171159

172160
class DiagramBuilder:
@@ -180,13 +168,7 @@ def __init__(self, nodes, edges=None, title="diagram", style=None):
180168

181169
def render(self, diagram_type: str, output_format: str) -> str:
182170
"""Main render method that delegates to specific renderers"""
183-
if output_format in ["mermaid", "ascii"]:
184-
raise NotImplementedError(
185-
"Mermaid and ASCII rendering has been removed. "
186-
"Use the agent's LLM capabilities to generate mermaid code directly."
187-
)
188171

189-
# Delegate to specific diagram type methods
190172
method_map = {
191173
"cloud": self._render_cloud,
192174
"graph": self._render_graph,
@@ -198,13 +180,6 @@ def render(self, diagram_type: str, output_format: str) -> str:
198180

199181
return method_map[diagram_type](output_format)
200182

201-
def _render_text(self, output_format: str) -> str:
202-
"""Handle text formats (mermaid/ascii) - unified for all diagram types"""
203-
raise NotImplementedError(
204-
"Mermaid and ASCII rendering has been removed. "
205-
"Use the agent's LLM capabilities to generate mermaid code directly."
206-
)
207-
208183
def _render_cloud(self, output_format: str) -> str:
209184
"""Create AWS architecture diagram"""
210185
if not self.nodes:
@@ -227,16 +202,34 @@ def _render_cloud(self, output_format: str) -> str:
227202
for node_id, node_type in invalid_nodes:
228203
# Find close matches
229204
close_matches = [k for k in COMMON_ALIASES.keys() if node_type.lower() in k or k in node_type.lower()]
205+
# Find canonical names for suggestions
206+
canonical_suggestions = [COMMON_ALIASES[k] for k in close_matches[:3]] if close_matches else []
207+
230208
if close_matches:
231-
suggestions.append(f" - '{node_id}' (type: '{node_type}') -> try: {close_matches[:3]}")
209+
suggestions.append(
210+
f" - '{node_id}' (type: '{node_type}') -> try: \
211+
{close_matches[:3]} (maps to: {canonical_suggestions})"
212+
)
232213
else:
233214
suggestions.append(f" - '{node_id}' (type: '{node_type}') -> no close matches found")
234215

235-
common_types = ["ec2", "s3", "lambda", "rds", "api_gateway", "cloudfront", "route53", "elb"]
216+
common_types = [
217+
"ec2",
218+
"s3",
219+
"lambda",
220+
"rds",
221+
"api_gateway",
222+
"cloudfront",
223+
"route53",
224+
"elb",
225+
"opensearch",
226+
"dynamodb",
227+
]
236228
error_msg = (
237229
f"Invalid AWS component types found:\n{chr(10).join(suggestions)}\n\n"
238230
f"Common types: {common_types}\nNote: "
239-
f"All 532+ AWS components are supported - try the exact AWS service name"
231+
f"All 532+ AWS components are supported - \
232+
try using one of the aliases in COMMON_ALIASES or the exact AWS service name"
240233
)
241234
raise ValueError(error_msg)
242235

@@ -381,12 +374,6 @@ def __init__(
381374

382375
def render(self, output_format: str = "png") -> str:
383376
"""Render the UML diagram based on type"""
384-
# Handle text formats (mermaid/ascii) for UML diagrams
385-
if output_format in ["mermaid", "ascii"]:
386-
raise NotImplementedError(
387-
"Mermaid and ASCII rendering has been removed. "
388-
"Use the agent's LLM capabilities to generate mermaid code directly."
389-
)
390377

391378
method_map = {
392379
# Structural diagrams
@@ -1041,13 +1028,6 @@ def _add_class_relationship(self, dot: graphviz.Digraph, rel: Dict):
10411028
multiplicity = rel.get("multiplicity", "")
10421029
dot.edge(rel["from"], rel["to"], label=multiplicity)
10431030

1044-
def _render_text(self, output_format: str) -> str:
1045-
"""Handle text formats (mermaid/ascii) for UML diagrams"""
1046-
raise NotImplementedError(
1047-
"Mermaid and ASCII rendering has been removed. "
1048-
"Use the agent's LLM capabilities to generate mermaid code directly."
1049-
)
1050-
10511031
def _save_diagram(self, dot: graphviz.Digraph, output_format: str) -> str:
10521032
"""Save diagram and return file path"""
10531033
output_path = save_diagram_to_directory(self.title, "")

0 commit comments

Comments
 (0)