
Commit 03de7af

Merge commit, 2 parents: 3d0a95a + 9821894

6 files changed: 37 additions & 186 deletions

backend/src/gemini_llm.py

Lines changed: 0 additions & 54 deletions
This file was deleted.

backend/src/generate_graphDocuments_from_llm.py

Lines changed: 0 additions & 44 deletions
This file was deleted.

backend/src/groq_llama3_llm.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

backend/src/llm.py

Lines changed: 35 additions & 14 deletions
@@ -18,14 +18,14 @@
 from src.shared.constants import MODEL_VERSIONS
 
 
-def get_llm(model_version: str):
+def get_llm(model: str):
     """Retrieve the specified language model based on the model name."""
-    env_key = "LLM_MODEL_CONFIG_" + model_version
+    env_key = "LLM_MODEL_CONFIG_" + model
     env_value = os.environ.get(env_key)
     logging.info("Model: {}".format(env_key))
-    if "gemini" in model_version:
+    if "gemini" in model:
         credentials, project_id = google.auth.default()
-        model_name = MODEL_VERSIONS[model_version]
+        model_name = MODEL_VERSIONS[model]
         llm = ChatVertexAI(
             model_name=model_name,
             convert_system_message_to_human=True,
@@ -40,15 +40,15 @@ def get_llm(model_version: str):
                 HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
             },
         )
-    elif "openai" in model_version:
-        model_name = MODEL_VERSIONS[model_version]
+    elif "openai" in model:
+        model_name = MODEL_VERSIONS[model]
         llm = ChatOpenAI(
             api_key=os.environ.get("OPENAI_API_KEY"),
             model=model_name,
             temperature=0,
         )
 
-    elif "azure" in model_version:
+    elif "azure" in model:
         model_name, api_endpoint, api_key, api_version = env_value.split(",")
         llm = AzureChatOpenAI(
             api_key=api_key,
@@ -60,21 +60,21 @@ def get_llm(model_version: str):
             timeout=None,
         )
 
-    elif "anthropic" in model_version:
+    elif "anthropic" in model:
         model_name, api_key = env_value.split(",")
         llm = ChatAnthropic(
             api_key=api_key, model=model_name, temperature=0, timeout=None
         )
 
-    elif "fireworks" in model_version:
+    elif "fireworks" in model:
         model_name, api_key = env_value.split(",")
         llm = ChatFireworks(api_key=api_key, model=model_name)
 
-    elif "groq" in model_version:
+    elif "groq" in model:
         model_name, base_url, api_key = env_value.split(",")
         llm = ChatGroq(api_key=api_key, model_name=model_name, temperature=0)
 
-    elif "bedrock" in model_version:
+    elif "bedrock" in model:
         model_name, aws_access_key, aws_secret_key, region_name = env_value.split(",")
         bedrock_client = boto3.client(
             service_name="bedrock-runtime",
@@ -87,17 +87,27 @@ def get_llm(model_version: str):
             client=bedrock_client, model_id=model_name, model_kwargs=dict(temperature=0)
         )
 
-    elif "ollama" in model_version:
+    elif "ollama" in model:
         model_name, base_url = env_value.split(",")
         llm = ChatOllama(base_url=base_url, model=model_name)
 
-    else:
+    elif "diffbot" in model:
         model_name = "diffbot"
         llm = DiffbotGraphTransformer(
             diffbot_api_key=os.environ.get("DIFFBOT_API_KEY"),
             extract_types=["entities", "facts"],
         )
-    logging.info(f"Model created - Model Version: {model_version}")
+
+    else:
+        model_name, api_endpoint, api_key = env_value.split(",")
+        llm = ChatOpenAI(
+            api_key=api_key,
+            base_url=api_endpoint,
+            model=model_name,
+            temperature=0,
+        )
+
+    logging.info(f"Model created - Model Version: {model}")
     return llm, model_name
 
 
@@ -162,8 +172,19 @@ def get_graph_document_list(
 
 
 def get_graph_from_llm(model, chunkId_chunkDoc_list, allowedNodes, allowedRelationship):
+
     llm, model_name = get_llm(model)
     combined_chunk_document_list = get_combined_chunks(chunkId_chunkDoc_list)
+
+    if allowedNodes is None or allowedNodes=="":
+        allowedNodes =[]
+    else:
+        allowedNodes = allowedNodes.split(',')
+    if allowedRelationship is None or allowedRelationship=="":
+        allowedRelationship=[]
+    else:
+        allowedRelationship = allowedRelationship.split(',')
+
     graph_document_list = get_graph_document_list(
         llm, combined_chunk_document_list, allowedNodes, allowedRelationship
     )
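
The notable behavioural change in llm.py is the new final else branch: a model key that matches no known provider is now treated as a generic OpenAI-compatible endpoint, configured through the LLM_MODEL_CONFIG_<model> environment variable as a comma-separated "model_name,api_endpoint,api_key" value. A minimal sketch of how that fallback could be exercised, assuming ChatOpenAI comes from langchain_openai as elsewhere in the module; the variable value and model key below are illustrative, not taken from the commit:

import os
from langchain_openai import ChatOpenAI

# Illustrative config only; the real key is LLM_MODEL_CONFIG_<model> for
# whatever model name the caller passes to get_llm().
os.environ["LLM_MODEL_CONFIG_custom_model"] = (
    "my-model,https://example.com/v1,sk-placeholder"
)

model = "custom_model"
env_value = os.environ.get("LLM_MODEL_CONFIG_" + model)

# Mirrors the new fallback branch: comma-separated model name, endpoint, key,
# handed to an OpenAI-compatible chat client pointed at that base_url.
model_name, api_endpoint, api_key = env_value.split(",")
llm = ChatOpenAI(
    api_key=api_key,
    base_url=api_endpoint,
    model=model_name,
    temperature=0,
)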

backend/src/main.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from src.graphDB_dataAccess import graphDBdataAccess
 from src.document_sources.local_file import get_documents_from_file_by_path
 from src.entities.source_node import sourceNode
-from src.generate_graphDocuments_from_llm import generate_graphDocuments
+from src.llm import get_graph_from_llm
 from src.document_sources.gcs_bucket import *
 from src.document_sources.s3_bucket import *
 from src.document_sources.wikipedia import *
@@ -373,7 +373,7 @@ def processing_chunks(chunkId_chunkDoc_list,graph,uri, userName, password, datab
 
     update_embedding_create_vector_index( graph, chunkId_chunkDoc_list, file_name)
     logging.info("Get graph document list from models")
-    graph_documents = generate_graphDocuments(model, graph, chunkId_chunkDoc_list, allowedNodes, allowedRelationship)
+    graph_documents = get_graph_from_llm(model, chunkId_chunkDoc_list, allowedNodes, allowedRelationship)
     cleaned_graph_documents = handle_backticks_nodes_relationship_id_type(graph_documents)
     save_graphDocuments_in_neo4j(graph, cleaned_graph_documents)
     chunks_and_graphDocuments_list = get_chunk_and_graphDocument(cleaned_graph_documents, chunkId_chunkDoc_list)
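
With the provider-specific wrappers deleted, processing_chunks in main.py now calls get_graph_from_llm directly, passing allowedNodes and allowedRelationship as comma-separated strings; the normalization added in llm.py turns an empty or missing value into an empty list and otherwise splits on commas. A small standalone sketch of that behaviour (the helper name and label values are hypothetical, written only to mirror the added branches):

def normalize_csv_filter(value):
    # Mirrors the new branches in get_graph_from_llm(): None or "" -> [],
    # otherwise split the comma-separated string into a list.
    if value is None or value == "":
        return []
    return value.split(",")

print(normalize_csv_filter(""))                     # []
print(normalize_csv_filter("Person,Organization"))  # ['Person', 'Organization']
print(normalize_csv_filter("WORKS_AT,LOCATED_IN"))  # ['WORKS_AT', 'LOCATED_IN']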

backend/src/openAI_llm.py

Lines changed: 0 additions & 24 deletions
This file was deleted.
