Skip to content

Commit 9f20813

Browse files
committed
Migrating tests to pytest
1 parent 9df1a40 commit 9f20813

14 files changed

+1122
-132
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,4 @@ doc/drawio
1212
**/__pycache__
1313
test.env
1414
test_19c.env
15+
pytest.env

pyproject.toml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,12 @@ dependencies = [
4343
"pandas==2.2.3"
4444
]
4545

46+
[project.optional-dependencies]
47+
test = [
48+
"anyio",
49+
"pytest",
50+
]
51+
4652
[project.urls]
4753
Homepage = "https://github.com/oracle/python-select-ai"
4854
Repository = "https://github.com/oracle/python-select-ai"
@@ -64,3 +70,9 @@ required-version = 24
6470
line-length = 79
6571
target-version = "py39"
6672
per-file-ignores = { "__init__.py" = ["F401"] }
73+
74+
[tool.pytest.ini_options]
75+
minversion = "8.3.0"
76+
testpaths = [
77+
"tests"
78+
]

src/select_ai/__init__.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@
2929
disconnect,
3030
is_connected,
3131
)
32+
from .errors import *
33+
from .privilege import (
34+
async_grant_http_access,
35+
async_grant_privileges,
36+
async_revoke_http_access,
37+
grant_http_access,
38+
grant_privileges,
39+
revoke_http_access,
40+
)
3241
from .profile import Profile
3342
from .provider import (
3443
AnthropicProvider,
@@ -40,10 +49,6 @@
4049
OCIGenAIProvider,
4150
OpenAIProvider,
4251
Provider,
43-
async_disable_provider,
44-
async_enable_provider,
45-
disable_provider,
46-
enable_provider,
4752
)
4853
from .synthetic_data import (
4954
SyntheticDataAttributes,

src/select_ai/async_profile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ async def _init_profile(self):
7474
if self.raise_error_if_exists:
7575
raise ProfileExistsError(self.profile_name)
7676

77-
if self.description is None:
77+
if self.description is None and not self.replace:
7878
self.description = await self._get_profile_description(
7979
profile_name=self.profile_name
8080
)
@@ -307,7 +307,7 @@ async def list(
307307
rows = await cr.fetchall()
308308
for row in rows:
309309
profile_name = row[0]
310-
description = row[1]
310+
description = await row[1].read() if row[1] else None
311311
attributes = await cls._get_attributes(
312312
profile_name=profile_name
313313
)

src/select_ai/privilege.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
# -----------------------------------------------------------------------------
2+
# Copyright (c) 2025, Oracle and/or its affiliates.
3+
#
4+
# Licensed under the Universal Permissive License v 1.0 as shown at
5+
# http://oss.oracle.com/licenses/upl.
6+
# -----------------------------------------------------------------------------
7+
from typing import List, Union
8+
9+
from .db import async_cursor, cursor
10+
from .sql import (
11+
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,
12+
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
13+
GRANT_PRIVILEGES_TO_USER,
14+
REVOKE_PRIVILEGES_FROM_USER,
15+
)
16+
17+
18+
async def async_grant_privileges(users: Union[str, List[str]]):
19+
"""
20+
This method grants execute privilege on the packages DBMS_CLOUD,
21+
DBMS_CLOUD_AI, DBMS_CLOUD_AI_AGENT and DBMS_CLOUD_PIPELINE.
22+
23+
"""
24+
if isinstance(users, str):
25+
users = [users]
26+
27+
async with async_cursor() as cr:
28+
for user in users:
29+
await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip()))
30+
31+
32+
async def async_revoke_privileges(users: Union[str, List[str]]):
33+
"""
34+
This method revokes execute privilege on the packages DBMS_CLOUD,
35+
DBMS_CLOUD_AI, DBMS_CLOUD_AI_AGENT and DBMS_CLOUD_PIPELINE.
36+
37+
"""
38+
if isinstance(users, str):
39+
users = [users]
40+
41+
async with async_cursor() as cr:
42+
for user in users:
43+
await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip()))
44+
45+
46+
async def async_grant_http_access(
47+
users: Union[str, List[str]],
48+
provider_endpoint: str,
49+
):
50+
"""
51+
Async method to add ACL for HTTP access.
52+
"""
53+
if isinstance(users, str):
54+
users = [users]
55+
56+
async with async_cursor() as cr:
57+
for user in users:
58+
await cr.execute(
59+
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
60+
user=user,
61+
host=provider_endpoint,
62+
)
63+
64+
65+
async def async_revoke_http_access(
66+
users: Union[str, List[str]],
67+
provider_endpoint: str,
68+
):
69+
"""
70+
Async method to remove ACL for HTTP access.
71+
"""
72+
if isinstance(users, str):
73+
users = [users]
74+
75+
async with async_cursor() as cr:
76+
for user in users:
77+
await cr.execute(
78+
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,
79+
user=user,
80+
host=provider_endpoint,
81+
)
82+
83+
84+
def grant_privileges(users: Union[str, List[str]]):
85+
"""
86+
This method grants execute privilege on the packages DBMS_CLOUD,
87+
DBMS_CLOUD_AI, DBMS_CLOUD_AI_AGENT and DBMS_CLOUD_PIPELINE
88+
"""
89+
if isinstance(users, str):
90+
users = [users]
91+
with cursor() as cr:
92+
for user in users:
93+
cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip()))
94+
95+
96+
def revoke_privileges(users: Union[str, List[str]]):
97+
"""
98+
This method revokes execute privilege on the packages DBMS_CLOUD,
99+
DBMS_CLOUD_AI, DBMS_CLOUD_AI_AGENT and DBMS_CLOUD_PIPELINE.
100+
"""
101+
if isinstance(users, str):
102+
users = [users]
103+
with cursor() as cr:
104+
for user in users:
105+
cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip()))
106+
107+
108+
def grant_http_access(users: Union[str, List[str]], provider_endpoint: str):
109+
"""
110+
Adds ACL entry for HTTP access
111+
"""
112+
if isinstance(users, str):
113+
users = [users]
114+
with cursor() as cr:
115+
for user in users:
116+
cr.execute(
117+
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
118+
user=user,
119+
host=provider_endpoint,
120+
)
121+
122+
123+
def revoke_http_access(users: Union[str, List[str]], provider_endpoint: str):
124+
"""
125+
Removes ACL entry for HTTP access
126+
"""
127+
if isinstance(users, str):
128+
users = [users]
129+
with cursor() as cr:
130+
for user in users:
131+
cr.execute(
132+
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,
133+
user=user,
134+
host=provider_endpoint,
135+
)

src/select_ai/profile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _init_profile(self) -> None:
6363
if self.raise_error_if_exists:
6464
raise ProfileExistsError(self.profile_name)
6565

66-
if self.description is None:
66+
if self.description is None and not self.replace:
6767
self.description = self._get_profile_description(
6868
profile_name=self.profile_name
6969
)
@@ -280,7 +280,7 @@ def list(
280280
)
281281
for row in cr.fetchall():
282282
profile_name = row[0]
283-
description = row[1]
283+
description = row[1].read() if row[1] else None
284284
attributes = cls._get_attributes(profile_name=profile_name)
285285
yield cls(
286286
profile_name=profile_name,

src/select_ai/provider.py

Lines changed: 0 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -193,101 +193,3 @@ class AnthropicProvider(Provider):
193193

194194
provider_name: str = ANTHROPIC
195195
provider_endpoint = "api.anthropic.com"
196-
197-
198-
@enforce_types
199-
async def async_enable_provider(
200-
users: Union[str, List[str]], provider_endpoint: str = None
201-
):
202-
"""
203-
Async API to enable AI profile for database users.
204-
205-
This method grants execute privilege on the packages DBMS_CLOUD,
206-
DBMS_CLOUD_AI and DBMS_CLOUD_PIPELINE. It also enables the database
207-
user to invoke the AI Provider (LLM) endpoint
208-
209-
"""
210-
if isinstance(users, str):
211-
users = [users]
212-
213-
async with async_cursor() as cr:
214-
for user in users:
215-
await cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip()))
216-
if provider_endpoint:
217-
await cr.execute(
218-
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
219-
user=user,
220-
host=provider_endpoint,
221-
)
222-
223-
224-
@enforce_types
225-
async def async_disable_provider(
226-
users: Union[str, List[str]], provider_endpoint: str = None
227-
):
228-
"""
229-
Async API to disable AI profile for database users
230-
231-
Disables AI provider for the user. This method revokes execute privilege
232-
on the packages DBMS_CLOUD, DBMS_CLOUD_AI and DBMS_CLOUD_PIPELINE. It
233-
also disables the user to invoke the AI Provider (LLM) endpoint
234-
"""
235-
if isinstance(users, str):
236-
users = [users]
237-
238-
async with async_cursor() as cr:
239-
for user in users:
240-
await cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip()))
241-
if provider_endpoint:
242-
await cr.execute(
243-
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,
244-
user=user,
245-
host=provider_endpoint,
246-
)
247-
248-
249-
@enforce_types
250-
def enable_provider(
251-
users: Union[str, List[str]], provider_endpoint: str = None
252-
):
253-
"""
254-
Enables AI profile for the user. This method grants execute privilege
255-
on the packages DBMS_CLOUD, DBMS_CLOUD_AI and DBMS_CLOUD_PIPELINE. It
256-
also enables the user to invoke the AI Provider (LLM) endpoint
257-
"""
258-
if isinstance(users, str):
259-
users = [users]
260-
261-
with cursor() as cr:
262-
for user in users:
263-
cr.execute(GRANT_PRIVILEGES_TO_USER.format(user.strip()))
264-
if provider_endpoint:
265-
cr.execute(
266-
ENABLE_AI_PROFILE_DOMAIN_FOR_USER,
267-
user=user,
268-
host=provider_endpoint,
269-
)
270-
271-
272-
@enforce_types
273-
def disable_provider(
274-
users: Union[str, List[str]], provider_endpoint: str = None
275-
):
276-
"""
277-
Disables AI provider for the user. This method revokes execute privilege
278-
on the packages DBMS_CLOUD, DBMS_CLOUD_AI and DBMS_CLOUD_PIPELINE. It
279-
also disables the user to invoke the AI(LLM) endpoint
280-
281-
"""
282-
if isinstance(users, str):
283-
users = [users]
284-
285-
with cursor() as cr:
286-
for user in users:
287-
cr.execute(REVOKE_PRIVILEGES_FROM_USER.format(user.strip()))
288-
if provider_endpoint:
289-
cr.execute(
290-
DISABLE_AI_PROFILE_DOMAIN_FOR_USER,
291-
user=user,
292-
host=provider_endpoint,
293-
)

src/select_ai/vector_index.py

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -81,41 +81,22 @@ class VectorIndexAttributes(SelectAIDataClass):
8181
to store vector embeddings and chunked data
8282
"""
8383

84-
chunk_size: Optional[int] = 1024
85-
chunk_overlap: Optional[int] = 128
84+
chunk_size: Optional[int] = None
85+
chunk_overlap: Optional[int] = None
8686
location: Optional[str] = None
87-
match_limit: Optional[int] = 5
87+
match_limit: Optional[int] = None
8888
object_storage_credential_name: Optional[str] = None
8989
profile_name: Optional[str] = None
90-
refresh_rate: Optional[int] = 1440
91-
similarity_threshold: Optional[float] = 0
92-
vector_distance_metric: Optional[VectorDistanceMetric] = (
93-
VectorDistanceMetric.COSINE
94-
)
90+
refresh_rate: Optional[int] = None
91+
similarity_threshold: Optional[float] = None
92+
vector_distance_metric: Optional[VectorDistanceMetric] = None
9593
vector_db_endpoint: Optional[str] = None
9694
vector_db_credential_name: Optional[str] = None
9795
vector_db_provider: Optional[VectorDBProvider] = None
9896
vector_dimension: Optional[int] = None
9997
vector_table_name: Optional[str] = None
10098
pipeline_name: Optional[str] = None
10199

102-
def json(self, exclude_null=True, for_update=False):
103-
attributes = self.dict(exclude_null=exclude_null)
104-
attributes.pop("pipeline_name", None)
105-
# Currently, the following are unmodifiable
106-
unmodifiable = [
107-
"location",
108-
"chunk_size",
109-
"chunk_overlap",
110-
"vector_dimension",
111-
"vector_table_name",
112-
"vector_distance_metric",
113-
]
114-
if for_update:
115-
for key in unmodifiable:
116-
attributes.pop(key, None)
117-
return json.dumps(attributes)
118-
119100
@classmethod
120101
def create(cls, *, vector_db_provider: Optional[str] = None, **kwargs):
121102
for subclass in cls.__subclasses__():
@@ -364,7 +345,7 @@ def set_attributes(
364345

365346
parameters = {"index_name": self.index_name}
366347
if attributes:
367-
parameters["attributes"] = attributes.json(for_update=True)
348+
parameters["attributes"] = attributes.json()
368349
self.attributes = attributes
369350
else:
370351
setattr(self.attributes, attribute_name, attribute_value)

0 commit comments

Comments
 (0)