diff --git a/server/src/handlers/http.rs b/server/src/handlers/http.rs index ad1fe855e..804317821 100644 --- a/server/src/handlers/http.rs +++ b/server/src/handlers/http.rs @@ -34,6 +34,7 @@ use self::middleware::{DisAllowRootUser, RouteExt}; mod about; mod health_check; mod ingest; +mod llm; mod logstream; mod middleware; mod query; @@ -229,6 +230,21 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) { .wrap(DisAllowRootUser), ), ); + + let llm_query_api = web::scope("/llm") + .service( + web::resource("").route( + web::post() + .to(llm::make_llm_request) + .authorize(Action::Query), + ), + ) + .service( + // to check if the API key for an LLM has been set up as env var + web::resource("isactive") + .route(web::post().to(llm::is_llm_active).authorize(Action::Query)), + ); + // Deny request if username is same as the env variable P_USERNAME. cfg.service( // Base path "{url}/api/v1" @@ -266,7 +282,8 @@ pub fn configure_routes(cfg: &mut web::ServiceConfig) { logstream_api, ), ) - .service(user_api), + .service(user_api) + .service(llm_query_api), ) // GET "/" ==> Serve the static frontend directory .service(ResourceFiles::new("/", generated).resolve_not_found_to_root()); diff --git a/server/src/handlers/http/llm.rs b/server/src/handlers/http/llm.rs new file mode 100644 index 000000000..ef8feccc0 --- /dev/null +++ b/server/src/handlers/http/llm.rs @@ -0,0 +1,176 @@ +/* + * Parseable Server (C) 2022 - 2023 Parseable, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. 
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+use actix_web::{http::header::ContentType, web, HttpResponse, Result};
+use http::{header, StatusCode};
+use itertools::Itertools;
+use reqwest;
+use serde_json::{json, Value};
+
+use crate::{
+    metadata::{error::stream_info::MetadataError, STREAM_INFO},
+    option::CONFIG,
+};
+
+/// OpenAI chat-completions endpoint that every prompt is posted to.
+const OPEN_AI_URL: &str = "https://api.openai.com/v1/chat/completions";
+
+/// Subset of the OpenAI chat-completions response that the handler reads.
+#[derive(serde::Deserialize, Debug)]
+struct ResponseData {
+    choices: Vec<Choice>,
+}
+
+/// One completion candidate inside [`ResponseData`].
+#[derive(serde::Deserialize, Debug)]
+struct Choice {
+    message: Message,
+}
+
+/// Message carried by a [`Choice`]; `content` is the text returned to the caller.
+#[derive(serde::Deserialize, Debug)]
+struct Message {
+    content: String,
+}
+
+/// JSON request body accepted by [`make_llm_request`].
+#[derive(serde::Deserialize, Debug)]
+pub struct AiPrompt {
+    /// Natural-language question to turn into SQL.
+    prompt: String,
+    /// Name of the stream whose schema is sent along with the question.
+    stream: String,
+}
+
+/// Temporary, serializable view of an Arrow field (name and data type) that is
+/// embedded in the prompt.
+#[derive(Debug, serde::Serialize)]
+struct Field {
+    name: String,
+    data_type: String,
+}
+
+impl From<&arrow_schema::Field> for Field {
+    fn from(field: &arrow_schema::Field) -> Self {
+        Self {
+            name: field.name().clone(),
+            data_type: field.data_type().to_string(),
+        }
+    }
+}
+
+/// Builds the instruction sent to the model: the stream name, its columns as
+/// JSON, the user's question, and the constraints on the expected output.
+fn build_prompt(stream: &str, prompt: &str, schema_json: &str) -> String {
+    // This is a raw string, so escapes such as `\n` are NOT interpreted; every
+    // line break below is therefore written as a real line break.
+    format!(
+        r#"I have a table called {}.
+It has the columns:
+{}
+Based on this, generate valid SQL for the query: "{}"
+Generate only SQL as output. Also add comments in SQL syntax to explain your actions.
+Don't output anything else.
+If it is not possible to generate valid SQL, output an SQL comment saying so."#,
+        stream, schema_json, prompt
+    )
+}
+
+/// Wraps the prompt in the chat-completions request payload.
+fn build_request_body(ai_prompt: String) -> impl serde::Serialize {
+    json!({
+        "model": "gpt-3.5-turbo",
+        "messages": [{ "role": "user", "content": ai_prompt}],
+        "temperature": 0.6,
+    })
+}
+
+/// Returns the key only when it is longer than three characters, which is the
+/// plausibility check shared by both handlers.
+fn usable_api_key(key: Option<&str>) -> Option<&str> {
+    key.filter(|key| key.len() > 3)
+}
+
+/// Asks OpenAI to translate the caller's question into SQL for one stream.
+///
+/// The stream's schema is looked up, serialized to JSON and sent together with
+/// the question; the content of the first returned choice is sent back as JSON.
+///
+/// # Errors
+///
+/// * [`LLMError::InvalidAPIKey`] when no usable OpenAI key is configured.
+/// * [`LLMError::StreamDoesNotExist`] when the schema lookup fails.
+/// * [`LLMError::FailedRequest`] when the HTTP call fails or a response body
+///   cannot be decoded.
+/// * [`LLMError::APIError`] when OpenAI reports an error or returns no choices.
+pub async fn make_llm_request(body: web::Json<AiPrompt>) -> Result<HttpResponse, LLMError> {
+    let api_key = usable_api_key(CONFIG.parseable.open_ai_key.as_deref())
+        .ok_or(LLMError::InvalidAPIKey)?;
+
+    let stream_name = &body.stream;
+    let schema = STREAM_INFO.schema(stream_name)?;
+    let filtered_schema = schema
+        .all_fields()
+        .into_iter()
+        .map(Field::from)
+        .collect_vec();
+
+    let schema_json =
+        serde_json::to_string(&filtered_schema).expect("always converted to valid json");
+
+    let prompt = build_prompt(stream_name, &body.prompt, &schema_json);
+    let body = build_request_body(prompt);
+
+    let client = reqwest::Client::new();
+    let response = client
+        .post(OPEN_AI_URL)
+        .header(header::CONTENT_TYPE, "application/json")
+        .bearer_auth(api_key)
+        .json(&body)
+        .send()
+        .await?;
+
+    if response.status().is_success() {
+        // The reply comes from a remote service: an undecodable body or an
+        // empty `choices` array is reported as an error instead of panicking.
+        let body: ResponseData = response.json().await?;
+        let choice = body
+            .choices
+            .into_iter()
+            .next()
+            .ok_or_else(|| LLMError::APIError("OpenAI returned no choices".to_string()))?;
+        Ok(HttpResponse::Ok()
+            .content_type("application/json")
+            .json(&choice.message.content))
+    } else {
+        let body: Value = response.json().await?;
+        // `as_str` yields the bare message; `Value::to_string` would keep the
+        // surrounding JSON quotes.
+        let message = body
+            .pointer("/error/message")
+            .and_then(Value::as_str)
+            .map_or_else(|| "Error from OpenAI".to_string(), str::to_owned);
+
+        Err(LLMError::APIError(message))
+    }
+}
+
+/// Reports whether a usable OpenAI key is configured, as `{"is_active": bool}`.
+///
+/// The request body is accepted but ignored.
+// NOTE(review): the extractor's type parameter was lost in extraction; `Value`
+// is assumed here - confirm against the original patch.
+pub async fn is_llm_active(_body: web::Json<Value>) -> HttpResponse {
+    let is_active = usable_api_key(CONFIG.parseable.open_ai_key.as_deref()).is_some();
+    HttpResponse::Ok()
+        .content_type("application/json")
+        .json(json!({"is_active": is_active}))
+}
+
+/// Failures that can occur while serving an LLM request.
+#[derive(Debug, thiserror::Error)]
+pub enum LLMError {
+    /// No OpenAI key is configured, or the configured one is not usable.
+    #[error("Either OpenAI key was not provided or was invalid")]
+    InvalidAPIKey,
+    /// The HTTP call to OpenAI failed or a response body could not be decoded.
+    #[error("Failed to call OpenAI endpoint: {0}")]
+    FailedRequest(#[from] reqwest::Error),
+    /// OpenAI answered with an error; the payload is the message to report.
+    #[error("{0}")]
+    APIError(String),
+    /// The schema lookup for the requested stream failed.
+    #[error("{0}")]
+    StreamDoesNotExist(#[from] MetadataError),
+}
+
+impl actix_web::ResponseError for LLMError {
+    /// Maps each failure to the HTTP status that names the party at fault.
+    fn status_code(&self) -> http::StatusCode {
+        match self {
+            // Server-side configuration problem.
+            Self::InvalidAPIKey => StatusCode::INTERNAL_SERVER_ERROR,
+            // The upstream service failed, not this server.
+            Self::FailedRequest(_) | Self::APIError(_) => StatusCode::BAD_GATEWAY,
+            // The caller named a stream that could not be found.
+            Self::StreamDoesNotExist(_) => StatusCode::NOT_FOUND,
+        }
+    }
+
+    /// Renders the error as a plain-text body carrying its display message.
+    fn error_response(&self) -> actix_web::HttpResponse {
+        actix_web::HttpResponse::build(self.status_code())
+            .insert_header(ContentType::plaintext())
+            .body(self.to_string())
+    }
+}
diff --git a/server/src/option.rs b/server/src/option.rs
index 14e389e3a..3da036d0c 100644
--- a/server/src/option.rs
+++ b/server/src/option.rs
@@ -184,6 +184,9 @@ pub struct Server {
     /// Server should send anonymous analytics or not
     pub send_analytics: bool,
 
+    /// Open AI access key
+    pub open_ai_key: Option<String>,
+
     /// Rows in Parquet Rowgroup
     pub row_group_size: usize,
 
@@ -232,6 +235,7 @@ impl FromArgMatches for Server {
             .get_one::<bool>(Self::SEND_ANALYTICS)
             .cloned()
             .expect("default for send analytics");
+        self.open_ai_key = m.get_one::<String>(Self::OPEN_AI_KEY).cloned();
         // converts Gib to bytes before assigning
         self.query_memory_pool_size = m
             .get_one::(Self::QUERY_MEM_POOL_SIZE)
@@ -271,6 +275,7 @@ impl Server {
     pub const PASSWORD: &str = "password";
     pub const CHECK_UPDATE: &str = "check-update";
     pub const SEND_ANALYTICS: &str = "send-analytics";
+    pub const OPEN_AI_KEY: &str = "open-ai-key";
     pub const QUERY_MEM_POOL_SIZE: &str = "query-mempool-size";
     pub const ROW_GROUP_SIZE: &str = "row-group-size";
     pub const PARQUET_COMPRESSION_ALGO: &str = "compression-algo";
@@ -351,6 +356,24 @@ impl Server {
.required(true) .help("Password for the basic authentication on the server"), ) + .arg( + Arg::new(Self::SEND_ANALYTICS) + .long(Self::SEND_ANALYTICS) + .env("P_SEND_ANONYMOUS_USAGE_DATA") + .value_name("BOOL") + .required(false) + .default_value("true") + .value_parser(value_parser!(bool)) + .help("Disable/Enable sending anonymous user data"), + ) + .arg( + Arg::new(Self::OPEN_AI_KEY) + .long(Self::OPEN_AI_KEY) + .env("OPENAI_API_KEY") + .value_name("STRING") + .required(false) + .help("Set OpenAI key to enable llm feature"), + ) .arg( Arg::new(Self::CHECK_UPDATE) .long(Self::CHECK_UPDATE) @@ -380,16 +403,6 @@ impl Server { .value_parser(value_parser!(usize)) .help("Number of rows in a row groups"), ) - .arg( - Arg::new(Self::SEND_ANALYTICS) - .long(Self::SEND_ANALYTICS) - .env("P_SEND_ANONYMOUS_USAGE_DATA") - .value_name("BOOL") - .required(false) - .default_value("true") - .value_parser(value_parser!(bool)) - .help("Disable/Enable sending anonymous user data"), - ) .arg( Arg::new(Self::PARQUET_COMPRESSION_ALGO) .long(Self::PARQUET_COMPRESSION_ALGO)