diff --git a/docs/model_serving_framework/deploy_sparse_model_to_SageMaker.ipynb b/docs/model_serving_framework/deploy_sparse_model_to_SageMaker.ipynb
index 7997346c93..71e9b4c086 100644
--- a/docs/model_serving_framework/deploy_sparse_model_to_SageMaker.ipynb
+++ b/docs/model_serving_framework/deploy_sparse_model_to_SageMaker.ipynb
@@ -52,7 +52,13 @@
    "outputs": [],
    "source": [
     "%%writefile handler/code/requirements.txt\n",
-    "sentence-transformers==5.0.0"
+    "transformers==4.56.1\n",
+    "huggingface_hub==0.35.0\n",
+    "hf_xet==1.1.10\n",
+    "tokenizers==0.22.0\n",
+    "regex==2025.9.1\n",
+    "safetensors==0.6.2\n",
+    "sentence-transformers==5.1.0"
    ]
   },
   {
@@ -134,30 +140,64 @@
     "        )\n",
     "        print(f\"Using device: {self.device}\")\n",
     "        self.model = SparseEncoder(model_id, device=self.device, trust_remote_code=trust_remote_code)\n",
+    "        self._warmup()\n",
     "        self.initialized = True\n",
     "\n",
-    "    def preprocess(self, requests):\n",
+    "    def _warmup(self):\n",
+    "        input_data = [{\"body\": [\"hello world\"] * 10}]\n",
+    "        self.handle(input_data, None)\n",
+    "\n",
+    "    def _preprocess(self, requests):\n",
     "        inputSentence = []\n",
     "        batch_idx = []\n",
+    "        formats = []  # per-text format: \"word\" or \"token_id\"\n",
     "\n",
     "        for request in requests:\n",
     "            request_body = request.get(\"body\")\n",
     "            if isinstance(request_body, bytearray):\n",
     "                request_body = request_body.decode(\"utf-8\")\n",
     "                request_body = json.loads((request_body))\n",
-    "            if isinstance(request_body, list):\n",
+    "\n",
+    "            # dict-based new schema: {\"texts\": str | list[str], \"sparse_embedding_format\": str}\n",
+    "            if isinstance(request_body, dict):\n",
+    "                texts = request_body.get(\"texts\")\n",
+    "                fmt = request_body.get(\"sparse_embedding_format\", \"word\")\n",
+    "                fmt = \"token_id\" if isinstance(fmt, str) and fmt.lower() == \"token_id\" else \"word\"\n",
+    "\n",
+    "                if isinstance(texts, list):\n",
+    "                    inputSentence += texts\n",
+    "                    batch_idx.append(len(texts))\n",
+    "                    formats += [fmt] * len(texts)\n",
+    "                else:\n",
+    "                    inputSentence.append(texts)\n",
+    "                    batch_idx.append(1)\n",
+    "                    formats.append(fmt)\n",
+    "\n",
+    "            # legacy schemas\n",
+    "            elif isinstance(request_body, list):\n",
     "                inputSentence += request_body\n",
     "                batch_idx.append(len(request_body))\n",
+    "                formats += [\"word\"] * len(request_body)\n",
     "            else:\n",
     "                inputSentence.append(request_body)\n",
    "                batch_idx.append(1)\n",
+    "                formats.append(\"word\")\n",
+    "\n",
+    "        return inputSentence, batch_idx, formats\n",
     "\n",
-    "        return inputSentence, batch_idx\n",
+    "    def _convert_token_ids(self, sparse_embedding):\n",
+    "        token_ids = self.model.tokenizer.convert_tokens_to_ids([x[0] for x in sparse_embedding])\n",
+    "        return [(str(token_ids[i]), sparse_embedding[i][1]) for i in range(len(token_ids))]\n",
     "\n",
     "    def handle(self, data, context):\n",
-    "        inputSentence, batch_idx = self.preprocess(data)\n",
+    "        inputSentence, batch_idx, formats = self._preprocess(data)\n",
     "        model_output = self.model.encode_document(inputSentence, batch_size=max_bs)\n",
-    "        sparse_embedding = list(map(dict,self.model.decode(model_output)))\n",
+    "\n",
+    "        sparse_embedding_word = self.model.decode(model_output)\n",
+    "        for i, fmt in enumerate(formats):\n",
+    "            if fmt == \"token_id\":\n",
+    "                sparse_embedding_word[i] = self._convert_token_ids(sparse_embedding_word[i])\n",
+    "        sparse_embedding = list(map(dict, sparse_embedding_word))\n",
     "\n",
     "        outputs = [sparse_embedding[s:e]\n",
     "                   for s, e in zip([0]+list(itertools.accumulate(batch_idx))[:-1],\n",
@@ -424,8 +464,8 @@
     "```json\n",
     "POST /_plugins/_ml/connectors/_create\n",
     "{\n",
-    "  \"name\": \"test\",\n",
-    "  \"description\": \"Test connector for Sagemaker model\",\n",
+    "  \"name\": \"Sagemaker Connector: embedding\",\n",
+    "  \"description\": \"The connector to sagemaker embedding model\",\n",
     "  \"version\": 1,\n",
     "  \"protocol\": \"aws_sigv4\",\n",
     "  \"credential\": {\n",
@@ -436,6 +476,7 @@
     "    \"region\": \"{region}\",\n",
     "    \"service_name\": \"sagemaker\",\n",
     "    \"input_docs_processed_step_size\": 2,\n",
+    "    \"sparse_embedding_format\": \"word\"\n",
     "  },\n",
     "  \"actions\": [\n",
     "    {\n",
@@ -445,7 +486,12 @@
     "        \"content-type\": \"application/json\"\n",
     "      },\n",
     "      \"url\": \"https://runtime.sagemaker.{region}.amazonaws.com/endpoints/{predictor.endpoint_name}/invocations\",\n",
-    "      \"request_body\": \"${parameters.input}\"\n",
+    "      \"request_body\": \"\"\"\n",
+    "      {\n",
+    "        \"texts\": ${parameters.input},\n",
+    "        \"sparse_embedding_format\": \"${parameters.sparse_embedding_format}\"\n",
+    "      }\n",
+    "      \"\"\"\n",
     "    }\n",
     "  ],\n",
     "  \"client_config\":{\n",