|
27 | 27 | "\n", |
28 | 28 | "Many use cases such as building a chatbot require text (text2text) generation models like **[BloomZ 7B1](https://huggingface.co/bigscience/bloomz-7b1)**, **[Flan T5 XXL](https://huggingface.co/google/flan-t5-xxl)**, and **[Flan T5 UL2](https://huggingface.co/google/flan-ul2)** to respond to user questions with insightful answers. The **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** models have picked up a lot of general knowledge in training, but we often need to ingest and use a large library of more specific information.\n", |
29 | 29 | "\n", |
30 | | - "In this notebook we will demonstrate how to use **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** to answer questions using a library of documents as a reference, by using document embeddings and retrieval. The embeddings are generated from **GPT-J-6B** embedding model. \n", |
| 30 | + "In this notebook we will demonstrate how to use **BloomZ 7B1**, **Flan T5 XXL**, and **Flan T5 UL2** to answer questions using a library of documents as a reference, by using document embeddings and retrieval. The embeddings are generated from **MiniLM-L6-v2** embedding model. \n", |
31 | 31 | "\n", |
32 | 32 | "**This notebook serves a template such that you can easily replace the example dataset by your own to build a custom question and asnwering application.**" |
33 | 33 | ] |
|
45 | 45 | "cell_type": "code", |
46 | 46 | "execution_count": null, |
47 | 47 | "metadata": { |
48 | | - "collapsed": false, |
49 | 48 | "jupyter": { |
50 | 49 | "outputs_hidden": false |
51 | 50 | }, |
|
57 | 56 | "outputs": [], |
58 | 57 | "source": [ |
59 | 58 | "!pip install --upgrade sagemaker --quiet\n", |
60 | | - "!pip install ipywidgets==7.0.0 --quiet\n", |
61 | | - "!pip install langchain==0.0.148 --quiet\n", |
62 | | - "!pip install faiss-cpu --quiet" |
| 59 | + "!pip install faiss-cpu --quiet\n", |
| 60 | + "!pip install langchain --quiet" |
63 | 61 | ] |
64 | 62 | }, |
65 | 63 | { |
|
70 | 68 | }, |
71 | 69 | "outputs": [], |
72 | 70 | "source": [ |
73 | | - "import time\n", |
74 | | - "import sagemaker, boto3, json\n", |
75 | | - "from sagemaker.session import Session\n", |
76 | | - "from sagemaker.model import Model\n", |
77 | | - "from sagemaker import image_uris, model_uris, script_uris, hyperparameters\n", |
78 | | - "from sagemaker.predictor import Predictor\n", |
| 71 | + "from sagemaker import Session\n", |
79 | 72 | "from sagemaker.utils import name_from_base\n", |
80 | | - "from typing import Any, Dict, List, Optional\n", |
81 | | - "from langchain.embeddings import SagemakerEndpointEmbeddings\n", |
82 | | - "from langchain.llms.sagemaker_endpoint import ContentHandlerBase\n", |
| 73 | + "from sagemaker.jumpstart.model import JumpStartModel\n", |
83 | 74 | "\n", |
84 | | - "sagemaker_session = Session()\n", |
85 | | - "aws_role = sagemaker_session.get_caller_identity_arn()\n", |
86 | | - "aws_region = boto3.Session().region_name\n", |
87 | | - "sess = sagemaker.Session()\n", |
88 | | - "model_version = \"1.*\"" |
89 | | - ] |
90 | | - }, |
91 | | - { |
92 | | - "cell_type": "code", |
93 | | - "execution_count": null, |
94 | | - "metadata": { |
95 | | - "tags": [] |
96 | | - }, |
97 | | - "outputs": [], |
98 | | - "source": [ |
99 | | - "def query_endpoint_with_json_payload(encoded_json, endpoint_name, content_type=\"application/json\"):\n", |
100 | | - " client = boto3.client(\"runtime.sagemaker\")\n", |
101 | | - " response = client.invoke_endpoint(\n", |
102 | | - " EndpointName=endpoint_name, ContentType=content_type, Body=encoded_json\n", |
103 | | - " )\n", |
104 | | - " return response\n", |
105 | | - "\n", |
106 | | - "\n", |
107 | | - "def parse_response_model_flan_t5(query_response):\n", |
108 | | - " model_predictions = json.loads(query_response[\"Body\"].read())\n", |
109 | | - " generated_text = model_predictions[\"generated_texts\"]\n", |
110 | | - " return generated_text\n", |
111 | | - "\n", |
112 | | - "\n", |
113 | | - "def parse_response_multiple_texts_bloomz(query_response):\n", |
114 | | - " generated_text = []\n", |
115 | | - " model_predictions = json.loads(query_response[\"Body\"].read())\n", |
116 | | - " for x in model_predictions[0]:\n", |
117 | | - " generated_text.append(x[\"generated_text\"])\n", |
118 | | - " return generated_text" |
| 75 | + "sagemaker_session = Session()" |
119 | 76 | ] |
120 | 77 | }, |
121 | 78 | { |
122 | 79 | "cell_type": "markdown", |
123 | 80 | "metadata": {}, |
124 | 81 | "source": [ |
125 | | - "Deploy SageMaker endpoint(s) for large language models and GPT-J 6B embedding model. Please uncomment the entries as below if you want to deploy multiple LLM models to compare their performance." |
| 82 | + "Deploy SageMaker endpoint(s) for large language models and MiniLM-L6-v2 embedding model. Please uncomment the entries as below if you want to deploy multiple LLM models to compare their performance." |
126 | 83 | ] |
127 | 84 | }, |
128 | 85 | { |
|
135 | 92 | "source": [ |
136 | 93 | "_MODEL_CONFIG_ = {\n", |
137 | 94 | " \"huggingface-text2text-flan-t5-xxl\": {\n", |
| 95 | + " \"model_version\": \"2.*\",\n", |
138 | 96 | " \"instance type\": \"ml.g5.12xlarge\",\n", |
139 | | - " \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n", |
140 | | - " \"parse_function\": parse_response_model_flan_t5,\n", |
141 | | - " \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n", |
142 | 97 | " },\n", |
143 | | - " \"huggingface-textembedding-gpt-j-6b\": {\n", |
| 98 | + " \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n", |
| 99 | + " \"model_version\": \"1.*\",\n", |
144 | 100 | " \"instance type\": \"ml.g5.24xlarge\",\n", |
145 | | - " \"env\": {\"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\", \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"},\n", |
146 | 101 | " },\n", |
147 | | - " # \"huggingface-textgeneration1-bloomz-7b1-fp16\": {\n", |
148 | | - " # \"instance type\": \"ml.g5.12xlarge\",\n", |
149 | | - " # \"env\": {},\n", |
150 | | - " # \"parse_function\": parse_response_multiple_texts_bloomz,\n", |
151 | | - " # \"prompt\": \"\"\"question: \\\"{question}\"\\\\n\\nContext: \\\"{context}\"\\\\n\\nAnswer:\"\"\",\n", |
| 102 | + " # \"huggingface-textembedding-all-MiniLM-L6-v2\": {\n", |
| 103 | + " # \"model_version\": \"3.*\",\n", |
| 104 | + " # \"instance type\": \"ml.g5.12xlarge\"\n", |
152 | 105 | " # },\n", |
153 | 106 | " # \"huggingface-text2text-flan-ul2-bf16\": {\n", |
154 | | - " # \"instance type\": \"ml.g5.24xlarge\",\n", |
155 | | - " # \"env\": {\n", |
156 | | - " # \"SAGEMAKER_MODEL_SERVER_WORKERS\": \"1\",\n", |
157 | | - " # \"TS_DEFAULT_WORKERS_PER_MODEL\": \"1\"\n", |
158 | | - " # },\n", |
159 | | - " # \"parse_function\": parse_response_model_flan_t5,\n", |
160 | | - " # \"prompt\": \"\"\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\"\",\n", |
161 | | - " # },\n", |
| 107 | + " # \"model_version\": \"2.*\",\n", |
| 108 | + " # \"instance type\": \"ml.g5.24xlarge\"\n", |
| 109 | + " # }\n", |
162 | 110 | "}" |
163 | 111 | ] |
164 | 112 | }, |
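
Pinned model versions drift over time. If a version above is unavailable in your region or SDK release, you can enumerate what JumpStart currently offers (this assumes a `sagemaker` release that ships `list_jumpstart_models`):

```python
from sagemaker.jumpstart.notebook_utils import list_jumpstart_models

# Enumerate all JumpStart model IDs, then filter client-side.
all_models = list_jumpstart_models()
print([m for m in all_models if "flan" in m or "minilm" in m.lower()])
```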
|
168 | 116 | "metadata": {}, |
169 | 117 | "outputs": [], |
170 | 118 | "source": [ |
171 | | - "newline, bold, unbold = \"\\n\", \"\\033[1m\", \"\\033[0m\"\n", |
172 | | - "\n", |
173 | 119 | "for model_id in _MODEL_CONFIG_:\n", |
174 | 120 | " endpoint_name = name_from_base(f\"jumpstart-example-raglc-{model_id}\")\n", |
175 | 121 | " inference_instance_type = _MODEL_CONFIG_[model_id][\"instance type\"]\n", |
| 122 | + " model_version = _MODEL_CONFIG_[model_id][\"model_version\"]\n", |
176 | 123 | "\n", |
177 | | - " # Retrieve the inference container uri. This is the base HuggingFace container image for the default model above.\n", |
178 | | - " deploy_image_uri = image_uris.retrieve(\n", |
179 | | - " region=None,\n", |
180 | | - " framework=None, # automatically inferred from model_id\n", |
181 | | - " image_scope=\"inference\",\n", |
182 | | - " model_id=model_id,\n", |
183 | | - " model_version=model_version,\n", |
184 | | - " instance_type=inference_instance_type,\n", |
185 | | - " )\n", |
186 | | - " # Retrieve the model uri.\n", |
187 | | - " model_uri = model_uris.retrieve(\n", |
188 | | - " model_id=model_id, model_version=model_version, model_scope=\"inference\"\n", |
189 | | - " )\n", |
190 | | - " model_inference = Model(\n", |
191 | | - " image_uri=deploy_image_uri,\n", |
192 | | - " model_data=model_uri,\n", |
193 | | - " role=aws_role,\n", |
194 | | - " predictor_cls=Predictor,\n", |
195 | | - " name=endpoint_name,\n", |
196 | | - " env=_MODEL_CONFIG_[model_id][\"env\"],\n", |
197 | | - " )\n", |
198 | | - " model_predictor_inference = model_inference.deploy(\n", |
199 | | - " initial_instance_count=1,\n", |
200 | | - " instance_type=inference_instance_type,\n", |
201 | | - " predictor_cls=Predictor,\n", |
202 | | - " endpoint_name=endpoint_name,\n", |
203 | | - " )\n", |
204 | | - " print(f\"{bold}Model {model_id} has been deployed successfully.{unbold}{newline}\")\n", |
205 | | - " _MODEL_CONFIG_[model_id][\"endpoint_name\"] = endpoint_name" |
| 124 | + " print(f\"Deploying {model_id}...\")\n", |
| 125 | + "\n", |
| 126 | + " model = JumpStartModel(model_id=model_id, model_version=model_version)\n", |
| 127 | + "\n", |
| 128 | + " try:\n", |
| 129 | + " predictor = model.deploy(\n", |
| 130 | + " initial_instance_count=1,\n", |
| 131 | + " instance_type=inference_instance_type,\n", |
| 132 | + " endpoint_name=name_from_base(f\"jumpstart-example-raglc-{model_id}\"),\n", |
| 133 | + " )\n", |
| 134 | + " print(f\"Deployed endpoint: {predictor.endpoint_name}\")\n", |
| 135 | + " _MODEL_CONFIG_[model_id][\"predictor\"] = predictor\n", |
| 136 | + " except Exception as e:\n", |
| 137 | + " print(f\"Error deploying {model_id}: {str(e)}\")\n", |
| 138 | + "\n", |
| 139 | + "print(\"Deployment process completed.\")" |
206 | 140 | ] |
207 | 141 | }, |
208 | 142 | { |
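
These g5 endpoints bill per instance-hour while they are up. A teardown sketch for when you finish experimenting, iterating over the predictors stored in `_MODEL_CONFIG_` above:

```python
# Delete every endpoint created above to stop incurring charges.
for model_id, config in _MODEL_CONFIG_.items():
    predictor = config.get("predictor")
    if predictor is not None:
        predictor.delete_model()     # removes the SageMaker model resource
        predictor.delete_endpoint()  # deletes the endpoint and its config
        print(f"Deleted endpoint for {model_id}")
```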
|
229 | 163 | "metadata": {}, |
230 | 164 | "outputs": [], |
231 | 165 | "source": [ |
232 | | - "payload = {\n", |
233 | | - " \"text_inputs\": question,\n", |
234 | | - " \"max_length\": 100,\n", |
235 | | - " \"num_return_sequences\": 1,\n", |
236 | | - " \"top_k\": 50,\n", |
237 | | - " \"top_p\": 0.95,\n", |
238 | | - " \"do_sample\": True,\n", |
239 | | - "}\n", |
240 | | - "\n", |
241 | 166 | "list_of_LLMs = list(_MODEL_CONFIG_.keys())\n", |
242 | | - "list_of_LLMs.remove(\"huggingface-textembedding-gpt-j-6b\") # remove the embedding model\n", |
243 | | - "\n", |
| 167 | + "list_of_LLMs = [model for model in list_of_LLMs if \"textembedding\" not in model]\n", |
244 | 168 | "\n", |
245 | 169 | "for model_id in list_of_LLMs:\n", |
246 | | - " endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n", |
247 | | - " query_response = query_endpoint_with_json_payload(\n", |
248 | | - " json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n", |
249 | | - " )\n", |
250 | | - " generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n", |
251 | | - " print(f\"For model: {model_id}, the generated output is: {generated_texts[0]}\\n\")" |
| 170 | + " predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n", |
| 171 | + " response = predictor.predict({\"inputs\": question})\n", |
| 172 | + " print(f\"For model: {model_id}, the generated output is:\\n\")\n", |
| 173 | + " print(f\"{response[0]['generated_text']}\\n\")" |
252 | 174 | ] |
253 | 175 | }, |
254 | 176 | { |
|
283 | 205 | "metadata": {}, |
284 | 206 | "outputs": [], |
285 | 207 | "source": [ |
286 | | - "parameters = {\n", |
287 | | - " \"max_length\": 200,\n", |
288 | | - " \"num_return_sequences\": 1,\n", |
289 | | - " \"top_k\": 250,\n", |
290 | | - " \"top_p\": 0.95,\n", |
291 | | - " \"do_sample\": False,\n", |
292 | | - " \"temperature\": 1,\n", |
293 | | - "}\n", |
| 208 | + "prompt = f\"Answer based on context:\\n\\n{context}\\n\\n{question}\"\n", |
294 | 209 | "\n", |
295 | 210 | "for model_id in list_of_LLMs:\n", |
296 | | - " endpoint_name = _MODEL_CONFIG_[model_id][\"endpoint_name\"]\n", |
297 | | - "\n", |
298 | | - " prompt = _MODEL_CONFIG_[model_id][\"prompt\"]\n", |
299 | | - "\n", |
300 | | - " text_input = prompt.replace(\"{context}\", context)\n", |
301 | | - " text_input = text_input.replace(\"{question}\", question)\n", |
302 | | - " payload = {\"text_inputs\": text_input, **parameters}\n", |
303 | | - "\n", |
304 | | - " query_response = query_endpoint_with_json_payload(\n", |
305 | | - " json.dumps(payload).encode(\"utf-8\"), endpoint_name=endpoint_name\n", |
306 | | - " )\n", |
307 | | - " generated_texts = _MODEL_CONFIG_[model_id][\"parse_function\"](query_response)\n", |
308 | | - " print(\n", |
309 | | - " f\"{bold}For model: {model_id}, the generated output is: {generated_texts[0]}{unbold}{newline}\"\n", |
310 | | - " )" |
| 211 | + " predictor = _MODEL_CONFIG_[model_id][\"predictor\"]\n", |
| 212 | + " response = predictor.predict({\"inputs\": prompt})\n", |
| 213 | + " print(f\"For model: {model_id}, the generated output is:\\n\")\n", |
| 214 | + " print(f\"{response[0]['generated_text']}\\n\")" |
311 | 215 | ] |
312 | 216 | }, |
313 | 217 | { |
|
330 | 234 | "\n", |
331 | 235 | "To achieve that, we will do following.\n", |
332 | 236 | "\n", |
333 | | - "1. **Generate embedings for each of document in the knowledge library with SageMaker GPT-J-6B embedding model.**\n", |
| 237 | + "1. **Generate embedings for each of document in the knowledge library with SageMaker MiniLM-L6-v2 embedding model.**\n", |
334 | 238 | "2. **Identify top K most relevant documents based on user query.**\n", |
335 | 239 | " - 2.1 **For a query of your interest, generate the embedding of the query using the same embedding model.**\n", |
336 | 240 | " - 2.2 **Search the indexes of top K most relevant documents in the embedding space using in-memory Faiss search.**\n", |
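
Steps 1 and 2 reduce to building a FAISS index over the document embeddings and querying it with the embedded question. A sketch against the raw `faiss` API (the notebook itself goes through LangChain; `documents` and `embed` are placeholders for the loaded library and the embedding endpoint):

```python
import faiss
import numpy as np

# Step 1: embed every document (embed() is a placeholder for the endpoint call).
doc_vecs = np.asarray(embed(documents), dtype="float32")

# Step 2: index the vectors and search for the top K nearest documents.
index = faiss.IndexFlatL2(doc_vecs.shape[1])  # exact L2 search, no training required
index.add(doc_vecs)

q_vec = np.asarray(embed([question]), dtype="float32")
distances, ids = index.search(q_vec, 3)  # top-3 nearest documents
top_docs = [documents[i] for i in ids[0]]

# The retrieved documents are then combined into the prompt as context.
context = "\n".join(top_docs)
```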
|
365 | 269 | "outputs": [], |
366 | 270 | "source": [ |
367 | 271 | "from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler\n", |
| 272 | + "from langchain.embeddings import SagemakerEndpointEmbeddings\n", |
| 273 | + "from typing import List\n", |
| 274 | + "import boto3\n", |
| 275 | + "\n", |
| 276 | + "aws_region = boto3.Session().region_name\n", |
368 | 277 | "\n", |
369 | 278 | "\n", |
370 | 279 | "class SagemakerEndpointEmbeddingsJumpStart(SagemakerEndpointEmbeddings):\n", |
|
405 | 314 | "\n", |
406 | 315 | "\n", |
407 | 316 | "content_handler = ContentHandler()\n", |
| 317 | + "endpoint_name = _MODEL_CONFIG_[\"huggingface-textembedding-all-MiniLM-L6-v2\"][\n", |
| 318 | + " \"predictor\"\n", |
| 319 | + "].endpoint_name\n", |
408 | 320 | "\n", |
409 | 321 | "embeddings = SagemakerEndpointEmbeddingsJumpStart(\n", |
410 | | - " endpoint_name=_MODEL_CONFIG_[\"huggingface-textembedding-gpt-j-6b\"][\"endpoint_name\"],\n", |
| 322 | + " endpoint_name=endpoint_name,\n", |
411 | 323 | " region_name=aws_region,\n", |
412 | 324 | " content_handler=content_handler,\n", |
413 | 325 | ")" |
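
A quick way to verify the wrapper end to end: all-MiniLM-L6-v2 produces 384-dimensional vectors, so the printed length should be 384 (the query string is illustrative):

```python
# Smoke-test the embedding endpoint through the LangChain wrapper.
sample_vector = embeddings.embed_query("What is SageMaker JumpStart?")
print(len(sample_vector))  # all-MiniLM-L6-v2 embeds into 384 dimensions
```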
|
428 | 340 | "source": [ |
429 | 341 | "from langchain.llms.sagemaker_endpoint import LLMContentHandler, SagemakerEndpoint\n", |
430 | 342 | "\n", |
431 | | - "parameters = {\n", |
432 | | - " \"max_length\": 200,\n", |
433 | | - " \"num_return_sequences\": 1,\n", |
434 | | - " \"top_k\": 250,\n", |
435 | | - " \"top_p\": 0.95,\n", |
436 | | - " \"do_sample\": False,\n", |
437 | | - " \"temperature\": 1,\n", |
438 | | - "}\n", |
439 | | - "\n", |
440 | 343 | "\n", |
441 | 344 | "class ContentHandler(LLMContentHandler):\n", |
442 | 345 | " content_type = \"application/json\"\n", |
443 | 346 | " accepts = \"application/json\"\n", |
444 | 347 | "\n", |
445 | 348 | " def transform_input(self, prompt: str, model_kwargs={}) -> bytes:\n", |
446 | | - " input_str = json.dumps({\"text_inputs\": prompt, **model_kwargs})\n", |
| 349 | + " input_str = json.dumps({\"inputs\": prompt, **model_kwargs})\n", |
447 | 350 | " return input_str.encode(\"utf-8\")\n", |
448 | 351 | "\n", |
449 | 352 | " def transform_output(self, output: bytes) -> str:\n", |
450 | 353 | " response_json = json.loads(output.read().decode(\"utf-8\"))\n", |
451 | | - " return response_json[\"generated_texts\"][0]\n", |
| 354 | + " return response_json[0][\"generated_text\"]\n", |
452 | 355 | "\n", |
453 | 356 | "\n", |
454 | 357 | "content_handler = ContentHandler()\n", |
| 358 | + "endpoint_name = _MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"predictor\"].endpoint_name\n", |
| 359 | + "\n", |
| 360 | + "parameters = {\n", |
| 361 | + " \"max_length\": 200,\n", |
| 362 | + " \"num_return_sequences\": 1,\n", |
| 363 | + " \"top_k\": 250,\n", |
| 364 | + " \"top_p\": 0.95,\n", |
| 365 | + " \"do_sample\": False,\n", |
| 366 | + " \"temperature\": 1,\n", |
| 367 | + "}\n", |
455 | 368 | "\n", |
456 | 369 | "sm_llm = SagemakerEndpoint(\n", |
457 | | - " endpoint_name=_MODEL_CONFIG_[\"huggingface-text2text-flan-t5-xxl\"][\"endpoint_name\"],\n", |
| 370 | + " endpoint_name=endpoint_name,\n", |
458 | 371 | " region_name=aws_region,\n", |
459 | 372 | " model_kwargs=parameters,\n", |
460 | 373 | " content_handler=content_handler,\n", |
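
Once wrapped, the endpoint can be called like any LangChain LLM. A quick check, reusing the `question` and `context` variables defined earlier in the notebook:

```python
# Invoke the Flan-T5 XXL endpoint through the LangChain wrapper.
answer = sm_llm(f"Answer based on context:\n\n{context}\n\n{question}")
print(answer)
```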
|
568 | 481 | "from langchain.text_splitter import CharacterTextSplitter\n", |
569 | 482 | "from langchain import PromptTemplate\n", |
570 | 483 | "from langchain.chains.question_answering import load_qa_chain\n", |
571 | | - "from langchain.document_loaders.csv_loader import CSVLoader" |
| 484 | + "from langchain.document_loaders.csv_loader import CSVLoader\n", |
| 485 | + "import json" |
572 | 486 | ] |
573 | 487 | }, |
574 | 488 | { |
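
The imports above come together in the retrieval chain roughly as follows. This is a sketch, not the notebook's exact cells: the CSV file name, its contents, the chunk sizes, and `k` are placeholders you would replace with your own dataset and tuning:

```python
from langchain.vectorstores import FAISS

# Load the knowledge library and split it into chunks that fit the LLM context window.
documents = CSVLoader(file_path="knowledge_library.csv").load()  # placeholder file name
docs = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0).split_documents(documents)

# Steps 1-2: embed the chunks via the MiniLM endpoint and retrieve the top matches.
docsearch = FAISS.from_documents(docs, embeddings)
relevant_docs = docsearch.similarity_search(question, k=3)

# Step 3: stuff the retrieved chunks into the prompt and send it to Flan-T5 XXL.
prompt_template = PromptTemplate(
    template="Answer based on context:\n\n{context}\n\n{question}",
    input_variables=["context", "question"],
)
chain = load_qa_chain(llm=sm_llm, prompt=prompt_template)
result = chain({"input_documents": relevant_docs, "question": question}, return_only_outputs=True)
print(result["output_text"])
```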
|
670 | 584 | "cell_type": "markdown", |
671 | 585 | "metadata": {}, |
672 | 586 | "source": [ |
673 | | - "Firstly, we **generate embedings for each of document in the knowledge library with SageMaker GPT-J-6B embedding model.**" |
| 587 | + "Firstly, we **generate embedings for each of document in the knowledge library with SageMaker MiniLM-L6-v2 embedding model.**" |
674 | 588 | ] |
675 | 589 | }, |
676 | 590 | { |
|
1384 | 1298 | ], |
1385 | 1299 | "instance_type": "ml.t3.medium", |
1386 | 1300 | "kernelspec": { |
1387 | | - "display_name": "Python 3 (Data Science 2.0)", |
| 1301 | + "display_name": "Python 3 (ipykernel)", |
1388 | 1302 | "language": "python", |
1389 | | - "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-data-science-38" |
| 1303 | + "name": "python3" |
1390 | 1304 | }, |
1391 | 1305 | "language_info": { |
1392 | 1306 | "codemirror_mode": { |
|
1398 | 1312 | "name": "python", |
1399 | 1313 | "nbconvert_exporter": "python", |
1400 | 1314 | "pygments_lexer": "ipython3", |
1401 | | - "version": "3.8.13" |
| 1315 | + "version": "3.11.9" |
1402 | 1316 | } |
1403 | 1317 | }, |
1404 | 1318 | "nbformat": 4, |
|