import json
from dataclasses import asdict
from typing import Dict, List, Union

from flask import Flask, request, Response

from api.api import CompletionRequest, OpenAiApiGenerator
from api.models import get_model_info_list, retrieve_model_info
from build.builder import BuilderArgs, TokenizerArgs
from generate import GeneratorArgs

1819
def create_app(args):
    """
    Creates a flask app that can be used to serve the model as a chat API.

    Args:
        args: Parsed CLI/builder arguments; forwarded to ``initialize_generator``
            and to the model-info lookup helpers.

    Returns:
        A configured ``Flask`` application exposing ``/chat``, ``/models`` and
        ``/models/<model_id>`` endpoints.
    """
    app = Flask(__name__)

    # The generator is created once per app and captured by the endpoint
    # closures below.
    gen: OpenAiApiGenerator = initialize_generator(args)

    def _del_none(d: Union[Dict, List]) -> Union[Dict, List]:
        """Recursively delete None values from a dictionary."""
        # NOTE(review): the `if v` filters drop ALL falsy values (0, "", False,
        # empty containers), not just None, despite the docstring — presumably
        # intentional to strip empty fields from serialized chunks; confirm.
        if type(d) is dict:
            return {k: _del_none(v) for k, v in d.items() if v}
        elif type(d) is list:
            return [_del_none(v) for v in d if v]
        return d

    @app.route("/chat", methods=["POST"])
    def chat_endpoint():
        """
        Endpoint for the Chat API. This endpoint is used to generate a response to a user prompt.
        This endpoint emulates the behavior of the OpenAI Chat API. (https://platform.openai.com/docs/api-reference/chat)

        **Warning**: Not all arguments of the CompletionRequest are consumed.

        See https://github.com/pytorch/torchchat/issues/973 and the OpenAiApiGenerator class for more details.

        If stream is set to true, the response will be streamed back as a series of CompletionResponseChunk objects. Otherwise,
        a single CompletionResponse object will be returned.
        """

        print(" === Completion Request ===")

        # Parse the request in to a CompletionRequest object
        data = request.get_json()
        req = CompletionRequest(**data)

        # BUG FIX: a spec-compliant client sends the JSON boolean `true`, which
        # get_json() parses to Python True — the original `== "true"` comparison
        # only matched the JSON *string* "true" and silently disabled streaming
        # for boolean requests. Accept both forms for backward compatibility.
        if data.get("stream") in ("true", True):

            def chunk_processor(chunked_completion_generator):
                """Inline function for postprocessing CompletionResponseChunk objects.

                Here, we just jsonify the chunk and yield it as a string.
                """
                for chunk in chunked_completion_generator:
                    # A chunk's delta content may be None (e.g. role-only or
                    # terminal chunks); echo an empty string instead.
                    if (next_tok := chunk.choices[0].delta.content) is None:
                        next_tok = ""
                    print(next_tok, end="")
                    yield json.dumps(_del_none(asdict(chunk)))

            # Stream chunks back as server-sent events.
            return Response(
                chunk_processor(gen.chunked_completion(req)),
                mimetype="text/event-stream",
            )
        else:
            # Non-streaming path: block until the full completion is ready.
            response = gen.sync_completion(req)

            return json.dumps(_del_none(asdict(response)))

    @app.route("/models", methods=["GET"])
    def models_endpoint():
        """List all models known to the server, OpenAI `/models`-style."""
        return json.dumps(asdict(get_model_info_list(args)))

    @app.route("/models/<model_id>", methods=["GET"])
    def models_retrieve_endpoint(model_id):
        """Retrieve info for a single model, or 404 if it is unknown."""
        if response := retrieve_model_info(args, model_id):
            return json.dumps(asdict(response))
        else:
            return "Model not found", 404

    return app
8590
8691
8792def initialize_generator (args ) -> OpenAiApiGenerator :
@@ -103,6 +108,5 @@ def initialize_generator(args) -> OpenAiApiGenerator:
103108
104109
def main(args):
    """Entry point: build the chat-serving Flask app and run the dev server."""
    create_app(args).run()
0 commit comments