From 238660106966ed9b678693ea0161c9855e304706 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Wed, 3 Sep 2025 21:32:13 -0700 Subject: [PATCH 1/3] v1.0.0 release - enforce types for certain utilities e.g. enable_provider, disable_provider - runsql returns empty df for no rows - validate arguments for vector index creation - convert AI profile fetched attributes to Python objects - profile.generate() should return DataFrame for action runsql --- pyproject.toml | 6 +- src/select_ai/_validations.py | 123 +++++++++++++++++++++++++++++++++ src/select_ai/async_profile.py | 14 ++-- src/select_ai/base_profile.py | 1 + src/select_ai/profile.py | 22 +++--- src/select_ai/provider.py | 13 ++-- src/select_ai/vector_index.py | 10 +++ src/select_ai/version.py | 2 +- 8 files changed, 172 insertions(+), 19 deletions(-) create mode 100644 src/select_ai/_validations.py diff --git a/pyproject.toml b/pyproject.toml index 2a11a13..c6d28c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ keywords = [ license = " UPL-1.0" license-files = ["LICENSE.txt"] classifiers = [ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "Intended Audience :: Developers", "Natural Language :: English", "Operating System :: OS Independent", @@ -34,7 +34,9 @@ classifiers = [ "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", - "Topic :: Database" + "Topic :: Database", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules" ] dependencies = [ "oracledb", diff --git a/src/select_ai/_validations.py b/src/select_ai/_validations.py new file mode 100644 index 0000000..70bf1ec --- /dev/null +++ b/src/select_ai/_validations.py @@ -0,0 +1,123 @@ +# ----------------------------------------------------------------------------- +# Copyright (c) 2025, Oracle and/or its affiliates. 
+# +# Licensed under the Universal Permissive License v 1.0 as shown at +# http://oss.oracle.com/licenses/upl. +# ----------------------------------------------------------------------------- + +import inspect +from collections.abc import Mapping, Sequence, Set +from functools import wraps +from typing import Any, get_args, get_origin, get_type_hints + +NoneType = type(None) + + +def _match(value, annot) -> bool: + """Recursively validate value against a typing annotation.""" + if annot is Any: + return True + + origin = get_origin(annot) + args = get_args(annot) + + # Handle Annotated[T, ...] → treat as T + if origin is getattr(__import__("typing"), "Annotated", None): + annot = args[0] + origin = get_origin(annot) + args = get_args(annot) + + # Optional[T] is Union[T, NoneType] + if origin is getattr(__import__("typing"), "Union", None): + return any(_match(value, a) for a in args) + + # Literal[…] + if origin is getattr(__import__("typing"), "Literal", None): + return any(value == lit for lit in args) + + # Tuple cases + if origin is tuple: + if not isinstance(value, tuple): + return False + if len(args) == 2 and args[1] is Ellipsis: + # tuple[T, ...] 
+ return all(_match(v, args[0]) for v in value) + if len(args) != len(value): + return False + return all(_match(v, a) for v, a in zip(value, args)) + + # Mappings (dict-like) + if origin in (dict, Mapping): + if not isinstance(value, Mapping): + return False + k_annot, v_annot = args if args else (Any, Any) + return all( + _match(k, k_annot) and _match(v, v_annot) for k, v in value.items() + ) + + # Sequences (list, Sequence) – but not str/bytes + if origin in (list, Sequence): + if isinstance(value, (str, bytes)): + return False + if not isinstance(value, Sequence): + return False + elem_annot = args[0] if args else Any + return all(_match(v, elem_annot) for v in value) + + # Sets + if origin in (set, frozenset, Set): + if not isinstance(value, (set, frozenset)): + return False + elem_annot = args[0] if args else Any + return all(_match(v, elem_annot) for v in value) + + # Fall back to normal isinstance for non-typing classes + if isinstance(annot, type): + return isinstance(value, annot) + + # If annot is a typing alias like 'list' without args + if origin is not None: + # Treat bare containers as accepting anything inside + return isinstance(value, origin) + + # Unknown/unsupported typing form: accept conservatively + return True + + +def enforce_types(func): + # Resolve ForwardRefs using function globals (handles "User" as a string, etc.) 
+ hints = get_type_hints( + func, globalns=func.__globals__, include_extras=True + ) + sig = inspect.signature(func) + + def _check(bound): + for name, val in bound.arguments.items(): + if name in hints: + annot = hints[name] + if not _match(val, annot): + raise TypeError( + f"Argument '{name}' failed type check: expected {annot!r}, " + f"got {type(val).__name__} -> {val!r}" + ) + + if inspect.iscoroutinefunction(func): + + @wraps(func) + async def aw(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + _check(bound) + return await func(*args, **kwargs) + + return aw + else: + + @wraps(func) + def w(*args, **kwargs): + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + _check(bound) + return func(*args, **kwargs) + + return w diff --git a/src/select_ai/async_profile.py b/src/select_ai/async_profile.py index da4f502..e84f403 100644 --- a/src/select_ai/async_profile.py +++ b/src/select_ai/async_profile.py @@ -344,8 +344,15 @@ async def generate( keyword_parameters=parameters, ) if data is not None: - return await data.read() - return None + result = await data.read() + else: + result = None + if action == Action.RUNSQL and result: + return pandas.DataFrame(json.loads(result)) + elif action == Action.RUNSQL: + return pandas.DataFrame() + else: + return result async def chat(self, prompt, params: Mapping = None) -> str: """Asynchronously chat with the LLM @@ -411,8 +418,7 @@ async def run_sql( :param params: Parameters to include in the LLM request :return: pandas.DataFrame """ - data = await self.generate(prompt, action=Action.RUNSQL, params=params) - return pandas.DataFrame(json.loads(data)) + return await self.generate(prompt, action=Action.RUNSQL, params=params) async def show_sql(self, prompt, params: Mapping = None): """Show the generated SQL diff --git a/src/select_ai/base_profile.py b/src/select_ai/base_profile.py index 5336ca2..1162d0a 100644 --- a/src/select_ai/base_profile.py +++ b/src/select_ai/base_profile.py @@ -73,6 
+73,7 @@ class ProfileAttributes(SelectAIDataClass): vector_index_name: Optional[str] = None def __post_init__(self): + super().__post_init__() if self.provider and not isinstance(self.provider, Provider): raise ValueError( f"'provider' must be an object of " f"type select_ai.Provider" diff --git a/src/select_ai/profile.py b/src/select_ai/profile.py index e375a01..fa068bd 100644 --- a/src/select_ai/profile.py +++ b/src/select_ai/profile.py @@ -8,7 +8,7 @@ import json from contextlib import contextmanager from dataclasses import replace as dataclass_replace -from typing import Iterator, Mapping, Optional, Union +from typing import Generator, Iterator, Mapping, Optional, Union import oracledb import pandas @@ -258,7 +258,9 @@ def _from_db(cls, profile_name: str) -> "Profile": raise ProfileNotFoundError(profile_name=profile_name) @classmethod - def list(cls, profile_name_pattern: str = ".*") -> Iterator["Profile"]: + def list( + cls, profile_name_pattern: str = ".*" + ) -> Generator["Profile", None, None]: """List AI Profiles saved in the database. 
:param str profile_name_pattern: Regular expressions can be used @@ -314,8 +316,15 @@ def generate( keyword_parameters=parameters, ) if data is not None: - return data.read() - return None + result = data.read() + else: + result = None + if action == Action.RUNSQL and result: + return pandas.DataFrame(json.loads(result)) + elif action == Action.RUNSQL: + return pandas.DataFrame() + else: + return result def chat(self, prompt: str, params: Mapping = None) -> str: """Chat with the LLM @@ -375,10 +384,7 @@ def run_sql(self, prompt: str, params: Mapping = None) -> pandas.DataFrame: :param params: Parameters to include in the LLM request :return: pandas.DataFrame """ - data = json.loads( - self.generate(prompt, action=Action.RUNSQL, params=params) - ) - return pandas.DataFrame(data) + return self.generate(prompt, action=Action.RUNSQL, params=params) def show_sql(self, prompt: str, params: Mapping = None) -> str: """Show the generated SQL diff --git a/src/select_ai/provider.py b/src/select_ai/provider.py index 9dec23c..cb87c13 100644 --- a/src/select_ai/provider.py +++ b/src/select_ai/provider.py @@ -9,6 +9,7 @@ from typing import List, Optional, Union from select_ai._abc import SelectAIDataClass +from select_ai._validations import enforce_types from .db import async_cursor, cursor from .sql import ( @@ -194,6 +195,7 @@ class AnthropicProvider(Provider): provider_endpoint = "api.anthropic.com" +@enforce_types async def async_enable_provider( users: Union[str, List[str]], provider_endpoint: str = None ): @@ -210,7 +212,7 @@ async def async_enable_provider( async with async_cursor() as cr: for user in users: - await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user)) + await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip())) if provider_endpoint: await cr.execute( ENABLE_AI_PROFILE_DOMAIN_FOR_USER, @@ -219,6 +221,7 @@ async def async_enable_provider( ) +@enforce_types async def async_disable_provider( users: Union[str, List[str]], provider_endpoint: str = None ): @@ 
-234,7 +237,7 @@ async def async_disable_provider( async with async_cursor() as cr: for user in users: - await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user)) + await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip())) if provider_endpoint: await cr.execute( DISABLE_AI_PROFILE_DOMAIN_FOR_USER, @@ -243,6 +246,7 @@ async def async_disable_provider( ) +@enforce_types def enable_provider( users: Union[str, List[str]], provider_endpoint: str = None ): @@ -256,7 +260,7 @@ def enable_provider( with cursor() as cr: for user in users: - cr.execute(GRANT_PRIVILEGES_TO_USER.format(user)) + cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip())) if provider_endpoint: cr.execute( ENABLE_AI_PROFILE_DOMAIN_FOR_USER, @@ -265,6 +269,7 @@ def enable_provider( ) +@enforce_types def disable_provider( users: Union[str, List[str]], provider_endpoint: str = None ): @@ -279,7 +284,7 @@ def disable_provider( with cursor() as cr: for user in users: - cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user)) + cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip())) if provider_endpoint: cr.execute( DISABLE_AI_PROFILE_DOMAIN_FOR_USER, diff --git a/src/select_ai/vector_index.py b/src/select_ai/vector_index.py index 9143d88..67f6470 100644 --- a/src/select_ai/vector_index.py +++ b/src/select_ai/vector_index.py @@ -119,6 +119,16 @@ def __init__( attributes: Optional[VectorIndexAttributes] = None, ): """Initialize a Vector Index""" + if attributes and not isinstance(attributes, VectorIndexAttributes): + raise TypeError( + "'attributes' must be an object of type " + "select_ai.VectorIndexAttributes" + ) + if profile and not isinstance(profile, BaseProfile): + raise TypeError( + "'profile' must be an object of type " + "select_ai.Profile or select_ai.AsyncProfile" + ) self.profile = profile self.index_name = index_name self.attributes = attributes diff --git a/src/select_ai/version.py b/src/select_ai/version.py index 1875fd5..633966b 100644 --- a/src/select_ai/version.py +++ 
b/src/select_ai/version.py @@ -5,4 +5,4 @@ # http://oss.oracle.com/licenses/upl. # ----------------------------------------------------------------------------- -__version__ = "1.0.0b1" +__version__ = "1.0.0" From ced57a8ca3604e24f91bfbdb14ddef246f434bfd Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Wed, 3 Sep 2025 22:54:38 -0700 Subject: [PATCH 2/3] Added documentation link --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index c6d28c3..6953244 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,6 +47,7 @@ dependencies = [ Homepage = "https://github.com/oracle/python-select-ai" Repository = "https://github.com/oracle/python-select-ai" Issues = "https://github.com/oracle/python-select-ai/issues" +Documentation = "https://docs.oracle.com/en/cloud/paas/autonomous-database/serverless/pysai/" [tool.setuptools.packages.find] where = ["src"] From c7d736a3cbf77ccee16a56a4b3905ce25826d1c5 Mon Sep 17 00:00:00 2001 From: Abhishek Singh Date: Wed, 3 Sep 2025 23:32:49 -0700 Subject: [PATCH 3/3] Added documentation in README --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index f98bb70..f8d1fa5 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,10 @@ Run python3 -m pip install select_ai ``` +## Documentation + +See [Select AI for Python documentation][documentation] + ## Samples Examples can be found in the [/samples][samples] directory @@ -81,6 +85,7 @@ Released under the Universal Permissive License v1.0 as shown at . [contributing]: https://github.com/oracle/python-select-ai/blob/main/CONTRIBUTING.md +[documentation]: https://docs.oracle.com/en/cloud/paas/autonomous-database/serverless/pysai/ [ghdiscussions]: https://github.com/oracle/python-select-ai/discussions [ghissues]: https://github.com/oracle/python-select-ai/issues [samples]: https://github.com/oracle/python-select-ai/tree/main/samples