Skip to content

Commit ca65432

Browse files
aashipandya and kartikpersistent
authored and committed
connection creation in extract and CancelledError handling for sse (#584)
1 parent 2091420 commit ca65432

File tree

2 files changed

+57
-45
lines changed

2 files changed

+57
-45
lines changed

backend/score.py

Lines changed: 35 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -165,33 +165,34 @@ async def extract_knowledge_graph_from_file(
165165
Nodes and Relations created in Neo4j databse for the pdf file
166166
"""
167167
try:
168-
graph = create_graph_database_connection(uri, userName, password, database)
169-
graphDb_data_Access = graphDBdataAccess(graph)
170168
if source_type == 'local file':
171169
result = await asyncio.to_thread(
172-
extract_graph_from_file_local_file, graph, model, merged_file_path, file_name, allowedNodes, allowedRelationship, uri)
170+
extract_graph_from_file_local_file, uri, userName, password, database, model, merged_file_path, file_name, allowedNodes, allowedRelationship)
173171

174172
elif source_type == 's3 bucket' and source_url:
175173
result = await asyncio.to_thread(
176-
extract_graph_from_file_s3, graph, model, source_url, aws_access_key_id, aws_secret_access_key, allowedNodes, allowedRelationship)
174+
extract_graph_from_file_s3, uri, userName, password, database, model, source_url, aws_access_key_id, aws_secret_access_key, allowedNodes, allowedRelationship)
177175

178176
elif source_type == 'web-url':
179177
result = await asyncio.to_thread(
180-
extract_graph_from_web_page, graph, model, source_url, allowedNodes, allowedRelationship)
178+
extract_graph_from_web_page, uri, userName, password, database, model, source_url, allowedNodes, allowedRelationship)
181179

182180
elif source_type == 'youtube' and source_url:
183181
result = await asyncio.to_thread(
184-
extract_graph_from_file_youtube, graph, model, source_url, allowedNodes, allowedRelationship)
182+
extract_graph_from_file_youtube, uri, userName, password, database, model, source_url, allowedNodes, allowedRelationship)
185183

186184
elif source_type == 'Wikipedia' and wiki_query:
187185
result = await asyncio.to_thread(
188-
extract_graph_from_file_Wikipedia, graph, model, wiki_query, max_sources, language, allowedNodes, allowedRelationship)
186+
extract_graph_from_file_Wikipedia, uri, userName, password, database, model, wiki_query, max_sources, language, allowedNodes, allowedRelationship)
189187

190188
elif source_type == 'gcs bucket' and gcs_bucket_name:
191189
result = await asyncio.to_thread(
192-
extract_graph_from_file_gcs, graph, model, gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename, allowedNodes, allowedRelationship)
190+
extract_graph_from_file_gcs, uri, userName, password, database, model, gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename, access_token, allowedNodes, allowedRelationship)
193191
else:
194192
return create_api_response('Failed',message='source_type is other than accepted source')
193+
194+
graph = create_graph_database_connection(uri, userName, password, database)
195+
graphDb_data_Access = graphDBdataAccess(graph)
195196
if result is not None:
196197
result['db_url'] = uri
197198
result['api_name'] = 'extract'
@@ -443,28 +444,32 @@ async def generate():
443444
if " " in url:
444445
uri= url.replace(" ","+")
445446
while True:
446-
if await request.is_disconnected():
447-
logging.info("Request disconnected")
448-
break
449-
#get the current status of document node
450-
graph = create_graph_database_connection(uri, userName, decoded_password, database)
451-
graphDb_data_Access = graphDBdataAccess(graph)
452-
result = graphDb_data_Access.get_current_status_document_node(file_name)
453-
if result is not None:
454-
status = json.dumps({'fileName':file_name,
455-
'status':result[0]['Status'],
456-
'processingTime':result[0]['processingTime'],
457-
'nodeCount':result[0]['nodeCount'],
458-
'relationshipCount':result[0]['relationshipCount'],
459-
'model':result[0]['model'],
460-
'total_chunks':result[0]['total_chunks'],
461-
'total_pages':result[0]['total_pages'],
462-
'fileSize':result[0]['fileSize'],
463-
'processed_chunk':result[0]['processed_chunk']
464-
})
465-
else:
466-
status = json.dumps({'fileName':file_name, 'status':'Failed'})
467-
yield status
447+
try:
448+
if await request.is_disconnected():
449+
logging.info(" SSE Client disconnected")
450+
break
451+
# get the current status of document node
452+
graph = create_graph_database_connection(uri, userName, decoded_password, database)
453+
graphDb_data_Access = graphDBdataAccess(graph)
454+
result = graphDb_data_Access.get_current_status_document_node(file_name)
455+
if result is not None:
456+
status = json.dumps({'fileName':file_name,
457+
'status':result[0]['Status'],
458+
'processingTime':result[0]['processingTime'],
459+
'nodeCount':result[0]['nodeCount'],
460+
'relationshipCount':result[0]['relationshipCount'],
461+
'model':result[0]['model'],
462+
'total_chunks':result[0]['total_chunks'],
463+
'total_pages':result[0]['total_pages'],
464+
'fileSize':result[0]['fileSize'],
465+
'processed_chunk':result[0]['processed_chunk'],
466+
'fileSource':result[0]['fileSource']
467+
})
468+
else:
469+
status = json.dumps({'fileName':file_name, 'status':'Failed'})
470+
yield status
471+
except asyncio.CancelledError:
472+
logging.info("SSE Connection cancelled")
468473

469474
return EventSourceResponse(generate(),ping=60)
470475

backend/src/main.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def create_source_node_graph_url_wikipedia(graph, model, wiki_query, source_type
215215
lst_file_name.append({'fileName':obj_source_node.file_name,'fileSize':obj_source_node.file_size,'url':obj_source_node.url, 'language':obj_source_node.language, 'status':'Failed'})
216216
return lst_file_name,success_count,failed_count
217217

218-
def extract_graph_from_file_local_file(graph, model, merged_file_path, fileName, allowedNodes, allowedRelationship,uri):
218+
def extract_graph_from_file_local_file(uri, userName, password, database, model, merged_file_path, fileName, allowedNodes, allowedRelationship):
219219

220220
logging.info(f'Process file name :{fileName}')
221221
gcs_file_cache = os.environ.get('GCS_FILE_CACHE')
@@ -227,9 +227,9 @@ def extract_graph_from_file_local_file(graph, model, merged_file_path, fileName,
227227
if pages==None or len(pages)==0:
228228
raise Exception(f'File content is not available for file : {file_name}')
229229

230-
return processing_source(graph, model, file_name, pages, allowedNodes, allowedRelationship, True, merged_file_path, uri)
230+
return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship, True, merged_file_path)
231231

232-
def extract_graph_from_file_s3(graph, model, source_url, aws_access_key_id, aws_secret_access_key):
232+
def extract_graph_from_file_s3(uri, userName, password, database, model, source_url, aws_access_key_id, aws_secret_access_key, allowedNodes, allowedRelationship):
233233

234234
if(aws_access_key_id==None or aws_secret_access_key==None):
235235
raise Exception('Please provide AWS access and secret keys')
@@ -240,44 +240,44 @@ def extract_graph_from_file_s3(graph, model, source_url, aws_access_key_id, aws_
240240
if pages==None or len(pages)==0:
241241
raise Exception(f'File content is not available for file : {file_name}')
242242

243-
return processing_source(graph, model, file_name, pages)
243+
return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship)
244244

245-
def extract_graph_from_web_page(graph, model, source_url, allowedNodes, allowedRelationship):
245+
def extract_graph_from_web_page(uri, userName, password, database, model, source_url, allowedNodes, allowedRelationship):
246246

247247
file_name, pages = get_documents_from_web_page(source_url)
248248

249249
if pages==None or len(pages)==0:
250250
raise Exception(f'Content is not available for given URL : {file_name}')
251251

252-
return processing_source(graph, model, file_name, pages, allowedNodes, allowedRelationship)
252+
return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship)
253253

254-
def extract_graph_from_file_youtube(graph, model, source_url, allowedNodes, allowedRelationship):
254+
def extract_graph_from_file_youtube(uri, userName, password, database, model, source_url, allowedNodes, allowedRelationship):
255255

256256
source_type, youtube_url = check_url_source(source_url)
257257
file_name, pages = get_documents_from_youtube(source_url)
258258

259259
if pages==None or len(pages)==0:
260260
raise Exception('Youtube transcript is not available for file : {file_name}')
261261

262-
return processing_source(graph, model, file_name, pages)
262+
return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship)
263263

264-
def extract_graph_from_file_Wikipedia(graph, model, wiki_query, max_sources, language, allowedNodes, allowedRelationship):
264+
def extract_graph_from_file_Wikipedia(uri, userName, password, database, model, wiki_query, max_sources, language, allowedNodes, allowedRelationship):
265265

266266
file_name, pages = get_documents_from_Wikipedia(wiki_query, language)
267267
if pages==None or len(pages)==0:
268268
raise Exception('Wikipedia page is not available for file : {file_name}')
269269

270-
return processing_source(graph, model, file_name, pages)
270+
return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship)
271271

272-
def extract_graph_from_file_gcs(graph, model, gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename, allowedNodes, allowedRelationship):
272+
def extract_graph_from_file_gcs(uri, userName, password, database, model, gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename, access_token, allowedNodes, allowedRelationship):
273273

274274
file_name, pages = get_documents_from_gcs(gcs_project_id, gcs_bucket_name, gcs_bucket_folder, gcs_blob_filename)
275275
if pages==None or len(pages)==0:
276276
raise Exception(f'File content is not available for file : {file_name}')
277277

278-
return processing_source(graph, model, file_name, pages)
278+
return processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship)
279279

280-
def processing_source(graph, model, file_name, pages, allowedNodes, allowedRelationship, is_uploaded_from_local=None, merged_file_path=None, uri=None):
280+
def processing_source(uri, userName, password, database, model, file_name, pages, allowedNodes, allowedRelationship, is_uploaded_from_local=None, merged_file_path=None):
281281
"""
282282
Extracts a Neo4jGraph from a PDF file based on the model.
283283
@@ -294,6 +294,7 @@ def processing_source(graph, model, file_name, pages, allowedNodes, allowedRelat
294294
status and model as attributes.
295295
"""
296296
start_time = datetime.now()
297+
graph = create_graph_database_connection(uri, userName, password, database)
297298
graphDb_data_Access = graphDBdataAccess(graph)
298299

299300
result = graphDb_data_Access.get_current_status_document_node(file_name)
@@ -344,7 +345,7 @@ def processing_source(graph, model, file_name, pages, allowedNodes, allowedRelat
344345
logging.info('Exit from running loop of processing file')
345346
exit
346347
else:
347-
node_count,rel_count = processing_chunks(selected_chunks,graph,file_name,model,allowedNodes,allowedRelationship,node_count, rel_count)
348+
node_count,rel_count = processing_chunks(selected_chunks,graph,uri, userName, password, database,file_name,model,allowedNodes,allowedRelationship,node_count, rel_count)
348349
end_time = datetime.now()
349350
processed_time = end_time - start_time
350351

@@ -397,8 +398,14 @@ def processing_source(graph, model, file_name, pages, allowedNodes, allowedRelat
397398
else:
398399
logging.info('File does not process because it\'s already in Processing status')
399400

400-
def processing_chunks(chunkId_chunkDoc_list,graph,file_name,model,allowedNodes,allowedRelationship, node_count, rel_count):
401+
def processing_chunks(chunkId_chunkDoc_list,graph,uri, userName, password, database,file_name,model,allowedNodes,allowedRelationship, node_count, rel_count):
401402
#create vector index and update chunk node with embedding
403+
if graph is not None:
404+
if graph._driver._closed:
405+
graph = create_graph_database_connection(uri, userName, password, database)
406+
else:
407+
graph = create_graph_database_connection(uri, userName, password, database)
408+
402409
update_embedding_create_vector_index( graph, chunkId_chunkDoc_list, file_name)
403410
logging.info("Get graph document list from models")
404411
graph_documents = generate_graphDocuments(model, graph, chunkId_chunkDoc_list, allowedNodes, allowedRelationship)

0 commit comments

Comments
 (0)