18
18
from src .shared .constants import MODEL_VERSIONS
19
19
20
20
21
- def get_llm (model_version : str ):
21
+ def get_llm (model : str ):
22
22
"""Retrieve the specified language model based on the model name."""
23
- env_key = "LLM_MODEL_CONFIG_" + model_version
23
+ env_key = "LLM_MODEL_CONFIG_" + model
24
24
env_value = os .environ .get (env_key )
25
25
logging .info ("Model: {}" .format (env_key ))
26
- if "gemini" in model_version :
26
+ if "gemini" in model :
27
27
credentials , project_id = google .auth .default ()
28
- model_name = MODEL_VERSIONS [model_version ]
28
+ model_name = MODEL_VERSIONS [model ]
29
29
llm = ChatVertexAI (
30
30
model_name = model_name ,
31
31
convert_system_message_to_human = True ,
@@ -40,15 +40,15 @@ def get_llm(model_version: str):
40
40
HarmCategory .HARM_CATEGORY_SEXUALLY_EXPLICIT : HarmBlockThreshold .BLOCK_NONE ,
41
41
},
42
42
)
43
- elif "openai" in model_version :
44
- model_name = MODEL_VERSIONS [model_version ]
43
+ elif "openai" in model :
44
+ model_name = MODEL_VERSIONS [model ]
45
45
llm = ChatOpenAI (
46
46
api_key = os .environ .get ("OPENAI_API_KEY" ),
47
47
model = model_name ,
48
48
temperature = 0 ,
49
49
)
50
50
51
- elif "azure" in model_version :
51
+ elif "azure" in model :
52
52
model_name , api_endpoint , api_key , api_version = env_value .split ("," )
53
53
llm = AzureChatOpenAI (
54
54
api_key = api_key ,
@@ -60,21 +60,21 @@ def get_llm(model_version: str):
60
60
timeout = None ,
61
61
)
62
62
63
- elif "anthropic" in model_version :
63
+ elif "anthropic" in model :
64
64
model_name , api_key = env_value .split ("," )
65
65
llm = ChatAnthropic (
66
66
api_key = api_key , model = model_name , temperature = 0 , timeout = None
67
67
)
68
68
69
- elif "fireworks" in model_version :
69
+ elif "fireworks" in model :
70
70
model_name , api_key = env_value .split ("," )
71
71
llm = ChatFireworks (api_key = api_key , model = model_name )
72
72
73
- elif "groq" in model_version :
73
+ elif "groq" in model :
74
74
model_name , base_url , api_key = env_value .split ("," )
75
75
llm = ChatGroq (api_key = api_key , model_name = model_name , temperature = 0 )
76
76
77
- elif "bedrock" in model_version :
77
+ elif "bedrock" in model :
78
78
model_name , aws_access_key , aws_secret_key , region_name = env_value .split ("," )
79
79
bedrock_client = boto3 .client (
80
80
service_name = "bedrock-runtime" ,
@@ -87,17 +87,27 @@ def get_llm(model_version: str):
87
87
client = bedrock_client , model_id = model_name , model_kwargs = dict (temperature = 0 )
88
88
)
89
89
90
- elif "ollama" in model_version :
90
+ elif "ollama" in model :
91
91
model_name , base_url = env_value .split ("," )
92
92
llm = ChatOllama (base_url = base_url , model = model_name )
93
93
94
- else :
94
+ elif "diffbot" in model :
95
95
model_name = "diffbot"
96
96
llm = DiffbotGraphTransformer (
97
97
diffbot_api_key = os .environ .get ("DIFFBOT_API_KEY" ),
98
98
extract_types = ["entities" , "facts" ],
99
99
)
100
- logging .info (f"Model created - Model Version: { model_version } " )
100
+
101
+ else :
102
+ model_name , api_endpoint , api_key = env_value .split ("," )
103
+ llm = ChatOpenAI (
104
+ api_key = api_key ,
105
+ base_url = api_endpoint ,
106
+ model = model_name ,
107
+ temperature = 0 ,
108
+ )
109
+
110
+ logging .info (f"Model created - Model Version: { model } " )
101
111
return llm , model_name
102
112
103
113
@@ -162,8 +172,19 @@ def get_graph_document_list(
162
172
163
173
164
174
def get_graph_from_llm (model , chunkId_chunkDoc_list , allowedNodes , allowedRelationship ):
175
+
165
176
llm , model_name = get_llm (model )
166
177
combined_chunk_document_list = get_combined_chunks (chunkId_chunkDoc_list )
178
+
179
+ if allowedNodes is None or allowedNodes == "" :
180
+ allowedNodes = []
181
+ else :
182
+ allowedNodes = allowedNodes .split (',' )
183
+ if allowedRelationship is None or allowedRelationship == "" :
184
+ allowedRelationship = []
185
+ else :
186
+ allowedRelationship = allowedRelationship .split (',' )
187
+
167
188
graph_document_list = get_graph_document_list (
168
189
llm , combined_chunk_document_list , allowedNodes , allowedRelationship
169
190
)
0 commit comments