mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-13 04:22:35 +00:00
One more attempt with Claude's help to close connections
This commit is contained in:
parent
fa100c77fd
commit
a424815804
16 changed files with 132 additions and 38 deletions
|
|
@ -108,3 +108,8 @@ class QuotaMiddleware:
|
||||||
)
|
)
|
||||||
body = json.dumps({"error": {"message": message}}).encode()
|
body = json.dumps({"error": {"message": message}}).encode()
|
||||||
await send({"type": "http.response.body", "body": body})
|
await send({"type": "http.response.body", "body": body})
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""Close the KV store connection."""
|
||||||
|
if self.kv is not None:
|
||||||
|
await self.kv.close()
|
||||||
|
|
|
||||||
|
|
@ -315,6 +315,7 @@ class MetaReferenceAgentsImpl(Agents):
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
await self.persistence_store.close()
|
await self.persistence_store.close()
|
||||||
|
await self.responses_store.shutdown()
|
||||||
|
|
||||||
# OpenAI responses
|
# OpenAI responses
|
||||||
async def get_openai_response(
|
async def get_openai_response(
|
||||||
|
|
|
||||||
|
|
@ -62,7 +62,8 @@ class LocalfsFilesImpl(Files):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
if self.sql_store:
|
||||||
|
await self.sql_store.close()
|
||||||
|
|
||||||
def _generate_file_id(self) -> str:
|
def _generate_file_id(self) -> str:
|
||||||
"""Generate a unique file ID for OpenAI API."""
|
"""Generate a unique file ID for OpenAI API."""
|
||||||
|
|
|
||||||
|
|
@ -181,7 +181,8 @@ class S3FilesImpl(Files):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
if self._sql_store:
|
||||||
|
await self._sql_store.close()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self) -> boto3.client:
|
def client(self) -> boto3.client:
|
||||||
|
|
|
||||||
|
|
@ -74,19 +74,21 @@ class InferenceStore:
|
||||||
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
if not self._worker_tasks:
|
if self._worker_tasks:
|
||||||
return
|
if self._queue is not None:
|
||||||
if self._queue is not None:
|
await self._queue.join()
|
||||||
await self._queue.join()
|
for t in self._worker_tasks:
|
||||||
for t in self._worker_tasks:
|
if not t.done():
|
||||||
if not t.done():
|
t.cancel()
|
||||||
t.cancel()
|
for t in self._worker_tasks:
|
||||||
for t in self._worker_tasks:
|
try:
|
||||||
try:
|
await t
|
||||||
await t
|
except asyncio.CancelledError:
|
||||||
except asyncio.CancelledError:
|
pass
|
||||||
pass
|
self._worker_tasks.clear()
|
||||||
self._worker_tasks.clear()
|
|
||||||
|
if self.sql_store:
|
||||||
|
await self.sql_store.close()
|
||||||
|
|
||||||
async def flush(self) -> None:
|
async def flush(self) -> None:
|
||||||
"""Wait for all queued writes to complete. Useful for testing."""
|
"""Wait for all queued writes to complete. Useful for testing."""
|
||||||
|
|
|
||||||
|
|
@ -96,19 +96,21 @@ class ResponsesStore:
|
||||||
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
logger.info("Write queue disabled for SQLite to avoid concurrency issues")
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
if not self._worker_tasks:
|
if self._worker_tasks:
|
||||||
return
|
if self._queue is not None:
|
||||||
if self._queue is not None:
|
await self._queue.join()
|
||||||
await self._queue.join()
|
for t in self._worker_tasks:
|
||||||
for t in self._worker_tasks:
|
if not t.done():
|
||||||
if not t.done():
|
t.cancel()
|
||||||
t.cancel()
|
for t in self._worker_tasks:
|
||||||
for t in self._worker_tasks:
|
try:
|
||||||
try:
|
await t
|
||||||
await t
|
except asyncio.CancelledError:
|
||||||
except asyncio.CancelledError:
|
pass
|
||||||
pass
|
self._worker_tasks.clear()
|
||||||
self._worker_tasks.clear()
|
|
||||||
|
if self.sql_store:
|
||||||
|
await self.sql_store.close()
|
||||||
|
|
||||||
async def flush(self) -> None:
|
async def flush(self) -> None:
|
||||||
"""Wait for all queued writes to complete. Useful for testing."""
|
"""Wait for all queued writes to complete. Useful for testing."""
|
||||||
|
|
|
||||||
|
|
@ -126,3 +126,9 @@ class SqlStore(Protocol):
|
||||||
:param nullable: Whether the column should be nullable (default: True)
|
:param nullable: Whether the column should be nullable (default: True)
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""
|
||||||
|
Close any persistent database connections.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
|
||||||
|
|
@ -197,6 +197,10 @@ class AuthorizedSqlStore:
|
||||||
"""Delete rows with automatic access control filtering."""
|
"""Delete rows with automatic access control filtering."""
|
||||||
await self.sql_store.delete(table, where)
|
await self.sql_store.delete(table, where)
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the underlying SQL store connection."""
|
||||||
|
await self.sql_store.close()
|
||||||
|
|
||||||
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
|
def _build_access_control_where_clause(self, policy: list[AccessRule]) -> str:
|
||||||
"""Build SQL WHERE clause for access control filtering.
|
"""Build SQL WHERE clause for access control filtering.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -311,3 +311,11 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
||||||
# The table creation will handle adding the column
|
# The table creation will handle adding the column
|
||||||
logger.error(f"Error adding column {column_name} to table {table}: {e}")
|
logger.error(f"Error adding column {column_name} to table {table}: {e}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""Close the database engine and all connections."""
|
||||||
|
if hasattr(self, "async_session"):
|
||||||
|
# Get the engine from the session maker
|
||||||
|
engine = self.async_session.kw.get("bind")
|
||||||
|
if engine:
|
||||||
|
await engine.dispose()
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,7 @@ async def files_provider(tmp_path):
|
||||||
provider = LocalfsFilesImpl(config, default_policy())
|
provider = LocalfsFilesImpl(config, default_policy())
|
||||||
await provider.initialize()
|
await provider.initialize()
|
||||||
yield provider
|
yield provider
|
||||||
|
await provider.shutdown()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -28,3 +28,5 @@ async def temp_prompt_store(tmp_path_factory):
|
||||||
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
|
store.kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=db_path))
|
||||||
|
|
||||||
yield store
|
yield store
|
||||||
|
|
||||||
|
await store.kvstore.close()
|
||||||
|
|
|
||||||
|
|
@ -52,17 +52,21 @@ def auth_app(tmp_path, request):
|
||||||
db_path = tmp_path / f"quota_{request.node.name}.db"
|
db_path = tmp_path / f"quota_{request.node.name}.db"
|
||||||
quota = build_quota_config(db_path)
|
quota = build_quota_config(db_path)
|
||||||
|
|
||||||
app = InjectClientIDMiddleware(
|
quota_middleware = QuotaMiddleware(
|
||||||
QuotaMiddleware(
|
inner_app,
|
||||||
inner_app,
|
kv_config=quota.kvstore,
|
||||||
kv_config=quota.kvstore,
|
anonymous_max_requests=quota.anonymous_max_requests,
|
||||||
anonymous_max_requests=quota.anonymous_max_requests,
|
authenticated_max_requests=quota.authenticated_max_requests,
|
||||||
authenticated_max_requests=quota.authenticated_max_requests,
|
window_seconds=86400,
|
||||||
window_seconds=86400,
|
|
||||||
),
|
|
||||||
client_id=f"client_{request.node.name}",
|
|
||||||
)
|
)
|
||||||
return app
|
app = InjectClientIDMiddleware(quota_middleware, client_id=f"client_{request.node.name}")
|
||||||
|
|
||||||
|
yield app
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
asyncio.run(quota_middleware.close())
|
||||||
|
|
||||||
|
|
||||||
def test_authenticated_quota_allows_up_to_limit(auth_app):
|
def test_authenticated_quota_allows_up_to_limit(auth_app):
|
||||||
|
|
@ -81,6 +85,8 @@ def test_authenticated_quota_blocks_after_limit(auth_app):
|
||||||
|
|
||||||
|
|
||||||
def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
||||||
|
import asyncio
|
||||||
|
|
||||||
inner_app = FastAPI()
|
inner_app = FastAPI()
|
||||||
|
|
||||||
@inner_app.get("/test")
|
@inner_app.get("/test")
|
||||||
|
|
@ -101,8 +107,12 @@ def test_anonymous_quota_allows_up_to_limit(tmp_path, request):
|
||||||
client = TestClient(app)
|
client = TestClient(app)
|
||||||
assert client.get("/test").status_code == 200
|
assert client.get("/test").status_code == 200
|
||||||
|
|
||||||
|
asyncio.run(app.close())
|
||||||
|
|
||||||
|
|
||||||
def test_anonymous_quota_blocks_after_limit(tmp_path, request):
|
def test_anonymous_quota_blocks_after_limit(tmp_path, request):
|
||||||
|
import asyncio
|
||||||
|
|
||||||
inner_app = FastAPI()
|
inner_app = FastAPI()
|
||||||
|
|
||||||
@inner_app.get("/test")
|
@inner_app.get("/test")
|
||||||
|
|
@ -125,3 +135,5 @@ def test_anonymous_quota_blocks_after_limit(tmp_path, request):
|
||||||
resp = client.get("/test")
|
resp = client.get("/test")
|
||||||
assert resp.status_code == 429
|
assert resp.status_code == 429
|
||||||
assert resp.json()["error"]["message"] == "Quota exceeded"
|
assert resp.json()["error"]["message"] == "Quota exceeded"
|
||||||
|
|
||||||
|
asyncio.run(app.close())
|
||||||
|
|
|
||||||
|
|
@ -89,6 +89,8 @@ async def test_inference_store_pagination_basic():
|
||||||
assert result3.data[0].id == "zebra-task"
|
assert result3.data[0].id == "zebra-task"
|
||||||
assert result3.has_more is False
|
assert result3.has_more is False
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_inference_store_pagination_ascending():
|
async def test_inference_store_pagination_ascending():
|
||||||
"""Test pagination with ascending order."""
|
"""Test pagination with ascending order."""
|
||||||
|
|
@ -126,6 +128,8 @@ async def test_inference_store_pagination_ascending():
|
||||||
assert result2.data[0].id == "charlie-task"
|
assert result2.data[0].id == "charlie-task"
|
||||||
assert result2.has_more is True
|
assert result2.has_more is True
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_inference_store_pagination_with_model_filter():
|
async def test_inference_store_pagination_with_model_filter():
|
||||||
"""Test pagination combined with model filtering."""
|
"""Test pagination combined with model filtering."""
|
||||||
|
|
@ -166,6 +170,8 @@ async def test_inference_store_pagination_with_model_filter():
|
||||||
assert result2.data[0].model == "model-a"
|
assert result2.data[0].model == "model-a"
|
||||||
assert result2.has_more is False
|
assert result2.has_more is False
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_inference_store_pagination_invalid_after():
|
async def test_inference_store_pagination_invalid_after():
|
||||||
"""Test error handling for invalid 'after' parameter."""
|
"""Test error handling for invalid 'after' parameter."""
|
||||||
|
|
@ -178,6 +184,8 @@ async def test_inference_store_pagination_invalid_after():
|
||||||
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
|
with pytest.raises(ValueError, match="Record with id='non-existent' not found in table 'chat_completions'"):
|
||||||
await store.list_chat_completions(after="non-existent", limit=2)
|
await store.list_chat_completions(after="non-existent", limit=2)
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_inference_store_pagination_no_limit():
|
async def test_inference_store_pagination_no_limit():
|
||||||
"""Test pagination behavior when no limit is specified."""
|
"""Test pagination behavior when no limit is specified."""
|
||||||
|
|
@ -208,3 +216,5 @@ async def test_inference_store_pagination_no_limit():
|
||||||
assert result.data[0].id == "beta-second" # Most recent first
|
assert result.data[0].id == "beta-second" # Most recent first
|
||||||
assert result.data[1].id == "omega-first"
|
assert result.data[1].id == "omega-first"
|
||||||
assert result.has_more is False
|
assert result.has_more is False
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
|
||||||
|
|
@ -98,6 +98,8 @@ async def test_responses_store_pagination_basic():
|
||||||
assert result3.data[0].id == "zebra-resp"
|
assert result3.data[0].id == "zebra-resp"
|
||||||
assert result3.has_more is False
|
assert result3.has_more is False
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_responses_store_pagination_ascending():
|
async def test_responses_store_pagination_ascending():
|
||||||
"""Test pagination with ascending order."""
|
"""Test pagination with ascending order."""
|
||||||
|
|
@ -136,6 +138,8 @@ async def test_responses_store_pagination_ascending():
|
||||||
assert result2.data[0].id == "charlie-resp"
|
assert result2.data[0].id == "charlie-resp"
|
||||||
assert result2.has_more is True
|
assert result2.has_more is True
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_responses_store_pagination_with_model_filter():
|
async def test_responses_store_pagination_with_model_filter():
|
||||||
"""Test pagination combined with model filtering."""
|
"""Test pagination combined with model filtering."""
|
||||||
|
|
@ -177,6 +181,8 @@ async def test_responses_store_pagination_with_model_filter():
|
||||||
assert result2.data[0].model == "model-a"
|
assert result2.data[0].model == "model-a"
|
||||||
assert result2.has_more is False
|
assert result2.has_more is False
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_responses_store_pagination_invalid_after():
|
async def test_responses_store_pagination_invalid_after():
|
||||||
"""Test error handling for invalid 'after' parameter."""
|
"""Test error handling for invalid 'after' parameter."""
|
||||||
|
|
@ -189,6 +195,8 @@ async def test_responses_store_pagination_invalid_after():
|
||||||
with pytest.raises(ValueError, match="Record with id.*'non-existent' not found in table 'openai_responses'"):
|
with pytest.raises(ValueError, match="Record with id.*'non-existent' not found in table 'openai_responses'"):
|
||||||
await store.list_responses(after="non-existent", limit=2)
|
await store.list_responses(after="non-existent", limit=2)
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_responses_store_pagination_no_limit():
|
async def test_responses_store_pagination_no_limit():
|
||||||
"""Test pagination behavior when no limit is specified."""
|
"""Test pagination behavior when no limit is specified."""
|
||||||
|
|
@ -221,6 +229,8 @@ async def test_responses_store_pagination_no_limit():
|
||||||
assert result.data[1].id == "omega-resp"
|
assert result.data[1].id == "omega-resp"
|
||||||
assert result.has_more is False
|
assert result.has_more is False
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_responses_store_get_response_object():
|
async def test_responses_store_get_response_object():
|
||||||
"""Test retrieving a single response object."""
|
"""Test retrieving a single response object."""
|
||||||
|
|
@ -249,6 +259,8 @@ async def test_responses_store_get_response_object():
|
||||||
with pytest.raises(ValueError, match="Response with id non-existent not found"):
|
with pytest.raises(ValueError, match="Response with id non-existent not found"):
|
||||||
await store.get_response_object("non-existent")
|
await store.get_response_object("non-existent")
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_responses_store_input_items_pagination():
|
async def test_responses_store_input_items_pagination():
|
||||||
"""Test pagination functionality for input items."""
|
"""Test pagination functionality for input items."""
|
||||||
|
|
@ -330,6 +342,8 @@ async def test_responses_store_input_items_pagination():
|
||||||
with pytest.raises(ValueError, match="Cannot specify both 'before' and 'after' parameters"):
|
with pytest.raises(ValueError, match="Cannot specify both 'before' and 'after' parameters"):
|
||||||
await store.list_response_input_items("test-resp", before="some-id", after="other-id")
|
await store.list_response_input_items("test-resp", before="some-id", after="other-id")
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_responses_store_input_items_before_pagination():
|
async def test_responses_store_input_items_before_pagination():
|
||||||
"""Test before pagination functionality for input items."""
|
"""Test before pagination functionality for input items."""
|
||||||
|
|
@ -390,3 +404,5 @@ async def test_responses_store_input_items_before_pagination():
|
||||||
ValueError, match="Input item with id 'non-existent' not found for response 'test-resp-before'"
|
ValueError, match="Input item with id 'non-existent' not found for response 'test-resp-before'"
|
||||||
):
|
):
|
||||||
await store.list_response_input_items("test-resp-before", before="non-existent")
|
await store.list_response_input_items("test-resp-before", before="non-existent")
|
||||||
|
|
||||||
|
await store.sql_store.close()
|
||||||
|
|
|
||||||
|
|
@ -64,6 +64,9 @@ async def test_sqlite_sqlstore():
|
||||||
assert result.data == [{"id": 12, "name": "test12"}]
|
assert result.data == [{"id": 12, "name": "test12"}]
|
||||||
assert result.has_more is False
|
assert result.has_more is False
|
||||||
|
|
||||||
|
# cleanup
|
||||||
|
await sqlstore.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_sqlstore_pagination_basic():
|
async def test_sqlstore_pagination_basic():
|
||||||
"""Test basic pagination functionality at the SQL store level."""
|
"""Test basic pagination functionality at the SQL store level."""
|
||||||
|
|
@ -128,6 +131,8 @@ async def test_sqlstore_pagination_basic():
|
||||||
assert result3.data[0]["id"] == "zebra"
|
assert result3.data[0]["id"] == "zebra"
|
||||||
assert result3.has_more is False
|
assert result3.has_more is False
|
||||||
|
|
||||||
|
await store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_sqlstore_pagination_with_filter():
|
async def test_sqlstore_pagination_with_filter():
|
||||||
"""Test pagination with WHERE conditions."""
|
"""Test pagination with WHERE conditions."""
|
||||||
|
|
@ -180,6 +185,8 @@ async def test_sqlstore_pagination_with_filter():
|
||||||
assert result2.data[0]["id"] == "xyz"
|
assert result2.data[0]["id"] == "xyz"
|
||||||
assert result2.has_more is False
|
assert result2.has_more is False
|
||||||
|
|
||||||
|
await store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_sqlstore_pagination_ascending_order():
|
async def test_sqlstore_pagination_ascending_order():
|
||||||
"""Test pagination with ascending order."""
|
"""Test pagination with ascending order."""
|
||||||
|
|
@ -228,6 +235,8 @@ async def test_sqlstore_pagination_ascending_order():
|
||||||
assert result2.data[0]["id"] == "alpha"
|
assert result2.data[0]["id"] == "alpha"
|
||||||
assert result2.has_more is True
|
assert result2.has_more is True
|
||||||
|
|
||||||
|
await store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_sqlstore_pagination_multi_column_ordering_error():
|
async def test_sqlstore_pagination_multi_column_ordering_error():
|
||||||
"""Test that multi-column ordering raises an error when using cursor pagination."""
|
"""Test that multi-column ordering raises an error when using cursor pagination."""
|
||||||
|
|
@ -265,6 +274,8 @@ async def test_sqlstore_pagination_multi_column_ordering_error():
|
||||||
assert len(result.data) == 1
|
assert len(result.data) == 1
|
||||||
assert result.data[0]["id"] == "task1"
|
assert result.data[0]["id"] == "task1"
|
||||||
|
|
||||||
|
await store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_sqlstore_pagination_cursor_requires_order_by():
|
async def test_sqlstore_pagination_cursor_requires_order_by():
|
||||||
"""Test that cursor pagination requires order_by parameter."""
|
"""Test that cursor pagination requires order_by parameter."""
|
||||||
|
|
@ -282,6 +293,8 @@ async def test_sqlstore_pagination_cursor_requires_order_by():
|
||||||
cursor=("id", "task1"),
|
cursor=("id", "task1"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_sqlstore_pagination_error_handling():
|
async def test_sqlstore_pagination_error_handling():
|
||||||
"""Test error handling for invalid columns and cursor IDs."""
|
"""Test error handling for invalid columns and cursor IDs."""
|
||||||
|
|
@ -414,6 +427,8 @@ async def test_where_operator_edge_cases():
|
||||||
with pytest.raises(ValueError, match="Unsupported operator"):
|
with pytest.raises(ValueError, match="Unsupported operator"):
|
||||||
await store.fetch_all("events", where={"ts": {"!=": base}})
|
await store.fetch_all("events", where={"ts": {"!=": base}})
|
||||||
|
|
||||||
|
await store.close()
|
||||||
|
|
||||||
|
|
||||||
async def test_sqlstore_pagination_custom_key_column():
|
async def test_sqlstore_pagination_custom_key_column():
|
||||||
"""Test pagination with custom primary key column (not 'id')."""
|
"""Test pagination with custom primary key column (not 'id')."""
|
||||||
|
|
@ -463,3 +478,5 @@ async def test_sqlstore_pagination_custom_key_column():
|
||||||
assert len(result2.data) == 1
|
assert len(result2.data) == 1
|
||||||
assert result2.data[0]["uuid"] == "uuid-alpha"
|
assert result2.data[0]["uuid"] == "uuid-alpha"
|
||||||
assert result2.has_more is False
|
assert result2.has_more is False
|
||||||
|
|
||||||
|
await store.close()
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,8 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
|
||||||
assert row is not None
|
assert row is not None
|
||||||
assert row["title"] == "User Document"
|
assert row["title"] == "User Document"
|
||||||
|
|
||||||
|
await base_sqlstore.close()
|
||||||
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
async def test_sql_policy_consistency(mock_get_authenticated_user):
|
async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
|
|
@ -163,6 +165,8 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
|
||||||
f"Difference: SQL only: {sql_ids - policy_ids}, Policy only: {policy_ids - sql_ids}"
|
f"Difference: SQL only: {sql_ids - policy_ids}, Policy only: {policy_ids - sql_ids}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await base_sqlstore.close()
|
||||||
|
|
||||||
|
|
||||||
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
|
||||||
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
|
async def test_authorized_store_user_attribute_capture(mock_get_authenticated_user):
|
||||||
|
|
@ -211,3 +215,5 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
|
||||||
# Third item should have null attributes (no authenticated user)
|
# Third item should have null attributes (no authenticated user)
|
||||||
assert result.data[2]["id"] == "item3"
|
assert result.data[2]["id"] == "item3"
|
||||||
assert result.data[2]["access_attributes"] is None
|
assert result.data[2]["access_attributes"] is None
|
||||||
|
|
||||||
|
await base_sqlstore.close()
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue