-
Notifications
You must be signed in to change notification settings - Fork 36
[DNM] Proposal for Adding local model loading and automatically dependency installation in webui #91
base: main
Are you sure you want to change the base?
[DNM] Proposal for Adding local model loading and automatically dependency installation in webui #91
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,52 @@ | |
| import sys | ||
|
|
||
| sys.path.append(os.path.join(os.path.dirname(__file__), "..")) | ||
|
|
||
| import importlib, pip | ||
|
|
||
def fix_package_name(package):
    """Return the importable module name for a pip requirement string.

    The requirement is cut at the first character that can start an extras
    list, an environment marker, a direct reference or a version specifier
    (``[``, ``;``, ``@``, ``~``, ``!``, ``>``, ``<``, ``=``), so inputs such as
    ``"gradio==3.36.1"``, ``"foo~=1.0"`` or ``"pkg[extra]>=2"`` all reduce to
    the bare distribution name.

    Args:
        package: A requirement string, e.g. ``"langchain-community==0.0.16"``.

    Returns:
        The name to pass to ``importlib.import_module``. An explicit entry in
        ``package_name_map`` wins; otherwise hyphens are replaced with
        underscores, because a module name cannot contain ``-``.
    """
    name = package
    for separator in ("[", ";", "@", "~", "!", ">", "<", "="):
        name = name.split(separator, 1)[0]
    name = name.strip()

    # Distribution name -> import name, for packages whose import name is not
    # simply the distribution name with "-" replaced by "_".
    package_name_map = {
    }

    if name in package_name_map:
        return package_name_map[name]
    # Without this, "langchain-community" or "sentence-transformers" would
    # never import successfully and would be reinstalled on every start-up.
    return name.replace("-", "_")
|
|
||
def check_availability_and_install(package_or_list, verbose=1):
    """Import each requested package and pip-install it when the import fails.

    Only importability is checked: a version pin in the requirement string is
    handed to pip when an install is triggered, but an already importable
    package is not compared against the pin.

    Args:
        package_or_list: A requirement string (e.g. ``"gradio==3.36.1"``) or a
            list/tuple of requirement strings.
        verbose: When ``1``, print the requested package(s) before checking.

    Raises:
        ValueError: If ``package_or_list`` is not a string, list or tuple.
    """
    import subprocess
    import sys

    def actual_func(package):
        module_name = fix_package_name(package)
        try:
            module = importlib.import_module(module_name)
            if module_name == 'gradio':
                # NOTE(review): the original code instantiated gradio.Progress
                # here, presumably to detect a gradio that is too old to
                # provide it -- confirm. An AttributeError now triggers the
                # install instead of escaping uncaught.
                module.Progress()
            return module
        except (ImportError, AttributeError):
            print(f"automatically install for {package}")
            # pip has no supported in-process API (pip.main is internal), so
            # run it as a subprocess of the current interpreter.
            result = subprocess.run(
                [sys.executable, '-m', 'pip', 'install', '-q', package],
                check=False,
            )
            if result.returncode != 0:
                print(f"pip install failed for {package} (exit code {result.returncode})")
            # Let later imports in this process see the newly installed files.
            importlib.invalidate_caches()
            return None

    if isinstance(package_or_list, (list, tuple)):
        if verbose == 1 and len(package_or_list) > 0:
            print(f"check_availability_and_install {package_or_list}")
        for pkg in package_or_list:
            actual_func(pkg)
    elif isinstance(package_or_list, str):
        if verbose == 1 and package_or_list != "":
            print(f"check_availability_and_install {package_or_list}")
        actual_func(package_or_list)
    else:
        raise ValueError(f"{package_or_list} with type of {type(package_or_list)} is not supported.")
|
|
||
| check_availability_and_install(['gradio==3.36.1', 'gradio_client==0.7.3', 'langchain==0.1.4', 'langchain-community==0.0.16', 'lz4', 'sentence-transformers==2.2.2', 'pyrecdp']) | ||
|
|
||
# Default the RecDP cache root to the current working directory when the
# variable is unset or empty. Using .get() avoids the KeyError that a plain
# os.environ[...] lookup raises when the variable is not defined at all.
if not os.environ.get('RECDP_CACHE_HOME'):
    os.environ['RECDP_CACHE_HOME'] = os.getcwd()
|
|
||
| from inference.inference_config import all_models, ModelDescription, Prompt | ||
| from inference.inference_config import InferenceConfig as FinetunedConfig | ||
| from inference.chat_process import ChatModelGptJ, ChatModelLLama # noqa: F401 | ||
|
|
@@ -48,6 +94,13 @@ | |
| RAGTextFix, | ||
| ) | ||
| from pyrecdp.primitives.document.reader import _default_file_readers | ||
| from pyrecdp.core.cache_utils import RECDP_MODELS_CACHE | ||
|
|
||
| import logging | ||
|
|
||
# Third-party loggers that are capped at ERROR so that their lower-level
# records do not reach the web UI's console output.
lib_list = [
    "httpcore",
    "httpx",
    "paramiko",
    "urllib3",
    "markdown_it",
    "matplotlib",
]
for lib in lib_list:
    third_party_logger = logging.getLogger(lib)
    third_party_logger.setLevel(logging.ERROR)
|
Comment on lines
+101
to
+103
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If this is to solve the DEBUG log problem, I think it can be solved after #101 is merged. |
||
|
|
||
|
|
||
| class CustomStopper(Stopper): | ||
|
|
@@ -143,7 +196,12 @@ def __init__( | |
| self.finetune_actor = None | ||
| self.finetune_status = False | ||
| self.default_rag_path = default_rag_path | ||
| self.embedding_model_name = "sentence-transformers/all-mpnet-base-v2" | ||
| local_embedding_model_path = os.path.join(RECDP_MODELS_CACHE, "sentence-transformers/all-mpnet-base-v2") | ||
| print(local_embedding_model_path) | ||
| if os.path.exists(local_embedding_model_path): | ||
| self.embedding_model_name = local_embedding_model_path | ||
| else: | ||
| self.embedding_model_name = "sentence-transformers/all-mpnet-base-v2" | ||
| self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model_name) | ||
|
|
||
| self._init_ui() | ||
|
|
@@ -794,23 +852,24 @@ def set_rag_default_path(self, selector, rag_path): | |
|
|
||
| def _init_ui(self): | ||
| mark_alive = None | ||
| private_key = paramiko.Ed25519Key.from_private_key_file("/root/.ssh/id_ed25519") | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why use a custom path to set private key?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For some reason, the id_rsa decoder reports an error in my environment; I found similar reports on Google.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, I see. I'm not sure about the format of your known_hosts, could you verify this method in your env? |
||
| for index in range(len(self.ray_nodes)): | ||
| if "node:__internal_head__" in ray.nodes()[index]["Resources"]: | ||
| mark_alive = index | ||
| node_ip = self.ray_nodes[index]["NodeName"] | ||
| self.ssh_connect[index] = paramiko.SSHClient() | ||
| self.ssh_connect[index].load_system_host_keys() | ||
| self.ssh_connect[index].set_missing_host_key_policy(paramiko.RejectPolicy()) | ||
| #self.ssh_connect[index].load_system_host_keys() | ||
| self.ssh_connect[index].set_missing_host_key_policy(paramiko.AutoAddPolicy()) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can't use AutoAddPolicy because of code scan issue.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @KepingYan , I see, thanks for the clarification, I'll make the change |
||
| self.ssh_connect[index].connect( | ||
| hostname=node_ip, port=self.node_port, username=self.user_name | ||
| hostname=node_ip, username=self.user_name, pkey=private_key | ||
| ) | ||
| self.ssh_connect[-1] = paramiko.SSHClient() | ||
| self.ssh_connect[-1].load_system_host_keys() | ||
| self.ssh_connect[-1].set_missing_host_key_policy(paramiko.RejectPolicy()) | ||
| #self.ssh_connect[-1].load_system_host_keys() | ||
| self.ssh_connect[-1].set_missing_host_key_policy(paramiko.AutoAddPolicy()) | ||
| self.ssh_connect[-1].connect( | ||
| hostname=self.ray_nodes[mark_alive]["NodeName"], | ||
| port=self.node_port, | ||
| username=self.user_name, | ||
| pkey=private_key | ||
| ) | ||
|
|
||
| title = "Manage LLM Lifecycle" | ||
|
|
||

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a script in #101 to install the UI environment, so maybe we don't need to check it in code anymore. Also, we have to use
`pip install 'git+https://github.com/intel/e2eAIOK.git#egg=pyrecdp&subdirectory=RecDP'` to install pyrecdp; `pip install pyrecdp` does not contain the latest code and will change the ray version.