Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ keywords = [
license = " UPL-1.0"
license-files = ["LICENSE.txt"]
classifiers = [
"Development Status :: 4 - Beta",
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Developers",
"Natural Language :: English",
"Operating System :: OS Independent",
Expand All @@ -34,7 +34,9 @@ classifiers = [
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Database"
"Topic :: Database",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Software Development :: Libraries :: Python Modules"
]
dependencies = [
"oracledb",
Expand All @@ -45,6 +47,7 @@ dependencies = [
Homepage = "https://github.com/oracle/python-select-ai"
Repository = "https://github.com/oracle/python-select-ai"
Issues = "https://github.com/oracle/python-select-ai/issues"
Documentation = "https://docs.oracle.com/en/cloud/paas/autonomous-database/serverless/pysai/"

[tool.setuptools.packages.find]
where = ["src"]
Expand Down
123 changes: 123 additions & 0 deletions src/select_ai/_validations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# -----------------------------------------------------------------------------
# Copyright (c) 2025, Oracle and/or its affiliates.
#
# Licensed under the Universal Permissive License v 1.0 as shown at
# http://oss.oracle.com/licenses/upl.
# -----------------------------------------------------------------------------

import inspect
from collections.abc import Mapping, Sequence, Set
from functools import wraps
from typing import Any, get_args, get_origin, get_type_hints

NoneType = type(None)


def _match(value, annot) -> bool:
"""Recursively validate value against a typing annotation."""
if annot is Any:
return True

origin = get_origin(annot)
args = get_args(annot)

# Handle Annotated[T, ...] → treat as T
if origin is getattr(__import__("typing"), "Annotated", None):
annot = args[0]
origin = get_origin(annot)
args = get_args(annot)

# Optional[T] is Union[T, NoneType]
if origin is getattr(__import__("typing"), "Union", None):
return any(_match(value, a) for a in args)

# Literal[…]
if origin is getattr(__import__("typing"), "Literal", None):
return any(value == lit for lit in args)

# Tuple cases
if origin is tuple:
if not isinstance(value, tuple):
return False
if len(args) == 2 and args[1] is Ellipsis:
# tuple[T, ...]
return all(_match(v, args[0]) for v in value)
if len(args) != len(value):
return False
return all(_match(v, a) for v, a in zip(value, args))

# Mappings (dict-like)
if origin in (dict, Mapping):
if not isinstance(value, Mapping):
return False
k_annot, v_annot = args if args else (Any, Any)
return all(
_match(k, k_annot) and _match(v, v_annot) for k, v in value.items()
)

# Sequences (list, Sequence) – but not str/bytes
if origin in (list, Sequence):
if isinstance(value, (str, bytes)):
return False
if not isinstance(value, Sequence):
return False
elem_annot = args[0] if args else Any
return all(_match(v, elem_annot) for v in value)

# Sets
if origin in (set, frozenset, Set):
if not isinstance(value, (set, frozenset)):
return False
elem_annot = args[0] if args else Any
return all(_match(v, elem_annot) for v in value)

# Fall back to normal isinstance for non-typing classes
if isinstance(annot, type):
return isinstance(value, annot)

# If annot is a typing alias like 'list' without args
if origin is not None:
# Treat bare containers as accepting anything inside
return isinstance(value, origin)

# Unknown/unsupported typing form: accept conservatively
return True


def enforce_types(func):
# Resolve ForwardRefs using function globals (handles "User" as a string, etc.)
hints = get_type_hints(
func, globalns=func.__globals__, include_extras=True
)
sig = inspect.signature(func)

def _check(bound):
for name, val in bound.arguments.items():
if name in hints:
annot = hints[name]
if not _match(val, annot):
raise TypeError(
f"Argument '{name}' failed type check: expected {annot!r}, "
f"got {type(val).__name__} -> {val!r}"
)

if inspect.iscoroutinefunction(func):

@wraps(func)
async def aw(*args, **kwargs):
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
_check(bound)
return await func(*args, **kwargs)

return aw
else:

@wraps(func)
def w(*args, **kwargs):
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
_check(bound)
return func(*args, **kwargs)

return w
14 changes: 10 additions & 4 deletions src/select_ai/async_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,15 @@ async def generate(
keyword_parameters=parameters,
)
if data is not None:
return await data.read()
return None
result = await data.read()
else:
result = None
if action == Action.RUNSQL and result:
return pandas.DataFrame(json.loads(result))
elif action == Action.RUNSQL:
return pandas.DataFrame()
else:
return result

async def chat(self, prompt, params: Mapping = None) -> str:
"""Asynchronously chat with the LLM
Expand Down Expand Up @@ -411,8 +418,7 @@ async def run_sql(
:param params: Parameters to include in the LLM request
:return: pandas.DataFrame
"""
data = await self.generate(prompt, action=Action.RUNSQL, params=params)
return pandas.DataFrame(json.loads(data))
return await self.generate(prompt, action=Action.RUNSQL, params=params)

async def show_sql(self, prompt, params: Mapping = None):
"""Show the generated SQL
Expand Down
1 change: 1 addition & 0 deletions src/select_ai/base_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class ProfileAttributes(SelectAIDataClass):
vector_index_name: Optional[str] = None

def __post_init__(self):
super().__post_init__()
if self.provider and not isinstance(self.provider, Provider):
raise ValueError(
f"'provider' must be an object of " f"type select_ai.Provider"
Expand Down
22 changes: 14 additions & 8 deletions src/select_ai/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import json
from contextlib import contextmanager
from dataclasses import replace as dataclass_replace
from typing import Iterator, Mapping, Optional, Union
from typing import Generator, Iterator, Mapping, Optional, Union

import oracledb
import pandas
Expand Down Expand Up @@ -258,7 +258,9 @@ def _from_db(cls, profile_name: str) -> "Profile":
raise ProfileNotFoundError(profile_name=profile_name)

@classmethod
def list(cls, profile_name_pattern: str = ".*") -> Iterator["Profile"]:
def list(
cls, profile_name_pattern: str = ".*"
) -> Generator["Profile", None, None]:
"""List AI Profiles saved in the database.

:param str profile_name_pattern: Regular expressions can be used
Expand Down Expand Up @@ -314,8 +316,15 @@ def generate(
keyword_parameters=parameters,
)
if data is not None:
return data.read()
return None
result = data.read()
else:
result = None
if action == Action.RUNSQL and result:
return pandas.DataFrame(json.loads(result))
elif action == Action.RUNSQL:
return pandas.DataFrame()
else:
return result

def chat(self, prompt: str, params: Mapping = None) -> str:
"""Chat with the LLM
Expand Down Expand Up @@ -375,10 +384,7 @@ def run_sql(self, prompt: str, params: Mapping = None) -> pandas.DataFrame:
:param params: Parameters to include in the LLM request
:return: pandas.DataFrame
"""
data = json.loads(
self.generate(prompt, action=Action.RUNSQL, params=params)
)
return pandas.DataFrame(data)
return self.generate(prompt, action=Action.RUNSQL, params=params)

def show_sql(self, prompt: str, params: Mapping = None) -> str:
"""Show the generated SQL
Expand Down
13 changes: 9 additions & 4 deletions src/select_ai/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import List, Optional, Union

from select_ai._abc import SelectAIDataClass
from select_ai._validations import enforce_types

from .db import async_cursor, cursor
from .sql import (
Expand Down Expand Up @@ -194,6 +195,7 @@ class AnthropicProvider(Provider):
provider_endpoint = "api.anthropic.com"


@enforce_types
async def async_enable_provider(
users: Union[str, List[str]], provider_endpoint: str = None
):
Expand All @@ -210,7 +212,7 @@ async def async_enable_provider(

async with async_cursor() as cr:
for user in users:
await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user))
await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip()))
if provider_endpoint:
await cr.execute(
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
Expand All @@ -219,6 +221,7 @@ async def async_enable_provider(
)


@enforce_types
async def async_disable_provider(
users: Union[str, List[str]], provider_endpoint: str = None
):
Expand All @@ -234,7 +237,7 @@ async def async_disable_provider(

async with async_cursor() as cr:
for user in users:
await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user))
await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip()))
if provider_endpoint:
await cr.execute(
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,
Expand All @@ -243,6 +246,7 @@ async def async_disable_provider(
)


@enforce_types
def enable_provider(
users: Union[str, List[str]], provider_endpoint: str = None
):
Expand All @@ -256,7 +260,7 @@ def enable_provider(

with cursor() as cr:
for user in users:
cr.execute(GRANT_PRIVILEGES_TO_USER.format(user))
cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip()))
if provider_endpoint:
cr.execute(
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
Expand All @@ -265,6 +269,7 @@ def enable_provider(
)


@enforce_types
def disable_provider(
users: Union[str, List[str]], provider_endpoint: str = None
):
Expand All @@ -279,7 +284,7 @@ def disable_provider(

with cursor() as cr:
for user in users:
cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user))
cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip()))
if provider_endpoint:
cr.execute(
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,
Expand Down
10 changes: 10 additions & 0 deletions src/select_ai/vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,16 @@ def __init__(
attributes: Optional[VectorIndexAttributes] = None,
):
"""Initialize a Vector Index"""
if attributes and not isinstance(attributes, VectorIndexAttributes):
raise TypeError(
"'attributes' must be an object of type "
"select_ai.VectorIndexAttributes"
)
if profile and not isinstance(profile, BaseProfile):
raise TypeError(
"'profile' must be an object of type "
"select_ai.Profile or select_ai.AsyncProfile"
)
self.profile = profile
self.index_name = index_name
self.attributes = attributes
Expand Down
2 changes: 1 addition & 1 deletion src/select_ai/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
# http://oss.oracle.com/licenses/upl.
# -----------------------------------------------------------------------------

__version__ = "1.0.0b1"
__version__ = "1.0.0"